跳转至

ztxexp

ztxexp

ztxexp 包级导出。

该模块定义用户最常用的导入入口,例如:

from ztxexp import ExperimentPipeline, ResultAnalyzer

__all__ module-attribute

__all__ = ['ExpManager', 'ExpRunner', 'ResultAnalyzer', 'ExperimentPipeline', 'RunContext', 'RunMetadata', 'MetricEvent', 'RunSummary', 'SkipRun', 'Tracker', 'JsonlTracker', 'MlflowTracker', 'WandbTracker', 'init_torch_env', 'set_process_priority']

__version__ module-attribute

__version__ = '1.0.3'

ExpManager

实验配置构建器。

该类维护一条配置流水线: grid -> variants -> modify -> where -> exclude_completed -> shuffle

参数:

名称 类型 描述 默认
base_config Namespace | Mapping[str, Any] | None

基础配置。可传 NamespacedictNone

None

示例:

>>> manager = (
...     ExpManager({"seed": 42})
...     .grid({"lr": [1e-3, 1e-2]})
...     .variants([{"model": "tiny"}, {"model": "base"}])
...     .where(lambda c: c["lr"] < 0.02)
... )
>>> len(manager.build())
4
源代码位于: ztxexp/manager.py
class ExpManager:
    """实验配置构建器。

    该类维护一条配置流水线:
    ``grid -> variants -> modify -> where -> exclude_completed -> shuffle``。

    Args:
        base_config: 基础配置。可传 ``Namespace``、``dict`` 或 ``None``。

    Examples:
        >>> manager = (
        ...     ExpManager({"seed": 42})
        ...     .grid({"lr": [1e-3, 1e-2]})
        ...     .variants([{"model": "tiny"}, {"model": "base"}])
        ...     .where(lambda c: c["lr"] < 0.02)
        ... )
        >>> len(manager.build())
        4
    """

    def __init__(self, base_config: argparse.Namespace | Mapping[str, Any] | None = None):
        base = {} if base_config is None else _namespace_to_dict(base_config)
        self._configs: list[ConfigDict] = [base]
        self._modifiers: list[Modifier] = []
        self._predicates: list[Predicate] = []
        self._exclude_completed_root: Path | None = None
        self._exclude_ignore_keys: set[str] = set()
        self._should_shuffle = False

    def grid(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExpManager":
        """按笛卡尔积扩展参数网格。

        Args:
            param_grid: 网格字典,例如
                ``{"lr": [1e-3, 1e-2], "batch_size": [16, 32]}``。

        Returns:
            ExpManager: 返回自身,支持链式调用。

        Notes:
            若 ``param_grid`` 为空,本方法为 no-op。
        """
        if not param_grid:
            return self

        keys = list(param_grid.keys())
        value_lists = [list(param_grid[key]) for key in keys]
        combos = list(itertools.product(*value_lists))

        expanded: list[ConfigDict] = []
        for base_config in self._configs:
            for combo in combos:
                next_config = copy.deepcopy(base_config)
                for key, value in zip(keys, combo):
                    next_config[key] = value
                expanded.append(next_config)

        self._configs = expanded
        return self

    def variants(
        self,
        variants: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]],
    ) -> "ExpManager":
        """按“独立变体”方式扩展配置。

        Args:
            variants: 推荐传 ``list[dict]``,例如
                ``[{"model": "tiny"}, {"model": "base", "layers": 12}]``。
                同时兼容旧格式 ``dict[str, list]``。

        Returns:
            ExpManager: 返回自身,支持链式调用。

        Notes:
            - ``list[dict]`` 语义更清晰,推荐优先使用。
            - ``dict[str, list]`` 会被转为单键变体集合。
        """
        if not variants:
            return self

        variant_dicts: list[ConfigDict] = []
        if isinstance(variants, Mapping):
            for key, values in variants.items():
                for value in values:
                    variant_dicts.append({key: value})
        else:
            variant_dicts = [dict(item) for item in variants]

        expanded: list[ConfigDict] = []
        for base_config in self._configs:
            for variant in variant_dicts:
                merged = copy.deepcopy(base_config)
                merged.update(copy.deepcopy(variant))
                expanded.append(merged)

        self._configs = expanded
        return self

    def random_search(
        self,
        space: Mapping[str, Sequence[Any]],
        n_trials: int,
        seed: int = 42,
    ) -> "ExpManager":
        """按随机采样方式扩展配置。

        Args:
            space: 参数搜索空间,值为候选列表。
            n_trials: 采样次数。
            seed: 随机种子。

        Returns:
            ExpManager: 返回自身,支持链式调用。

        Notes:
            - 采样采用“有放回”策略;
            - 不替代 ``grid/variants``,可与其组合使用。
        """
        if not space or n_trials <= 0:
            return self

        rng = random.Random(seed)
        keys = list(space.keys())
        value_lists = [list(space[key]) for key in keys]

        expanded: list[ConfigDict] = []
        for base_config in self._configs:
            for _ in range(n_trials):
                next_config = copy.deepcopy(base_config)
                for key, values in zip(keys, value_lists):
                    if not values:
                        continue
                    next_config[key] = rng.choice(values)
                expanded.append(next_config)

        self._configs = expanded
        return self

    def modify(self, modifier: Modifier) -> "ExpManager":
        """注册配置修改器。

        Args:
            modifier: 修改函数。支持两种风格:
                1) 原地修改并返回 ``None``;
                2) 返回修改后的新字典。

        Returns:
            ExpManager: 返回自身。
        """
        self._modifiers.append(modifier)
        return self

    def where(self, predicate: Predicate) -> "ExpManager":
        """注册配置过滤器。

        Args:
            predicate: 谓词函数。返回 ``True`` 表示保留该配置。

        Returns:
            ExpManager: 返回自身。
        """
        self._predicates.append(predicate)
        return self

    def exclude_completed(
        self,
        results_root: str | Path,
        ignore_keys: Sequence[str] | None = None,
    ) -> "ExpManager":
        """排除已成功完成的配置。

        Args:
            results_root: 历史 run 根目录。
            ignore_keys: 配置对比时忽略的键(可选)。

        Returns:
            ExpManager: 返回自身。

        Notes:
            仅将满足以下条件的 run 视为“已完成”:
            1) ``run.json.schema_version == RUN_SCHEMA_VERSION``;
            2) ``run.json.status == succeeded``。
        """
        self._exclude_completed_root = Path(results_root)
        self._exclude_ignore_keys = set(ignore_keys or [])
        return self

    def shuffle(self) -> "ExpManager":
        """在最终构建结果上随机打乱顺序。

        Returns:
            ExpManager: 返回自身。
        """
        self._should_shuffle = True
        return self

    def build(self) -> list[ConfigDict]:
        """执行所有阶段并返回最终配置列表。

        Returns:
            list[dict[str, Any]]: 最终配置列表。

        Raises:
            TypeError: 当某个修改器返回值不是 ``dict`` 或 ``None``。

        Examples:
            >>> manager = ExpManager({"a": 1}).modify(lambda c: {**c, "b": 2})
            >>> manager.build()[0]["b"]
            2
        """
        configs = [copy.deepcopy(config) for config in self._configs]

        if self._modifiers:
            modified_configs: list[ConfigDict] = []
            for config in configs:
                next_config = config
                for modifier in self._modifiers:
                    result = modifier(next_config)
                    if result is None:
                        result = next_config
                    if not isinstance(result, dict):
                        raise TypeError("Modifier must return dict or None.")
                    next_config = result
                modified_configs.append(next_config)
            configs = modified_configs

        if self._predicates:
            filtered = []
            for config in configs:
                if all(predicate(config) for predicate in self._predicates):
                    filtered.append(config)
            configs = filtered

        if self._exclude_completed_root:
            completed_configs = self._load_completed_configs(self._exclude_completed_root)
            configs = [
                config
                for config in configs
                if not any(
                    self._configs_equal(config, completed, self._exclude_ignore_keys)
                    for completed in completed_configs
                )
            ]

        if self._should_shuffle:
            random.shuffle(configs)

        return configs

    # ---- v0.1 兼容别名 ----

    def add_grid_search(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExpManager":
        """``grid`` 的兼容别名。"""
        return self.grid(param_grid)

    def add_variants(
        self,
        variant_space: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]],
    ) -> "ExpManager":
        """``variants`` 的兼容别名。"""
        return self.variants(variant_space)

    def add_modifier(self, modifier_func: Modifier) -> "ExpManager":
        """``modify`` 的兼容别名。"""
        return self.modify(modifier_func)

    def add_filter(self, filter_func: Predicate) -> "ExpManager":
        """``where`` 的兼容别名。"""
        return self.where(filter_func)

    def filter_completed(
        self,
        results_path: str | Path,
        ignore_keys: Sequence[str] | None = None,
    ) -> "ExpManager":
        """``exclude_completed`` 的兼容别名。"""
        return self.exclude_completed(results_path, ignore_keys=ignore_keys)

    def get_configs(self) -> list[ConfigDict]:
        """``build`` 的兼容别名。"""
        return self.build()

    def add_random_search(
        self,
        space: Mapping[str, Sequence[Any]],
        n_trials: int,
        seed: int = 42,
    ) -> "ExpManager":
        """``random_search`` 的兼容别名。"""
        return self.random_search(space=space, n_trials=n_trials, seed=seed)

    # ---- 内部辅助函数 ----

    def _load_completed_configs(self, results_root: Path) -> list[ConfigDict]:
        """加载“已成功完成”实验的配置列表。

        Args:
            results_root: 历史 run 根目录。

        Returns:
            list[dict[str, Any]]: 已完成实验对应的配置字典列表。
        """
        if not results_root.exists():
            return []

        completed: list[ConfigDict] = []
        for run_dir in utils.get_subdirectories(results_root):
            run_meta = utils.load_json(run_dir / "run.json")
            if not run_meta:
                continue
            if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
                continue
            if run_meta.get("status") != RUN_STATUS_SUCCEEDED:
                continue

            config = utils.load_json(run_dir / "config.json")
            if isinstance(config, dict):
                completed.append(config)

        return completed

    def _configs_equal(self, a: ConfigDict, b: ConfigDict, ignore_keys: Iterable[str]) -> bool:
        """判断两个配置是否等价。

        Args:
            a: 配置 A。
            b: 配置 B。
            ignore_keys: 忽略键集合。

        Returns:
            bool: ``True`` 表示等价,``False`` 表示不等价。

        Notes:
            采用严格键集合比较,避免“缺字段却误判相等”的问题。
        """
        ignore = set(ignore_keys)
        left_keys = set(a.keys()) - ignore
        right_keys = set(b.keys()) - ignore
        if left_keys != right_keys:
            return False

        for key in left_keys:
            if self._normalize_value(a[key]) != self._normalize_value(b[key]):
                return False
        return True

    def _normalize_value(self, value: Any) -> Any:
        """归一化配置值,便于稳定比较。

        Args:
            value: 任意配置值。

        Returns:
            Any: 归一化后的值。

        Notes:
            - dict: 按 key 排序后递归归一化;
            - list/tuple: 统一为 list 并递归归一化。
        """
        if isinstance(value, dict):
            return {k: self._normalize_value(v) for k, v in sorted(value.items())}
        if isinstance(value, (list, tuple)):
            return [self._normalize_value(item) for item in value]
        return value

__init__

__init__(base_config: Namespace | Mapping[str, Any] | None = None)
源代码位于: ztxexp/manager.py
def __init__(self, base_config: argparse.Namespace | Mapping[str, Any] | None = None):
    base = {} if base_config is None else _namespace_to_dict(base_config)
    self._configs: list[ConfigDict] = [base]
    self._modifiers: list[Modifier] = []
    self._predicates: list[Predicate] = []
    self._exclude_completed_root: Path | None = None
    self._exclude_ignore_keys: set[str] = set()
    self._should_shuffle = False

add_filter

add_filter(filter_func: Predicate) -> 'ExpManager'

where 的兼容别名。

源代码位于: ztxexp/manager.py
def add_filter(self, filter_func: Predicate) -> "ExpManager":
    """``where`` 的兼容别名。"""
    return self.where(filter_func)
add_grid_search(param_grid: Mapping[str, Sequence[Any]]) -> 'ExpManager'

grid 的兼容别名。

源代码位于: ztxexp/manager.py
def add_grid_search(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExpManager":
    """``grid`` 的兼容别名。"""
    return self.grid(param_grid)

add_modifier

add_modifier(modifier_func: Modifier) -> 'ExpManager'

modify 的兼容别名。

源代码位于: ztxexp/manager.py
def add_modifier(self, modifier_func: Modifier) -> "ExpManager":
    """``modify`` 的兼容别名。"""
    return self.modify(modifier_func)
add_random_search(space: Mapping[str, Sequence[Any]], n_trials: int, seed: int = 42) -> 'ExpManager'

random_search 的兼容别名。

源代码位于: ztxexp/manager.py
def add_random_search(
    self,
    space: Mapping[str, Sequence[Any]],
    n_trials: int,
    seed: int = 42,
) -> "ExpManager":
    """``random_search`` 的兼容别名。"""
    return self.random_search(space=space, n_trials=n_trials, seed=seed)

add_variants

add_variants(variant_space: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]]) -> 'ExpManager'

variants 的兼容别名。

源代码位于: ztxexp/manager.py
def add_variants(
    self,
    variant_space: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]],
) -> "ExpManager":
    """``variants`` 的兼容别名。"""
    return self.variants(variant_space)

build

build() -> list[ConfigDict]

执行所有阶段并返回最终配置列表。

返回:

类型 描述
list[ConfigDict]

list[dict[str, Any]]: 最终配置列表。

引发:

类型 描述
TypeError

当某个修改器返回值不是 dictNone

示例:

>>> manager = ExpManager({"a": 1}).modify(lambda c: {**c, "b": 2})
>>> manager.build()[0]["b"]
2
源代码位于: ztxexp/manager.py
def build(self) -> list[ConfigDict]:
    """执行所有阶段并返回最终配置列表。

    Returns:
        list[dict[str, Any]]: 最终配置列表。

    Raises:
        TypeError: 当某个修改器返回值不是 ``dict`` 或 ``None``。

    Examples:
        >>> manager = ExpManager({"a": 1}).modify(lambda c: {**c, "b": 2})
        >>> manager.build()[0]["b"]
        2
    """
    configs = [copy.deepcopy(config) for config in self._configs]

    if self._modifiers:
        modified_configs: list[ConfigDict] = []
        for config in configs:
            next_config = config
            for modifier in self._modifiers:
                result = modifier(next_config)
                if result is None:
                    result = next_config
                if not isinstance(result, dict):
                    raise TypeError("Modifier must return dict or None.")
                next_config = result
            modified_configs.append(next_config)
        configs = modified_configs

    if self._predicates:
        filtered = []
        for config in configs:
            if all(predicate(config) for predicate in self._predicates):
                filtered.append(config)
        configs = filtered

    if self._exclude_completed_root:
        completed_configs = self._load_completed_configs(self._exclude_completed_root)
        configs = [
            config
            for config in configs
            if not any(
                self._configs_equal(config, completed, self._exclude_ignore_keys)
                for completed in completed_configs
            )
        ]

    if self._should_shuffle:
        random.shuffle(configs)

    return configs

