leagents 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
leagents/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """leagents — agentic orchestration for the LeRobot pipeline."""
2
+
3
+ __version__ = "0.0.4"
@@ -0,0 +1,21 @@
1
+ from leagents.agents.base import RunResult, Runner, dry_runner, subprocess_runner
2
+ from leagents.agents.data_agent import DataAgent
3
+ from leagents.agents.eval_agent import EvalAgent, EvalError
4
+ from leagents.agents.improve_agent import ImproveAgent, ImproveError
5
+ from leagents.agents.knowledge_agent import KnowledgeAgent
6
+ from leagents.agents.train_agent import TrainAgent, TrainError
7
+
8
+ __all__ = [
9
+ "RunResult",
10
+ "Runner",
11
+ "dry_runner",
12
+ "subprocess_runner",
13
+ "DataAgent",
14
+ "EvalAgent",
15
+ "EvalError",
16
+ "ImproveAgent",
17
+ "ImproveError",
18
+ "KnowledgeAgent",
19
+ "TrainAgent",
20
+ "TrainError",
21
+ ]
@@ -0,0 +1,42 @@
1
+ """Shared subprocess runner for agents that wrap LeRobot CLIs.
2
+
3
+ Runners are injectable so tests (and --dry-run) exercise the full loop
4
+ without lerobot installed or a GPU present.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import subprocess
10
+ import time
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Callable, Sequence
14
+
15
+
16
@dataclass
class RunResult:
    """Outcome of one command handed to a ``Runner``."""

    # Argument vector that was executed (or only recorded, for a dry run).
    cmd: list[str]
    # Exit status of the command; the agents treat any non-zero value as failure.
    exit_code: int
    # Elapsed time of the run, in seconds.
    duration_s: float
    # File that received the command's output, when there is one.
    log_path: Path | None = None
22
+
23
+
24
+ Runner = Callable[[Sequence[str], Path], RunResult]
25
+
26
+
27
def subprocess_runner(cmd: Sequence[str], log_path: Path) -> RunResult:
    """Execute ``cmd`` and send its combined stdout/stderr to ``log_path``.

    The child process writes straight into the log file, so the output of
    long-running jobs is never accumulated in this process's memory.  The
    log's parent directory is created on demand and an existing log file is
    overwritten.

    Returns a ``RunResult`` carrying the command, its exit code, the elapsed
    wall time in seconds and the log path.
    """
    argv = list(cmd)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    started = time.monotonic()
    with log_path.open("w") as sink:
        completed = subprocess.run(argv, stdout=sink, stderr=subprocess.STDOUT)
    elapsed = time.monotonic() - started
    return RunResult(argv, completed.returncode, elapsed, log_path)
37
+
38
+
39
def dry_runner(cmd: Sequence[str], log_path: Path) -> RunResult:
    """Record ``cmd`` in ``log_path`` without executing anything.

    The log receives a single ``DRY RUN: ...`` line with the space-joined
    command, and the returned ``RunResult`` reports exit code 0 and a
    duration of 0.0 seconds.
    """
    argv = list(cmd)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    rendered = " ".join(argv)
    log_path.write_text(f"DRY RUN: {rendered}\n")
    return RunResult(argv, 0, 0.0, log_path)
@@ -0,0 +1,38 @@
1
+ """Data Agent — M0 scope (DESIGN.md §3.2, §8).
2
+
3
+ M0 collects no new data: the agent resolves the seed dataset from the
4
+ proposal and records provenance for the cycle. RoboGene-style curation,
5
+ MimicGen-style amplification, and GenAug-style augmentation land in M1.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from leagents.contracts import DatasetRef, Proposal
11
+ from leagents.events import Event, EventBus
12
+
13
+
14
class DataAgent:
    """Data stage of a cycle: resolve the dataset named by the proposal."""

    def __init__(self, bus: EventBus):
        self.bus = bus

    def run(self, *, run_id: str, cycle: int, proposal: Proposal) -> DatasetRef:
        """Build the ``DatasetRef`` for ``proposal`` and announce it on the bus.

        A ``dataset_resolved`` event (stage ``collect``) is emitted with the
        proposal's action and the resolved dataset fields, then the
        reference is returned to the caller.
        """
        dataset = DatasetRef(
            repo_id=proposal.dataset,
            num_episodes=proposal.num_episodes,
            notes=proposal.notes,
        )
        details = {
            "action": proposal.action,
            "repo_id": dataset.repo_id,
            "num_episodes": dataset.num_episodes,
            "notes": dataset.notes,
        }
        resolved = Event(
            run_id=run_id,
            stage="collect",
            kind="dataset_resolved",
            cycle=cycle,
            payload=details,
        )
        self.bus.emit(resolved)
        return dataset
