openadapt_ml-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/scripts/demo_policy.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import argparse
+
+from openadapt_ml.datasets.next_action import build_next_action_sft_samples
+from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
+from openadapt_ml.models.dummy_adapter import DummyAdapter
+from openadapt_ml.models.qwen_vl import QwenVLAdapter
+from openadapt_ml.models.api_adapter import ApiVLMAdapter
+from openadapt_ml.runtime.policy import AgentPolicy
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--backend",
+        choices=["dummy", "qwen3", "qwen2_5", "claude", "openai"],
+        default="dummy",
+    )
+    args = parser.parse_args()
+
+    # Use synthetic data to build one SFT-style sample
+    sessions = generate_synthetic_sessions(num_sessions=1, seed=99, output_dir="synthetic/demo")
+    episodes = [ep for sess in sessions for ep in sess.episodes]
+    samples = build_next_action_sft_samples(episodes)
+
+    # Use the first sample as-is; the dummy adapter does not depend on
+    # any particular target content.
+    sample = samples[0]
+
+    if args.backend == "dummy":
+        adapter = DummyAdapter()
+    elif args.backend == "qwen3":
+        adapter = QwenVLAdapter.from_pretrained(
+            "Qwen/Qwen3-VL-8B-Instruct",
+            lora_config=None,
+            load_in_4bit=False,
+        )
+    elif args.backend == "qwen2_5":
+        adapter = QwenVLAdapter.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            lora_config=None,
+            load_in_4bit=False,
+        )
+    elif args.backend == "claude":
+        adapter = ApiVLMAdapter(provider="anthropic")
+    else:  # openai
+        adapter = ApiVLMAdapter(provider="openai")
+    policy = AgentPolicy(adapter)
+
+    action, thought, state, raw_text = policy.predict_action_from_sample(sample)
+    print("Raw sample messages:")
+    for m in sample["messages"]:
+        print(f"[{m['role']}] {m['content']}")
+
+    print("\nPredicted action:", action)
+    print("Thought:", thought)
+    print("State:", state)
+    print("Raw output:", raw_text)
+
+if __name__ == "__main__":
+    main()
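For reference, the same pipeline can be driven programmatically rather than via the CLI. The following is a minimal sketch using only the DummyAdapter backend (no model weights or API keys needed), assuming the package is installed and the APIs behave as shown in the diff above:

# Sketch: exercise demo_policy.py's pipeline directly, dummy backend only.
from openadapt_ml.datasets.next_action import build_next_action_sft_samples
from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
from openadapt_ml.models.dummy_adapter import DummyAdapter
from openadapt_ml.runtime.policy import AgentPolicy

# Generate one synthetic session and build SFT-style samples from it.
sessions = generate_synthetic_sessions(num_sessions=1, seed=99, output_dir="synthetic/demo")
episodes = [ep for sess in sessions for ep in sess.episodes]
samples = build_next_action_sft_samples(episodes)

# Predict the next action for the first sample.
policy = AgentPolicy(DummyAdapter())
action, thought, state, raw_text = policy.predict_action_from_sample(samples[0])
print(action, thought, state)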
openadapt_ml/scripts/eval_policy.py
@@ -0,0 +1,287 @@
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional
+
+import yaml
+
+from openadapt_ml.datasets.next_action import build_next_action_sft_samples, parse_action_som
+from openadapt_ml.evals.trajectory_matching import evaluate_policy_on_episodes
+from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
+from openadapt_ml.models.dummy_adapter import DummyAdapter
+from openadapt_ml.models.qwen_vl import QwenVLAdapter
+from openadapt_ml.models.api_adapter import ApiVLMAdapter
+from openadapt_ml.runtime.policy import AgentPolicy
+
+
+def _load_config(path: str | Path) -> dict:
+    with open(path, "r", encoding="utf-8") as f:
+        return yaml.safe_load(f)
+
+
+def main(
+    config_path: str,
+    backend: str,
+    output_json: str | None,
+    ignore_lora: bool = False,
+    log_samples: Optional[str] = None,
+    log_limit: Optional[int] = None,
+    dsl_mode: str = "coord",
+    eval_on_training_data: bool = False,
+    no_jitter: bool = False,
+    scenario: Optional[str] = None,
+) -> None:
+    cfg = _load_config(config_path)
+
+    # Determine if using Set-of-Marks (SoM) mode
+    use_som = dsl_mode == "som"
+
+    # Synthetic data config
+    synth_cfg: Dict[str, Any] = cfg.get("synthetic_data", {})
+    num_sessions = synth_cfg.get("num_sessions", 4)
+    seed = synth_cfg.get("seed", 999)
+
+    # Determine output directory and jitter setting
+    if eval_on_training_data:
+        # Use the SAME data directory as training to test memorization
+        output_dir = synth_cfg.get("output_dir", "synthetic_train")
+        # When evaluating on training data, use the same jitter setting as
+        # training (default True unless explicitly set).
+        jitter = synth_cfg.get("jitter", True) and not no_jitter
+        print(f"[INFO] Evaluating on TRAINING data from: {output_dir}")
+    else:
+        # Generate fresh data for generalization testing
+        output_dir = synth_cfg.get("output_dir", "synthetic_eval") + "_eval"
+        jitter = not no_jitter
+        print(f"[INFO] Evaluating on FRESH data in: {output_dir}")
+
+    if no_jitter:
+        print("[INFO] Jitter disabled - using deterministic layouts")
+
+    # Determine scenario: CLI arg takes precedence, then config, then default "login"
+    scenario_to_use = scenario if scenario else synth_cfg.get("scenario", "login")
+
+    # Generate sessions with SoM if requested
+    sessions = generate_synthetic_sessions(
+        num_sessions=num_sessions,
+        seed=seed,
+        output_dir=output_dir,
+        use_som=use_som,
+        jitter=jitter,
+        scenario=scenario_to_use,
+    )
+    print(f"[INFO] Scenario: {scenario_to_use}")
+    episodes = [ep for sess in sessions for ep in sess.episodes]
+
+    # Build samples with the appropriate DSL mode
+    samples = build_next_action_sft_samples(episodes, use_som=use_som)
+
+    # Backend / adapter selection
+    if backend == "dummy":
+        adapter = DummyAdapter()
+    elif backend == "qwen3":
+        model_cfg = cfg.get("model", {})
+        model_name = model_cfg.get("name", "Qwen/Qwen3-VL-8B-Instruct")
+        load_in_4bit = model_cfg.get("load_in_4bit", False)
+
+        # Optionally ignore LoRA to evaluate the base model only.
+        if ignore_lora:
+            lora_cfg = None
+        else:
+            lora_cfg = cfg.get("lora")
+
+        adapter = QwenVLAdapter.from_pretrained(
+            model_name,
+            lora_config=lora_cfg,
+            load_in_4bit=load_in_4bit,
+        )
+    elif backend == "qwen2_5":
+        adapter = QwenVLAdapter.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            lora_config=None,
+            load_in_4bit=False,
+        )
+    elif backend == "claude":
+        adapter = ApiVLMAdapter(provider="anthropic")
+    elif backend == "openai":
+        adapter = ApiVLMAdapter(provider="openai")
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
+
+    policy = AgentPolicy(adapter)
+
+    log_fn: Optional[Callable[[Dict[str, Any]], None]] = None
+    log_file_handle = None
+    if log_samples is not None:
+        log_path = Path(log_samples)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_file_handle = open(log_path, "w", encoding="utf-8")
+
+        def _log(record: Dict[str, Any]) -> None:
+            assert log_file_handle is not None
+            log_file_handle.write(json.dumps(record) + "\n")
+
+        log_fn = _log
+
+    try:
+        metrics = evaluate_policy_on_episodes(
+            policy,
+            episodes,
+            samples,
+            log_fn=log_fn,
+            log_limit=log_limit,
+            use_som=use_som,
+        )
+    finally:
+        if log_file_handle is not None:
+            log_file_handle.close()
+
+    print(f"Evaluation results (DSL mode: {dsl_mode}):")
+    print(f" num_episodes: {metrics.num_episodes}")
+    print(f" num_steps: {metrics.num_steps}")
+    print(f" action_type_accuracy: {metrics.action_type_accuracy:.4f}")
+    if metrics.mean_coord_error is not None:
+        print(
+            " mean_coord_error (normalized): "
+            f"{metrics.mean_coord_error:.4f} (n={metrics.coord_error_count})"
+        )
+    else:
+        print(" mean_coord_error (normalized): N/A")
+    if metrics.episode_success_rate is not None:
+        print(f" episode_success_rate: {metrics.episode_success_rate:.4f}")
+    else:
+        print(" episode_success_rate: N/A")
+    if metrics.click_hit_rate is not None:
+        print(f" click_hit_rate: {metrics.click_hit_rate:.4f}")
+    else:
+        print(" click_hit_rate: N/A")
+    if metrics.mean_episode_progress is not None:
+        print(f" mean_episode_progress: {metrics.mean_episode_progress:.4f}")
+    else:
+        print(" mean_episode_progress: N/A")
+    if metrics.mean_episode_step_score is not None:
+        print(f" mean_episode_step_score: {metrics.mean_episode_step_score:.4f}")
+    else:
+        print(" mean_episode_step_score: N/A")
+    if metrics.weak_episode_success_rate is not None:
+        print(f" weak_episode_success_rate: {metrics.weak_episode_success_rate:.4f}")
+    else:
+        print(" weak_episode_success_rate: N/A")
+    if metrics.state_success_rate is not None:
+        print(f" state_success_rate: {metrics.state_success_rate:.4f}")
+    else:
+        print(" state_success_rate: N/A")
+    if metrics.bbox_hit_rate is not None:
+        print(f" bbox_hit_rate: {metrics.bbox_hit_rate:.4f}")
+    else:
+        print(" bbox_hit_rate: N/A")
+    if metrics.element_accuracy is not None:
+        print(f" element_accuracy: {metrics.element_accuracy:.4f}")
+    else:
+        print(" element_accuracy: N/A")
+
+    if output_json is not None:
+        payload = {
+            "config_path": str(config_path),
+            "backend": backend,
+            "dsl_mode": dsl_mode,
+            "metrics": {
+                "num_episodes": metrics.num_episodes,
+                "num_steps": metrics.num_steps,
+                "action_type_accuracy": metrics.action_type_accuracy,
+                "mean_coord_error": metrics.mean_coord_error,
+                "coord_error_count": metrics.coord_error_count,
+                "episode_success_rate": metrics.episode_success_rate,
+                "click_hit_rate": metrics.click_hit_rate,
+                "bbox_hit_rate": metrics.bbox_hit_rate,
+                "mean_episode_progress": metrics.mean_episode_progress,
+                "mean_episode_step_score": metrics.mean_episode_step_score,
+                "weak_episode_success_rate": metrics.weak_episode_success_rate,
+                "state_success_rate": metrics.state_success_rate,
+                "element_accuracy": metrics.element_accuracy if hasattr(metrics, "element_accuracy") else None,
+            },
+        }
+        out_path = Path(output_json)
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(out_path, "w", encoding="utf-8") as f:
+            json.dump(payload, f, indent=2)
+        print(f"Metrics written to {output_json}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate a policy on synthetic episodes.")
+    parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        choices=["dummy", "qwen3", "qwen2_5", "claude", "openai"],
+        default="qwen2_5",
+        help="Backend adapter to use for evaluation.",
+    )
+    parser.add_argument(
+        "--output-json",
+        type=str,
+        default=None,
+        help="Optional path to write metrics as JSON.",
+    )
+    parser.add_argument(
+        "--ignore-lora",
+        action="store_true",
+        help="Ignore any LoRA config in the YAML and evaluate the base model only.",
+    )
+    parser.add_argument(
+        "--log-samples",
+        type=str,
+        default=None,
+        help="Optional path to write per-step eval logs as JSONL.",
+    )
+    parser.add_argument(
+        "--log-limit",
+        type=int,
+        default=None,
+        help="Maximum number of steps to log (default: no limit).",
+    )
+    parser.add_argument(
+        "--dsl-mode",
+        type=str,
+        choices=["coord", "som"],
+        default="coord",
+        help="DSL mode: 'coord' for coordinate-based (CLICK(x=..., y=...)), "
+        "'som' for Set-of-Marks index-based (CLICK([1])). Default: coord.",
+    )
+    parser.add_argument(
+        "--overfit",
+        action="store_true",
+        help="Evaluate on training data to check memorization/overfitting. "
+        "If not set, generates fresh data to test generalization.",
+    )
+    parser.add_argument(
+        "--no-jitter",
+        action="store_true",
+        help="Disable jitter for deterministic UI layouts. "
+        "Useful for testing memorization of fixed layouts.",
+    )
+    parser.add_argument(
+        "--scenario",
+        type=str,
+        choices=["login", "registration"],
+        default=None,
+        help="Scenario type: 'login' (6 steps, 3 elements) or 'registration' (12 steps, 6 elements). "
+        "Overrides config if provided.",
+    )
+    args = parser.parse_args()
+
+    main(
+        config_path=args.config,
+        backend=args.backend,
+        output_json=args.output_json,
+        ignore_lora=args.ignore_lora,
+        log_samples=args.log_samples,
+        log_limit=args.log_limit,
+        dsl_mode=args.dsl_mode,
+        eval_on_training_data=args.overfit,
+        no_jitter=args.no_jitter,
+        scenario=args.scenario,
+    )
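The script expects a YAML config whose keys mirror the cfg.get(...) lookups above: a synthetic_data section (num_sessions, seed, output_dir, jitter, scenario), a model section (name, load_in_4bit), and an optional lora section. A sketch of generating such a config from Python follows; the file path and values are illustrative, not prescribed by the package:

# Sketch: write a config that main() above will accept, using the key names
# implied by the cfg.get(...) calls in the diff. Values are illustrative.
from pathlib import Path

import yaml

config = {
    "synthetic_data": {
        "num_sessions": 4,
        "seed": 999,
        "output_dir": "synthetic_train",
        "jitter": True,
        "scenario": "login",
    },
    "model": {"name": "Qwen/Qwen3-VL-8B-Instruct", "load_in_4bit": False},
    "lora": None,  # or a LoRA config dict passed to QwenVLAdapter.from_pretrained
}
Path("configs").mkdir(exist_ok=True)
Path("configs/eval_demo.yaml").write_text(yaml.safe_dump(config), encoding="utf-8")

With the config in place, the script would presumably be run as python -m openadapt_ml.scripts.eval_policy --config configs/eval_demo.yaml --backend dummy, since the scripts directory is a package.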
openadapt_ml/scripts/make_gif.py
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import argparse
+import glob
+import os
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from PIL import Image, ImageDraw, ImageFont
+
+
+def _load_font(size: int = 16) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:  # type: ignore[name-defined]
+    try:
+        return ImageFont.truetype("arial.ttf", size)
+    except OSError:
+        return ImageFont.load_default()
+
+
+FONT = _load_font(16)
+
+
+def _load_frames(frames_dir: Path, pattern: str) -> List[Path]:
+    paths = sorted(Path(p) for p in glob.glob(str(frames_dir / pattern)))
+    if not paths:
+        raise ValueError(f"No frames matched pattern '{pattern}' under {frames_dir}")
+    return paths
+
+
+def _default_login_caption(filename: str, index: int) -> str:
+    # Heuristic captions for the synthetic login script based on step index
+    # and the conventional *_step_{k}.png naming.
+    name = os.path.basename(filename)
+    # Try to extract step index from name if present.
+    step_idx = index
+    for part in name.split("_"):
+        if part.startswith("step"):
+            try:
+                step_idx = int(part.replace("step", "").replace(".png", ""))
+            except ValueError:
+                pass
+    if step_idx == 0:
+        return "Step 0: Initial login screen (WAIT)"
+    if step_idx == 1:
+        return "Step 1: CLICK username field"
+    if step_idx == 2:
+        return "Step 2: TYPE username"
+    if step_idx == 3:
+        return "Step 3: CLICK password field"
+    if step_idx == 4:
+        return "Step 4: TYPE password (masked)"
+    if step_idx == 5:
+        return "Step 5: CLICK Login button"
+    if step_idx == 6:
+        return "Step 6: DONE (logged in)"
+    return f"Step {step_idx}: synthetic step"
+
+
+def _draw_caption(image: Image.Image, text: str) -> Image.Image:
+    img = image.convert("RGB").copy()
+    draw = ImageDraw.Draw(img)
+    width, height = img.size
+
+    # Draw a dark rectangle at the bottom as the text background (the alpha in the fill is ignored on RGB images, so it renders opaque)
+    padding = 8
+    text_width, text_height = draw.textbbox((0, 0), text, font=FONT)[2:4]  # type: ignore[assignment]
+    rect_height = text_height + 2 * padding
+    y0 = height - rect_height
+    draw.rectangle([(0, y0), (width, height)], fill=(0, 0, 0, 180))
+    x_text = max(padding, (width - text_width) // 2)
+    y_text = y0 + padding
+    draw.text((x_text, y_text), text, font=FONT, fill=(255, 255, 255))
+    return img
+
+
+def make_gif(
+    frames_dir: Path,
+    pattern: str,
+    output: Path,
+    duration_ms: int = 1000,
+    scenario: Optional[str] = None,
+) -> None:
+    frame_paths = _load_frames(frames_dir, pattern)
+
+    frames: List[Image.Image] = []
+    for idx, frame_path in enumerate(frame_paths):
+        img = Image.open(frame_path)
+        if scenario == "login":
+            caption = _default_login_caption(frame_path.name, idx)
+            img = _draw_caption(img, caption)
+        frames.append(img)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    frames[0].save(
+        output,
+        save_all=True,
+        append_images=frames[1:],
+        duration=duration_ms,
+        loop=0,
+    )
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Generate an animated GIF from a sequence of PNG frames.",
+    )
+    parser.add_argument(
+        "--frames-dir",
+        type=str,
+        required=True,
+        help="Directory containing frame PNGs (e.g. synthetic_demo/session_0000)",
+    )
+    parser.add_argument(
+        "--pattern",
+        type=str,
+        default="*step_*.png",
+        help="Glob pattern for frame filenames inside frames-dir (default: *step_*.png)",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        required=True,
+        help="Output GIF path",
+    )
+    parser.add_argument(
+        "--duration-ms",
+        type=int,
+        default=1000,
+        help="Frame duration in milliseconds (default: 1000)",
+    )
+    parser.add_argument(
+        "--scenario",
+        type=str,
+        default=None,
+        choices=["login", None],  # type: ignore[list-item]
+        help="Optional built-in captioning scenario (e.g. 'login')",
+    )
+
+    args = parser.parse_args()
+
+    frames_dir = Path(args.frames_dir)
+    output = Path(args.output)
+
+    make_gif(
+        frames_dir=frames_dir,
+        pattern=args.pattern,
+        output=output,
+        duration_ms=args.duration_ms,
+        scenario=args.scenario,
+    )
+
+
+if __name__ == "__main__":
+    main()
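make_gif() can also be called directly from Python. A minimal sketch, assuming frames produced by the synthetic session generator and using the example directory from the script's own --frames-dir help text (the output path here is illustrative):

# Sketch: call make_gif() programmatically instead of via the CLI.
from pathlib import Path

from openadapt_ml.scripts.make_gif import make_gif

make_gif(
    frames_dir=Path("synthetic_demo/session_0000"),
    pattern="*step_*.png",
    output=Path("gifs/login_demo.gif"),
    duration_ms=800,
    scenario="login",  # enables the built-in step captions
)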
openadapt_ml/scripts/prepare_synthetic.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
+
+
+def main() -> None:
+    output_dir = Path("synthetic") / "debug"
+    sessions = generate_synthetic_sessions(num_sessions=2, seed=42, output_dir=output_dir)
+
+    print(f"Generated {len(sessions)} sessions into {output_dir.resolve()}")
+
+    total_episodes = 0
+    total_steps = 0
+    missing_images: list[str] = []
+
+    for session in sessions:
+        total_episodes += len(session.episodes)
+        for episode in session.episodes:
+            total_steps += len(episode.steps)
+            for step in episode.steps:
+                path = step.observation.image_path
+                if not path:
+                    missing_images.append(f"[no path] in episode {episode.id}")
+                    continue
+                if not os.path.exists(path):
+                    missing_images.append(path)
+
+    print(f"Episodes: {total_episodes}, Steps: {total_steps}")
+
+    if missing_images:
+        print("Missing images:")
+        for p in missing_images:
+            print(" -", p)
+        raise SystemExit(1)
+
+    print("All observation image paths exist.")
+
+
+if __name__ == "__main__":
+    main()