exclude_completed

exclude_completed(results_root: str | Path, ignore_keys: Sequence[str] | None = None) -> 'ExpManager'

排除已成功完成的配置。

参数:

名称 类型 描述 默认
results_root str | Path

历史 run 根目录。

必需
ignore_keys Sequence[str] | None

配置对比时忽略的键(可选)。

None

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身。

Notes

仅将满足以下条件的 run 视为“已完成”: 1) run.json.schema_version == RUN_SCHEMA_VERSION; 2) run.json.status == succeeded

源代码位于: ztxexp/manager.py
def exclude_completed(
    self,
    results_root: str | Path,
    ignore_keys: Sequence[str] | None = None,
) -> "ExpManager":
    """排除已成功完成的配置。

    Args:
        results_root: 历史 run 根目录。
        ignore_keys: 配置对比时忽略的键(可选)。

    Returns:
        ExpManager: 返回自身。

    Notes:
        仅将满足以下条件的 run 视为“已完成”:
        1) ``run.json.schema_version == RUN_SCHEMA_VERSION``;
        2) ``run.json.status == succeeded``。
    """
    self._exclude_completed_root = Path(results_root)
    self._exclude_ignore_keys = set(ignore_keys or [])
    return self

filter_completed

filter_completed(results_path: str | Path, ignore_keys: Sequence[str] | None = None) -> 'ExpManager'

exclude_completed 的兼容别名。

源代码位于: ztxexp/manager.py
def filter_completed(
    self,
    results_path: str | Path,
    ignore_keys: Sequence[str] | None = None,
) -> "ExpManager":
    """``exclude_completed`` 的兼容别名。"""
    return self.exclude_completed(results_path, ignore_keys=ignore_keys)

get_configs

get_configs() -> list[ConfigDict]

build 的兼容别名。

源代码位于: ztxexp/manager.py
def get_configs(self) -> list[ConfigDict]:
    """``build`` 的兼容别名。"""
    return self.build()

grid

grid(param_grid: Mapping[str, Sequence[Any]]) -> 'ExpManager'

按笛卡尔积扩展参数网格。

参数:

名称 类型 描述 默认
param_grid Mapping[str, Sequence[Any]]

网格字典,例如 {"lr": [1e-3, 1e-2], "batch_size": [16, 32]}

必需

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身,支持链式调用。

Notes

param_grid 为空,本方法为 no-op。

源代码位于: ztxexp/manager.py
def grid(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExpManager":
    """按笛卡尔积扩展参数网格。

    Args:
        param_grid: 网格字典,例如
            ``{"lr": [1e-3, 1e-2], "batch_size": [16, 32]}``。

    Returns:
        ExpManager: 返回自身,支持链式调用。

    Notes:
        若 ``param_grid`` 为空,本方法为 no-op。
    """
    if not param_grid:
        return self

    keys = list(param_grid.keys())
    value_lists = [list(param_grid[key]) for key in keys]
    combos = list(itertools.product(*value_lists))

    expanded: list[ConfigDict] = []
    for base_config in self._configs:
        for combo in combos:
            next_config = copy.deepcopy(base_config)
            for key, value in zip(keys, combo):
                next_config[key] = value
            expanded.append(next_config)

    self._configs = expanded
    return self

modify

modify(modifier: Modifier) -> 'ExpManager'

注册配置修改器。

参数:

名称 类型 描述 默认
modifier Modifier

修改函数。支持两种风格: 1) 原地修改并返回 None; 2) 返回修改后的新字典。

必需

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身。

源代码位于: ztxexp/manager.py
def modify(self, modifier: Modifier) -> "ExpManager":
    """注册配置修改器。

    Args:
        modifier: 修改函数。支持两种风格:
            1) 原地修改并返回 ``None``;
            2) 返回修改后的新字典。

    Returns:
        ExpManager: 返回自身。
    """
    self._modifiers.append(modifier)
    return self
random_search(space: Mapping[str, Sequence[Any]], n_trials: int, seed: int = 42) -> 'ExpManager'

按随机采样方式扩展配置。

参数:

名称 类型 描述 默认
space Mapping[str, Sequence[Any]]

参数搜索空间,值为候选列表。

必需
n_trials int

采样次数。

必需
seed int

随机种子。

42

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身,支持链式调用。

Notes
  • 采样采用“有放回”策略;
  • 不替代 grid/variants,可与其组合使用。
源代码位于: ztxexp/manager.py
def random_search(
    self,
    space: Mapping[str, Sequence[Any]],
    n_trials: int,
    seed: int = 42,
) -> "ExpManager":
    """按随机采样方式扩展配置。

    Args:
        space: 参数搜索空间,值为候选列表。
        n_trials: 采样次数。
        seed: 随机种子。

    Returns:
        ExpManager: 返回自身,支持链式调用。

    Notes:
        - 采样采用“有放回”策略;
        - 不替代 ``grid/variants``,可与其组合使用。
    """
    if not space or n_trials <= 0:
        return self

    rng = random.Random(seed)
    keys = list(space.keys())
    value_lists = [list(space[key]) for key in keys]

    expanded: list[ConfigDict] = []
    for base_config in self._configs:
        for _ in range(n_trials):
            next_config = copy.deepcopy(base_config)
            for key, values in zip(keys, value_lists):
                if not values:
                    continue
                next_config[key] = rng.choice(values)
            expanded.append(next_config)

    self._configs = expanded
    return self

shuffle

shuffle() -> 'ExpManager'

在最终构建结果上随机打乱顺序。

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身。

源代码位于: ztxexp/manager.py
def shuffle(self) -> "ExpManager":
    """在最终构建结果上随机打乱顺序。

    Returns:
        ExpManager: 返回自身。
    """
    self._should_shuffle = True
    return self

variants

variants(variants: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]]) -> 'ExpManager'

按“独立变体”方式扩展配置。

参数:

名称 类型 描述 默认
variants Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]]

推荐传 list[dict],例如 [{"model": "tiny"}, {"model": "base", "layers": 12}]。 同时兼容旧格式 dict[str, list]

必需

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身,支持链式调用。

Notes
  • list[dict] 语义更清晰,推荐优先使用。
  • dict[str, list] 会被转为单键变体集合。
源代码位于: ztxexp/manager.py
def variants(
    self,
    variants: Sequence[Mapping[str, Any]] | Mapping[str, Sequence[Any]],
) -> "ExpManager":
    """按“独立变体”方式扩展配置。

    Args:
        variants: 推荐传 ``list[dict]``,例如
            ``[{"model": "tiny"}, {"model": "base", "layers": 12}]``。
            同时兼容旧格式 ``dict[str, list]``。

    Returns:
        ExpManager: 返回自身,支持链式调用。

    Notes:
        - ``list[dict]`` 语义更清晰,推荐优先使用。
        - ``dict[str, list]`` 会被转为单键变体集合。
    """
    if not variants:
        return self

    variant_dicts: list[ConfigDict] = []
    if isinstance(variants, Mapping):
        for key, values in variants.items():
            for value in values:
                variant_dicts.append({key: value})
    else:
        variant_dicts = [dict(item) for item in variants]

    expanded: list[ConfigDict] = []
    for base_config in self._configs:
        for variant in variant_dicts:
            merged = copy.deepcopy(base_config)
            merged.update(copy.deepcopy(variant))
            expanded.append(merged)

    self._configs = expanded
    return self

where

where(predicate: Predicate) -> 'ExpManager'

注册配置过滤器。

参数:

名称 类型 描述 默认
predicate Predicate

谓词函数。返回 True 表示保留该配置。

必需

返回:

名称 类型 描述
ExpManager 'ExpManager'

返回自身。

源代码位于: ztxexp/manager.py
def where(self, predicate: Predicate) -> "ExpManager":
    """注册配置过滤器。

    Args:
        predicate: 谓词函数。返回 ``True`` 表示保留该配置。

    Returns:
        ExpManager: 返回自身。
    """
    self._predicates.append(predicate)
    return self

ExpRunner

实验执行器。