@@ -0,0 +1,105 @@
1
+ """Eval Agent — `lerobot-eval` LIBERO gate (DESIGN.md §3.4)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from leagents.agents.base import Runner, subprocess_runner
10
+ from leagents.config import EvalConfig
11
+ from leagents.contracts import CheckpointRecord, EvalReport
12
+ from leagents.events import Event, EventBus
13
+ from leagents.orchestrator.constitution import Constitution, ConstitutionError
14
+
15
+
16
class EvalError(Exception):
    """Raised when the evaluation job fails or its output cannot be used."""
18
+
19
+
20
class EvalAgent:
    """Gate a checkpoint by running ``lerobot-eval`` and parsing its report.

    The command is built from the ``EvalConfig``, vetted by the constitution,
    executed through the injected runner, and the resulting
    ``eval_info.json`` is condensed into an ``EvalReport``.  Progress is
    published on the event bus under the ``"eval"`` stage.
    """

    def __init__(
        self,
        cfg: EvalConfig,
        constitution: Constitution,
        bus: EventBus,
        runner: Runner = subprocess_runner,
    ):
        self.cfg = cfg
        self.constitution = constitution
        self.bus = bus
        self.runner = runner

    def build_command(self, checkpoint: CheckpointRecord, output_dir: Path) -> list[str]:
        """Return the ``lerobot-eval`` argument vector for ``checkpoint``."""
        return [
            "lerobot-eval",
            f"--policy.path={checkpoint.path}",
            f"--env.type={self.cfg.env_type}",
            f"--env.task={self.cfg.task_suite}",
            f"--eval.n_episodes={self.cfg.n_episodes}",
            # lerobot-eval rejects batch_size > n_episodes (default batch is 50)
            f"--eval.batch_size={min(self.cfg.batch_size, self.cfg.n_episodes)}",
            f"--output_dir={output_dir}",
            *self.cfg.extra_args,
        ]

    def run(
        self, *, run_id: str, cycle: int, checkpoint: CheckpointRecord, workdir: Path
    ) -> EvalReport:
        """Evaluate ``checkpoint`` for one cycle and return its ``EvalReport``.

        Output goes to ``workdir/cycle_<cycle>/eval`` and the job log to
        ``workdir/cycle_<cycle>/eval.log``.

        Raises:
            ConstitutionError: a constitution check denied the evaluation or
                the command; a ``constitution_denied`` event is emitted first.
            EvalError: the job exited non-zero, or ``eval_info.json`` is
                missing, malformed, or lacks a usable success metric.
        """
        output_dir = workdir / f"cycle_{cycle}" / "eval"
        cmd = self.build_command(checkpoint, output_dir)

        for verdict in (self.constitution.check_eval(self.cfg.env_type, self.cfg.n_episodes),
                        self.constitution.check_command(cmd)):
            if not verdict.allowed:
                self.bus.emit(Event(run_id, "eval", "constitution_denied", cycle,
                                    {"rule": verdict.rule, "reason": verdict.reason}))
                raise ConstitutionError(verdict)

        self.bus.emit(Event(run_id, "eval", "job_started", cycle, {"cmd": cmd}))
        # log lives outside output_dir: lerobot CLIs refuse a pre-existing output_dir
        result = self.runner(cmd, output_dir.parent / "eval.log")
        self.bus.emit(Event(run_id, "eval", "job_finished", cycle,
                            {"exit_code": result.exit_code, "duration_s": result.duration_s,
                             "log": str(result.log_path)}))
        if result.exit_code != 0:
            raise EvalError(f"lerobot-eval exited {result.exit_code}, see {result.log_path}")

        success_rate, per_task, raw = self._parse_eval_info(output_dir / "eval_info.json")
        report = EvalReport(
            checkpoint=checkpoint.path,
            env_type=self.cfg.env_type,
            task_suite=self.cfg.task_suite,
            n_episodes=self.cfg.n_episodes,
            success_rate=success_rate,
            per_task=per_task,
            raw=raw,
        )
        self.bus.emit(Event(run_id, "eval", "report", cycle,
                            {"success_rate": success_rate, "task_suite": self.cfg.task_suite}))
        return report

    @staticmethod
    def _parse_eval_info(path: Path) -> tuple[float, dict[str, float], dict[str, Any]]:
        """Parse lerobot-eval's eval_info.json.

        lerobot 0.5 schema: {"overall": {"pc_success": <0-100>, ...},
        "per_group": {<group>: {"pc_success": ...}}, "per_task": [...]}.
        Returns (success fraction, per-group success fractions, raw data).

        Every way the file can be unusable is reported as ``EvalError`` so
        callers need to handle a single exception type: missing file,
        unreadable or malformed JSON, a payload that is not a JSON object,
        a missing success metric, or a metric that is not numeric.
        """
        if not path.exists():
            raise EvalError(f"eval output missing: {path}")
        try:
            data = json.loads(path.read_text())
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise EvalError(f"eval output unreadable: {path} ({exc})") from exc
        if not isinstance(data, dict):
            raise EvalError(f"eval output is not a JSON object: {path}")

        # Older/alternative layouts keep the aggregate under "aggregated" or
        # at the top level, so fall back through them in that order.
        aggregate = data.get("overall") or data.get("aggregated") or data
        if not isinstance(aggregate, dict):
            raise EvalError(f"no success metric found in {path}")
        try:
            if "pc_success" in aggregate:  # percentage 0-100 by definition
                success = float(aggregate["pc_success"]) / 100.0
            elif "success_rate" in aggregate:  # fraction 0-1
                success = float(aggregate["success_rate"])
            else:
                raise EvalError(f"no success metric found in {path}")
        except (TypeError, ValueError) as exc:
            raise EvalError(f"non-numeric success metric in {path}: {exc}") from exc

        per_task: dict[str, float] = {}
        groups = data.get("per_group")
        if isinstance(groups, dict):
            for group, metrics in groups.items():
                # Groups without a pc_success entry are simply left out.
                if not isinstance(metrics, dict) or "pc_success" not in metrics:
                    continue
                try:
                    per_task[group] = float(metrics["pc_success"]) / 100.0
                except (TypeError, ValueError) as exc:
                    raise EvalError(
                        f"non-numeric success metric for group {group!r} in {path}: {exc}"
                    ) from exc
        return success, per_task, data
@@ -0,0 +1,140 @@
1
+ """Improvement Agent — DexFlyWheel step 3: success-filtered rollout collection
2
+ (M1, DESIGN.md §3.5).
3
+
4
+ Runs the blessed policy in sim via ``leagents.scripts.collect_rollouts`` (a
5
+ subprocess, like every other agent) and keeps only successful episodes as a
6
+ new LeRobotDataset. Residual-RL (flywheel step 2) and merging the rollout
7
+ dataset into the training mix are the remaining M1 work.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ from leagents.agents.base import Runner, subprocess_runner
17
+ from leagents.config import EvalConfig, ImproveConfig
18
+ from leagents.contracts import CheckpointRecord, DatasetRef
19
+ from leagents.events import Event, EventBus
20
+ from leagents.orchestrator.constitution import Constitution, ConstitutionError
21
+
22
+
23
class ImproveError(Exception):
    """Raised when rollout collection or the rollout merge fails."""
