
Weights & Biases (wandb)

Installation

# Core install
pip install wandb

# With integrations
pip install "wandb[media]"       # rich media logging
pip install "wandb[importers]"   # import from MLflow, Comet
pip install "wandb[sweeps]"      # sweep agent

# Verify
wandb --version

Authenticate:

# Interactive (opens browser)
wandb login

# Non-interactive (CI/CD)
wandb login $WANDB_API_KEY

# Or set env var (no command needed)
export WANDB_API_KEY=your_api_key_here

# Verify auth
wandb whoami

# Logout
wandb logout

Configuration

Key Environment Variables

# Authentication
export WANDB_API_KEY=your_api_key

# Project and entity defaults
export WANDB_PROJECT=my-project
export WANDB_ENTITY=my-team

# Modes (set one; the last export wins)
export WANDB_MODE=online        # default, sync in real time
export WANDB_MODE=offline       # store runs locally, sync later with wandb sync
export WANDB_MODE=disabled      # turn off wandb entirely

# Suppress wandb console output (useful in CI logs)
export WANDB_SILENT=true

# Directory for local run data
export WANDB_DIR=/tmp/wandb

# Artifact caching
export WANDB_CACHE_DIR=~/.cache/wandb

# Anonymous runs (no login required)
export WANDB_ANONYMOUS=allow
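
The same variables can also be set from Python before a run starts, which can be convenient in notebooks or launcher scripts. The sketch below is one possible approach; `configure_wandb_env` is a hypothetical helper and not part of wandb. It only fills in values that are not already set, except for an explicit offline request.

```python
import os


def configure_wandb_env(project, entity=None, offline=False, env=None):
    """Fill in wandb environment variables without overriding existing ones.

    Hypothetical helper for illustration; not a wandb API.
    """
    env = os.environ if env is None else env
    env.setdefault("WANDB_PROJECT", project)
    if entity is not None:
        env.setdefault("WANDB_ENTITY", entity)
    if offline:
        # An explicit offline request replaces any inherited mode.
        env["WANDB_MODE"] = "offline"
    return env


# Demonstrated on a plain dict so the real environment is left untouched.
settings = configure_wandb_env("my-project", entity="my-team", offline=True, env={})
print(settings)
```

Call it with no `env` argument to modify `os.environ`, and do so before `wandb.init()` so the values are picked up.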

wandb/settings (Project-Level)

# wandb/settings (created in the project directory by wandb init)
[default]
entity = my-team
project = my-project
mode = online
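
The settings file uses INI syntax, so it can be inspected with the standard library. This is a minimal sketch, assuming the file has the `[default]` section shown above; `read_wandb_settings` is a hypothetical helper name.

```python
import configparser

SAMPLE = """\
[default]
entity = my-team
project = my-project
mode = online
"""


def read_wandb_settings(text):
    """Return the [default] section of a wandb settings file as a dict."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    if not parser.has_section("default"):
        return {}
    return dict(parser["default"])


settings = read_wandb_settings(SAMPLE)
print(settings["entity"], settings["project"], settings["mode"])
```

To read a real file, pass `open("wandb/settings").read()` instead of `SAMPLE`.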

Core Commands

CLI

Command                                           Description
wandb login                                       Authenticate
wandb whoami                                      Show current user
wandb init                                        Initialize a project
wandb sync ./wandb/offline-run-*                  Sync offline runs
wandb artifact get entity/project/name:version    Download artifact
wandb artifact put -n name -t type file           Upload artifact
wandb sweep config.yaml                           Create sweep
wandb agent sweep-id                              Start sweep agent
wandb server start                                Start local wandb server
wandb disabled                                    Disable wandb globally
wandb enabled                                     Re-enable wandb
wandb status                                      Show current configuration settings
wandb sync --clean                                Clean up local runs that have already been synced
wandb restore run-id                              Restore a run's code/config

Python API

Function                                             Description
wandb.init(project, entity, name, config)            Start a run
wandb.log({"loss": 0.5, "acc": 0.9})                 Log metrics
wandb.log({"loss": 0.5}, step=100)                   Log at specific step
wandb.finish()                                       End a run
wandb.config.update({"lr": 0.01})                    Update config
wandb.summary["best_acc"] = 0.95                     Set summary metric
wandb.save("model.pkl")                              Save file to run
wandb.restore("model.pkl")                           Restore saved file
wandb.watch(model)                                   Track model gradients
wandb.watch(model, log="all", log_freq=100)          Watch with frequency
wandb.alert(title, text, level)                      Send alert notification
wandb.mark_preempting()                              Signal preemption
wandb.define_metric("val/*", step_metric="epoch")    Custom x-axis
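
As a rough illustration of the custom x-axis entry, the sketch below declares `epoch` as the step metric for every `val/*` key and then logs the epoch alongside the validation values. `setup_epoch_axis`, `log_validation`, and `RecordingRun` are hypothetical names; `RecordingRun` is only a stand-in so the example runs without a wandb account. In real code, pass the object returned by `wandb.init()`.

```python
def setup_epoch_axis(run):
    """Plot every val/* metric against a logged "epoch" value."""
    run.define_metric("epoch")
    run.define_metric("val/*", step_metric="epoch")


def log_validation(run, epoch, val_loss, val_acc):
    """Log validation metrics together with the epoch they belong to."""
    run.log({"epoch": epoch, "val/loss": val_loss, "val/acc": val_acc})


class RecordingRun:
    """Stand-in for a wandb run that only records calls (illustration only)."""

    def __init__(self):
        self.metrics = []
        self.logged = []

    def define_metric(self, name, step_metric=None):
        self.metrics.append((name, step_metric))

    def log(self, data):
        self.logged.append(data)


run = RecordingRun()
setup_epoch_axis(run)
log_validation(run, epoch=3, val_loss=0.42, val_acc=0.88)
print(run.metrics)
print(run.logged)
```

The step metric has to be logged in the same dictionary (or earlier) for the custom axis to have values to plot against.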

Advanced Usage

Full Training Run

import wandb
import torch
import torch.nn as nn
from torch.optim import AdamW

# Initialize run
run = wandb.init(
    project="image-classification",
    entity="my-team",
    name="resnet50-baseline",
    tags=["baseline", "resnet", "imagenet"],
    notes="First run with default augmentation",
    config={
        "architecture": "resnet50",
        "learning_rate": 1e-4,
        "weight_decay": 1e-2,
        "epochs": 50,
        "batch_size": 64,
        "optimizer": "adamw",
        "dataset": "imagenet",
    }
)

# Access config (works with sweeps)
cfg = wandb.config

model = build_model(cfg.architecture)
optimizer = AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

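# Track gradients (use log="all" to also track weights)
wandb.watch(model, log="gradients", log_freq=100)

# build_model, train_epoch, evaluate, collect_predictions, the data loaders,
# and class_names are user-supplied and not shown here.
best_val_acc, best_epoch = 0.0, -1

# Training loop
for epoch in range(cfg.epochs):
    train_loss, train_acc = train_epoch(model, optimizer, train_loader)
    val_loss, val_acc = evaluate(model, val_loader)

    wandb.log({
        "epoch": epoch,
        "train/loss": train_loss,
        "train/acc": train_acc,
        "val/loss": val_loss,
        "val/acc": val_acc,
        "lr": optimizer.param_groups[0]["lr"],
    })

    # Remember the best epoch for the run summary
    if val_acc > best_val_acc:
        best_val_acc, best_epoch = val_acc, epoch

    # Log a confusion matrix every 10 epochs
    if epoch % 10 == 0:
        # collect_predictions (hypothetical) returns true and predicted class indices
        labels, preds = collect_predictions(model, val_loader)
        wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
            probs=None, y_true=labels, preds=preds,
            class_names=class_names
        )})

# Set summary metrics (shown in run table)
wandb.summary["best_val_acc"] = best_val_acc
wandb.summary["best_epoch"] = best_epoch

wandb.finish()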

Hyperparameter Sweeps

# sweep.yaml
program: train.py
method: bayes          # grid | random | bayes

metric:
  goal: maximize
  name: val/acc

parameters:
  learning_rate:
    distribution: log_uniform_values
    min: 1e-5
    max: 1e-2
  batch_size:
    values: [32, 64, 128]
  optimizer:
    values: [adam, adamw, sgd]
  dropout:
    distribution: uniform
    min: 0.1
    max: 0.5
  n_layers:
    values: [2, 3, 4, 5]

# Optional: stop early if not improving
early_terminate:
  type: hyperband
  min_iter: 3
  eta: 3

# Create the sweep
wandb sweep sweep.yaml
# Returns: sweep ID like abc12345

# Run agents (can run on multiple machines)
wandb agent my-team/my-project/abc12345

# Limit agent runs
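wandb agent --count 20 my-team/my-project/abc12345

# train.py: works with both direct runs and sweeps
import wandb

# Defaults for direct runs; a sweep overrides any key it defines
DEFAULTS = {"learning_rate": 1e-4, "batch_size": 64, "optimizer": "adamw"}

def train():
    with wandb.init(config=DEFAULTS) as run:
        cfg = run.config
        # cfg.learning_rate, cfg.batch_size, etc. are injected by the sweep
        model = build_model()
        optimizer = build_optimizer(model, cfg.optimizer, cfg.learning_rate)
        for epoch in range(50):
            train_epoch(model, optimizer)
            run.log({"epoch": epoch, "val/acc": evaluate(model)})

if __name__ == "__main__":
    # The wandb agent CLI launches this script once per trial
    train()

# Alternative: drive the sweep from Python instead of the CLI
# wandb.agent("my-team/my-project/abc12345", function=train, count=50)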
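
The sweep definition can also be written as a Python dictionary and registered with `wandb.sweep`. The sketch below mirrors part of the YAML above and adds a small sampler to illustrate what `log_uniform_values` means: values are drawn uniformly in log space between `min` and `max`. `sample_parameter` and `create_sweep` are hypothetical helpers; the sampler is only an illustration and not how wandb's Bayesian search picks values.

```python
import math
import random

sweep_config = {
    "method": "bayes",
    "metric": {"goal": "maximize", "name": "val/acc"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "batch_size": {"values": [32, 64, 128]},
        "dropout": {"distribution": "uniform", "min": 0.1, "max": 0.5},
    },
}


def sample_parameter(spec, rng):
    """Draw one value the way a random search would (illustration only)."""
    if "values" in spec:
        return rng.choice(spec["values"])
    if spec["distribution"] == "log_uniform_values":
        # Uniform in log space between min and max
        log_value = rng.uniform(math.log(spec["min"]), math.log(spec["max"]))
        return math.exp(log_value)
    if spec["distribution"] == "uniform":
        return rng.uniform(spec["min"], spec["max"])
    raise ValueError(f"unsupported spec: {spec}")


def create_sweep(config, project):
    """Register the sweep with wandb; needs wandb installed and a login."""
    import wandb

    return wandb.sweep(sweep=config, project=project)


rng = random.Random(0)
trial = {
    name: sample_parameter(spec, rng)
    for name, spec in sweep_config["parameters"].items()
}
print(trial)
```

`create_sweep(sweep_config, "my-project")` would return a sweep ID that can be passed to `wandb agent` or `wandb.agent`.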

Artifacts

# Log an artifact (dataset, model, predictions)
artifact = wandb.Artifact(
    name="imagenet-preprocessed",
    type="dataset",
    description="ImageNet resized to 224x224",
    metadata={"num_examples": 1281167, "split": "train"}
)
artifact.add_dir("./data/imagenet/")              # add directory
artifact.add_file("./data/labels.json")           # add single file

# Log to the run
run.log_artifact(artifact)

# Log model artifact
model_artifact = wandb.Artifact("resnet50-v1", type="model")
model_artifact.add_file("model.pt")
run.log_artifact(model_artifact)

# Use an artifact in a run
artifact = run.use_artifact("my-team/project/imagenet-preprocessed:latest")
data_dir = artifact.download()           # returns local path

# Mark artifact as used, then fetch one file instead of the whole artifact
artifact = run.use_artifact("my-team/project/resnet50-v1:v3")
artifact.get_path("model.pt").download()  # download single file

# Link to model registry
run.link_artifact(model_artifact, target_path="my-team/model-registry/ResNet50Classifier")
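
Artifact references follow the pattern `entity/project/name:alias`, where the part after the colon is either a version such as `v3` or a movable alias such as `latest`. The sketch below splits a reference into its parts and flags whether it is pinned to a version. `parse_artifact_ref` is a hypothetical helper, and treating a missing alias as `latest` is an assumption made here for illustration.

```python
import re


def parse_artifact_ref(ref):
    """Split "entity/project/name:alias" into its parts.

    Hypothetical helper for illustration; wandb resolves these strings itself.
    """
    path, _, alias = ref.partition(":")
    parts = path.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected entity/project/name[:alias], got {ref!r}")
    entity, project, name = parts
    alias = alias or "latest"  # assumption: no alias means the latest version
    return {
        "entity": entity,
        "project": project,
        "name": name,
        "alias": alias,
        # v0, v1, ... name a specific version; anything else is a movable alias
        "pinned": re.fullmatch(r"v\d+", alias) is not None,
    }


model_ref = parse_artifact_ref("my-team/project/resnet50-v1:v3")
data_ref = parse_artifact_ref("my-team/project/imagenet-preprocessed:latest")
print(model_ref["name"], model_ref["pinned"])
print(data_ref["name"], data_ref["pinned"])
```

A check like this can help a downstream job log whether it consumed a fixed version or whatever the alias pointed to at the time.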

Tables (Structured Logging)

# Log predictions as a table
columns = ["id", "image", "label", "prediction", "confidence", "correct"]
data = []

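# Assumes val_loader yields one (image, label) pair at a time and that
# model.predict returns (class index, confidence); adapt for batched loaders.
for idx, (img, label) in enumerate(val_loader):
    pred, conf = model.predict(img)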
    data.append([
        idx,
        wandb.Image(img),
        class_names[label],
        class_names[pred],
        conf,
        label == pred
    ])

table = wandb.Table(columns=columns, data=data)
wandb.log({"predictions": table})

# Log table from pandas DataFrame
import pandas as pd
df = pd.read_csv("results.csv")
wandb.log({"results": wandb.Table(dataframe=df)})
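
When predictions arrive in batches, it can be easier to build the table rows in a separate function and hand the result to `wandb.Table`. This is a minimal sketch; `prediction_rows` is a hypothetical helper, and the image column is left out so that the example runs without wandb.

```python
def prediction_rows(labels, preds, confidences, class_names, start_id=0):
    """Turn one batch of predictions into (columns, rows) for a table."""
    columns = ["id", "label", "prediction", "confidence", "correct"]
    data = []
    for offset, (label, pred, conf) in enumerate(zip(labels, preds, confidences)):
        # int()/float() also accept NumPy or PyTorch scalars
        label, pred = int(label), int(pred)
        data.append([
            start_id + offset,
            class_names[label],
            class_names[pred],
            float(conf),
            label == pred,
        ])
    return columns, data


class_names = ["cat", "dog", "bird"]
columns, data = prediction_rows(
    labels=[0, 1, 2],
    preds=[0, 2, 2],
    confidences=[0.9, 0.6, 0.8],
    class_names=class_names,
)
print(columns)
print(data)
```

The result can then be logged with `wandb.log({"predictions": wandb.Table(columns=columns, data=data)})`.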

Media Logging

# Images
wandb.log({"image": wandb.Image(pil_image)})
wandb.log({"segmentation": wandb.Image(img, masks={"pred": {"mask_data": mask}})})
wandb.log({"bboxes": wandb.Image(img, boxes={"pred": {"box_data": boxes}})})

# Audio
wandb.log({"audio": wandb.Audio(audio_array, sample_rate=44100)})

# Video
wandb.log({"video": wandb.Video(frames, fps=24, format="mp4")})

# 3D Point Cloud
wandb.log({"point_cloud": wandb.Object3D({"type": "lidar/beta", "points": points})})

# Plotly / matplotlib
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(history["loss"])
wandb.log({"loss_curve": wandb.Image(fig)})
# or directly:
wandb.log({"plotly_chart": wandb.Plotly(plotly_fig)})

# HTML
wandb.log({"report": wandb.Html(html_string)})

Integrations

# PyTorch Lightning
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project="my-project", log_model=True)
trainer = pl.Trainer(logger=logger, max_epochs=50)

# HuggingFace Transformers (auto-integration via env var)
import os
os.environ["WANDB_PROJECT"] = "my-project"
# Training with Trainer logs automatically

from transformers import Trainer, TrainingArguments
args = TrainingArguments(
    output_dir="./output",
    report_to="wandb",
    run_name="bert-finetune"
)

# Keras
from wandb.integration.keras import WandbCallback as KerasWandbCallback

wandb_callback = KerasWandbCallback(
    monitor="val_accuracy",
    save_model=True,
    log_weights=True
)
model.fit(X_train, y_train, callbacks=[wandb_callback])

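# XGBoost
import xgboost as xgb
from wandb.integration.xgboost import WandbCallback
model = xgb.train(params, dtrain, callbacks=[WandbCallback()])

# scikit-learn (plot_classifier also needs predictions and class probabilities)
from wandb.sklearn import plot_classifier
y_pred = clf.predict(X_test)
y_probas = clf.predict_proba(X_test)
plot_classifier(clf, X_train, X_test, y_train, y_test, y_pred, y_probas, labels=classes)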

Resuming Runs

# Resume a crashed run
run = wandb.init(
    project="my-project",
    id="existing-run-id",
    resume="allow"          # "must" | "allow" | "never" | "auto"
)
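
Resuming only works if the restarted process knows the original run ID. One way to arrange that, sketched below, is to store the ID next to the checkpoints and reuse it on restart. `load_or_create_run_id` and the file name are hypothetical; `secrets.token_hex` is used here only to produce a short random ID.

```python
import secrets
import tempfile
from pathlib import Path


def load_or_create_run_id(path):
    """Return the run ID stored at `path`, creating one on first use."""
    path = Path(path)
    if path.exists():
        return path.read_text().strip()
    run_id = secrets.token_hex(4)  # 8 hexadecimal characters
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(run_id)
    return run_id


# Demonstration in a temporary directory: the second call reuses the first ID.
with tempfile.TemporaryDirectory() as tmp:
    id_file = Path(tmp) / "checkpoints" / "wandb_run_id.txt"
    first = load_or_create_run_id(id_file)
    second = load_or_create_run_id(id_file)

print(first == second)
```

The returned value would then be passed as `wandb.init(project="my-project", id=run_id, resume="allow")`, so a restarted job continues the same run.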

Common Workflows

Offline Mode and Sync

# Run offline
WANDB_MODE=offline python train.py

# Sync all offline runs later
wandb sync ./wandb/offline-run-*

# Sync specific run
wandb sync ./wandb/offline-run-20260516_120000-abc123/

Compare Runs Programmatically

import wandb

api = wandb.Api()

# Fetch runs from a project
runs = api.runs("my-team/my-project", filters={"config.architecture": "resnet50"})

# Build comparison DataFrame
import pandas as pd
rows = []
for run in runs:
    rows.append({
        "id": run.id,
        "name": run.name,
        "val_acc": run.summary.get("best_val_acc"),
        "lr": run.config.get("learning_rate"),
        "state": run.state
    })

df = pd.DataFrame(rows).sort_values("val_acc", ascending=False)
print(df.head(10))

Sweep with Parallel Agents

# Terminal 1
wandb agent my-team/project/sweep-id

# Terminal 2 (same machine or different)
wandb agent my-team/project/sweep-id

# On a GPU cluster: start one agent per GPU, then wait for all of them
for GPU in 0 1 2 3; do
  CUDA_VISIBLE_DEVICES=$GPU wandb agent --count 10 my-team/project/sweep-id &
done
wait

Tips and Best Practices

  • Always use wandb.config for hyperparameters — it integrates directly with sweeps and makes configs searchable in the UI.
  • wandb.define_metric lets you set a custom x-axis per metric — use it to plot validation metrics against epoch rather than step.
  • Name your runs with name= in wandb.init — auto-generated names are forgettable; use arch-dataset-timestamp patterns.
  • Use tags to group related runs across experiments — filter and compare by tag in the UI without renaming projects.
  • Bayesian sweeps tend to pay off only once enough completed runs exist to model the metric; start with method: random for quick exploration, then switch to method: bayes after narrowing the search space.
  • wandb.watch(model, log="all") logs both gradients and weights; use log="gradients" when logging overhead matters, since it records less data per logging step.
  • Artifacts version your data and models — always reference artifacts by alias (:latest, :best) rather than version number in downstream jobs.
  • Offline mode is essential for HPC clusters without internet — run offline, sync when you have connectivity.
  • wandb.summary is the single-row summary shown in the runs table — set your best epoch metrics here, not just the last step’s values.
  • Link models to the model registry for production promotion workflows — it decouples model versioning from experiment runs.
  • early_terminate: hyperband in sweeps cuts unpromising trials early — essential when each trial is expensive (GPU hours).
  • Reports turn experiment comparisons into shareable documents — use them for team reviews and experiment write-ups instead of copying screenshots.