源代码位于: ztxexp/runner.py
class ExpRunner:
    """实验执行器。"""

    def __init__(
        self,
        configs: list[dict[str, Any]],
        results_root: str | Path,
        exp_function: ExperimentFn | None = None,
    ):
        self.configs = [dict(config) for config in configs]
        self.results_root = Path(results_root)
        self.exp_function = exp_function
        utils.create_dir(self.results_root)

    def run(
        self,
        exp_function: ExperimentFn | None = None,
        mode: str = "sequential",
        workers: int = 1,
        cpu_threshold: int = 80,
        execution_mode: str | None = None,
        num_workers: int | None = None,
        dynamic_cpu_threshold: int | None = None,
        metadata: RunMetadata | None = None,
        max_attempts: int = 1,
        retry_on: tuple[str, ...] = ("Exception",),
        tracker_specs: list[dict[str, Any]] | None = None,
        trackers: list[Tracker] | None = None,
    ) -> RunSummary:
        """执行全部配置并返回汇总。

        Args:
            exp_function: 单次实验函数,签名应为
                ``exp_fn(ctx: RunContext) -> dict | None``。
            mode: 执行模式。可选
                ``sequential`` / ``process_pool`` / ``joblib`` / ``dynamic``。
            workers: 并行 worker 数。
            cpu_threshold: ``dynamic`` 模式提交新任务时的 CPU 阈值。
            execution_mode: 兼容参数,等价于 ``mode``。
            num_workers: 兼容参数,等价于 ``workers``。
            dynamic_cpu_threshold: 兼容参数,等价于 ``cpu_threshold``。
            metadata: 运行元数据模板。框架会补全可采集字段。
            max_attempts: 每个配置最大尝试次数(失败重试上限)。
            retry_on: 可重试异常名集合(支持父类名,如 ``Exception``)。
            tracker_specs: 追踪器规格列表(字符串模式构造 tracker)。
            trackers: 追踪器实例列表(当前进程内对象)。

        Returns:
            RunSummary: 本次批量执行汇总(成功/失败/跳过计数与耗时)。

        Raises:
            ValueError: 未提供 ``exp_function`` 或 ``mode`` 不合法时抛出。

        Notes:
            - ``exp_fn`` 返回 ``dict`` 时自动写入 ``metrics.json``;
            - ``exp_fn`` 返回 ``None`` 时不写 ``metrics.json``;
            - 返回非 ``dict|None`` 会判定为失败并写 ``error.log``;
            - 抛出 ``SkipRun`` 会标记为 ``skipped``;
            - 成功判定以 ``run.json.status == succeeded`` 为准。

        Examples:
            >>> def exp_fn(ctx: RunContext):
            ...     return {"score": 0.9}
            >>> summary = ExpRunner([{"lr": 0.001}], "./results").run(exp_fn)
            >>> summary.total
            1
        """
        if execution_mode is not None:
            mode = execution_mode
        if num_workers is not None:
            workers = num_workers
        if dynamic_cpu_threshold is not None:
            cpu_threshold = dynamic_cpu_threshold

        experiment = exp_function or self.exp_function
        if experiment is None:
            raise ValueError("exp_function is required.")

        total = len(self.configs)
        started = time.time()

        if total == 0:
            return RunSummary(
                total=0,
                succeeded=0,
                failed=0,
                skipped=0,
                duration_sec=0.0,
                failed_run_ids=[],
            )

        resolved_specs = list(tracker_specs or [])
        has_jsonl_spec = any(
            str(spec.get("type", "")).lower() == "jsonl" for spec in resolved_specs
        )
        has_jsonl_instance = any(isinstance(tracker, JsonlTracker) for tracker in (trackers or []))
        if not has_jsonl_spec and not has_jsonl_instance:
            resolved_specs.append({"type": "jsonl", "kwargs": {}})

        if mode == "sequential":
            records = [
                _execute_single_run(
                    config=config,
                    exp_function=experiment,
                    results_root=self.results_root,
                    metadata=metadata,
                    max_attempts=max_attempts,
                    retry_on=retry_on,
                    tracker_specs=resolved_specs,
                    trackers=trackers,
                )
                for config in self.configs
            ]
        elif mode == "process_pool":
            if trackers:
                print("Live tracker instances are ignored in process_pool mode.")
            records = self._run_process_pool(
                exp_function=experiment,
                workers=workers,
                metadata=metadata,
                max_attempts=max_attempts,
                retry_on=retry_on,
                tracker_specs=resolved_specs,
            )
        elif mode == "joblib":
            if trackers:
                print("Live tracker instances are ignored in joblib mode.")
            records = self._run_joblib(
                exp_function=experiment,
                workers=workers,
                metadata=metadata,
                max_attempts=max_attempts,
                retry_on=retry_on,
                tracker_specs=resolved_specs,
            )
        elif mode == "dynamic":
            if trackers:
                print("Live tracker instances are ignored in dynamic mode.")
            records = self._run_dynamic(
                exp_function=experiment,
                workers=workers,
                cpu_threshold=cpu_threshold,
                metadata=metadata,
                max_attempts=max_attempts,
                retry_on=retry_on,
                tracker_specs=resolved_specs,
            )
        else:
            raise ValueError(
                f"Invalid mode '{mode}'. Choose from sequential/process_pool/joblib/dynamic."
            )

        duration = round(time.time() - started, 6)
        return self._summarize(records, total, duration)

    def _run_process_pool(
        self,
        exp_function: ExperimentFn,
        workers: int,
        metadata: RunMetadata | None,
        max_attempts: int,
        retry_on: tuple[str, ...],
        tracker_specs: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """使用 ProcessPoolExecutor 并行执行。"""
        records: list[dict[str, Any]] = []
        with ProcessPoolExecutor(max_workers=workers) as executor:
            future_map = {
                executor.submit(
                    _execute_single_run,
                    config,
                    exp_function,
                    self.results_root,
                    metadata,
                    max_attempts,
                    retry_on,
                    tracker_specs,
                    None,
                ): config
                for config in self.configs
            }
            for future in as_completed(future_map):
                try:
                    records.append(future.result())
                except Exception as exc:  # pragma: no cover
                    records.append(_failure_record_from_exception(exc))
        return records

    def _run_joblib(
        self,
        exp_function: ExperimentFn,
        workers: int,
        metadata: RunMetadata | None,
        max_attempts: int,
        retry_on: tuple[str, ...],
        tracker_specs: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """使用 joblib 并行执行。"""
        try:
            return Parallel(n_jobs=workers, prefer="processes")(
                delayed(_execute_single_run)(
                    config,
                    exp_function,
                    self.results_root,
                    metadata,
                    max_attempts,
                    retry_on,
                    tracker_specs,
                    None,
                )
                for config in self.configs
            )
        except Exception as exc:  # pragma: no cover
            return [_failure_record_from_exception(exc) for _ in self.configs]

    def _run_dynamic(
        self,
        exp_function: ExperimentFn,
        workers: int,
        cpu_threshold: int,
        metadata: RunMetadata | None,
        max_attempts: int,
        retry_on: tuple[str, ...],
        tracker_specs: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """动态调度执行(实验特性)。"""
        pending = deque(self.configs)
        in_flight: dict[Any, dict[str, Any]] = {}
        records: list[dict[str, Any]] = []

        with ProcessPoolExecutor(max_workers=workers) as executor:
            while pending or in_flight:
                cpu_usage = psutil.cpu_percent(interval=0.2)

                while pending and len(in_flight) < workers and cpu_usage < cpu_threshold:
                    config = pending.popleft()
                    future = executor.submit(
                        _execute_single_run,
                        config,
                        exp_function,
                        self.results_root,
                        metadata,
                        max_attempts,
                        retry_on,
                        tracker_specs,
                        None,
                    )
                    in_flight[future] = config
                    cpu_usage = psutil.cpu_percent(interval=0.0)

                if not in_flight:
                    time.sleep(0.2)
                    continue

                done, _ = wait(
                    in_flight.keys(),
                    timeout=0.5,
                    return_when=FIRST_COMPLETED,
                )

                for future in done:
                    in_flight.pop(future, None)
                    try:
                        records.append(future.result())
                    except Exception as exc:  # pragma: no cover
                        records.append(_failure_record_from_exception(exc))

        return records

    def _summarize(
        self,
        records: list[dict[str, Any]],
        total: int,
        duration_sec: float,
    ) -> RunSummary:
        """将执行记录聚合为 ``RunSummary``。"""
        succeeded = sum(1 for record in records if record.get("status") == RUN_STATUS_SUCCEEDED)
        failed = sum(1 for record in records if record.get("status") == RUN_STATUS_FAILED)
        skipped = sum(1 for record in records if record.get("status") == RUN_STATUS_SKIPPED)
        failed_run_ids = [
            str(record.get("run_id"))
            for record in records
            if record.get("status") == RUN_STATUS_FAILED
        ]

        return RunSummary(
            total=total,
            succeeded=succeeded,
            failed=failed,
            skipped=skipped,
            duration_sec=duration_sec,
            failed_run_ids=failed_run_ids,
        )

configs instance-attribute

configs = [(dict(config)) for config in configs]

exp_function instance-attribute

exp_function = exp_function

results_root instance-attribute

results_root = Path(results_root)

__init__

__init__(configs: list[dict[str, Any]], results_root: str | Path, exp_function: ExperimentFn | None = None)
源代码位于: ztxexp/runner.py
def __init__(
    self,
    configs: list[dict[str, Any]],
    results_root: str | Path,
    exp_function: ExperimentFn | None = None,
):
    self.configs = [dict(config) for config in configs]
    self.results_root = Path(results_root)
    self.exp_function = exp_function
    utils.create_dir(self.results_root)

run

run(exp_function: ExperimentFn | None = None, mode: str = 'sequential', workers: int = 1, cpu_threshold: int = 80, execution_mode: str | None = None, num_workers: int | None = None, dynamic_cpu_threshold: int | None = None, metadata: RunMetadata | None = None, max_attempts: int = 1, retry_on: tuple[str, ...] = ('Exception',), tracker_specs: list[dict[str, Any]] | None = None, trackers: list[Tracker] | None = None) -> RunSummary

执行全部配置并返回汇总。

参数:

名称 类型 描述 默认
exp_function ExperimentFn | None

单次实验函数,签名应为 exp_fn(ctx: RunContext) -> dict | None

None
mode str

执行模式。可选 sequential / process_pool / joblib / dynamic

'sequential'
workers int

并行 worker 数。

1
cpu_threshold int

dynamic 模式提交新任务时的 CPU 阈值。

80
execution_mode str | None

兼容参数,等价于 mode

None
num_workers int | None

兼容参数,等价于 workers

None
dynamic_cpu_threshold int | None

兼容参数,等价于 cpu_threshold

None
metadata RunMetadata | None

运行元数据模板。框架会补全可采集字段。

None
max_attempts int

每个配置最大尝试次数(失败重试上限)。

1
retry_on tuple[str, ...]

可重试异常名集合(支持父类名,如 Exception)。

('Exception',)
tracker_specs list[dict[str, Any]] | None

追踪器规格列表(字符串模式构造 tracker)。

None
trackers list[Tracker] | None

追踪器实例列表(当前进程内对象)。

None

返回:

名称 类型 描述
RunSummary RunSummary

本次批量执行汇总(成功/失败/跳过计数与耗时)。

引发:

类型 描述
ValueError

未提供 exp_functionmode 不合法时抛出。

Notes
  • exp_fn 返回 dict 时自动写入 metrics.json
  • exp_fn 返回 None 时不写 metrics.json
  • 返回非 dict|None 会判定为失败并写 error.log
  • 抛出 SkipRun 会标记为 skipped
  • 成功判定以 run.json.status == succeeded 为准。

示例:

>>> def exp_fn(ctx: RunContext):
...     return {"score": 0.9}
>>> summary = ExpRunner([{"lr": 0.001}], "./results").run(exp_fn)
>>> summary.total
1
源代码位于: ztxexp/runner.py
def run(
    self,
    exp_function: ExperimentFn | None = None,
    mode: str = "sequential",
    workers: int = 1,
    cpu_threshold: int = 80,
    execution_mode: str | None = None,
    num_workers: int | None = None,
    dynamic_cpu_threshold: int | None = None,
    metadata: RunMetadata | None = None,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("Exception",),
    tracker_specs: list[dict[str, Any]] | None = None,
    trackers: list[Tracker] | None = None,
) -> RunSummary:
    """执行全部配置并返回汇总。

    Args:
        exp_function: 单次实验函数,签名应为
            ``exp_fn(ctx: RunContext) -> dict | None``。
        mode: 执行模式。可选
            ``sequential`` / ``process_pool`` / ``joblib`` / ``dynamic``。
        workers: 并行 worker 数。
        cpu_threshold: ``dynamic`` 模式提交新任务时的 CPU 阈值。
        execution_mode: 兼容参数,等价于 ``mode``。
        num_workers: 兼容参数,等价于 ``workers``。
        dynamic_cpu_threshold: 兼容参数,等价于 ``cpu_threshold``。
        metadata: 运行元数据模板。框架会补全可采集字段。
        max_attempts: 每个配置最大尝试次数(失败重试上限)。
        retry_on: 可重试异常名集合(支持父类名,如 ``Exception``)。
        tracker_specs: 追踪器规格列表(字符串模式构造 tracker)。
        trackers: 追踪器实例列表(当前进程内对象)。

    Returns:
        RunSummary: 本次批量执行汇总(成功/失败/跳过计数与耗时)。

    Raises:
        ValueError: 未提供 ``exp_function`` 或 ``mode`` 不合法时抛出。

    Notes:
        - ``exp_fn`` 返回 ``dict`` 时自动写入 ``metrics.json``;
        - ``exp_fn`` 返回 ``None`` 时不写 ``metrics.json``;
        - 返回非 ``dict|None`` 会判定为失败并写 ``error.log``;
        - 抛出 ``SkipRun`` 会标记为 ``skipped``;
        - 成功判定以 ``run.json.status == succeeded`` 为准。

    Examples:
        >>> def exp_fn(ctx: RunContext):
        ...     return {"score": 0.9}
        >>> summary = ExpRunner([{"lr": 0.001}], "./results").run(exp_fn)
        >>> summary.total
        1
    """
    if execution_mode is not None:
        mode = execution_mode
    if num_workers is not None:
        workers = num_workers
    if dynamic_cpu_threshold is not None:
        cpu_threshold = dynamic_cpu_threshold

    experiment = exp_function or self.exp_function
    if experiment is None:
        raise ValueError("exp_function is required.")

    total = len(self.configs)
    started = time.time()

    if total == 0:
        return RunSummary(
            total=0,
            succeeded=0,
            failed=0,
            skipped=0,
            duration_sec=0.0,
            failed_run_ids=[],
        )

    resolved_specs = list(tracker_specs or [])
    has_jsonl_spec = any(
        str(spec.get("type", "")).lower() == "jsonl" for spec in resolved_specs
    )
    has_jsonl_instance = any(isinstance(tracker, JsonlTracker) for tracker in (trackers or []))
    if not has_jsonl_spec and not has_jsonl_instance:
        resolved_specs.append({"type": "jsonl", "kwargs": {}})

    if mode == "sequential":
        records = [
            _execute_single_run(
                config=config,
                exp_function=experiment,
                results_root=self.results_root,
                metadata=metadata,
                max_attempts=max_attempts,
                retry_on=retry_on,
                tracker_specs=resolved_specs,
                trackers=trackers,
            )
            for config in self.configs
        ]
    elif mode == "process_pool":
        if trackers:
            print("Live tracker instances are ignored in process_pool mode.")
        records = self._run_process_pool(
            exp_function=experiment,
            workers=workers,
            metadata=metadata,
            max_attempts=max_attempts,
            retry_on=retry_on,
            tracker_specs=resolved_specs,
        )
    elif mode == "joblib":
        if trackers:
            print("Live tracker instances are ignored in joblib mode.")
        records = self._run_joblib(
            exp_function=experiment,
            workers=workers,
            metadata=metadata,
            max_attempts=max_attempts,
            retry_on=retry_on,
            tracker_specs=resolved_specs,
        )
    elif mode == "dynamic":
        if trackers:
            print("Live tracker instances are ignored in dynamic mode.")
        records = self._run_dynamic(
            exp_function=experiment,
            workers=workers,
            cpu_threshold=cpu_threshold,
            metadata=metadata,
            max_attempts=max_attempts,
            retry_on=retry_on,
            tracker_specs=resolved_specs,
        )
    else:
        raise ValueError(
            f"Invalid mode '{mode}'. Choose from sequential/process_pool/joblib/dynamic."
        )

    duration = round(time.time() - started, 6)
    return self._summarize(records, total, duration)

ExperimentPipeline

实验流水线 Facade。

设计目标: 1. 减少样板代码; 2. 将管理与执行组合为统一入口; 3. 保留底层 ExpManager / ExpRunner 的可控性。

参数:

名称 类型 描述 默认
results_root str | Path

运行产物根目录。

必需
base_config Mapping[str, Any] | None

基础配置字典,后续 grid/variants 将基于它扩展。

None

示例:

>>> pipeline = (
...     ExperimentPipeline("./results", base_config={"seed": 42})
...     .grid({"lr": [1e-3, 1e-2]})
...     .variants([{"model": "tiny"}, {"model": "base"}])
... )
>>> configs = pipeline.build()
>>> len(configs)
4
源代码位于: ztxexp/pipeline.py
class ExperimentPipeline:
    """实验流水线 Facade。

    设计目标:
    1. 减少样板代码;
    2. 将管理与执行组合为统一入口;
    3. 保留底层 ``ExpManager`` / ``ExpRunner`` 的可控性。

    Args:
        results_root: 运行产物根目录。
        base_config: 基础配置字典,后续 grid/variants 将基于它扩展。

    Examples:
        >>> pipeline = (
        ...     ExperimentPipeline("./results", base_config={"seed": 42})
        ...     .grid({"lr": [1e-3, 1e-2]})
        ...     .variants([{"model": "tiny"}, {"model": "base"}])
        ... )
        >>> configs = pipeline.build()
        >>> len(configs)
        4
    """

    def __init__(
        self,
        results_root: str | Path,
        base_config: Mapping[str, Any] | None = None,
    ):
        self.results_root = Path(results_root)
        self._manager = ExpManager(base_config)
        self._exclude_completed = False
        self._experiment_name: str | None = None
        self._group_name: str | None = None
        self._tags: dict[str, str] | list[str] | None = None
        self._parent_run_id: str | None = None
        self._retry_max_attempts = 1
        self._retry_on = ("Exception",)
        self._tracker_specs: list[dict[str, Any]] = []
        self._trackers: list[Tracker] = []

    def grid(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExperimentPipeline":
        """添加网格参数空间。

        Args:
            param_grid: 参数网格,例如 ``{"lr": [1e-3, 1e-2]}``。

        Returns:
            ExperimentPipeline: 返回自身以支持链式调用。
        """
        self._manager.grid(param_grid)
        return self

    def variants(self, variants: Sequence[Mapping[str, Any]]) -> "ExperimentPipeline":
        """添加独立变体空间。

        Args:
            variants: 变体列表,每个元素是一个配置片段字典。

        Returns:
            ExperimentPipeline: 返回自身以支持链式调用。

        Examples:
            >>> pipeline.variants([{"model": "tiny"}, {"model": "base"}])
        """
        self._manager.variants(variants)
        return self

    def random_search(
        self,
        space: Mapping[str, Sequence[Any]],
        n_trials: int,
        seed: int = 42,
    ) -> "ExperimentPipeline":
        """添加随机搜索空间。"""
        self._manager.random_search(space=space, n_trials=n_trials, seed=seed)
        return self

    def modify(self, fn: Callable[[dict[str, Any]], dict[str, Any] | None]) -> "ExperimentPipeline":
        """注册配置修改函数。

        Args:
            fn: 配置修改器。可原地修改并返回 ``None``,也可返回新字典。

        Returns:
            ExperimentPipeline: 返回自身以支持链式调用。
        """
        self._manager.modify(fn)
        return self

    def where(self, fn: Callable[[dict[str, Any]], bool]) -> "ExperimentPipeline":
        """注册配置过滤函数。

        Args:
            fn: 谓词函数。返回 ``True`` 表示保留配置。

        Returns:
            ExperimentPipeline: 返回自身以支持链式调用。
        """
        self._manager.where(fn)
        return self

    def exclude_completed(self) -> "ExperimentPipeline":
        """启用“排除已完成实验”逻辑。

        Returns:
            ExperimentPipeline: 返回自身以支持链式调用。
        """
        self._exclude_completed = True
        return self

    def name(self, experiment_name: str) -> "ExperimentPipeline":
        """设置实验名称。"""
        self._experiment_name = experiment_name
        return self

    def group(self, group_name: str) -> "ExperimentPipeline":
        """设置实验分组。"""
        self._group_name = group_name
        return self

    def tags(self, tags: dict[str, str] | list[str]) -> "ExperimentPipeline":
        """设置实验标签。"""
        self._tags = tags
        return self

    def lineage(self, parent_run_id: str | None) -> "ExperimentPipeline":
        """设置父 run ID。"""
        self._parent_run_id = parent_run_id
        return self

    def retry(
        self,
        max_attempts: int = 1,
        retry_on: tuple[str, ...] = ("Exception",),
    ) -> "ExperimentPipeline":
        """设置失败重试策略。"""
        self._retry_max_attempts = max(1, int(max_attempts))
        self._retry_on = retry_on
        return self

    def track(self, tracker: Tracker | str, **kwargs: Any) -> "ExperimentPipeline":
        """注册追踪器。

        Args:
            tracker: 追踪器实例或内置追踪器名(``jsonl/mlflow/wandb``)。
            **kwargs: 追踪器初始化参数(字符串模式下使用)。
        """
        if isinstance(tracker, str):
            self._tracker_specs.append({"type": tracker.lower(), "kwargs": dict(kwargs)})
            return self

        self._trackers.append(tracker)
        return self

    def build(self) -> list[dict[str, Any]]:
        """构建最终配置列表。

        Returns:
            list[dict[str, Any]]: 构建完成的配置字典列表。
        """
        if self._exclude_completed:
            self._manager.exclude_completed(self.results_root)
        return self._manager.build()

    def run(
        self,
        exp_fn: Callable[[RunContext], dict[str, Any] | None],
        mode: str = "sequential",
        workers: int = 1,
        cpu_threshold: int = 80,
    ) -> RunSummary:
        """构建配置并执行实验。

        Args:
            exp_fn: 单次实验函数,签名为 ``exp_fn(ctx: RunContext) -> dict | None``。
                - 返回 ``dict``:框架写入 ``metrics.json``;
                - 返回 ``None``:不写 ``metrics.json``;
                - 返回其他类型:判定为失败并写 ``error.log``;
                - 抛出 ``SkipRun``:判定为 ``skipped``。
            mode: 执行模式,支持 ``sequential`` / ``process_pool`` /
                ``joblib`` / ``dynamic``。
            workers: 并行 worker 数量。
            cpu_threshold: ``dynamic`` 模式下的 CPU 提交阈值。

        Returns:
            RunSummary: 批量执行汇总信息。

        Notes:
            用户业务文件建议统一写入 ``ctx.run_dir / \"artifacts\"``;
            过程曲线建议使用 ``ctx.log_metric(...)`` 写入 ``metrics.jsonl``。

        Examples:
            >>> def exp_fn(ctx: RunContext):
            ...     return {"score": 1.0}
            >>> summary = ExperimentPipeline("./results").run(exp_fn)
            >>> summary.total >= 0
            True
        """
        configs = self.build()
        runner = ExpRunner(configs=configs, results_root=self.results_root)
        run_meta = RunMetadata(
            experiment_name=self._experiment_name,
            group=self._group_name,
            tags=self._tags,
            parent_run_id=self._parent_run_id,
        )

        has_jsonl_spec = any(spec.get("type") == "jsonl" for spec in self._tracker_specs)
        has_jsonl_instance = any(
            tracker.__class__.__name__ == "JsonlTracker" for tracker in self._trackers
        )
        if not has_jsonl_spec and not has_jsonl_instance:
            self._tracker_specs.append({"type": "jsonl", "kwargs": {}})

        return runner.run(
            exp_function=exp_fn,
            mode=mode,
            workers=workers,
            cpu_threshold=cpu_threshold,
            metadata=run_meta,
            max_attempts=self._retry_max_attempts,
            retry_on=self._retry_on,
            tracker_specs=self._tracker_specs,
            trackers=self._trackers,
        )

results_root instance-attribute

results_root = Path(results_root)

__init__

__init__(results_root: str | Path, base_config: Mapping[str, Any] | None = None)
源代码位于: ztxexp/pipeline.py
def __init__(
    self,
    results_root: str | Path,
    base_config: Mapping[str, Any] | None = None,
):
    self.results_root = Path(results_root)
    self._manager = ExpManager(base_config)
    self._exclude_completed = False
    self._experiment_name: str | None = None
    self._group_name: str | None = None
    self._tags: dict[str, str] | list[str] | None = None
    self._parent_run_id: str | None = None
    self._retry_max_attempts = 1
    self._retry_on = ("Exception",)
    self._tracker_specs: list[dict[str, Any]] = []
    self._trackers: list[Tracker] = []

build

build() -> list[dict[str, Any]]

构建最终配置列表。

返回:

类型 描述
list[dict[str, Any]]

list[dict[str, Any]]: 构建完成的配置字典列表。

源代码位于: ztxexp/pipeline.py
def build(self) -> list[dict[str, Any]]:
    """构建最终配置列表。

    Returns:
        list[dict[str, Any]]: 构建完成的配置字典列表。
    """
    if self._exclude_completed:
        self._manager.exclude_completed(self.results_root)
    return self._manager.build()

exclude_completed

exclude_completed() -> 'ExperimentPipeline'

启用“排除已完成实验”逻辑。

返回:

名称 类型 描述
ExperimentPipeline 'ExperimentPipeline'

返回自身以支持链式调用。

源代码位于: ztxexp/pipeline.py
def exclude_completed(self) -> "ExperimentPipeline":
    """启用“排除已完成实验”逻辑。

    Returns:
        ExperimentPipeline: 返回自身以支持链式调用。
    """
    self._exclude_completed = True
    return self

grid

grid(param_grid: Mapping[str, Sequence[Any]]) -> 'ExperimentPipeline'

添加网格参数空间。

参数:

名称 类型 描述 默认
param_grid Mapping[str, Sequence[Any]]

参数网格,例如 {"lr": [1e-3, 1e-2]}

必需

返回:

名称 类型 描述
ExperimentPipeline 'ExperimentPipeline'

返回自身以支持链式调用。

源代码位于: ztxexp/pipeline.py
def grid(self, param_grid: Mapping[str, Sequence[Any]]) -> "ExperimentPipeline":
    """添加网格参数空间。

    Args:
        param_grid: 参数网格,例如 ``{"lr": [1e-3, 1e-2]}``。

    Returns:
        ExperimentPipeline: 返回自身以支持链式调用。
    """
    self._manager.grid(param_grid)
    return self

group

group(group_name: str) -> 'ExperimentPipeline'

设置实验分组。

源代码位于: ztxexp/pipeline.py
def group(self, group_name: str) -> "ExperimentPipeline":
    """设置实验分组。"""
    self._group_name = group_name
    return self

lineage

lineage(parent_run_id: str | None) -> 'ExperimentPipeline'

设置父 run ID。

源代码位于: ztxexp/pipeline.py
def lineage(self, parent_run_id: str | None) -> "ExperimentPipeline":
    """设置父 run ID。"""
    self._parent_run_id = parent_run_id
    return self

modify

modify(fn: Callable[[dict[str, Any]], dict[str, Any] | None]) -> 'ExperimentPipeline'

注册配置修改函数。

参数:

名称 类型 描述 默认
fn Callable[[dict[str, Any]], dict[str, Any] | None]

配置修改器。可原地修改并返回 None,也可返回新字典。

必需

返回:

名称 类型 描述
ExperimentPipeline 'ExperimentPipeline'

返回自身以支持链式调用。

源代码位于: ztxexp/pipeline.py
def modify(self, fn: Callable[[dict[str, Any]], dict[str, Any] | None]) -> "ExperimentPipeline":
    """注册配置修改函数。

    Args:
        fn: 配置修改器。可原地修改并返回 ``None``,也可返回新字典。

    Returns:
        ExperimentPipeline: 返回自身以支持链式调用。
    """
    self._manager.modify(fn)
    return self

name

name(experiment_name: str) -> 'ExperimentPipeline'

设置实验名称。

源代码位于: ztxexp/pipeline.py
def name(self, experiment_name: str) -> "ExperimentPipeline":
    """设置实验名称。"""
    self._experiment_name = experiment_name
    return self
random_search(space: Mapping[str, Sequence[Any]], n_trials: int, seed: int = 42) -> 'ExperimentPipeline'

添加随机搜索空间。

源代码位于: ztxexp/pipeline.py
def random_search(
    self,
    space: Mapping[str, Sequence[Any]],
    n_trials: int,
    seed: int = 42,
) -> "ExperimentPipeline":
    """添加随机搜索空间。"""
    self._manager.random_search(space=space, n_trials=n_trials, seed=seed)
    return self

retry

retry(max_attempts: int = 1, retry_on: tuple[str, ...] = ('Exception',)) -> 'ExperimentPipeline'

设置失败重试策略。

源代码位于: ztxexp/pipeline.py
def retry(
    self,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("Exception",),
) -> "ExperimentPipeline":
    """设置失败重试策略。"""
    self._retry_max_attempts = max(1, int(max_attempts))
    self._retry_on = retry_on
    return self

run

run(exp_fn: Callable[[RunContext], dict[str, Any] | None], mode: str = 'sequential', workers: int = 1, cpu_threshold: int = 80) -> RunSummary

构建配置并执行实验。

参数:

名称 类型 描述 默认
exp_fn Callable[[RunContext], dict[str, Any] | None]

单次实验函数,签名为 exp_fn(ctx: RunContext) -> dict | None。 - 返回 dict:框架写入 metrics.json; - 返回 None:不写 metrics.json; - 返回其他类型:判定为失败并写 error.log; - 抛出 SkipRun:判定为 skipped

必需
mode str

执行模式,支持 sequential / process_pool / joblib / dynamic

'sequential'
workers int

并行 worker 数量。

1
cpu_threshold int

dynamic 模式下的 CPU 提交阈值。

80

返回:

名称 类型 描述
RunSummary RunSummary

批量执行汇总信息。

Notes

用户业务文件建议统一写入 ctx.run_dir / "artifacts"; 过程曲线建议使用 ctx.log_metric(...) 写入 metrics.jsonl

示例:

>>> def exp_fn(ctx: RunContext):
...     return {"score": 1.0}
>>> summary = ExperimentPipeline("./results").run(exp_fn)
>>> summary.total >= 0
True
源代码位于: ztxexp/pipeline.py
def run(
    self,
    exp_fn: Callable[[RunContext], dict[str, Any] | None],
    mode: str = "sequential",
    workers: int = 1,
    cpu_threshold: int = 80,
) -> RunSummary:
    """构建配置并执行实验。

    Args:
        exp_fn: 单次实验函数,签名为 ``exp_fn(ctx: RunContext) -> dict | None``。
            - 返回 ``dict``:框架写入 ``metrics.json``;
            - 返回 ``None``:不写 ``metrics.json``;
            - 返回其他类型:判定为失败并写 ``error.log``;
            - 抛出 ``SkipRun``:判定为 ``skipped``。
        mode: 执行模式,支持 ``sequential`` / ``process_pool`` /
            ``joblib`` / ``dynamic``。
        workers: 并行 worker 数量。
        cpu_threshold: ``dynamic`` 模式下的 CPU 提交阈值。

    Returns:
        RunSummary: 批量执行汇总信息。

    Notes:
        用户业务文件建议统一写入 ``ctx.run_dir / \"artifacts\"``;
        过程曲线建议使用 ``ctx.log_metric(...)`` 写入 ``metrics.jsonl``。

    Examples:
        >>> def exp_fn(ctx: RunContext):
        ...     return {"score": 1.0}
        >>> summary = ExperimentPipeline("./results").run(exp_fn)
        >>> summary.total >= 0
        True
    """
    configs = self.build()
    runner = ExpRunner(configs=configs, results_root=self.results_root)
    run_meta = RunMetadata(
        experiment_name=self._experiment_name,
        group=self._group_name,
        tags=self._tags,
        parent_run_id=self._parent_run_id,
    )

    has_jsonl_spec = any(spec.get("type") == "jsonl" for spec in self._tracker_specs)
    has_jsonl_instance = any(
        tracker.__class__.__name__ == "JsonlTracker" for tracker in self._trackers
    )
    if not has_jsonl_spec and not has_jsonl_instance:
        self._tracker_specs.append({"type": "jsonl", "kwargs": {}})

    return runner.run(
        exp_function=exp_fn,
        mode=mode,
        workers=workers,
        cpu_threshold=cpu_threshold,
        metadata=run_meta,
        max_attempts=self._retry_max_attempts,
        retry_on=self._retry_on,
        tracker_specs=self._tracker_specs,
        trackers=self._trackers,
    )

tags

tags(tags: dict[str, str] | list[str]) -> 'ExperimentPipeline'

设置实验标签。

源代码位于: ztxexp/pipeline.py
def tags(self, tags: dict[str, str] | list[str]) -> "ExperimentPipeline":
    """设置实验标签。"""
    self._tags = tags
    return self

track

track(tracker: Tracker | str, **kwargs: Any) -> 'ExperimentPipeline'

注册追踪器。

参数:

名称 类型 描述 默认
tracker Tracker | str

追踪器实例或内置追踪器名(jsonl/mlflow/wandb)。

必需
**kwargs Any

追踪器初始化参数(字符串模式下使用)。

{}
源代码位于: ztxexp/pipeline.py
def track(self, tracker: Tracker | str, **kwargs: Any) -> "ExperimentPipeline":
    """注册追踪器。

    Args:
        tracker: 追踪器实例或内置追踪器名(``jsonl/mlflow/wandb``)。
        **kwargs: 追踪器初始化参数(字符串模式下使用)。
    """
    if isinstance(tracker, str):
        self._tracker_specs.append({"type": tracker.lower(), "kwargs": dict(kwargs)})
        return self

    self._trackers.append(tracker)
    return self

variants

variants(variants: Sequence[Mapping[str, Any]]) -> 'ExperimentPipeline'

添加独立变体空间。

参数:

名称 类型 描述 默认
variants Sequence[Mapping[str, Any]]

变体列表,每个元素是一个配置片段字典。

必需

返回:

名称 类型 描述
ExperimentPipeline 'ExperimentPipeline'

返回自身以支持链式调用。

示例:

>>> pipeline.variants([{"model": "tiny"}, {"model": "base"}])
源代码位于: ztxexp/pipeline.py
def variants(self, variants: Sequence[Mapping[str, Any]]) -> "ExperimentPipeline":
    """添加独立变体空间。

    Args:
        variants: 变体列表,每个元素是一个配置片段字典。

    Returns:
        ExperimentPipeline: 返回自身以支持链式调用。

    Examples:
        >>> pipeline.variants([{"model": "tiny"}, {"model": "base"}])
    """
    self._manager.variants(variants)
    return self

where

where(fn: Callable[[dict[str, Any]], bool]) -> 'ExperimentPipeline'

注册配置过滤函数。

参数:

名称 类型 描述 默认
fn Callable[[dict[str, Any]], bool]

谓词函数。返回 True 表示保留配置。

必需

返回:

名称 类型 描述
ExperimentPipeline 'ExperimentPipeline'

返回自身以支持链式调用。

源代码位于: ztxexp/pipeline.py
def where(self, fn: Callable[[dict[str, Any]], bool]) -> "ExperimentPipeline":
    """注册配置过滤函数。

    Args:
        fn: 谓词函数。返回 ``True`` 表示保留配置。

    Returns:
        ExperimentPipeline: 返回自身以支持链式调用。
    """
    self._manager.where(fn)
    return self

JsonlTracker

将生命周期事件写入 events.jsonl 的轻量追踪器。

参数:

名称 类型 描述 默认
events_filename str

事件文件名。

'events.jsonl'
源代码位于: ztxexp/tracking/jsonl.py
class JsonlTracker:
    """将生命周期事件写入 ``events.jsonl`` 的轻量追踪器。

    Args:
        events_filename: 事件文件名。
    """

    def __init__(self, events_filename: str = "events.jsonl"):
        self.events_filename = events_filename

    def _events_path(self, ctx: RunContext):
        return ctx.run_dir / self.events_filename

    def _append(self, ctx: RunContext, payload: dict[str, Any]) -> None:
        utils.append_jsonl(self._events_path(ctx), payload)

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """记录 run 启动事件。"""
        self._append(
            ctx,
            {
                "event": "run_start",
                "run_id": ctx.run_id,
                "timestamp": utils.utc_now_iso(),
                "meta": meta.to_dict(),
            },
        )

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """记录指标事件。"""
        self._append(
            ctx,
            {
                "event": "metric",
                "run_id": ctx.run_id,
                "payload": event.to_dict(),
            },
        )

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """记录 run 结束事件。"""
        self._append(
            ctx,
            {
                "event": "run_end",
                "run_id": ctx.run_id,
                "timestamp": utils.utc_now_iso(),
                "summary": summary,
            },
        )

events_filename instance-attribute

events_filename = events_filename

__init__

__init__(events_filename: str = 'events.jsonl')
源代码位于: ztxexp/tracking/jsonl.py
def __init__(self, events_filename: str = "events.jsonl"):
    self.events_filename = events_filename

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None

记录指标事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """记录指标事件。"""
    self._append(
        ctx,
        {
            "event": "metric",
            "run_id": ctx.run_id,
            "payload": event.to_dict(),
        },
    )

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None

记录 run 结束事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """记录 run 结束事件。"""
    self._append(
        ctx,
        {
            "event": "run_end",
            "run_id": ctx.run_id,
            "timestamp": utils.utc_now_iso(),
            "summary": summary,
        },
    )

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None

记录 run 启动事件。

源代码位于: ztxexp/tracking/jsonl.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """记录 run 启动事件。"""
    self._append(
        ctx,
        {
            "event": "run_start",
            "run_id": ctx.run_id,
            "timestamp": utils.utc_now_iso(),
            "meta": meta.to_dict(),
        },
    )

MetricEvent dataclass

单条指标事件。

属性:

名称 类型 描述
step int

指标对应的 step(epoch/global step)。

timestamp str

事件时间(ISO8601)。

metrics dict[str, float]

指标字典。

split str

数据划分(train/valid/test)。

phase str

阶段标识(fit/eval/infer)。

源代码位于: ztxexp/types.py
@dataclass(slots=True)
class MetricEvent:
    """单条指标事件。

    Attributes:
        step: 指标对应的 step(epoch/global step)。
        timestamp: 事件时间(ISO8601)。
        metrics: 指标字典。
        split: 数据划分(train/valid/test)。
        phase: 阶段标识(fit/eval/infer)。
    """

    step: int
    timestamp: str
    metrics: dict[str, float]
    split: str = "train"
    phase: str = "fit"

    def to_dict(self) -> dict[str, Any]:
        """转换为字典。"""
        return asdict(self)

metrics instance-attribute

metrics: dict[str, float]

phase class-attribute instance-attribute

phase: str = 'fit'

split class-attribute instance-attribute

split: str = 'train'

step instance-attribute

step: int

timestamp instance-attribute

timestamp: str

__init__

__init__(step: int, timestamp: str, metrics: dict[str, float], split: str = 'train', phase: str = 'fit') -> None

to_dict

to_dict() -> dict[str, Any]

转换为字典。

源代码位于: ztxexp/types.py
def to_dict(self) -> dict[str, Any]:
    """转换为字典。"""
    return asdict(self)

MlflowTracker

MLflow 追踪器(可选依赖)。

源代码位于: ztxexp/tracking/adapters.py
class MlflowTracker:
    """MLflow 追踪器(可选依赖)。"""

    def __init__(
        self,
        tracking_uri: str | None = None,
        experiment_name: str | None = None,
        run_name: str | None = None,
    ):
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        self.run_name = run_name
        self._started = False

    def _mlflow(self):
        try:
            import mlflow

            return mlflow
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "mlflow is required for MlflowTracker. Install with: pip install ztxexp[mlflow]"
            ) from exc

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        mlflow = self._mlflow()
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        if self.experiment_name or meta.experiment_name:
            mlflow.set_experiment(self.experiment_name or str(meta.experiment_name))
        mlflow.start_run(run_name=self.run_name or ctx.run_id)
        self._started = True
        mlflow.log_params(
            {
                k: v
                for k, v in ctx.config.items()
                if isinstance(v, (str, int, float, bool))
            }
        )
        mlflow.set_tags(
            {
                "run_id": ctx.run_id,
                "group": meta.group or "",
                "parent_run_id": meta.parent_run_id or "",
            }
        )

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        if not self._started:
            return
        mlflow = self._mlflow()
        mlflow.log_metrics(event.metrics, step=event.step)

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        if not self._started:
            return
        mlflow = self._mlflow()
        status = summary.get("status")
        if status is not None:
            mlflow.set_tag("status", str(status))
        mlflow.end_run()
        self._started = False

experiment_name instance-attribute

experiment_name = experiment_name

run_name instance-attribute

run_name = run_name

tracking_uri instance-attribute

tracking_uri = tracking_uri

__init__

__init__(tracking_uri: str | None = None, experiment_name: str | None = None, run_name: str | None = None)
源代码位于: ztxexp/tracking/adapters.py
def __init__(
    self,
    tracking_uri: str | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
):
    self.tracking_uri = tracking_uri
    self.experiment_name = experiment_name
    self.run_name = run_name
    self._started = False

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    if not self._started:
        return
    mlflow = self._mlflow()
    mlflow.log_metrics(event.metrics, step=event.step)

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    if not self._started:
        return
    mlflow = self._mlflow()
    status = summary.get("status")
    if status is not None:
        mlflow.set_tag("status", str(status))
    mlflow.end_run()
    self._started = False

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    mlflow = self._mlflow()
    if self.tracking_uri:
        mlflow.set_tracking_uri(self.tracking_uri)
    if self.experiment_name or meta.experiment_name:
        mlflow.set_experiment(self.experiment_name or str(meta.experiment_name))
    mlflow.start_run(run_name=self.run_name or ctx.run_id)
    self._started = True
    mlflow.log_params(
        {
            k: v
            for k, v in ctx.config.items()
            if isinstance(v, (str, int, float, bool))
        }
    )
    mlflow.set_tags(
        {
            "run_id": ctx.run_id,
            "group": meta.group or "",
            "parent_run_id": meta.parent_run_id or "",
        }
    )

ResultAnalyzer

实验结果分析器(仅支持 schema v2)。

参数:

名称 类型 描述 默认
results_path str | Path

运行根目录。

必需

引发:

类型 描述
FileNotFoundError

结果目录不存在时抛出。

示例:

>>> analyzer = ResultAnalyzer("./results_demo")
>>> df = analyzer.to_dataframe(statuses=("succeeded",))
源代码位于: ztxexp/analyzer.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
class ResultAnalyzer:
    """实验结果分析器(仅支持 schema v2)。

    Args:
        results_path: 运行根目录。

    Raises:
        FileNotFoundError: 结果目录不存在时抛出。

    Examples:
        >>> analyzer = ResultAnalyzer("./results_demo")
        >>> df = analyzer.to_dataframe(statuses=("succeeded",))
    """

    def __init__(self, results_path: str | Path):
        self.results_path = Path(results_path)
        if not self.results_path.exists():
            raise FileNotFoundError(f"Results path does not exist: {self.results_path}")

    def to_records(
        self,
        statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
        metrics_filename: str = "metrics.json",
        experiment_name: str | None = None,
        group: str | None = None,
        tags: dict[str, str] | list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """读取 run 目录并合并为记录列表。

        合并顺序:``config -> metrics -> run_meta``,后者覆盖前者同名键。
        仅处理 ``schema_version == 2`` 的 run 目录。

        Args:
            statuses: 允许状态集合;传 ``None`` 表示不过滤状态。
            metrics_filename: 指标文件名,默认 ``metrics.json``。
            experiment_name: 实验名称过滤条件。
            group: 分组过滤条件。
            tags: 标签过滤条件。

        Returns:
            list[dict[str, Any]]: 扁平化记录列表。每条记录至少包含:
                - 配置字段(来自 ``config.json``);
                - 指标字段(来自 ``metrics.json``,若文件存在);
                - 运行元字段(来自 ``run.json``,如 ``status/run_id``);
                - ``run_dir``(绝对路径字符串)。

        Examples:
            >>> records = ResultAnalyzer("./results_demo").to_records(statuses=None)
            >>> isinstance(records, list)
            True
        """
        records: list[dict[str, Any]] = []
        target_statuses = set(statuses) if statuses is not None else None

        for run_dir in utils.get_subdirectories(self.results_path):
            record = self._load_record(run_dir, metrics_filename)
            if record is None:
                continue

            status = record.get("status")
            if target_statuses is not None and status not in target_statuses:
                continue
            if experiment_name and record.get("experiment_name") != experiment_name:
                continue
            if group and record.get("group") != group:
                continue
            if tags and not self._tags_match(record.get("tags"), tags):
                continue

            records.append(record)

        return records

    def to_dataframe(
        self,
        statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
        metrics_filename: str = "metrics.json",
        experiment_name: str | None = None,
        group: str | None = None,
        tags: dict[str, str] | list[str] | None = None,
    ) -> pd.DataFrame:
        """将记录列表转为 DataFrame。

        Args:
            statuses: 状态过滤条件。
            metrics_filename: 指标文件名。

        Returns:
            pd.DataFrame: 聚合后的数据表;若无数据返回空 DataFrame。
        """
        records = self.to_records(
            statuses=statuses,
            metrics_filename=metrics_filename,
            experiment_name=experiment_name,
            group=group,
            tags=tags,
        )
        if not records:
            return pd.DataFrame()
        return pd.DataFrame.from_records(records)

    def to_csv(
        self,
        output_path: str | Path,
        sort_by: Sequence[str] | None = None,
        statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
        metrics_filename: str = "metrics.json",
        experiment_name: str | None = None,
        group: str | None = None,
        tags: dict[str, str] | list[str] | None = None,
    ) -> pd.DataFrame:
        """导出 CSV,并返回导出所用 DataFrame。

        Args:
            output_path: CSV 输出路径。
            sort_by: 排序字段列表(仅会使用存在于列中的字段)。
            statuses: 状态过滤条件。
            metrics_filename: 指标文件名。

        Returns:
            pd.DataFrame: 导出用 DataFrame(可能为空)。
        """
        df = self.to_dataframe(
            statuses=statuses,
            metrics_filename=metrics_filename,
            experiment_name=experiment_name,
            group=group,
            tags=tags,
        )
        if df.empty:
            print("No records found to export.")
            return df

        if sort_by:
            valid_keys = [key for key in sort_by if key in df.columns]
            if valid_keys:
                df = df.sort_values(by=valid_keys).reset_index(drop=True)

        df.to_csv(output_path, index=False)
        print(f"Saved CSV to {output_path}")
        return df

    def to_metric_events(
        self,
        statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
        metrics_stream_filename: str = "metrics.jsonl",
        experiment_name: str | None = None,
        group: str | None = None,
        tags: dict[str, str] | list[str] | None = None,
    ) -> list[MetricEvent]:
        """读取 step 级指标事件。

        Args:
            statuses: 状态过滤条件。
            metrics_stream_filename: 指标流文件名,默认 ``metrics.jsonl``。
            experiment_name: 实验名称过滤条件。
            group: 分组过滤条件。
            tags: 标签过滤条件。

        Returns:
            list[MetricEvent]: 事件列表。仅返回结构合法的事件:
                - ``step`` 必须是 ``int``;
                - ``timestamp`` 必须是 ``str``;
                - ``metrics`` 必须是 ``dict`` 且值可转为 ``float``。

        Notes:
            无效行会被跳过,不会抛出异常中断整个读取流程。
        """
        events: list[MetricEvent] = []
        target_statuses = set(statuses) if statuses is not None else None

        for run_dir in utils.get_subdirectories(self.results_path):
            run_meta = utils.load_json(run_dir / "run.json")
            if not run_meta:
                continue
            if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
                continue
            if target_statuses is not None and run_meta.get("status") not in target_statuses:
                continue
            if experiment_name and run_meta.get("experiment_name") != experiment_name:
                continue
            if group and run_meta.get("group") != group:
                continue
            if tags and not self._tags_match(run_meta.get("tags"), tags):
                continue

            rows = utils.load_jsonl(run_dir / metrics_stream_filename, skip_invalid=True)
            for row in rows:
                metrics = row.get("metrics")
                step = row.get("step")
                timestamp = row.get("timestamp")
                if not isinstance(metrics, dict):
                    continue
                if not isinstance(step, int):
                    continue
                if not isinstance(timestamp, str):
                    continue
                try:
                    events.append(
                        MetricEvent(
                            step=step,
                            timestamp=timestamp,
                            metrics={k: float(v) for k, v in metrics.items()},
                            split=str(row.get("split", "train")),
                            phase=str(row.get("phase", "fit")),
                        )
                    )
                except Exception:
                    continue

        return events

    def to_curve_dataframe(
        self,
        metric_key: str | None = None,
        statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
        metrics_stream_filename: str = "metrics.jsonl",
        experiment_name: str | None = None,
        group: str | None = None,
        tags: dict[str, str] | list[str] | None = None,
    ) -> pd.DataFrame:
        """将 step 级指标事件转为曲线 DataFrame。

        Args:
            metric_key: 指标键。若为空则展开全部指标。
            statuses: 状态过滤条件。
            metrics_stream_filename: 指标流文件名。
            experiment_name: 实验名称过滤条件。
            group: 分组过滤条件。
            tags: 标签过滤条件。

        Returns:
            pd.DataFrame: 曲线数据表。
                - 基础列始终包含:``run_id/timestamp/step/split/phase``;
                - 当 ``metric_key`` 非空时,只返回该指标列;
                - 当 ``metric_key`` 为空时,展开 ``metrics`` 的全部键。

        Examples:
            >>> analyzer = ResultAnalyzer("./results_demo")
            >>> df = analyzer.to_curve_dataframe(metric_key="loss")
            >>> set(["run_id", "step"]).issubset(df.columns) if not df.empty else True
            True
        """
        rows: list[dict[str, Any]] = []
        target_statuses = set(statuses) if statuses is not None else None

        for run_dir in utils.get_subdirectories(self.results_path):
            run_meta = utils.load_json(run_dir / "run.json")
            if not run_meta:
                continue
            if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
                continue
            if target_statuses is not None and run_meta.get("status") not in target_statuses:
                continue
            if experiment_name and run_meta.get("experiment_name") != experiment_name:
                continue
            if group and run_meta.get("group") != group:
                continue
            if tags and not self._tags_match(run_meta.get("tags"), tags):
                continue

            run_id = str(run_meta.get("run_id") or run_dir.name)
            records = utils.load_jsonl(run_dir / metrics_stream_filename, skip_invalid=True)
            for record in records:
                metrics = record.get("metrics")
                if not isinstance(metrics, dict):
                    continue

                base = {
                    "run_id": run_id,
                    "timestamp": record.get("timestamp"),
                    "step": record.get("step"),
                    "split": record.get("split", "train"),
                    "phase": record.get("phase", "fit"),
                }
                if metric_key:
                    if metric_key in metrics:
                        base[metric_key] = metrics[metric_key]
                        rows.append(base)
                else:
                    expanded = dict(base)
                    expanded.update(metrics)
                    rows.append(expanded)

        if not rows:
            return pd.DataFrame()
        return pd.DataFrame.from_records(rows)

    def to_pivot_excel(
        self,
        output_path: str | Path,
        df: pd.DataFrame,
        index_cols: Sequence[str],
        column_cols: Sequence[str],
        value_cols: Sequence[str],
        add_ranking: bool = True,
        ranking_ascending: bool = False,
    ) -> None:
        """生成透视表并导出 Excel。

        Args:
            output_path: Excel 输出路径。
            df: 输入数据表。
            index_cols: 透视表行索引字段。
            column_cols: 透视表列索引字段。
            value_cols: 值字段。
            add_ranking: 是否附加名次标签(1st/2nd/3rd)。
            ranking_ascending: 排名方向。``False`` 通常用于“值越大越好”。

        Returns:
            None

        Notes:
            该功能依赖 ``openpyxl``。未安装时会给出提示并安全返回。
        """
        if df.empty:
            print("DataFrame is empty, cannot generate pivot table.")
            return

        try:
            pivot_df = df.pivot_table(index=index_cols, columns=column_cols, values=value_cols)
        except Exception as exc:
            print(f"Failed to create pivot table: {exc}")
            return

        if not add_ranking:
            try:
                pivot_df.to_excel(output_path)
            except ImportError:
                print(
                    "openpyxl is required for Excel export. "
                    "Install with: pip install openpyxl"
                )
                return
            print(f"Saved pivot table to {output_path}")
            return

        rank_df = pivot_df.rank(method="min", ascending=ranking_ascending)
        final_pivot = pivot_df.astype(str)
        rank_labels = {1.0: " (1st)", 2.0: " (2nd)", 3.0: " (3rd)"}

        for col in final_pivot.columns:
            for idx in final_pivot.index:
                value = pivot_df.at[idx, col]
                if pd.notna(value):
                    rank = rank_df.at[idx, col]
                    final_pivot.at[idx, col] = f"{value:.4f}{rank_labels.get(rank, '')}"
                else:
                    final_pivot.at[idx, col] = ""

        try:
            final_pivot.to_excel(output_path)
        except ImportError:
            print(
                "openpyxl is required for Excel export. "
                "Install with: pip install openpyxl"
            )
            return
        print(f"Saved ranked pivot table to {output_path}")

    def clean_results(
        self,
        statuses: Sequence[str] | None = (
            RUN_STATUS_FAILED,
            RUN_STATUS_RUNNING,
            RUN_STATUS_SKIPPED,
        ),
        predicate: RecordPredicate | None = None,
        dry_run: bool = True,
        metrics_filename: str = "metrics.json",
        confirm: bool = True,
    ) -> list[Path]:
        """清理匹配条件的 run 目录。

        删除条件采用 OR 逻辑:
        1) ``status in statuses``(当 statuses 非 None);
        2) ``predicate(record) is True``(当 predicate 非空)。

        Args:
            statuses: 状态过滤集合;``None`` 表示不按状态筛选。
            predicate: 自定义删除规则。
            dry_run: 为 ``True`` 时只打印并返回候选,不执行删除。
            metrics_filename: 指标文件名。
            confirm: 非 dry-run 且为 ``True`` 时,删除前二次确认。

        Returns:
            list[Path]:
                - dry-run: 待删目录列表;
                - 非 dry-run: 实际删除成功的目录列表。

        Examples:
            >>> analyzer = ResultAnalyzer("./results_demo")
            >>> analyzer.clean_results(statuses=("failed",), dry_run=True)
        """
        target_statuses = set(statuses) if statuses is not None else None
        to_delete: list[Path] = []

        for run_dir in utils.get_subdirectories(self.results_path):
            record = self._load_record(run_dir, metrics_filename)
            if record is None:
                continue

            should_delete = False
            if target_statuses is not None and record.get("status") in target_statuses:
                should_delete = True
            if predicate and predicate(record):
                should_delete = True

            if should_delete:
                to_delete.append(run_dir)

        if not to_delete:
            print("No folders matched cleanup criteria.")
            return []

        print(f"Found {len(to_delete)} folders to delete.")
        for run_dir in to_delete:
            print(f"  - {run_dir.name}")

        if dry_run:
            print("Dry run enabled. Nothing was deleted.")
            return to_delete

        if confirm:
            answer = input(f"Delete these {len(to_delete)} folders permanently? (yes/no): ")
            if answer.strip().lower() != "yes":
                print("Deletion canceled.")
                return []

        deleted: list[Path] = []
        for run_dir in to_delete:
            try:
                shutil.rmtree(run_dir)
                deleted.append(run_dir)
            except Exception as exc:  # pragma: no cover
                print(f"Failed to delete {run_dir}: {exc}")

        print(f"Deleted {len(deleted)} folders.")
        return deleted

    def _load_record(self, run_dir: Path, metrics_filename: str) -> dict[str, Any] | None:
        """从单个 run 目录加载扁平记录。

        Args:
            run_dir: 单个 run 目录。
            metrics_filename: 指标文件名。

        Returns:
            dict[str, Any] | None: 合并记录;若目录不符合 v2 协议则返回 None。
        """
        run_meta = utils.load_json(run_dir / "run.json")
        if not run_meta:
            return None
        if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
            return None

        config = utils.load_json(run_dir / "config.json") or {}
        if not isinstance(config, dict):
            return None

        metrics = utils.load_json(run_dir / metrics_filename) or {}
        if not isinstance(metrics, dict):
            metrics = {}

        record: dict[str, Any] = {}
        record.update(config)
        record.update(metrics)
        record.update(run_meta)
        record["run_dir"] = str(run_dir.resolve())
        return record

    def _tags_match(
        self,
        record_tags: Any,
        target_tags: dict[str, str] | list[str],
    ) -> bool:
        """判断标签是否匹配。"""
        if isinstance(target_tags, dict):
            if not isinstance(record_tags, dict):
                return False
            for key, value in target_tags.items():
                if str(record_tags.get(key)) != str(value):
                    return False
            return True

        if isinstance(target_tags, list):
            if isinstance(record_tags, list):
                return all(tag in record_tags for tag in target_tags)
            if isinstance(record_tags, dict):
                values = set(record_tags.values()) | set(record_tags.keys())
                return all(tag in values for tag in target_tags)
            return False

        return False

results_path instance-attribute

results_path = Path(results_path)

__init__

__init__(results_path: str | Path)
源代码位于: ztxexp/analyzer.py
def __init__(self, results_path: str | Path):
    self.results_path = Path(results_path)
    if not self.results_path.exists():
        raise FileNotFoundError(f"Results path does not exist: {self.results_path}")

clean_results

clean_results(statuses: Sequence[str] | None = (RUN_STATUS_FAILED, RUN_STATUS_RUNNING, RUN_STATUS_SKIPPED), predicate: RecordPredicate | None = None, dry_run: bool = True, metrics_filename: str = 'metrics.json', confirm: bool = True) -> list[Path]

清理匹配条件的 run 目录。

删除条件采用 OR 逻辑: 1) status in statuses(当 statuses 非 None); 2) predicate(record) is True(当 predicate 非空)。

参数:

名称 类型 描述 默认
statuses Sequence[str] | None

状态过滤集合;None 表示不按状态筛选。

(RUN_STATUS_FAILED, RUN_STATUS_RUNNING, RUN_STATUS_SKIPPED)
predicate RecordPredicate | None

自定义删除规则。

None
dry_run bool

True 时只打印并返回候选,不执行删除。

True
metrics_filename str

指标文件名。

'metrics.json'
confirm bool

非 dry-run 且为 True 时,删除前二次确认。

True

返回:

类型 描述
list[Path]

list[Path]: - dry-run: 待删目录列表; - 非 dry-run: 实际删除成功的目录列表。

示例:

>>> analyzer = ResultAnalyzer("./results_demo")
>>> analyzer.clean_results(statuses=("failed",), dry_run=True)
源代码位于: ztxexp/analyzer.py
def clean_results(
    self,
    statuses: Sequence[str] | None = (
        RUN_STATUS_FAILED,
        RUN_STATUS_RUNNING,
        RUN_STATUS_SKIPPED,
    ),
    predicate: RecordPredicate | None = None,
    dry_run: bool = True,
    metrics_filename: str = "metrics.json",
    confirm: bool = True,
) -> list[Path]:
    """清理匹配条件的 run 目录。

    删除条件采用 OR 逻辑:
    1) ``status in statuses``(当 statuses 非 None);
    2) ``predicate(record) is True``(当 predicate 非空)。

    Args:
        statuses: 状态过滤集合;``None`` 表示不按状态筛选。
        predicate: 自定义删除规则。
        dry_run: 为 ``True`` 时只打印并返回候选,不执行删除。
        metrics_filename: 指标文件名。
        confirm: 非 dry-run 且为 ``True`` 时,删除前二次确认。

    Returns:
        list[Path]:
            - dry-run: 待删目录列表;
            - 非 dry-run: 实际删除成功的目录列表。

    Examples:
        >>> analyzer = ResultAnalyzer("./results_demo")
        >>> analyzer.clean_results(statuses=("failed",), dry_run=True)
    """
    target_statuses = set(statuses) if statuses is not None else None
    to_delete: list[Path] = []

    for run_dir in utils.get_subdirectories(self.results_path):
        record = self._load_record(run_dir, metrics_filename)
        if record is None:
            continue

        should_delete = False
        if target_statuses is not None and record.get("status") in target_statuses:
            should_delete = True
        if predicate and predicate(record):
            should_delete = True

        if should_delete:
            to_delete.append(run_dir)

    if not to_delete:
        print("No folders matched cleanup criteria.")
        return []

    print(f"Found {len(to_delete)} folders to delete.")
    for run_dir in to_delete:
        print(f"  - {run_dir.name}")

    if dry_run:
        print("Dry run enabled. Nothing was deleted.")
        return to_delete

    if confirm:
        answer = input(f"Delete these {len(to_delete)} folders permanently? (yes/no): ")
        if answer.strip().lower() != "yes":
            print("Deletion canceled.")
            return []

    deleted: list[Path] = []
    for run_dir in to_delete:
        try:
            shutil.rmtree(run_dir)
            deleted.append(run_dir)
        except Exception as exc:  # pragma: no cover
            print(f"Failed to delete {run_dir}: {exc}")

    print(f"Deleted {len(deleted)} folders.")
    return deleted

to_csv

to_csv(output_path: str | Path, sort_by: Sequence[str] | None = None, statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,), metrics_filename: str = 'metrics.json', experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None) -> pd.DataFrame

导出 CSV,并返回导出所用 DataFrame。

参数:

名称 类型 描述 默认
output_path str | Path

CSV 输出路径。

必需
sort_by Sequence[str] | None

排序字段列表(仅会使用存在于列中的字段)。

None
statuses Sequence[str] | None

状态过滤条件。

(RUN_STATUS_SUCCEEDED,)
metrics_filename str

指标文件名。

'metrics.json'

返回:

类型 描述
DataFrame

pd.DataFrame: 导出用 DataFrame(可能为空)。

源代码位于: ztxexp/analyzer.py
def to_csv(
    self,
    output_path: str | Path,
    sort_by: Sequence[str] | None = None,
    statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
    metrics_filename: str = "metrics.json",
    experiment_name: str | None = None,
    group: str | None = None,
    tags: dict[str, str] | list[str] | None = None,
) -> pd.DataFrame:
    """导出 CSV,并返回导出所用 DataFrame。

    Args:
        output_path: CSV 输出路径。
        sort_by: 排序字段列表(仅会使用存在于列中的字段)。
        statuses: 状态过滤条件。
        metrics_filename: 指标文件名。

    Returns:
        pd.DataFrame: 导出用 DataFrame(可能为空)。
    """
    df = self.to_dataframe(
        statuses=statuses,
        metrics_filename=metrics_filename,
        experiment_name=experiment_name,
        group=group,
        tags=tags,
    )
    if df.empty:
        print("No records found to export.")
        return df

    if sort_by:
        valid_keys = [key for key in sort_by if key in df.columns]
        if valid_keys:
            df = df.sort_values(by=valid_keys).reset_index(drop=True)

    df.to_csv(output_path, index=False)
    print(f"Saved CSV to {output_path}")
    return df

to_curve_dataframe

to_curve_dataframe(metric_key: str | None = None, statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,), metrics_stream_filename: str = 'metrics.jsonl', experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None) -> pd.DataFrame