25
+
26
+
27
class ImproveAgent:
    """Collect success-filtered rollouts and fold them into the rollout mix.

    Each cycle runs ``leagents.scripts.collect_rollouts`` as a subprocess,
    reads the JSON summary it leaves in the log, and, when a previous mix
    exists and episodes were kept, merges old and new data with
    ``lerobot-edit-dataset``.  Progress is published on the event bus under
    the ``"improve"`` stage.
    """

    def __init__(
        self,
        cfg: ImproveConfig,
        eval_cfg: EvalConfig,
        constitution: Constitution,
        bus: EventBus,
        runner: Runner = subprocess_runner,
    ):
        self.cfg = cfg
        self.eval_cfg = eval_cfg
        self.constitution = constitution
        self.bus = bus
        self.runner = runner

    def build_command(self, checkpoint: CheckpointRecord, out_dir: Path, repo_id: str) -> list[str]:
        """Return the argument vector for the rollout collector script."""
        argv = [sys.executable, "-m", "leagents.scripts.collect_rollouts"]
        argv += [
            f"--policy-path={checkpoint.path}",
            f"--env-type={self.eval_cfg.env_type}",
            f"--task={self.eval_cfg.task_suite}",
            f"--episodes={self.cfg.episodes}",
            f"--out={out_dir}",
            f"--repo-id={repo_id}",
            f"--device={self.cfg.device}",
        ]
        argv.extend(self.cfg.extra_args)
        # The task text is optional and always goes after the extra args.
        if self.cfg.task_text:
            argv.append(f"--task-text={self.cfg.task_text}")
        return argv

    def run(
        self, *, run_id: str, cycle: int, checkpoint: CheckpointRecord, workdir: Path,
        prev_mix: DatasetRef | None = None,
    ) -> tuple[DatasetRef, dict]:
        """Collect rollouts for ``checkpoint`` and return ``(dataset, summary)``.

        The returned dataset is ``prev_mix`` (or the fresh, empty rollout
        reference when there is no previous mix) if nothing was kept, the
        fresh rollout dataset if there is no previous mix, and otherwise the
        merge of both.

        Raises:
            ConstitutionError: a constitution check denied the job; a
                ``constitution_denied`` event is emitted first.
            ImproveError: the collector or the merge exited non-zero, or the
                collector log holds no JSON summary.
        """
        cycle_dir = workdir / f"cycle_{cycle}"
        out_dir = cycle_dir / "rollouts"
        repo_id = f"local/{run_id}-cycle{cycle}-rollouts"
        cmd = self.build_command(checkpoint, out_dir, repo_id)

        # Both checks are evaluated up front; the first denial wins.
        verdicts = (
            self.constitution.check_eval(self.eval_cfg.env_type, self.cfg.episodes),
            self.constitution.check_command(cmd),
        )
        denied = next((v for v in verdicts if not v.allowed), None)
        if denied is not None:
            self.bus.emit(Event(run_id, "improve", "constitution_denied", cycle,
                                {"rule": denied.rule, "reason": denied.reason}))
            raise ConstitutionError(denied)

        self.bus.emit(Event(run_id, "improve", "job_started", cycle, {"cmd": cmd}))
        result = self.runner(cmd, cycle_dir / "rollouts.log")
        finished = {
            "exit_code": result.exit_code,
            "duration_s": result.duration_s,
            "log": str(result.log_path),
        }
        self.bus.emit(Event(run_id, "improve", "job_finished", cycle, finished))
        if result.exit_code != 0:
            raise ImproveError(f"rollout collection exited {result.exit_code}, "
                               f"see {result.log_path}")

        summary = self._parse_summary(result.log_path)
        self.bus.emit(Event(run_id, "improve", "rollouts_collected", cycle, summary))
        ref = DatasetRef(
            repo_id=repo_id,
            root=str(out_dir),
            num_episodes=summary.get("kept"),
            notes=f"success-filtered rollouts from cycle {cycle} blessed checkpoint",
        )

        if summary.get("kept", 0) == 0:
            # Nothing new was kept: carry the previous mix forward if there is one.
            dataset = prev_mix or ref
        elif prev_mix is None:
            dataset = ref
        else:
            dataset = self._merge(run_id, cycle, workdir, prev_mix, ref)
        return dataset, summary

    def _merge(
        self, run_id: str, cycle: int, workdir: Path, a: DatasetRef, b: DatasetRef
    ) -> DatasetRef:
        """Accumulate rollouts across cycles: mix_{n} = mix_{n-1} ∪ rollouts_n
        (DexFlyWheel's growing dataset), via lerobot-edit-dataset merge.

        Raises ``ImproveError`` when the merge command exits non-zero.
        """
        # NOTE(review): unlike the collector command, this command is not
        # passed through constitution.check_command — confirm that is intended.
        cycle_dir = workdir / f"cycle_{cycle}"
        mix_root = cycle_dir / "rollouts_mix"
        mix_repo_id = f"local/{run_id}-rollouts-mix-c{cycle}"
        cmd = [
            "lerobot-edit-dataset",
            "--operation.type=merge",
            f"--operation.repo_ids=[{a.repo_id}, {b.repo_id}]",
            f"--operation.roots=[{a.root}, {b.root}]",
            f"--new_repo_id={mix_repo_id}",
            f"--new_root={mix_root}",
            "--push_to_hub=false",
        ]
        started = {"cmd": cmd, "stage": "merge"}
        self.bus.emit(Event(run_id, "improve", "job_started", cycle, started))
        result = self.runner(cmd, cycle_dir / "rollouts_merge.log")
        finished = {
            "exit_code": result.exit_code,
            "duration_s": result.duration_s,
            "log": str(result.log_path),
        }
        self.bus.emit(Event(run_id, "improve", "job_finished", cycle, finished))
        if result.exit_code != 0:
            raise ImproveError(f"rollout merge exited {result.exit_code}, see {result.log_path}")

        # Unknown episode counts contribute 0; a total of 0 is stored as None.
        total = (a.num_episodes or 0) + (b.num_episodes or 0)
        merged = DatasetRef(
            repo_id=mix_repo_id,
            root=str(mix_root),
            num_episodes=total or None,
            notes=f"rollout mix through cycle {cycle}",
        )
        self.bus.emit(Event(run_id, "improve", "rollouts_merged", cycle,
                            {"repo_id": mix_repo_id, "episodes": total}))
        return merged

    @staticmethod
    def _parse_summary(log_path: Path | None) -> dict:
        """Return the collector's JSON summary from its log.

        The collector prints the summary as its last stdout line; the log is
        scanned from the end and the first line that starts with ``{`` and
        parses as JSON is returned.

        Raises ``ImproveError`` when the log is missing or holds no such line.
        """
        log_file = None if log_path is None else Path(log_path)
        if log_file is None or not log_file.exists():
            raise ImproveError(f"rollout log missing: {log_path}")
        for raw in reversed(log_file.read_text().splitlines()):
            candidate = raw.strip()
            if not candidate.startswith("{"):
                continue
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                # Looks like JSON but is not; keep scanning earlier lines.
                pass
        raise ImproveError(f"no JSON summary found in {log_path}")
