跳转至

ztxexp.tracking

ztxexp.tracking

实验追踪器导出。

__all__ module-attribute

__all__ = ['Tracker', 'JsonlTracker', 'MlflowTracker', 'WandbTracker']

JsonlTracker

将生命周期事件写入 events.jsonl 的轻量追踪器。

参数:

名称 类型 描述 默认
events_filename str

事件文件名。

'events.jsonl'
源代码位于: ztxexp/tracking/jsonl.py
class JsonlTracker:
    """Lightweight tracker that appends lifecycle events to a JSONL file.

    Every callback builds one record and hands it to ``utils.append_jsonl``
    with the path ``ctx.run_dir / events_filename``.

    Args:
        events_filename: Name of the events file inside the run directory.
    """

    def __init__(self, events_filename: str = "events.jsonl"):
        self.events_filename = events_filename

    def _events_path(self, ctx: RunContext):
        """Return the location of the events file for the given run."""
        run_dir = ctx.run_dir
        return run_dir / self.events_filename

    def _append(self, ctx: RunContext, payload: dict[str, Any]) -> None:
        """Append a single record to the events file of the given run."""
        target = self._events_path(ctx)
        utils.append_jsonl(target, payload)

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """Record the ``run_start`` event together with the run metadata."""
        record = {
            "event": "run_start",
            "run_id": ctx.run_id,
            "timestamp": utils.utc_now_iso(),
            "meta": meta.to_dict(),
        }
        self._append(ctx, record)

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """Record a ``metric`` event carrying the serialized metric event."""
        record = {
            "event": "metric",
            "run_id": ctx.run_id,
            "payload": event.to_dict(),
        }
        self._append(ctx, record)

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """Record the ``run_end`` event together with the run summary."""
        record = {
            "event": "run_end",
            "run_id": ctx.run_id,
            "timestamp": utils.utc_now_iso(),
            "summary": summary,
        }
        self._append(ctx, record)

events_filename instance-attribute

events_filename = events_filename

__init__

__init__(events_filename: str = 'events.jsonl')
源代码位于: ztxexp/tracking/jsonl.py
def __init__(self, events_filename: str = "events.jsonl"):
    """Create the tracker and remember the events file name.

    Args:
        events_filename: Name of the events file.
    """
    self.events_filename = events_filename

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None

记录指标事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """Record a ``metric`` event carrying the serialized metric event."""
    record = {
        "event": "metric",
        "run_id": ctx.run_id,
        "payload": event.to_dict(),
    }
    self._append(ctx, record)

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None

记录 run 结束事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """Record the ``run_end`` event together with the run summary."""
    record = {
        "event": "run_end",
        "run_id": ctx.run_id,
        "timestamp": utils.utc_now_iso(),
        "summary": summary,
    }
    self._append(ctx, record)

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None

记录 run 启动事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """Record the ``run_start`` event together with the run metadata."""
    record = {
        "event": "run_start",
        "run_id": ctx.run_id,
        "timestamp": utils.utc_now_iso(),
        "meta": meta.to_dict(),
    }
    self._append(ctx, record)

MlflowTracker

MLflow 追踪器(可选依赖)。