将 step 级指标事件转为曲线 DataFrame。

参数:

名称 类型 描述 默认
metric_key str | None

指标键。若为空则展开全部指标。

None
statuses Sequence[str] | None

状态过滤条件。

(RUN_STATUS_SUCCEEDED,)
metrics_stream_filename str

指标流文件名。

'metrics.jsonl'
experiment_name str | None

实验名称过滤条件。

None
group str | None

分组过滤条件。

None
tags dict[str, str] | list[str] | None

标签过滤条件。

None

返回:

类型 描述
DataFrame

pd.DataFrame: 曲线数据表。 - 基础列始终包含:run_id/timestamp/step/split/phase; - 当 metric_key 非空时,只返回该指标列; - 当 metric_key 为空时,展开 metrics 的全部键。

示例:

>>> analyzer = ResultAnalyzer("./results_demo")
>>> df = analyzer.to_curve_dataframe(metric_key="loss")
>>> set(["run_id", "step"]).issubset(df.columns) if not df.empty else True
True
源代码位于: ztxexp/analyzer.py
def to_curve_dataframe(
    self,
    metric_key: str | None = None,
    statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
    metrics_stream_filename: str = "metrics.jsonl",
    experiment_name: str | None = None,
    group: str | None = None,
    tags: dict[str, str] | list[str] | None = None,
) -> pd.DataFrame:
    """将 step 级指标事件转为曲线 DataFrame。

    Args:
        metric_key: 指标键。若为空则展开全部指标。
        statuses: 状态过滤条件。
        metrics_stream_filename: 指标流文件名。
        experiment_name: 实验名称过滤条件。
        group: 分组过滤条件。
        tags: 标签过滤条件。

    Returns:
        pd.DataFrame: 曲线数据表。
            - 基础列始终包含:``run_id/timestamp/step/split/phase``;
            - 当 ``metric_key`` 非空时,只返回该指标列;
            - 当 ``metric_key`` 为空时,展开 ``metrics`` 的全部键。

    Examples:
        >>> analyzer = ResultAnalyzer("./results_demo")
        >>> df = analyzer.to_curve_dataframe(metric_key="loss")
        >>> set(["run_id", "step"]).issubset(df.columns) if not df.empty else True
        True
    """
    rows: list[dict[str, Any]] = []
    target_statuses = set(statuses) if statuses is not None else None

    for run_dir in utils.get_subdirectories(self.results_path):
        run_meta = utils.load_json(run_dir / "run.json")
        if not run_meta:
            continue
        if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
            continue
        if target_statuses is not None and run_meta.get("status") not in target_statuses:
            continue
        if experiment_name and run_meta.get("experiment_name") != experiment_name:
            continue
        if group and run_meta.get("group") != group:
            continue
        if tags and not self._tags_match(run_meta.get("tags"), tags):
            continue

        run_id = str(run_meta.get("run_id") or run_dir.name)
        records = utils.load_jsonl(run_dir / metrics_stream_filename, skip_invalid=True)
        for record in records:
            metrics = record.get("metrics")
            if not isinstance(metrics, dict):
                continue

            base = {
                "run_id": run_id,
                "timestamp": record.get("timestamp"),
                "step": record.get("step"),
                "split": record.get("split", "train"),
                "phase": record.get("phase", "fit"),
            }
            if metric_key:
                if metric_key in metrics:
                    base[metric_key] = metrics[metric_key]
                    rows.append(base)
            else:
                expanded = dict(base)
                expanded.update(metrics)
                rows.append(expanded)

    if not rows:
        return pd.DataFrame()
    return pd.DataFrame.from_records(rows)