@@ -0,0 +1,156 @@
1
+ """Knowledge Agent — OKF knowledge layer (DESIGN.md §3.6, M1).
2
+
3
+ Two operations mirroring the Karpathy-LLM-Wiki workflow:
4
+ - ingest: after each cycle, distill the eval outcome into task/policy
5
+ pages (markdown + YAML frontmatter, per Google's OKF spec shape).
6
+ - lint: health-check pass over the wiki (schema, provenance).
7
+
8
+ Pages are advisory context only — loop control stays deterministic.
9
+ Layer-1 artifacts (events.jsonl, SQLite, eval reports) are never edited;
10
+ this agent writes only under the knowledge root.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import time
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import yaml
20
+
21
+ from leagents.contracts import EvalReport
22
+ from leagents.events import Event, EventBus
23
+ from leagents.llm import LLMClient, NullLLM
24
+
25
+ _PAGE_TYPES = {"task", "policy", "experiment"}
26
+ _STATUSES = {"observed-once", "replicated", "human-confirmed"}
27
+
28
+ _LESSONS_HEADING = "## Lessons"
29
+
30
+ _LESSONS_SYSTEM = (
31
+ "You maintain one page of a robotics run-knowledge wiki. Given the page and a new "
32
+ "observation, rewrite ONLY the Lessons section as a short markdown bullet list of "
33
+ "durable, actionable lessons. No preamble, bullets only."
34
+ )
35
+
36
+
37
+ def _slug(name: str) -> str:
38
+ return name.lower().replace(" ", "-").replace("/", "-")
39
+
40
+
41
+ def _read_page(path: Path) -> tuple[dict[str, Any], str]:
42
+ text = path.read_text()
43
+ if text.startswith("---\n"):
44
+ _, frontmatter, body = text.split("---\n", 2)
45
+ return yaml.safe_load(frontmatter) or {}, body.lstrip("\n")
46
+ return {}, text
47
+
48
+
49
def _write_page(path: Path, meta: dict[str, Any], body: str) -> None:
    """Write a wiki page as YAML frontmatter followed by the markdown body.

    Parent directories are created as needed.  ``meta`` is dumped in its
    existing key order between two ``---`` lines; ``body`` is stripped of
    surrounding whitespace and terminated by a single newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    header = yaml.safe_dump(meta, sort_keys=False, allow_unicode=True).strip()
    document = "".join(("---\n", header, "\n---\n\n", body.strip(), "\n"))
    path.write_text(document)
