跳转至

ztxexp.template_wizard

ztxexp.template_wizard

Interactive template wizard for the `ztxexp init-template` command.

This module provides a command-line questionnaire and scaffold generator that produces a runnable experiment template project.

MetricMode module-attribute

MetricMode = Literal['final_only', 'final_plus_stream']

PRESET_MODULES module-attribute

PRESET_MODULES = ('data', 'model', 'train', 'eval', 'infer', 'deploy')

RunMode module-attribute

RunMode = Literal['sequential', 'process_pool', 'joblib', 'dynamic']

SUPPORTED_METRIC_MODES module-attribute

SUPPORTED_METRIC_MODES = ('final_only', 'final_plus_stream')

SUPPORTED_RUN_MODES module-attribute

SUPPORTED_RUN_MODES = ('sequential', 'process_pool', 'joblib', 'dynamic')

SUPPORTED_TRACKERS module-attribute

SUPPORTED_TRACKERS = ('none', 'jsonl', 'mlflow', 'wandb')

TEMPLATE_MARKER_FILENAME module-attribute

TEMPLATE_MARKER_FILENAME = '.ztxexp-managed-template.json'

TrackerMode module-attribute

TrackerMode = Literal['none', 'jsonl', 'mlflow', 'wandb']

TemplateAnswers dataclass

Collected answers for template generation.

源代码位于: ztxexp/template_wizard.py
@dataclass(slots=True)
class TemplateAnswers:
    """Collected answers for template generation.

    Returned by ``ask_template_questions`` and read by the plan builder and
    the renderers. The field order below defines the positional ``__init__``
    signature, so it must not be changed.
    """

    # Identifier used as the template directory name and in generated defaults
    # (e.g. the ``./results_<experiment_name>`` results root).
    experiment_name: str
    # When True, the generated script adds ``pipeline.variants([...])`` branches.
    ablation: bool
    # When True, an empty ``models/`` directory (with ``.gitkeep``) is created.
    with_models_dir: bool
    # Names of stub packages generated under ``modules/``.
    modules: list[str]
    # Default ``--mode`` of the generated CLI (stored as ``default_mode`` in config).
    mode: RunMode
    # ``"final_plus_stream"`` additionally emits demo streaming ``ctx.log_metric`` calls.
    metric_mode: MetricMode
    # Tracker backend selection forwarded to the tracker snippet renderer and config.
    tracker: TrackerMode

ablation instance-attribute

ablation: bool

experiment_name instance-attribute

experiment_name: str

metric_mode instance-attribute

metric_mode: MetricMode

mode instance-attribute

mode: RunMode

modules instance-attribute

modules: list[str]

tracker instance-attribute

tracker: TrackerMode

with_models_dir instance-attribute

with_models_dir: bool

__init__

__init__(experiment_name: str, ablation: bool, with_models_dir: bool, modules: list[str], mode: RunMode, metric_mode: MetricMode, tracker: TrackerMode) -> None

TemplateRenderResult dataclass

Template rendering result.

源代码位于: ztxexp/template_wizard.py
@dataclass(slots=True)
class TemplateRenderResult:
    """Template rendering result."""

    template_dir: Path
    action: str
    changed: bool
    dry_run: bool
    warnings: list[str] = field(default_factory=list)
    files: list[Path] = field(default_factory=list)

    def summary_lines(self) -> list[str]:
        """Build summary lines for CLI output."""
        lines = [
            (
                f"template_dir={self.template_dir}, action={self.action}, "
                f"changed={self.changed}, dry_run={self.dry_run}"
            )
        ]
        if self.warnings:
            for warning in self.warnings:
                lines.append(f"warning={warning}")
        if self.files:
            lines.append(f"files={len(self.files)}")
            for file_path in self.files:
                lines.append(f"file={file_path}")
        return lines

action instance-attribute

action: str

changed instance-attribute

changed: bool

dry_run instance-attribute

dry_run: bool

files class-attribute instance-attribute

files: list[Path] = field(default_factory=list)

template_dir instance-attribute

template_dir: Path

warnings class-attribute instance-attribute

warnings: list[str] = field(default_factory=list)

__init__

__init__(template_dir: Path, action: str, changed: bool, dry_run: bool, warnings: list[str] = list(), files: list[Path] = list()) -> None

summary_lines

summary_lines() -> list[str]

Build summary lines for CLI output.

源代码位于: ztxexp/template_wizard.py
def summary_lines(self) -> list[str]:
    """Return human-readable CLI lines describing this render result.

    Starts with a single status line, then adds one ``warning=`` line per
    warning. When files are present, it appends a ``files=<count>`` line
    followed by one ``file=`` line per path.
    """
    out = [
        f"template_dir={self.template_dir}, action={self.action}, "
        f"changed={self.changed}, dry_run={self.dry_run}"
    ]
    if self.warnings:
        out += [f"warning={w}" for w in self.warnings]
    if not self.files:
        return out
    out.append(f"files={len(self.files)}")
    out += [f"file={p}" for p in self.files]
    return out

ask_template_questions

ask_template_questions(name: str | None = None, yes: bool = False, input_fn: Callable[[str], str] | None = None) -> TemplateAnswers

Ask interactive questions and return normalized answers.

参数:

| 名称 | 类型 | 描述 | 默认 |
| --- | --- | --- | --- |
| `name` | `str \| None` | Optional preset experiment name. | `None` |
| `yes` | `bool` | If True, use recommended defaults without prompting. | `False` |
| `input_fn` | `Callable[[str], str] \| None` | Input function for testing or custom IO. | `None` |

返回:

| 名称 | 类型 | 描述 |
| --- | --- | --- |
| `TemplateAnswers` | `TemplateAnswers` | Normalized questionnaire results. |

源代码位于: ztxexp/template_wizard.py
def ask_template_questions(
    name: str | None = None,
    yes: bool = False,
    input_fn: Callable[[str], str] | None = None,
) -> TemplateAnswers:
    """Ask interactive questions and return normalized answers.

    The recommended answers are defined once. They are returned unchanged when
    ``yes`` is True and also serve as the pre-selected default of every
    interactive question, so the two paths cannot drift apart.

    Args:
        name: Optional preset experiment name.
        yes: If True, use recommended defaults without prompting.
        input_fn: Input function for testing or custom IO.

    Returns:
        TemplateAnswers: Normalized questionnaire results.
    """
    if input_fn is None:
        input_fn = input

    default_name = _sanitize_identifier(name or "my_experiment", default="my_experiment")

    # Single source of truth for the recommended configuration.
    recommended = TemplateAnswers(
        experiment_name=default_name,
        ablation=False,
        with_models_dir=True,
        modules=["data", "model", "train", "eval"],
        mode="sequential",
        metric_mode="final_plus_stream",
        tracker="jsonl",
    )
    if yes:
        return recommended

    experiment_name = _sanitize_identifier(
        _ask_text("1) 实验名称", default_name, input_fn=input_fn),
        default="my_experiment",
    )
    ablation = _ask_yes_no(
        "2) 是否进行消融实验?", default=recommended.ablation, input_fn=input_fn
    )
    with_models_dir = _ask_yes_no(
        "3) 是否生成 models 目录?", default=recommended.with_models_dir, input_fn=input_fn
    )
    modules = _collect_modules_interactive(input_fn=input_fn)

    # `_ask_choice` takes a 1-based `default_index`. The original call for the
    # two-option metric menu passed default_index=2, which only makes sense
    # 1-based; hence the `+ 1` below.
    mode = _ask_choice(
        "5) 执行模式:",
        options=SUPPORTED_RUN_MODES,
        default_index=SUPPORTED_RUN_MODES.index(recommended.mode) + 1,
        input_fn=input_fn,
    )
    metric_mode = _ask_choice(
        "6) 指标记录方式:",
        options=SUPPORTED_METRIC_MODES,
        default_index=SUPPORTED_METRIC_MODES.index(recommended.metric_mode) + 1,
        input_fn=input_fn,
    )
    tracker = _ask_choice(
        "7) 追踪器:",
        options=SUPPORTED_TRACKERS,
        default_index=SUPPORTED_TRACKERS.index(recommended.tracker) + 1,
        input_fn=input_fn,
    )

    return TemplateAnswers(
        experiment_name=experiment_name,
        ablation=ablation,
        with_models_dir=with_models_dir,
        modules=modules,
        mode=mode,  # type: ignore[arg-type]
        metric_mode=metric_mode,  # type: ignore[arg-type]
        tracker=tracker,  # type: ignore[arg-type]
    )

build_template_plan

build_template_plan(answers: TemplateAnswers, project_root: Path, output_dir: str | Path | None = None) -> dict[str, object]

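Build the scaffold plan from the answers. The plan resolves the output root (by default `<project_root>/experiments`) and the template directory `<output_root>/<experiment_name>`.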

源代码位于: ztxexp/template_wizard.py
def build_template_plan(
    answers: TemplateAnswers,
    project_root: Path,
    output_dir: str | Path | None = None,
) -> dict[str, object]:
    """Build scaffold plan from answers.

    Without ``output_dir`` the output root is ``<project_root>/experiments``.
    Otherwise ``output_dir`` is ``~``-expanded: absolute paths are used as-is,
    and relative ones are joined onto ``project_root``. Both the output root
    and the template directory ``<output_root>/<experiment_name>`` are
    resolved.

    Returns:
        dict: ``project_root`` (as passed in), ``output_root`` and
        ``template_dir``.
    """
    if output_dir is not None:
        requested = Path(output_dir).expanduser()
        base = requested if requested.is_absolute() else project_root / requested
    else:
        base = project_root / "experiments"

    output_root = base.resolve()
    return {
        "project_root": project_root,
        "output_root": output_root,
        "template_dir": (output_root / answers.experiment_name).resolve(),
    }

check_init_prerequisites

check_init_prerequisites(project_root: Path) -> list[str]

Check optional prerequisites and return warning messages.

The check is advisory and never blocks template generation.

源代码位于: ztxexp/template_wizard.py
def check_init_prerequisites(project_root: Path) -> list[str]:
    """Check optional prerequisites and return warning messages.

    It looks for two things:

    * the vibe managed block (both ``START_MARKER`` and ``END_MARKER``) in an
      ``AGENTS.md``-style file at the project root;
    * the ``SKILL.md`` of ``SKILL_NAME`` under ``skills/`` or ``.codex/skills/``.

    The check is advisory and never blocks template generation. A missing,
    unreadable or non-UTF-8 file only results in a warning, never an exception.

    Args:
        project_root: Project root directory to inspect.

    Returns:
        list[str]: Warning messages; empty when every prerequisite is found.
    """
    warnings: list[str] = []

    vibe_ok = False
    for filename in ("AGENTS.md", "agents.md", "agents.MD"):
        path = project_root / filename
        # is_file() (not exists()) so a directory with this name is skipped
        # instead of making read_text() raise.
        if not path.is_file():
            continue
        try:
            text = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-UTF-8 file: treat as "no managed block" so the
            # advisory check never aborts template generation.
            continue
        if START_MARKER in text and END_MARKER in text:
            vibe_ok = True
            break

    if not vibe_ok:
        warnings.append(
            "未检测到 vibe 受管区块,建议先执行: ztxexp init-vibe"
        )

    skill_candidates = [
        project_root / "skills" / SKILL_NAME / "SKILL.md",
        project_root / ".codex" / "skills" / SKILL_NAME / "SKILL.md",
    ]
    if not any(path.is_file() for path in skill_candidates):
        warnings.append(
            "未检测到 ztx-exp-manager skill,建议先执行: ztxexp init-skill"
        )

    return warnings

create_template_scaffold

create_template_scaffold(plan: dict[str, object], answers: TemplateAnswers, dry_run: bool = False, force: bool = False, warnings: list[str] | None = None) -> TemplateRenderResult

Create template scaffold according to plan and answers.

源代码位于: ztxexp/template_wizard.py
def create_template_scaffold(
    plan: dict[str, object],
    answers: TemplateAnswers,
    dry_run: bool = False,
    force: bool = False,
    warnings: list[str] | None = None,
) -> TemplateRenderResult:
    """Create template scaffold according to plan and answers.

    The scaffold consists of:

    * ``main_experiment.py``;
    * ``configs/base.json``;
    * the managed-template marker file;
    * ``artifacts/`` and optionally ``models/`` placeholders;
    * one stub package per selected module under ``modules/``.

    Only files whose on-disk content differs from the rendered content are
    written. Unchanged files are left untouched, so their mtimes are preserved.

    Args:
        plan: Scaffold plan (e.g. from ``build_template_plan``). Only
            ``"template_dir"`` is read.
        answers: Questionnaire answers used to render every file.
        dry_run: If True, compute the result without touching the filesystem.
        force: Allow writing into an existing directory that is not a managed
            template.
        warnings: Advisory messages copied into the result.

    Returns:
        TemplateRenderResult: ``action`` is ``"create"`` or ``"update"``
        (prefixed with ``"would_"`` in dry-run mode), or ``"unchanged"``.
        ``files`` lists every file of the scaffold.

    Raises:
        ValueError: If ``template_dir`` exists but is not a directory, or is an
            unmanaged directory and ``force`` is False.
    """
    template_dir = Path(plan["template_dir"])  # type: ignore[arg-type]
    existed = template_dir.exists()

    # A regular file in place of the template directory cannot be fixed by
    # --force; fail with a clear message instead of a mkdir error later.
    if existed and not template_dir.is_dir():
        raise ValueError(f"Template path exists and is not a directory: {template_dir}")

    if existed and not _is_managed_template(template_dir) and not force:
        raise ValueError(
            f"Template directory exists and is unmanaged: {template_dir}. "
            "Use --force to overwrite generated files."
        )

    base_config_payload = (
        json.dumps(render_base_config(answers), ensure_ascii=False, indent=2) + "\n"
    )
    main_script = render_template_script(answers)
    marker_payload = _render_marker(answers)

    files_to_write: dict[Path, str] = {
        template_dir / "main_experiment.py": main_script,
        template_dir / "configs" / "base.json": base_config_payload,
        template_dir / "artifacts" / ".gitkeep": "",
        template_dir / "modules" / "__init__.py": "",
        template_dir / TEMPLATE_MARKER_FILENAME: marker_payload,
    }

    if answers.with_models_dir:
        files_to_write[template_dir / "models" / ".gitkeep"] = ""

    for module in answers.modules:
        module_name = _sanitize_identifier(module, default="module")
        files_to_write[template_dir / "modules" / module_name / "__init__.py"] = (
            f'"""Module stub for `{module_name}`."""\n'
        )

    changed_files: list[Path] = []
    for path, content in files_to_write.items():
        try:
            current: str | None = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Missing, unreadable or non-UTF-8 file: it needs (re)writing. Real
            # permission problems surface from write_text() below.
            current = None
        if current != content:
            changed_files.append(path)

    changed = bool(changed_files) or not existed
    if not changed:
        action = "unchanged"
    else:
        action = "create" if not existed else "update"
        if dry_run:
            action = f"would_{action}"

    if changed and not dry_run:
        for path in changed_files:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(files_to_write[path], encoding="utf-8")

    return TemplateRenderResult(
        template_dir=template_dir,
        action=action,
        changed=changed,
        dry_run=dry_run,
        warnings=list(warnings or []),
        files=sorted(files_to_write.keys()),
    )

is_interactive_terminal

is_interactive_terminal() -> bool

Return whether both stdin and stdout are attached to interactive terminals.

源代码位于: ztxexp/template_wizard.py
def is_interactive_terminal() -> bool:
    """Report whether both standard input and standard output are TTYs.

    stdout is only queried when stdin is already a TTY.
    """
    streams = (sys.stdin, sys.stdout)
    return all(stream.isatty() for stream in streams)

render_base_config

render_base_config(answers: TemplateAnswers) -> dict[str, object]

Render base config payload for generated template.

源代码位于: ztxexp/template_wizard.py
def render_base_config(answers: TemplateAnswers) -> dict[str, object]:
    """Return the payload serialized into the template's ``configs/base.json``.

    It mirrors the questionnaire answers and adds a fixed ``seed`` of 42. The
    run mode is stored under the ``default_mode`` key.
    """
    config: dict[str, object] = {"experiment_name": answers.experiment_name}
    config["seed"] = 42
    config.update(
        ablation=answers.ablation,
        with_models_dir=answers.with_models_dir,
        modules=answers.modules,
        default_mode=answers.mode,
        metric_mode=answers.metric_mode,
        tracker=answers.tracker,
    )
    return config

render_template_script

render_template_script(answers: TemplateAnswers) -> str

Render generated command-line experiment script.

源代码位于: ztxexp/template_wizard.py
def render_template_script(answers: TemplateAnswers) -> str:
    """Render generated command-line experiment script.

    The returned source (written as ``main_experiment.py``) contains:

    * an ``exp_fn`` demo experiment;
    * ``_load_base_config`` reading ``configs/base.json`` next to the script;
    * ``build_pipeline`` gridding over ``lr``;
    * an argparse CLI with ``run``, ``analyze`` and ``clean`` subcommands,
      whose ``--results-root`` defaults to ``./results_<experiment_name>``.

    Args:
        answers: Questionnaire answers that select the optional snippets.

    Returns:
        str: Complete Python source of the generated script.
    """
    # NOTE(review): `_render_module_stubs` is not visible here; it presumably
    # returns the source of per-module stub definitions plus the indented calls
    # spliced into exp_fn — confirm against its definition.
    module_stubs, module_calls = _render_module_stubs(answers.modules)

    # Ablation: inject a `pipeline.variants([...])` call into build_pipeline
    # with a baseline branch and a branch without the optional block.
    ablation_snippet = ""
    if answers.ablation:
        ablation_snippet = (
            "    pipeline = pipeline.variants([\n"
            '        {"ablation_branch": "baseline", "use_optional_block": True},\n'
            '        {"ablation_branch": "no_optional_block", "use_optional_block": False},\n'
            "    ])\n"
        )

    # Streaming metrics: two demo `ctx.log_metric` calls at the top of exp_fn,
    # emitted only for metric_mode == "final_plus_stream".
    stream_metric_snippet = ""
    if answers.metric_mode == "final_plus_stream":
        stream_metric_snippet = (
            "    ctx.log_metric(step=1, metrics={\"loss\": 0.88}, split=\"train\", phase=\"fit\")\n"
            "    ctx.log_metric(step=2, metrics={\"loss\": 0.73}, split=\"train\", phase=\"fit\")\n"
        )

    # NOTE(review): `_render_tracker_snippet` is not visible here; its output is
    # spliced directly before `return pipeline`, so it is assumed to be
    # 4-space-indented lines ending in a newline (or "") — confirm.
    tracker_snippet = _render_tracker_snippet(answers.tracker, answers.experiment_name)
    default_results_root = f"./results_{answers.experiment_name}"

    # The template is an f-string. Literal braces in the generated code are
    # doubled, and `\\n` becomes a `\n` escape in the generated source.
    # `{list(SUPPORTED_RUN_MODES)}` renders as a Python list literal for
    # argparse `choices`.
    # NOTE(review): the generated exp_fn writes <run_dir>/artifacts/summary.txt
    # without creating the directory. This assumes ztxexp creates it before
    # exp_fn runs — confirm.
    return f'''"""Generated by `ztxexp init-template`."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

from ztxexp import ExperimentPipeline, ResultAnalyzer, RunContext


{module_stubs}


def exp_fn(ctx: RunContext) -> dict[str, float] | None:
    """Single experiment function."""
    cfg = ctx.config
    lr = float(cfg.get("lr", 0.001))
{stream_metric_snippet}    artifact_path = Path(ctx.run_dir) / "artifacts" / "summary.txt"
    artifact_path.write_text(
        f"run={{ctx.run_id}}, lr={{lr}}\\n",
        encoding="utf-8",
    )

{module_calls}

    score = max(0.0, 1.0 - lr)
    return {{"score": round(score, 4), "latency_ms": round(120 + lr * 1000, 3)}}


def _load_base_config() -> dict[str, Any]:
    config_path = Path(__file__).resolve().parent / "configs" / "base.json"
    return json.loads(config_path.read_text(encoding="utf-8"))


def build_pipeline(results_root: str) -> ExperimentPipeline:
    base_config = _load_base_config()
    pipeline = (
        ExperimentPipeline(results_root=results_root, base_config=base_config)
        .grid({{"lr": [0.001, 0.01]}})
    )
{ablation_snippet}{tracker_snippet}    return pipeline


def run_experiment(results_root: str, mode: str) -> None:
    summary = build_pipeline(results_root=results_root).run(exp_fn, mode=mode)
    print(summary)


def analyze_results(results_root: str, output_csv: str | None) -> None:
    analyzer = ResultAnalyzer(results_root)
    df = analyzer.to_dataframe(statuses=("succeeded",))
    print(df.head(20))
    if output_csv:
        analyzer.to_csv(output_csv, statuses=("succeeded",))


def clean_results(results_root: str, apply_delete: bool) -> None:
    analyzer = ResultAnalyzer(results_root)
    analyzer.clean_results(
        statuses=("failed", "running", "skipped"),
        dry_run=not apply_delete,
        confirm=False,
    )


def main() -> int:
    parser = argparse.ArgumentParser(description="Generated experiment CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    run_parser = subparsers.add_parser("run", help="Run experiments")
    run_parser.add_argument("--results-root", default="{default_results_root}")
    run_parser.add_argument(
        "--mode",
        choices={list(SUPPORTED_RUN_MODES)},
        default="{answers.mode}",
    )

    analyze_parser = subparsers.add_parser("analyze", help="Analyze experiment results")
    analyze_parser.add_argument("--results-root", default="{default_results_root}")
    analyze_parser.add_argument("--output-csv", default=None)

    clean_parser = subparsers.add_parser("clean", help="Clean failed/skipped/running runs")
    clean_parser.add_argument("--results-root", default="{default_results_root}")
    clean_parser.add_argument("--apply", action="store_true")

    args = parser.parse_args()
    if args.command == "run":
        run_experiment(results_root=args.results_root, mode=args.mode)
        return 0
    if args.command == "analyze":
        analyze_results(results_root=args.results_root, output_csv=args.output_csv)
        return 0
    if args.command == "clean":
        clean_results(results_root=args.results_root, apply_delete=args.apply)
        return 0
    parser.error(f"Unknown command: {{args.command}}")
    return 2


if __name__ == "__main__":
    raise SystemExit(main())
'''

resolve_project_root

resolve_project_root(project_root: str | Path | None) -> Path

Resolve and validate project root directory.

源代码位于: ztxexp/template_wizard.py
def resolve_project_root(project_root: str | Path | None) -> Path:
    """Return ``project_root`` as an absolute, resolved directory path.

    ``None`` selects the current working directory, and ``~`` is expanded.

    Raises:
        ValueError: If the resolved path does not exist or is not a directory.
    """
    candidate = Path(project_root) if project_root is not None else Path.cwd()
    resolved = candidate.expanduser().resolve()
    if not resolved.exists():
        raise ValueError(f"project root does not exist: {resolved}")
    if not resolved.is_dir():
        raise ValueError(f"project root is not a directory: {resolved}")
    return resolved