to_dataframe

to_dataframe(statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,), metrics_filename: str = 'metrics.json', experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None) -> pd.DataFrame

将记录列表转为 DataFrame。

参数:

名称 类型 描述 默认
statuses Sequence[str] | None

状态过滤条件。

(RUN_STATUS_SUCCEEDED,)
metrics_filename str

指标文件名。

'metrics.json'

返回:

类型 描述
DataFrame

pd.DataFrame: 聚合后的数据表;若无数据返回空 DataFrame。

源代码位于: ztxexp/analyzer.py
def to_dataframe(
    self,
    statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
    metrics_filename: str = "metrics.json",
    experiment_name: str | None = None,
    group: str | None = None,
    tags: dict[str, str] | list[str] | None = None,
) -> pd.DataFrame:
    """将记录列表转为 DataFrame。

    Args:
        statuses: 状态过滤条件。
        metrics_filename: 指标文件名。

    Returns:
        pd.DataFrame: 聚合后的数据表;若无数据返回空 DataFrame。
    """
    records = self.to_records(
        statuses=statuses,
        metrics_filename=metrics_filename,
        experiment_name=experiment_name,
        group=group,
        tags=tags,
    )
    if not records:
        return pd.DataFrame()
    return pd.DataFrame.from_records(records)