53
+
54
+
55
class KnowledgeAgent:
    """Maintains the markdown knowledge pages stored under ``root``.

    ``ingest`` appends one observation per cycle to the policy page and the
    task page; ``lint`` checks every page's frontmatter.  Both publish their
    outcome on the event bus under the ``"knowledge"`` stage.
    """

    def __init__(self, root: Path, bus: EventBus, llm: LLMClient | None = None):
        self.root = Path(root)
        self.bus = bus
        # Without a client, NullLLM stands in so the code path stays the same.
        self.llm = llm or NullLLM()

    # -- ingest ----------------------------------------------------------
    def ingest(
        self,
        *,
        run_id: str,
        cycle: int,
        policy: str,
        report: EvalReport,
        decision: str,
    ) -> list[Path]:
        """Record one cycle's eval outcome on the policy and task pages.

        Returns the paths of the pages that were written and emits a
        ``knowledge_updated`` event listing them.
        """
        observation = (
            f"- {time.strftime('%Y-%m-%d')} · run {run_id} · cycle {cycle} · "
            f"{report.task_suite} · policy {policy} · "
            f"success {report.success_rate:.1%} · decision {decision}"
        )
        touched = [
            self._upsert_page(
                self.root / "policies" / f"{_slug(policy)}.md",
                page_type="policy",
                title=f"Policy: {policy}",
                name=policy,
                observation=observation,
                provenance={"run": run_id, "cycle": cycle},
            ),
            self._upsert_page(
                self.root / "tasks" / f"{_slug(report.task_suite)}.md",
                page_type="task",
                title=f"Task: {report.task_suite}",
                name=report.task_suite,
                observation=observation,
                provenance={"run": run_id, "cycle": cycle},
            ),
        ]
        self.bus.emit(
            Event(run_id, "knowledge", "knowledge_updated", cycle,
                  {"pages": [str(p) for p in touched]})
        )
        return touched

    def _upsert_page(
        self,
        path: Path,
        *,
        page_type: str,
        title: str,
        name: str,
        observation: str,
        provenance: dict[str, Any],
    ) -> Path:
        """Create or update one page with a new observation and provenance entry.

        The observation is inserted just before the Lessons heading (or
        appended when the page has none).  If the LLM returns non-empty
        text, everything from the Lessons heading onwards is replaced by it.
        Returns ``path``.
        """
        if path.exists():
            meta, body = _read_page(path)
            # A page that predates the agent (no or partial frontmatter) would
            # otherwise never receive the fields lint() requires.
            meta.setdefault("name", name)
            meta.setdefault("type", page_type)
        else:
            meta = {"name": name, "type": page_type, "status": "observed-once",
                    "provenance": []}
            body = f"# {title}\n\n## Observations\n\n{_LESSONS_HEADING}\n"

        # Tolerate ``provenance: null`` or a single hand-written entry
        # instead of failing on ``.append``.
        history = meta.get("provenance")
        if not isinstance(history, list):
            history = [] if history is None else [history]
            meta["provenance"] = history
        history.append(provenance)

        if meta.get("status") != "human-confirmed":  # human verdicts are never downgraded
            meta["status"] = "replicated" if len(history) >= 2 else "observed-once"
        meta["updated"] = time.strftime("%Y-%m-%d")

        head, sep, lessons = body.partition(_LESSONS_HEADING)
        body = head.rstrip("\n") + f"\n{observation}\n\n" + sep + lessons

        lessons_update = self.llm.complete(
            f"PAGE:\n{body}\n\nNEW OBSERVATION:\n{observation}", system=_LESSONS_SYSTEM
        ).strip()
        if lessons_update:
            head, _, _ = body.partition(_LESSONS_HEADING)
            body = head + f"{_LESSONS_HEADING}\n{lessons_update}\n"

        _write_page(path, meta, body)
        return path

    # -- lint --------------------------------------------------------------
    def lint(self) -> list[dict[str, str]]:
        """Health check over all pages; findings feed the next experiments.

        Every ``*.md`` file under the root except ``KNOWLEDGE.md`` is
        checked for a parseable mapping frontmatter with a known ``type``, a
        known ``status`` and non-empty ``provenance``.  The findings are
        emitted as a ``lint_report`` event and returned.
        """
        findings: list[dict[str, str]] = []
        for path in sorted(self.root.rglob("*.md")):
            if path.name == "KNOWLEDGE.md":
                continue
            try:
                meta, _ = _read_page(path)
            except Exception as exc:
                # Deliberately broad: lint must report a bad page, not abort on it.
                findings.append({"page": str(path), "problem": f"unparseable: {exc!r}"})
                continue
            if not isinstance(meta, dict):
                # Valid YAML that is not a mapping (a list, a bare string, ...).
                findings.append({
                    "page": str(path),
                    "problem": f"unparseable: frontmatter is {type(meta).__name__}, "
                               "expected a mapping",
                })
                continue
            if meta.get("type") not in _PAGE_TYPES:
                findings.append({"page": str(path), "problem": f"bad type {meta.get('type')!r}"})
            if meta.get("status") not in _STATUSES:
                findings.append({"page": str(path),
                                 "problem": f"bad status {meta.get('status')!r}"})
            if not meta.get("provenance"):
                findings.append({"page": str(path), "problem": "missing provenance"})
        self.bus.emit(Event("knowledge-lint", "knowledge", "lint_report",
                            payload={"findings": findings}))
        return findings