源代码位于: ztxexp/tracking/adapters.py
class MlflowTracker:
    """MLflow tracker (optional dependency).

    ``mlflow`` is imported lazily, so constructing the tracker does not
    require the package to be installed.

    Args:
        tracking_uri: When set, passed to ``mlflow.set_tracking_uri`` at run
            start.
        experiment_name: Experiment to select at run start; takes priority
            over ``meta.experiment_name``.
        run_name: Name of the MLflow run; falls back to ``ctx.run_id``.
    """

    def __init__(
        self,
        tracking_uri: str | None = None,
        experiment_name: str | None = None,
        run_name: str | None = None,
    ):
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        self.run_name = run_name
        # True between a successful start_run() and the matching end_run().
        self._started = False

    def _mlflow(self):
        """Import and return the ``mlflow`` module.

        Raises:
            ImportError: If ``mlflow`` is not installed.
        """
        try:
            import mlflow

            return mlflow
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "mlflow is required for MlflowTracker. Install with: pip install ztxexp[mlflow]"
            ) from exc

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """Start an MLflow run, then log scalar config values and tags."""
        mlflow = self._mlflow()
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        if self.experiment_name or meta.experiment_name:
            mlflow.set_experiment(self.experiment_name or str(meta.experiment_name))
        mlflow.start_run(run_name=self.run_name or ctx.run_id)
        self._started = True
        # Only str/int/float/bool config values are logged as params.
        mlflow.log_params(
            {
                k: v
                for k, v in ctx.config.items()
                if isinstance(v, (str, int, float, bool))
            }
        )
        mlflow.set_tags(
            {
                "run_id": ctx.run_id,
                "group": meta.group or "",
                "parent_run_id": meta.parent_run_id or "",
            }
        )

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """Log ``event.metrics`` at ``event.step``; no-op if no run was started."""
        if not self._started:
            return
        mlflow = self._mlflow()
        mlflow.log_metrics(event.metrics, step=event.step)

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """Tag the run with its status and end it; no-op if no run was started.

        The run is ended even when writing the status tag fails, so a tagging
        error cannot leave the MLflow run active. The error still propagates.
        """
        if not self._started:
            return
        mlflow = self._mlflow()
        try:
            status = summary.get("status")
            if status is not None:
                mlflow.set_tag("status", str(status))
        finally:
            # Reset first so the tracker state is cleared even if end_run() fails.
            self._started = False
            mlflow.end_run()

experiment_name instance-attribute

experiment_name = experiment_name

run_name instance-attribute

run_name = run_name

tracking_uri instance-attribute

tracking_uri = tracking_uri

__init__

__init__(tracking_uri: str | None = None, experiment_name: str | None = None, run_name: str | None = None)
源代码位于: ztxexp/tracking/adapters.py
def __init__(
    self,
    tracking_uri: str | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
):
    """Store the MLflow settings; no MLflow call is made here.

    Args:
        tracking_uri: When set, passed to ``mlflow.set_tracking_uri`` at run
            start.
        experiment_name: Experiment to select at run start; takes priority
            over ``meta.experiment_name``.
        run_name: Name of the MLflow run; falls back to ``ctx.run_id``.
    """
    self.tracking_uri = tracking_uri
    self.experiment_name = experiment_name
    self.run_name = run_name
    # True between a successful start_run() and the matching end_run().
    self._started = False

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """Log ``event.metrics`` at ``event.step``; no-op if no run was started."""
    if self._started:
        client = self._mlflow()
        client.log_metrics(event.metrics, step=event.step)

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """Tag the run with its status and end it; no-op if no run was started.

    The run is ended even when writing the status tag fails, so a tagging
    error cannot leave the MLflow run active. The error still propagates.
    """
    if not self._started:
        return
    mlflow = self._mlflow()
    try:
        status = summary.get("status")
        if status is not None:
            mlflow.set_tag("status", str(status))
    finally:
        # Reset first so the tracker state is cleared even if end_run() fails.
        self._started = False
        mlflow.end_run()

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """Start an MLflow run, then log scalar config values and tags."""
    client = self._mlflow()

    if self.tracking_uri:
        client.set_tracking_uri(self.tracking_uri)

    # The tracker-level experiment name wins over the one from the metadata.
    if self.experiment_name:
        client.set_experiment(self.experiment_name)
    elif meta.experiment_name:
        client.set_experiment(str(meta.experiment_name))

    client.start_run(run_name=self.run_name or ctx.run_id)
    self._started = True

    # Only str/int/float/bool config values are logged as params.
    scalar_types = (str, int, float, bool)
    params = {}
    for key, value in ctx.config.items():
        if isinstance(value, scalar_types):
            params[key] = value
    client.log_params(params)

    tags = {
        "run_id": ctx.run_id,
        "group": meta.group or "",
        "parent_run_id": meta.parent_run_id or "",
    }
    client.set_tags(tags)

Tracker

Bases: Protocol

实验追踪器协议。

所有追踪器都应实现生命周期三段回调: 1. run 开始; 2. 指标事件; 3. run 结束。

源代码位于: ztxexp/tracking/base.py
class Tracker(Protocol):
    """Experiment tracker protocol.

    Every tracker is expected to implement three lifecycle callbacks:
    1. run start;
    2. metric event;
    3. run end.
    """

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """Run-start callback."""

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """Metric-event callback."""

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """Run-end callback."""

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None