to_metric_events

to_metric_events(statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,), metrics_stream_filename: str = 'metrics.jsonl', experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None) -> list[MetricEvent]

读取 step 级指标事件。

参数:

名称 类型 描述 默认
statuses Sequence[str] | None

状态过滤条件。

(RUN_STATUS_SUCCEEDED,)
metrics_stream_filename str

指标流文件名,默认 metrics.jsonl

'metrics.jsonl'
experiment_name str | None

实验名称过滤条件。

None
group str | None

分组过滤条件。

None
tags dict[str, str] | list[str] | None

标签过滤条件。

None

返回:

类型 描述
list[MetricEvent]

list[MetricEvent]: 事件列表。仅返回结构合法的事件: - step 必须是 int; - timestamp 必须是 str; - metrics 必须是 dict 且值可转为 float

Notes

无效行会被跳过,不会抛出异常中断整个读取流程。

源代码位于: ztxexp/analyzer.py
def to_metric_events(
    self,
    statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
    metrics_stream_filename: str = "metrics.jsonl",
    experiment_name: str | None = None,
    group: str | None = None,
    tags: dict[str, str] | list[str] | None = None,
) -> list[MetricEvent]:
    """读取 step 级指标事件。

    Args:
        statuses: 状态过滤条件。
        metrics_stream_filename: 指标流文件名,默认 ``metrics.jsonl``。
        experiment_name: 实验名称过滤条件。
        group: 分组过滤条件。
        tags: 标签过滤条件。

    Returns:
        list[MetricEvent]: 事件列表。仅返回结构合法的事件:
            - ``step`` 必须是 ``int``;
            - ``timestamp`` 必须是 ``str``;
            - ``metrics`` 必须是 ``dict`` 且值可转为 ``float``。

    Notes:
        无效行会被跳过,不会抛出异常中断整个读取流程。
    """
    events: list[MetricEvent] = []
    target_statuses = set(statuses) if statuses is not None else None

    for run_dir in utils.get_subdirectories(self.results_path):
        run_meta = utils.load_json(run_dir / "run.json")
        if not run_meta:
            continue
        if run_meta.get("schema_version") != RUN_SCHEMA_VERSION:
            continue
        if target_statuses is not None and run_meta.get("status") not in target_statuses:
            continue
        if experiment_name and run_meta.get("experiment_name") != experiment_name:
            continue
        if group and run_meta.get("group") != group:
            continue
        if tags and not self._tags_match(run_meta.get("tags"), tags):
            continue

        rows = utils.load_jsonl(run_dir / metrics_stream_filename, skip_invalid=True)
        for row in rows:
            metrics = row.get("metrics")
            step = row.get("step")
            timestamp = row.get("timestamp")
            if not isinstance(metrics, dict):
                continue
            if not isinstance(step, int):
                continue
            if not isinstance(timestamp, str):
                continue
            try:
                events.append(
                    MetricEvent(
                        step=step,
                        timestamp=timestamp,
                        metrics={k: float(v) for k, v in metrics.items()},
                        split=str(row.get("split", "train")),
                        phase=str(row.get("phase", "fit")),
                    )
                )
            except Exception:
                continue

    return events