@@ -0,0 +1,98 @@
1
+ """Training Agent — subprocess wrapper over `lerobot-train` (DESIGN.md §3.3)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from leagents.agents.base import Runner, subprocess_runner
8
+ from leagents.config import PolicyRung, TrainConfig
9
+ from leagents.contracts import CheckpointRecord, DatasetRef
10
+ from leagents.events import Event, EventBus
11
+ from leagents.orchestrator.constitution import Constitution, ConstitutionError
12
+
13
+
14
class TrainError(Exception):
    """Raised when the training job exits with a non-zero status."""
16
+
17
+
18
class TrainAgent:
    """Runs ``lerobot-train`` for one policy rung and reports the checkpoint.

    The command is vetted by the constitution, executed through the injected
    runner, and progress is published on the event bus under the ``"train"``
    stage.
    """

    def __init__(
        self,
        cfg: TrainConfig,
        constitution: Constitution,
        bus: EventBus,
        runner: Runner = subprocess_runner,
    ):
        self.cfg = cfg
        self.constitution = constitution
        self.bus = bus
        self.runner = runner

    def build_command(
        self,
        rung: PolicyRung,
        dataset: DatasetRef,
        output_dir: Path,
        init_override: str | None = None,
        steps: int | None = None,
    ) -> list[str]:
        """Return the ``lerobot-train`` argument vector.

        ``init_override`` takes precedence over the rung's own ``init``;
        with neither set the policy is selected by type.  ``steps`` falls
        back to the configured step count when it is ``None`` or 0.
        """
        init = init_override or rung.init
        if init:
            policy_arg = f"--policy.path={init}"
        else:
            policy_arg = f"--policy.type={rung.name}"

        argv = [
            "lerobot-train",
            policy_arg,
            # lerobot >= 0.5 refuses to train without policy.repo_id unless hub
            # push is explicitly off; leagents manages checkpoints locally.
            "--policy.push_to_hub=false",
            f"--dataset.repo_id={dataset.repo_id}",
        ]
        if dataset.root:
            argv.append(f"--dataset.root={dataset.root}")
        # An explicit episode list is only passed for datasets without a root.
        if dataset.num_episodes and dataset.root is None:
            argv.append(f"--dataset.episodes={list(range(dataset.num_episodes))}")
        argv += [
            f"--output_dir={output_dir}",
            f"--steps={steps or self.cfg.steps}",
            f"--batch_size={self.cfg.batch_size}",
        ]
        argv.extend(self.cfg.extra_args)
        return argv

    def run(
        self,
        *,
        run_id: str,
        cycle: int,
        rung: PolicyRung,
        dataset: DatasetRef,
        workdir: Path,
        init_override: str | None = None,
        steps: int | None = None,
        stage: str = "train",
    ) -> CheckpointRecord:
        """Train for one cycle and return the record of the last checkpoint.

        Output goes to ``workdir/cycle_<cycle>/<stage>`` and the job log to
        ``workdir/cycle_<cycle>/<stage>.log``.  A checkpoint directory that
        is absent after a successful run only produces a ``warning`` event;
        the record is returned regardless.

        Raises:
            ConstitutionError: a constitution check denied the job; a
                ``constitution_denied`` event is emitted first.
            TrainError: ``lerobot-train`` exited non-zero.
        """
        cycle_dir = workdir / f"cycle_{cycle}"
        output_dir = cycle_dir / stage
        cmd = self.build_command(rung, dataset, output_dir, init_override, steps)

        # Both checks are evaluated up front; the first denial wins.
        verdicts = (
            self.constitution.check_train(steps or self.cfg.steps),
            self.constitution.check_command(cmd),
        )
        denied = next((v for v in verdicts if not v.allowed), None)
        if denied is not None:
            self.bus.emit(Event(run_id, "train", "constitution_denied", cycle,
                                {"rule": denied.rule, "reason": denied.reason}))
            raise ConstitutionError(denied)

        self.bus.emit(Event(run_id, "train", "job_started", cycle, {"cmd": cmd, "stage": stage}))
        # log lives outside output_dir: lerobot-train refuses a pre-existing output_dir
        result = self.runner(cmd, cycle_dir / f"{stage}.log")
        finished = {
            "exit_code": result.exit_code,
            "duration_s": result.duration_s,
            "log": str(result.log_path),
        }
        self.bus.emit(Event(run_id, "train", "job_finished", cycle, finished))
        if result.exit_code != 0:
            raise TrainError(f"lerobot-train exited {result.exit_code}, see {result.log_path}")

        # lerobot-train checkpoint layout: output_dir/checkpoints/last/pretrained_model
        checkpoint = output_dir / "checkpoints" / "last" / "pretrained_model"
        if not checkpoint.exists():
            missing = {"msg": f"expected checkpoint missing: {checkpoint}"}
            self.bus.emit(Event(run_id, "train", "warning", cycle, missing))
        return CheckpointRecord(policy_type=rung.name, path=str(checkpoint), cycle=cycle)