指标事件回调。

源代码位于: ztxexp/tracking/base.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """Metric-event callback."""

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None

run 结束回调。

源代码位于: ztxexp/tracking/base.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """Run-end callback."""

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None

run 启动回调。

源代码位于: ztxexp/tracking/base.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """Run-start callback."""

WandbTracker

Weights & Biases 追踪器(可选依赖)。

源代码位于: ztxexp/tracking/adapters.py
class WandbTracker:
    """Weights & Biases tracker (optional dependency).

    ``wandb`` is imported lazily, so constructing the tracker does not
    require the package to be installed.

    Args:
        project: Passed to ``wandb.init`` as ``project``.
        entity: Passed to ``wandb.init`` as ``entity``.
        **kwargs: Extra keyword arguments forwarded to ``wandb.init``.
    """

    def __init__(self, project: str | None = None, entity: str | None = None, **kwargs: Any):
        self.project = project
        self.entity = entity
        self.kwargs = kwargs
        # Run handle returned by wandb.init(); None while no run is active.
        self._run = None

    def _wandb(self):
        """Import and return the ``wandb`` module.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        try:
            import wandb

            return wandb
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "wandb is required for WandbTracker. Install with: pip install ztxexp[wandb]"
            ) from exc

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """Initialise a W&B run named after ``ctx.run_id`` and seed its summary."""
        wandb = self._wandb()
        self._run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=ctx.config,
            name=ctx.run_id,
            reinit=True,
            **self.kwargs,
        )
        if self._run is not None:
            if meta.experiment_name:
                self._run.summary["experiment_name"] = meta.experiment_name
            if meta.group:
                self._run.summary["group"] = meta.group

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """Log ``event.metrics`` at ``event.step``; no-op if no run is active."""
        run = self._run
        if run is None:
            return
        # Log through the run handle this tracker owns rather than the
        # module-level wandb.log(), which targets the process-global run.
        run.log(dict(event.metrics), step=event.step)

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """Merge ``summary`` into the run summary and finish the run.

        The run is finished even when the summary update fails, so an error
        there cannot leave the W&B run open. The error still propagates.
        """
        run = self._run
        if run is None:
            return
        # Drop the handle first so the tracker state is cleared in every case.
        self._run = None
        try:
            run.summary.update(summary)
        finally:
            run.finish()

entity instance-attribute

entity = entity

kwargs instance-attribute

kwargs = kwargs

project instance-attribute

project = project

__init__

__init__(project: str | None = None, entity: str | None = None, **kwargs: Any)
源代码位于: ztxexp/tracking/adapters.py
def __init__(self, project: str | None = None, entity: str | None = None, **kwargs: Any):
    """Store the W&B settings; no W&B call is made here.

    Args:
        project: Passed to ``wandb.init`` as ``project``.
        entity: Passed to ``wandb.init`` as ``entity``.
        **kwargs: Extra keyword arguments forwarded to ``wandb.init``.
    """
    self.project = project
    self.entity = entity
    self.kwargs = kwargs
    # Run handle returned by wandb.init(); None while no run is active.
    self._run = None

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """Log ``event.metrics`` at ``event.step``; no-op if no run is active."""
    run = self._run
    if run is None:
        return
    # Log through the run handle this tracker owns rather than the
    # module-level wandb.log(), which targets the process-global run.
    run.log(dict(event.metrics), step=event.step)

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """Merge ``summary`` into the run summary and finish the run.

    The run is finished even when the summary update fails, so an error
    there cannot leave the W&B run open. The error still propagates.
    """
    run = self._run
    if run is None:
        return
    # Drop the handle first so the tracker state is cleared in every case.
    self._run = None
    try:
        run.summary.update(summary)
    finally:
        run.finish()

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """Initialise a W&B run named after ``ctx.run_id`` and seed its summary."""
    wandb_module = self._wandb()
    run = wandb_module.init(
        project=self.project,
        entity=self.entity,
        config=ctx.config,
        name=ctx.run_id,
        reinit=True,
        **self.kwargs,
    )
    self._run = run
    if run is None:
        return

    # Copy the optional metadata fields into the run summary when present.
    for field in ("experiment_name", "group"):
        value = getattr(meta, field)
        if value:
            run.summary[field] = value