to_pivot_excel

to_pivot_excel(output_path: str | Path, df: DataFrame, index_cols: Sequence[str], column_cols: Sequence[str], value_cols: Sequence[str], add_ranking: bool = True, ranking_ascending: bool = False) -> None

生成透视表并导出 Excel。

参数:

名称 类型 描述 默认
output_path str | Path

Excel 输出路径。

必需
df DataFrame

输入数据表。

必需
index_cols Sequence[str]

透视表行索引字段。

必需
column_cols Sequence[str]

透视表列索引字段。

必需
value_cols Sequence[str]

值字段。

必需
add_ranking bool

是否附加名次标签(1st/2nd/3rd)。

True
ranking_ascending bool

排名方向。False 通常用于“值越大越好”。

False

返回:

类型 描述
None

None

Notes

该功能依赖 openpyxl。未安装时会给出提示并安全返回。

源代码位于: ztxexp/analyzer.py
def to_pivot_excel(
    self,
    output_path: str | Path,
    df: pd.DataFrame,
    index_cols: Sequence[str],
    column_cols: Sequence[str],
    value_cols: Sequence[str],
    add_ranking: bool = True,
    ranking_ascending: bool = False,
) -> None:
    """生成透视表并导出 Excel。

    Args:
        output_path: Excel 输出路径。
        df: 输入数据表。
        index_cols: 透视表行索引字段。
        column_cols: 透视表列索引字段。
        value_cols: 值字段。
        add_ranking: 是否附加名次标签(1st/2nd/3rd)。
        ranking_ascending: 排名方向。``False`` 通常用于“值越大越好”。

    Returns:
        None

    Notes:
        该功能依赖 ``openpyxl``。未安装时会给出提示并安全返回。
    """
    if df.empty:
        print("DataFrame is empty, cannot generate pivot table.")
        return

    try:
        pivot_df = df.pivot_table(index=index_cols, columns=column_cols, values=value_cols)
    except Exception as exc:
        print(f"Failed to create pivot table: {exc}")
        return

    if not add_ranking:
        try:
            pivot_df.to_excel(output_path)
        except ImportError:
            print(
                "openpyxl is required for Excel export. "
                "Install with: pip install openpyxl"
            )
            return
        print(f"Saved pivot table to {output_path}")
        return

    rank_df = pivot_df.rank(method="min", ascending=ranking_ascending)
    final_pivot = pivot_df.astype(str)
    rank_labels = {1.0: " (1st)", 2.0: " (2nd)", 3.0: " (3rd)"}

    for col in final_pivot.columns:
        for idx in final_pivot.index:
            value = pivot_df.at[idx, col]
            if pd.notna(value):
                rank = rank_df.at[idx, col]
                final_pivot.at[idx, col] = f"{value:.4f}{rank_labels.get(rank, '')}"
            else:
                final_pivot.at[idx, col] = ""

    try:
        final_pivot.to_excel(output_path)
    except ImportError:
        print(
            "openpyxl is required for Excel export. "
            "Install with: pip install openpyxl"
        )
        return
    print(f"Saved ranked pivot table to {output_path}")

to_records

to_records(statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,), metrics_filename: str = 'metrics.json', experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None) -> list[dict[str, Any]]

读取 run 目录并合并为记录列表。

合并顺序:config -> metrics -> run_meta,后者覆盖前者同名键。 仅处理 schema_version == 2 的 run 目录。

参数:

名称 类型 描述 默认
statuses Sequence[str] | None

允许状态集合;传 None 表示不过滤状态。

(RUN_STATUS_SUCCEEDED,)
metrics_filename str

指标文件名,默认 metrics.json

'metrics.json'
experiment_name str | None

实验名称过滤条件。

None
group str | None

分组过滤条件。

None
tags dict[str, str] | list[str] | None

标签过滤条件。

None

返回:

类型 描述
list[dict[str, Any]]

list[dict[str, Any]]: 扁平化记录列表。每条记录至少包含: - 配置字段(来自 config.json); - 指标字段(来自 metrics.json,若文件存在); - 运行元字段(来自 run.json,如 status/run_id); - run_dir(绝对路径字符串)。

示例:

>>> records = ResultAnalyzer("./results_demo").to_records(statuses=None)
>>> isinstance(records, list)
True
源代码位于: ztxexp/analyzer.py
def to_records(
    self,
    statuses: Sequence[str] | None = (RUN_STATUS_SUCCEEDED,),
    metrics_filename: str = "metrics.json",
    experiment_name: str | None = None,
    group: str | None = None,
    tags: dict[str, str] | list[str] | None = None,
) -> list[dict[str, Any]]:
    """读取 run 目录并合并为记录列表。

    合并顺序:``config -> metrics -> run_meta``,后者覆盖前者同名键。
    仅处理 ``schema_version == 2`` 的 run 目录。

    Args:
        statuses: 允许状态集合;传 ``None`` 表示不过滤状态。
        metrics_filename: 指标文件名,默认 ``metrics.json``。
        experiment_name: 实验名称过滤条件。
        group: 分组过滤条件。
        tags: 标签过滤条件。

    Returns:
        list[dict[str, Any]]: 扁平化记录列表。每条记录至少包含:
            - 配置字段(来自 ``config.json``);
            - 指标字段(来自 ``metrics.json``,若文件存在);
            - 运行元字段(来自 ``run.json``,如 ``status/run_id``);
            - ``run_dir``(绝对路径字符串)。

    Examples:
        >>> records = ResultAnalyzer("./results_demo").to_records(statuses=None)
        >>> isinstance(records, list)
        True
    """
    records: list[dict[str, Any]] = []
    target_statuses = set(statuses) if statuses is not None else None

    for run_dir in utils.get_subdirectories(self.results_path):
        record = self._load_record(run_dir, metrics_filename)
        if record is None:
            continue

        status = record.get("status")
        if target_statuses is not None and status not in target_statuses:
            continue
        if experiment_name and record.get("experiment_name") != experiment_name:
            continue
        if group and record.get("group") != group:
            continue
        if tags and not self._tags_match(record.get("tags"), tags):
            continue

        records.append(record)

    return records

RunContext dataclass

单次实验运行上下文。

该对象由 ExpRunner 在每个 run 开始时构造,并传入用户实验函数。 exp_fn 的推荐契约如下:

  1. 函数签名:exp_fn(ctx: RunContext) -> dict | None
  2. 返回 dict 时框架会写入 metrics.json
  3. 返回 None 时不写 metrics.json,但 run 仍可成功;
  4. 业务产物统一写入 ctx.run_dir / "artifacts"
  5. 过程指标用 ctx.log_metric(...) 写入 metrics.jsonl

属性:

名称 类型 描述
run_id str

当前运行唯一 ID(同时也是 run 目录名)。

run_dir Path

当前运行目录绝对路径。

config dict[str, Any]

当前运行最终配置字典。

logger Logger

当前运行专属日志对象(输出到 run.log)。

meta RunMetadata

当前 run 元数据对象。

示例:

>>> def exp_fn(ctx: RunContext):
...     lr = ctx.config["lr"]
...     ctx.logger.info("lr=%s", lr)
...     return {"score": 1.0 - lr}
源代码位于: ztxexp/types.py
@dataclass(slots=True)
class RunContext:
    """单次实验运行上下文。

    该对象由 ``ExpRunner`` 在每个 run 开始时构造,并传入用户实验函数。
    ``exp_fn`` 的推荐契约如下:

    1. 函数签名:``exp_fn(ctx: RunContext) -> dict | None``;
    2. 返回 ``dict`` 时框架会写入 ``metrics.json``;
    3. 返回 ``None`` 时不写 ``metrics.json``,但 run 仍可成功;
    4. 业务产物统一写入 ``ctx.run_dir / "artifacts"``;
    5. 过程指标用 ``ctx.log_metric(...)`` 写入 ``metrics.jsonl``。

    Attributes:
        run_id: 当前运行唯一 ID(同时也是 run 目录名)。
        run_dir: 当前运行目录绝对路径。
        config: 当前运行最终配置字典。
        logger: 当前运行专属日志对象(输出到 run.log)。
        meta: 当前 run 元数据对象。

    Examples:
        >>> def exp_fn(ctx: RunContext):
        ...     lr = ctx.config["lr"]
        ...     ctx.logger.info("lr=%s", lr)
        ...     return {"score": 1.0 - lr}
    """

    run_id: str
    run_dir: Path
    config: dict[str, Any]
    logger: logging.Logger
    meta: RunMetadata = field(default_factory=RunMetadata)
    _metrics_jsonl_path: Path | None = field(default=None, repr=False)
    _trackers: list["Tracker"] = field(default_factory=list, repr=False)

    def log_metric(
        self,
        step: int,
        metrics: dict[str, float],
        split: str = "train",
        phase: str = "fit",
    ) -> None:
        """记录 step 级指标并通知 tracker。

        该方法用于写入过程曲线数据,不替代 ``exp_fn`` 的最终 ``return dict``。
        典型分工是:
        1. ``ctx.log_metric`` 负责每步/每轮中间指标;
        2. ``return dict`` 负责最终汇总指标。

        Args:
            step: 当前 step(例如 epoch 或 global step)。
            metrics: 指标字典,值应可转为 JSON(建议 ``float``)。
            split: 数据划分,如 ``train/valid/test``。
            phase: 阶段标识,如 ``fit/eval/infer``。

        Returns:
            None

        Examples:
            >>> ctx.log_metric(step=1, metrics={"loss": 0.91}, split="train", phase="fit")
        """
        event = MetricEvent(
            step=step,
            timestamp=datetime.now(timezone.utc).isoformat(),
            metrics=metrics,
            split=split,
            phase=phase,
        )
        payload = event.to_dict()

        if self._metrics_jsonl_path is not None:
            utils.append_jsonl(self._metrics_jsonl_path, payload)

        for tracker in self._trackers:
            tracker.on_metric(self, event)

config instance-attribute

config: dict[str, Any]

logger instance-attribute

logger: Logger

meta class-attribute instance-attribute

meta: RunMetadata = field(default_factory=RunMetadata)

run_dir instance-attribute

run_dir: Path

run_id instance-attribute

run_id: str

__init__

__init__(run_id: str, run_dir: Path, config: dict[str, Any], logger: Logger, meta: RunMetadata = RunMetadata(), _metrics_jsonl_path: Path | None = None, _trackers: list['Tracker'] = list()) -> None

log_metric

log_metric(step: int, metrics: dict[str, float], split: str = 'train', phase: str = 'fit') -> None

记录 step 级指标并通知 tracker。

该方法用于写入过程曲线数据,不替代 exp_fn 的最终 return dict。 典型分工是: 1. ctx.log_metric 负责每步/每轮中间指标; 2. return dict 负责最终汇总指标。

参数:

名称 类型 描述 默认
step int

当前 step(例如 epoch 或 global step)。

必需
metrics dict[str, float]

指标字典,值应可转为 JSON(建议 float)。

必需
split str

数据划分,如 train/valid/test

'train'
phase str

阶段标识,如 fit/eval/infer

'fit'

返回:

类型 描述
None

None

示例:

>>> ctx.log_metric(step=1, metrics={"loss": 0.91}, split="train", phase="fit")
源代码位于: ztxexp/types.py
def log_metric(
    self,
    step: int,
    metrics: dict[str, float],
    split: str = "train",
    phase: str = "fit",
) -> None:
    """记录 step 级指标并通知 tracker。

    该方法用于写入过程曲线数据,不替代 ``exp_fn`` 的最终 ``return dict``。
    典型分工是:
    1. ``ctx.log_metric`` 负责每步/每轮中间指标;
    2. ``return dict`` 负责最终汇总指标。

    Args:
        step: 当前 step(例如 epoch 或 global step)。
        metrics: 指标字典,值应可转为 JSON(建议 ``float``)。
        split: 数据划分,如 ``train/valid/test``。
        phase: 阶段标识,如 ``fit/eval/infer``。

    Returns:
        None

    Examples:
        >>> ctx.log_metric(step=1, metrics={"loss": 0.91}, split="train", phase="fit")
    """
    event = MetricEvent(
        step=step,
        timestamp=datetime.now(timezone.utc).isoformat(),
        metrics=metrics,
        split=split,
        phase=phase,
    )
    payload = event.to_dict()

    if self._metrics_jsonl_path is not None:
        utils.append_jsonl(self._metrics_jsonl_path, payload)

    for tracker in self._trackers:
        tracker.on_metric(self, event)

RunMetadata dataclass

运行元数据。

用于描述一次 run 的治理与复现上下文。字段均为可选,框架会在运行时 自动填充可采集部分(如 python 版本、平台、命令行等)。

属性:

名称 类型 描述
experiment_name str | None

实验名称。

group str | None

实验分组。

tags dict[str, str] | list[str] | None

标签(可为字典或字符串列表)。

parent_run_id str | None

父 run ID(用于 lineage)。

attempt int | None

当前尝试次数(重试时递增)。

git_commit str | None

当前代码 commit。

python_version str | None

Python 版本。

platform str | None

运行平台描述。

hostname str | None

主机名。

started_cmd str | None

启动命令。

dataset_version str | None

数据版本标识。

seed int | None

随机种子。

extras dict[str, Any] | None

其它扩展元数据。

源代码位于: ztxexp/types.py
@dataclass(slots=True)
class RunMetadata:
    """运行元数据。

    用于描述一次 run 的治理与复现上下文。字段均为可选,框架会在运行时
    自动填充可采集部分(如 python 版本、平台、命令行等)。

    Attributes:
        experiment_name: 实验名称。
        group: 实验分组。
        tags: 标签(可为字典或字符串列表)。
        parent_run_id: 父 run ID(用于 lineage)。
        attempt: 当前尝试次数(重试时递增)。
        git_commit: 当前代码 commit。
        python_version: Python 版本。
        platform: 运行平台描述。
        hostname: 主机名。
        started_cmd: 启动命令。
        dataset_version: 数据版本标识。
        seed: 随机种子。
        extras: 其它扩展元数据。
    """

    experiment_name: str | None = None
    group: str | None = None
    tags: dict[str, str] | list[str] | None = None
    parent_run_id: str | None = None
    attempt: int | None = None
    git_commit: str | None = None
    python_version: str | None = None
    platform: str | None = None
    hostname: str | None = None
    started_cmd: str | None = None
    dataset_version: str | None = None
    seed: int | None = None
    extras: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """转换为字典。"""
        return asdict(self)

attempt class-attribute instance-attribute

attempt: int | None = None

dataset_version class-attribute instance-attribute

dataset_version: str | None = None

experiment_name class-attribute instance-attribute

experiment_name: str | None = None

extras class-attribute instance-attribute

extras: dict[str, Any] | None = None

git_commit class-attribute instance-attribute

git_commit: str | None = None

group class-attribute instance-attribute

group: str | None = None

hostname class-attribute instance-attribute

hostname: str | None = None

parent_run_id class-attribute instance-attribute

parent_run_id: str | None = None

platform class-attribute instance-attribute

platform: str | None = None

python_version class-attribute instance-attribute

python_version: str | None = None

seed class-attribute instance-attribute

seed: int | None = None

started_cmd class-attribute instance-attribute

started_cmd: str | None = None

tags class-attribute instance-attribute

tags: dict[str, str] | list[str] | None = None

__init__

__init__(experiment_name: str | None = None, group: str | None = None, tags: dict[str, str] | list[str] | None = None, parent_run_id: str | None = None, attempt: int | None = None, git_commit: str | None = None, python_version: str | None = None, platform: str | None = None, hostname: str | None = None, started_cmd: str | None = None, dataset_version: str | None = None, seed: int | None = None, extras: dict[str, Any] | None = None) -> None

to_dict

to_dict() -> dict[str, Any]

转换为字典。

源代码位于: ztxexp/types.py
def to_dict(self) -> dict[str, Any]:
    """转换为字典。"""
    return asdict(self)

RunSummary dataclass

一次批量执行的汇总结果。

属性:

名称 类型 描述
total int

本次执行计划中的配置总数。

succeeded int

成功运行数量。

failed int

失败运行数量。

skipped int

跳过运行数量。

duration_sec float

本次批量执行总耗时(秒)。

failed_run_ids list[str]

失败 run 的 ID 列表。

示例:

>>> summary = RunSummary(4, 3, 1, 0, 2.35, ["20260301_xxx"])
>>> summary.failed
1
源代码位于: ztxexp/types.py
@dataclass(slots=True)
class RunSummary:
    """一次批量执行的汇总结果。

    Attributes:
        total: 本次执行计划中的配置总数。
        succeeded: 成功运行数量。
        failed: 失败运行数量。
        skipped: 跳过运行数量。
        duration_sec: 本次批量执行总耗时(秒)。
        failed_run_ids: 失败 run 的 ID 列表。

    Examples:
        >>> summary = RunSummary(4, 3, 1, 0, 2.35, ["20260301_xxx"])
        >>> summary.failed
        1
    """

    total: int
    succeeded: int
    failed: int
    skipped: int
    duration_sec: float
    failed_run_ids: list[str]

duration_sec instance-attribute

duration_sec: float

failed instance-attribute

failed: int

failed_run_ids instance-attribute

failed_run_ids: list[str]

skipped instance-attribute

skipped: int

succeeded instance-attribute

succeeded: int

total instance-attribute

total: int

__init__

__init__(total: int, succeeded: int, failed: int, skipped: int, duration_sec: float, failed_run_ids: list[str]) -> None

SkipRun

Bases: Exception

主动跳过当前运行。

exp_fn 中抛出该异常时,当前 run 会被标记为 skipped, 而不是 failed。适用于“业务上不合法、无需重试”的配置分支。

示例:

>>> from ztxexp import SkipRun
>>> def exp_fn(ctx):
...     if ctx.config.get("batch_size", 0) <= 0:
...         raise SkipRun("batch_size must be positive")
...     return {"score": 0.9}
源代码位于: ztxexp/runner.py
class SkipRun(Exception):
    """主动跳过当前运行。

    在 ``exp_fn`` 中抛出该异常时,当前 run 会被标记为 ``skipped``,
    而不是 ``failed``。适用于“业务上不合法、无需重试”的配置分支。

    Examples:
        >>> from ztxexp import SkipRun
        >>> def exp_fn(ctx):
        ...     if ctx.config.get("batch_size", 0) <= 0:
        ...         raise SkipRun("batch_size must be positive")
        ...     return {"score": 0.9}
    """

Tracker

Bases: Protocol

实验追踪器协议。

所有追踪器都应实现生命周期三段回调: 1. run 开始; 2. 指标事件; 3. run 结束。

源代码位于: ztxexp/tracking/base.py
class Tracker(Protocol):
    """实验追踪器协议。

    所有追踪器都应实现生命周期三段回调:
    1. run 开始;
    2. 指标事件;
    3. run 结束。
    """

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        """run 启动回调。"""

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        """指标事件回调。"""

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        """run 结束回调。"""

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None

指标事件回调。

源代码位于: ztxexp/tracking/base.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    """指标事件回调。"""

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None

run 结束回调。

源代码位于: ztxexp/tracking/base.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    """run 结束回调。"""

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None

run 启动回调。

源代码位于: ztxexp/tracking/base.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    """run 启动回调。"""

WandbTracker

Weights & Biases 追踪器(可选依赖)。

源代码位于: ztxexp/tracking/adapters.py
class WandbTracker:
    """Weights & Biases 追踪器(可选依赖)。"""

    def __init__(self, project: str | None = None, entity: str | None = None, **kwargs: Any):
        self.project = project
        self.entity = entity
        self.kwargs = kwargs
        self._run = None

    def _wandb(self):
        try:
            import wandb

            return wandb
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "wandb is required for WandbTracker. Install with: pip install ztxexp[wandb]"
            ) from exc

    def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
        wandb = self._wandb()
        self._run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=ctx.config,
            name=ctx.run_id,
            reinit=True,
            **self.kwargs,
        )
        if self._run is not None:
            if meta.experiment_name:
                self._run.summary["experiment_name"] = meta.experiment_name
            if meta.group:
                self._run.summary["group"] = meta.group

    def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
        if self._run is None:
            return
        wandb = self._wandb()
        wandb.log(dict(event.metrics), step=event.step)

    def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
        if self._run is None:
            return
        self._run.summary.update(summary)
        self._run.finish()
        self._run = None

entity instance-attribute

entity = entity

kwargs instance-attribute

kwargs = kwargs

project instance-attribute

project = project

__init__

__init__(project: str | None = None, entity: str | None = None, **kwargs: Any)
源代码位于: ztxexp/tracking/adapters.py
def __init__(self, project: str | None = None, entity: str | None = None, **kwargs: Any):
    self.project = project
    self.entity = entity
    self.kwargs = kwargs
    self._run = None

on_metric

on_metric(ctx: RunContext, event: MetricEvent) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_metric(self, ctx: RunContext, event: MetricEvent) -> None:
    if self._run is None:
        return
    wandb = self._wandb()
    wandb.log(dict(event.metrics), step=event.step)

on_run_end

on_run_end(ctx: RunContext, summary: dict[str, object]) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_end(self, ctx: RunContext, summary: dict[str, object]) -> None:
    if self._run is None:
        return
    self._run.summary.update(summary)
    self._run.finish()
    self._run = None

on_run_start

on_run_start(ctx: RunContext, meta: RunMetadata) -> None
源代码位于: ztxexp/tracking/adapters.py
def on_run_start(self, ctx: RunContext, meta: RunMetadata) -> None:
    wandb = self._wandb()
    self._run = wandb.init(
        project=self.project,
        entity=self.entity,
        config=ctx.config,
        name=ctx.run_id,
        reinit=True,
        **self.kwargs,
    )
    if self._run is not None:
        if meta.experiment_name:
            self._run.summary["experiment_name"] = meta.experiment_name
        if meta.group:
            self._run.summary["group"] = meta.group

init_torch_env

init_torch_env(seed: int = 3407, use_gpu: bool = True, gpu_id: int = 0, deterministic: bool = False, benchmark: bool = False) -> Any

初始化 PyTorch 实验环境。

参数:

名称 类型 描述 默认
seed int

全局随机种子(Python/NumPy/Torch 同步设置)。

3407
use_gpu bool

是否优先使用 GPU。

True
gpu_id int

当启用 GPU 时使用的设备 ID。

0
deterministic bool

是否启用 cuDNN 确定性模式。

False
benchmark bool

是否启用 cuDNN benchmark 自动搜索最优算法。

False

返回:

类型 描述
Any

torch.device: 最终使用的设备对象(CPU 或 CUDA)。

引发:

类型 描述
ImportError

未安装 torch 时抛出。

示例:

>>> device = init_torch_env(seed=42, use_gpu=False)
>>> str(device)
'cpu'
源代码位于: ztxexp/environment.py
def init_torch_env(
    seed: int = 3407,
    use_gpu: bool = True,
    gpu_id: int = 0,
    deterministic: bool = False,
    benchmark: bool = False,
) -> Any:
    """初始化 PyTorch 实验环境。

    Args:
        seed: 全局随机种子(Python/NumPy/Torch 同步设置)。
        use_gpu: 是否优先使用 GPU。
        gpu_id: 当启用 GPU 时使用的设备 ID。
        deterministic: 是否启用 cuDNN 确定性模式。
        benchmark: 是否启用 cuDNN benchmark 自动搜索最优算法。

    Returns:
        torch.device: 最终使用的设备对象(CPU 或 CUDA)。

    Raises:
        ImportError: 未安装 torch 时抛出。

    Examples:
        >>> device = init_torch_env(seed=42, use_gpu=False)
        >>> str(device)
        'cpu'
    """
    # 延迟导入 torch,避免非 torch 用户在导入包时失败。
    torch = _require_torch()

    # 设置三套随机种子,提升实验复现稳定性。
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # GPU 路径:设置设备与 cuDNN 行为。
    if torch.cuda.is_available() and use_gpu:
        torch.cuda.manual_seed_all(seed)
        torch.cuda.set_device(gpu_id)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = benchmark
        print(f"Using GPU: {gpu_id}")
        return torch.device(f"cuda:{gpu_id}")

    # CPU 回退路径。
    print("Using CPU")
    return torch.device("cpu")

set_process_priority

set_process_priority(priority: str = 'high') -> None

设置当前进程优先级。

参数:

名称 类型 描述 默认
priority str

优先级级别,可选值为 high / normal / low

'high'

返回:

类型 描述
None

None

引发:

类型 描述
PermissionError

当前权限不足,无法修改进程优先级。

示例:

>>> set_process_priority("normal")
源代码位于: ztxexp/environment.py
def set_process_priority(priority: str = "high") -> None:
    """设置当前进程优先级。

    Args:
        priority: 优先级级别,可选值为 ``high`` / ``normal`` / ``low``。

    Returns:
        None

    Raises:
        PermissionError: 当前权限不足,无法修改进程优先级。

    Examples:
        >>> set_process_priority("normal")
    """
    process = psutil.Process(os.getpid())
    selected = priority.strip().lower()

    try:
        # Windows 使用优先级类;类 Unix 使用 nice 值。
        if os.name == "nt":
            if selected == "high":
                process.nice(psutil.HIGH_PRIORITY_CLASS)
            elif selected == "low":
                process.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
            else:
                process.nice(psutil.NORMAL_PRIORITY_CLASS)
        else:
            if selected == "high":
                process.nice(-10)
            elif selected == "low":
                process.nice(10)
            else:
                process.nice(0)

    except (psutil.AccessDenied, PermissionError) as exc:  # pragma: no cover
        raise PermissionError(
            "Insufficient permission to change process priority. "
            "Try running with elevated privileges or use 'normal'."
        ) from exc