openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/benchmarks/azure.py
@@ -30,16 +30,13 @@ from __future__ import annotations
 
 import json
 import logging
-import os
-import tempfile
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable
+from typing import Callable
 
-from openadapt_ml.benchmarks.agent import BenchmarkAgent
-from openadapt_ml.benchmarks.base import BenchmarkResult, BenchmarkTask
+from openadapt_evals import BenchmarkAgent, BenchmarkResult, BenchmarkTask
 
 logger = logging.getLogger(__name__)
 
@@ -233,7 +230,9 @@ class AzureMLClient:
             resource_group_name=self.config.resource_group,
             workspace_name=self.config.workspace_name,
         )
-        logger.info(f"Connected to Azure ML workspace: {self.config.workspace_name}")
+        logger.info(
+            f"Connected to Azure ML workspace: {self.config.workspace_name}"
+        )
         return self._client
 
     def _get_credential(self):
@@ -241,11 +240,13 @@
         from openadapt_ml.config import settings
 
         # Use service principal if credentials are configured
-        if all([
-            settings.azure_client_id,
-            settings.azure_client_secret,
-            settings.azure_tenant_id,
-        ]):
+        if all(
+            [
+                settings.azure_client_id,
+                settings.azure_client_secret,
+                settings.azure_tenant_id,
+            ]
+        ):
             logger.info("Using service principal authentication")
             return self._ClientSecretCredential(
                 tenant_id=settings.azure_tenant_id,
@@ -301,7 +302,10 @@ class AzureMLClient:
                 f"/providers/Microsoft.ManagedIdentity"
                 f"/userAssignedIdentities/{self.config.managed_identity_name}"
             )
-            compute.identity = {"type": "UserAssigned", "user_assigned_identities": [identity_id]}
+            compute.identity = {
+                "type": "UserAssigned",
+                "user_assigned_identities": [identity_id],
+            }
 
         print(f" Creating VM: {name}...", end="", flush=True)
         self.client.compute.begin_create_or_update(compute).result()
@@ -381,6 +385,7 @@ class AzureMLClient:
 
         import time
         import uuid
+
         timestamp = int(time.time())
         unique_id = str(uuid.uuid4())[:8]
         job_name = f"waa-{compute_name}-{timestamp}-{unique_id}"
@@ -490,7 +495,7 @@ class AzureWAAOrchestrator:
             List of BenchmarkResult for all tasks.
         """
         # Load tasks
-        from openadapt_ml.benchmarks.waa import WAAAdapter
+        from openadapt_evals import WAAMockAdapter as WAAAdapter
 
         adapter = WAAAdapter(waa_repo_path=self.waa_repo_path)
         if task_ids:
@@ -530,17 +535,21 @@
 
         try:
             # Provision VMs in parallel
-            print(f"[2/4] Provisioning {num_workers} Azure VM(s)... (this takes 3-5 minutes)")
+            print(
+                f"[2/4] Provisioning {num_workers} Azure VM(s)... (this takes 3-5 minutes)"
+            )
             self._provision_workers(workers)
-            print(f" VM(s) ready")
+            print(" VM(s) ready")
 
             # Submit jobs to workers
-            print(f"[3/4] Submitting evaluation jobs...")
-            self._submit_worker_jobs(workers, task_batches, agent, max_steps_per_task, timeout_hours)
-            print(f" Jobs submitted")
+            print("[3/4] Submitting evaluation jobs...")
+            self._submit_worker_jobs(
+                workers, task_batches, agent, max_steps_per_task, timeout_hours
+            )
+            print(" Jobs submitted")
 
             # Wait for completion and collect results
-            print(f"[4/4] Waiting for workers to complete...")
+            print("[4/4] Waiting for workers to complete...")
             results = self._wait_and_collect_results(workers, on_worker_complete)
 
             self._current_run.status = "completed"
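
The azure.py hunks above are mostly Black-style reformatting plus one substantive change: the in-tree openadapt_ml.benchmarks base/waa modules (deleted in this release; items 87-92 in the file list above) are replaced by imports from the external openadapt_evals package. A minimal sketch of the new import surface, using only the names and the keyword argument that appear in this diff (any further WAAMockAdapter parameters are not shown here, and the path is illustrative):

    from openadapt_evals import (
        BenchmarkAgent,
        BenchmarkResult,
        BenchmarkTask,
        WAAMockAdapter,
    )

    # Constructed the same way AzureWAAOrchestrator does above; the repo path
    # is assumed to point at a local WindowsAgentArena checkout.
    adapter = WAAMockAdapter(waa_repo_path="/path/to/WindowsAgentArena")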
openadapt_ml/benchmarks/azure_ops_tracker.py
@@ -0,0 +1,521 @@
+"""Azure operations status tracker.
+
+Writes real-time status to azure_ops_status.json for dashboard consumption.
+Used by CLI commands (setup-waa, run-waa, vm monitor) to provide visibility
+into long-running Azure operations.
+
+Usage:
+    from openadapt_ml.benchmarks.azure_ops_tracker import AzureOpsTracker
+
+    tracker = AzureOpsTracker()
+    tracker.start_operation("docker_build", total_steps=12)
+    tracker.update(phase="pulling_base_image", step=1, log_lines=["Pulling from ..."])
+    tracker.append_log("Step 1/12 : FROM dockurr/windows:latest")
+    tracker.finish_operation()
+"""
+
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass, asdict, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+# VM pricing from vm_monitor.py
+VM_HOURLY_RATES = {
+    "Standard_D2_v3": 0.096,
+    "Standard_D4_v3": 0.192,
+    "Standard_D8_v3": 0.384,
+    "Standard_D4s_v3": 0.192,
+    "Standard_D8s_v3": 0.384,
+    "Standard_D4ds_v5": 0.422,  # Updated pricing as per spec
+    "Standard_D8ds_v5": 0.384,
+    "Standard_D16ds_v5": 0.768,
+    "Standard_D32ds_v5": 1.536,
+}
+
+# Typical operation durations in seconds (for ETA estimation)
+TYPICAL_DURATIONS = {
+    "docker_build": 600,  # ~10 minutes for waa-auto build
+    "docker_pull": 300,  # ~5 minutes for large image pull
+    "windows_boot": 900,  # ~15 minutes for first Windows boot
+    "benchmark": 1800,  # ~30 minutes for 20 tasks
+}
+
+DEFAULT_OUTPUT_FILE = Path("benchmark_results/azure_ops_status.json")
+
+
+@dataclass
+class AzureOpsStatus:
+    """Status of current Azure operation.
+
+    Attributes:
+        operation: Current operation type (idle, vm_create, docker_install,
+            docker_build, windows_boot, benchmark, etc.)
+        phase: Specific phase within the operation.
+        step: Current step number.
+        total_steps: Total number of steps in the operation.
+        progress_pct: Progress percentage (0-100).
+        log_tail: Last N lines of log output.
+        started_at: ISO timestamp when operation started.
+        elapsed_seconds: Seconds since operation started.
+        eta_seconds: Estimated seconds remaining (None if unknown).
+        cost_usd: Running cost in USD.
+        hourly_rate_usd: Hourly VM rate in USD.
+        vm_ip: VM IP address if available.
+        vm_state: VM power state (running, starting, stopped, deallocated).
+        vm_size: Azure VM size.
+        vnc_url: VNC URL for accessing Windows desktop.
+        error: Error message if operation failed.
+        download_bytes: Bytes downloaded so far (for image pulls).
+        download_total_bytes: Total bytes to download.
+        build_id: Current Docker build run ID (to detect new builds).
+    """
+
+    operation: str = "idle"
+    phase: str = ""
+    step: int = 0
+    total_steps: int = 0
+    progress_pct: float = 0.0
+    log_tail: list[str] = field(default_factory=list)
+    started_at: str | None = None
+    elapsed_seconds: float = 0.0
+    eta_seconds: float | None = None
+    cost_usd: float = 0.0
+    hourly_rate_usd: float = 0.422  # Default for Standard_D4ds_v5
+    vm_ip: str | None = None
+    vm_state: str = "unknown"
+    vm_size: str = "Standard_D4ds_v5"
+    vnc_url: str | None = None
+    error: str | None = None
+    download_bytes: int = 0
+    download_total_bytes: int = 0
+    build_id: str | None = None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for JSON serialization."""
+        return asdict(self)
+
+
+class AzureOpsTracker:
+    """Tracks Azure operations and writes status to JSON file.
+
+    The tracker maintains a status file that the dashboard can poll to
+    display real-time progress of Azure operations.
+    """
+
+    MAX_LOG_LINES = 100
+
+    def __init__(
+        self,
+        output_file: str | Path = DEFAULT_OUTPUT_FILE,
+        vm_size: str = "Standard_D4ds_v5",
+    ):
+        """Initialize tracker.
+
+        Args:
+            output_file: Path to output JSON file.
+            vm_size: Azure VM size for cost calculation.
+        """
+        self.output_file = Path(output_file)
+        self.vm_size = vm_size
+        self.hourly_rate = VM_HOURLY_RATES.get(vm_size, 0.422)
+        self._status = AzureOpsStatus(
+            vm_size=vm_size,
+            hourly_rate_usd=self.hourly_rate,
+        )
+        self._start_time: datetime | None = None
+
+    def start_operation(
+        self,
+        operation: str,
+        total_steps: int = 0,
+        phase: str = "",
+        vm_ip: str | None = None,
+        vm_state: str = "running",
+        build_id: str | None = None,
+        started_at: datetime | None = None,
+    ) -> None:
+        """Start tracking a new operation.
+
+        Args:
+            operation: Operation type (vm_create, docker_install, docker_build,
+                windows_boot, benchmark, etc.)
+            total_steps: Total number of steps in the operation.
+            phase: Initial phase description.
+            vm_ip: VM IP address if known.
+            vm_state: VM power state.
+            build_id: Unique identifier for this build (to detect new builds).
+            started_at: When the operation actually started (uses now if not provided).
+        """
+        self._start_time = started_at or datetime.now()
+        self._status = AzureOpsStatus(
+            operation=operation,
+            phase=phase,
+            step=0,
+            total_steps=total_steps,
+            progress_pct=0.0,
+            log_tail=[],  # Clear stale logs
+            started_at=self._start_time.isoformat(),
+            elapsed_seconds=0.0,
+            eta_seconds=TYPICAL_DURATIONS.get(
+                operation
+            ),  # Use typical duration as initial ETA
+            cost_usd=0.0,
+            hourly_rate_usd=self.hourly_rate,
+            vm_ip=vm_ip,
+            vm_state=vm_state,
+            vm_size=self.vm_size,
+            vnc_url="http://localhost:8006" if vm_ip else None,
+            error=None,
+            download_bytes=0,
+            download_total_bytes=0,
+            build_id=build_id,
+        )
+        self._write_status()
+
+    def update(
+        self,
+        phase: str | None = None,
+        step: int | None = None,
+        total_steps: int | None = None,
+        log_lines: list[str] | None = None,
+        vm_ip: str | None = None,
+        vm_state: str | None = None,
+        error: str | None = None,
+        download_bytes: int | None = None,
+        download_total_bytes: int | None = None,
+        build_id: str | None = None,
+    ) -> None:
+        """Update operation status.
+
+        Args:
+            phase: Current phase description.
+            step: Current step number.
+            total_steps: Total steps (can be updated if discovered during operation).
+            log_lines: New log lines to append.
+            vm_ip: VM IP address.
+            vm_state: VM power state.
+            error: Error message if operation failed.
+            download_bytes: Bytes downloaded so far.
+            download_total_bytes: Total bytes to download.
+            build_id: Build identifier (clears log if different from current).
+        """
+        # If build_id changed, this is a new build - clear stale logs
+        if build_id is not None and build_id != self._status.build_id:
+            self._status.build_id = build_id
+            self._status.log_tail = []
+            self._status.error = None
+            self._start_time = datetime.now()
+            self._status.started_at = self._start_time.isoformat()
+
+        if phase is not None:
+            self._status.phase = phase
+        if step is not None:
+            self._status.step = step
+        if total_steps is not None:
+            self._status.total_steps = total_steps
+        if log_lines is not None:
+            for line in log_lines:
+                self.append_log(line)
+        if vm_ip is not None:
+            self._status.vm_ip = vm_ip
+            self._status.vnc_url = "http://localhost:8006"
+        if vm_state is not None:
+            self._status.vm_state = vm_state
+        if error is not None:
+            self._status.error = error
+        if download_bytes is not None:
+            self._status.download_bytes = download_bytes
+        if download_total_bytes is not None:
+            self._status.download_total_bytes = download_total_bytes
+
+        # Update derived fields
+        self._update_progress()
+        self._write_status()
+
+    def append_log(self, line: str) -> None:
+        """Append a log line (keeps last MAX_LOG_LINES).
+
+        Args:
+            line: Log line to append.
+        """
+        self._status.log_tail.append(line.rstrip())
+        if len(self._status.log_tail) > self.MAX_LOG_LINES:
+            self._status.log_tail = self._status.log_tail[-self.MAX_LOG_LINES :]
+        self._update_progress()
+        self._write_status()
+
+    def parse_docker_build_line(self, line: str) -> dict[str, Any]:
+        """Parse Docker build output for step progress and download info.
+
+        Handles both patterns:
+        - Old style: "Step X/Y : ..."
+        - Buildx style: "#N [stage X/Y] ..." or "#N sha256:... XXXMB / YGB ..."
+
+        Args:
+            line: Docker build output line.
+
+        Returns:
+            Dict with parsed info: {step, total_steps, download_bytes, download_total_bytes, phase}
+        """
+        result: dict[str, Any] = {}
+
+        # Old style: "Step X/Y : ..."
+        step_match = re.search(r"Step\s+(\d+)/(\d+)", line)
+        if step_match:
+            result["step"] = int(step_match.group(1))
+            result["total_steps"] = int(step_match.group(2))
+
+        # Buildx style: "#N [stage X/Y] ..."
+        buildx_stage = re.search(r"#\d+\s+\[.*?\s+(\d+)/(\d+)\]", line)
+        if buildx_stage:
+            result["step"] = int(buildx_stage.group(1))
+            result["total_steps"] = int(buildx_stage.group(2))
+
+        # Download progress: "sha256:... XXXMB / YGB ..." or "XXX.XXMB / YY.YYGB ..."
+        download_match = re.search(
+            r"(\d+(?:\.\d+)?)\s*(MB|GB|KB|B)\s*/\s*(\d+(?:\.\d+)?)\s*(MB|GB|KB|B)",
+            line,
+        )
+        if download_match:
+            size_multipliers = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
+            downloaded = float(download_match.group(1))
+            downloaded_unit = download_match.group(2)
+            total = float(download_match.group(3))
+            total_unit = download_match.group(4)
+            result["download_bytes"] = int(
+                downloaded * size_multipliers[downloaded_unit]
+            )
+            result["download_total_bytes"] = int(total * size_multipliers[total_unit])
+
+        # Extract phase from buildx output
+        if line.startswith("#"):
+            # #N DONE, #N CACHED, #N [stage]
+            phase_match = re.match(r"#\d+\s+(.*)", line)
+            if phase_match:
+                phase_text = phase_match.group(1)[:80]
+                # Clean up ANSI codes
+                phase_text = re.sub(r"\x1b\[[0-9;]*m", "", phase_text)
+                result["phase"] = phase_text.strip()
+
+        # Apply updates if we found anything
+        if "step" in result:
+            self._status.step = result["step"]
+        if "total_steps" in result:
+            self._status.total_steps = result["total_steps"]
+        if "download_bytes" in result:
+            self._status.download_bytes = result["download_bytes"]
+        if "download_total_bytes" in result:
+            self._status.download_total_bytes = result["download_total_bytes"]
+        if "phase" in result:
+            self._status.phase = result["phase"]
+
+        if result:
+            self._update_progress()
+
+        return result
+
+    def is_error_line(self, line: str) -> bool:
+        """Check if a line is an error message.
+
+        Args:
+            line: Log line to check.
+
+        Returns:
+            True if line contains an error.
+        """
+        error_patterns = [
+            r"ERROR:",
+            r"failed to build",
+            r"failed to solve",
+            r"error reading from server",
+            r"rpc error",
+        ]
+        return any(re.search(p, line, re.IGNORECASE) for p in error_patterns)
+
+    def finish_operation(self, success: bool = True, error: str | None = None) -> None:
+        """Mark operation as complete.
+
+        Args:
+            success: Whether the operation completed successfully.
+            error: Error message if operation failed.
+        """
+        if error:
+            self._status.error = error
+        self._status.operation = "complete" if success else "failed"
+        self._status.progress_pct = 100.0 if success else self._status.progress_pct
+        self._update_progress()
+        self._write_status()
+
+    def set_idle(self) -> None:
+        """Reset tracker to idle state."""
+        self._start_time = None
+        self._status = AzureOpsStatus(
+            vm_size=self.vm_size,
+            hourly_rate_usd=self.hourly_rate,
+        )
+        self._write_status()
+
+    def get_status(self) -> AzureOpsStatus:
+        """Get current status (with updated elapsed time and cost)."""
+        self._update_progress()
+        return self._status
+
+    def _update_progress(self) -> None:
+        """Update derived fields (elapsed time, cost, progress percentage, ETA)."""
+        # Update elapsed time
+        if self._start_time:
+            elapsed = datetime.now() - self._start_time
+            self._status.elapsed_seconds = elapsed.total_seconds()
+
+        # Update cost
+        elapsed_hours = self._status.elapsed_seconds / 3600
+        self._status.cost_usd = elapsed_hours * self.hourly_rate
+
+        # Calculate progress from multiple sources
+        progress_pct = 0.0
+        eta_seconds = None
+
+        # 1. Download progress (most accurate during image pulls)
+        if self._status.download_total_bytes > 0:
+            download_pct = (
+                self._status.download_bytes / self._status.download_total_bytes
+            ) * 100
+            progress_pct = max(progress_pct, download_pct)
+
+            # ETA from download speed
+            if self._status.download_bytes > 0 and self._status.elapsed_seconds > 1:
+                bytes_per_sec = (
+                    self._status.download_bytes / self._status.elapsed_seconds
+                )
+                remaining_bytes = (
+                    self._status.download_total_bytes - self._status.download_bytes
+                )
+                if bytes_per_sec > 0:
+                    eta_seconds = remaining_bytes / bytes_per_sec
+
+        # 2. Step-based progress
+        if self._status.total_steps > 0:
+            step_pct = (self._status.step / self._status.total_steps) * 100
+            progress_pct = max(progress_pct, step_pct)
+
+            # ETA from step rate (only if we have meaningful progress)
+            if self._status.step > 0 and self._status.elapsed_seconds > 10:
+                time_per_step = self._status.elapsed_seconds / self._status.step
+                remaining_steps = self._status.total_steps - self._status.step
+                step_eta = time_per_step * remaining_steps
+                # Use step ETA if we don't have download ETA or if step progress > download
+                if (
+                    eta_seconds is None
+                    or step_pct
+                    > (
+                        self._status.download_bytes
+                        / max(self._status.download_total_bytes, 1)
+                    )
+                    * 100
+                ):
+                    eta_seconds = step_eta
+
+        # 3. Fallback: Use typical duration if no progress info
+        if eta_seconds is None and self._status.operation in TYPICAL_DURATIONS:
+            typical = TYPICAL_DURATIONS[self._status.operation]
+            remaining = max(0, typical - self._status.elapsed_seconds)
+            eta_seconds = remaining
+            # Estimate progress from elapsed vs typical
+            if progress_pct == 0 and self._status.elapsed_seconds > 0:
+                progress_pct = min(95, (self._status.elapsed_seconds / typical) * 100)
+
+        self._status.progress_pct = min(100.0, progress_pct)
+        self._status.eta_seconds = eta_seconds
+
+    def _write_status(self) -> None:
+        """Write current status to JSON file."""
+        self.output_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.output_file, "w") as f:
+            json.dump(self._status.to_dict(), f, indent=2)
+
+
+# Global tracker instance for convenience
+_tracker: AzureOpsTracker | None = None
+
+
+def get_tracker(
+    output_file: str | Path = DEFAULT_OUTPUT_FILE,
+    vm_size: str = "Standard_D4ds_v5",
+) -> AzureOpsTracker:
+    """Get or create global tracker instance.
+
+    Args:
+        output_file: Path to output JSON file.
+        vm_size: Azure VM size for cost calculation.
+
+    Returns:
+        AzureOpsTracker instance.
+    """
+    global _tracker
+    if _tracker is None:
+        _tracker = AzureOpsTracker(output_file=output_file, vm_size=vm_size)
+    return _tracker
+
+
+def read_status(
+    status_file: str | Path = DEFAULT_OUTPUT_FILE,
+) -> dict[str, Any]:
+    """Read status from JSON file with fresh computed values.
+
+    This function reads the persisted status and recomputes time-dependent
+    fields (elapsed_seconds, cost_usd) based on the current time. This ensures
+    the API always returns accurate values without relying on client-side
+    computation.
+
+    Args:
+        status_file: Path to status JSON file.
+
+    Returns:
+        Status dictionary with fresh elapsed_seconds and cost_usd, or idle status
+        if file doesn't exist.
+    """
+    status_path = Path(status_file)
+    if status_path.exists():
+        try:
+            with open(status_path) as f:
+                status = json.load(f)
+
+            # Recompute time-dependent fields if operation is active
+            if status.get("started_at") and status.get("operation") not in (
+                "idle",
+                "complete",
+                "failed",
+            ):
+                started_at = datetime.fromisoformat(status["started_at"])
+                elapsed = datetime.now() - started_at
+                elapsed_seconds = max(0, elapsed.total_seconds())
+
+                # Update elapsed time
+                status["elapsed_seconds"] = elapsed_seconds
+
+                # Update cost based on elapsed time
+                hourly_rate = status.get("hourly_rate_usd", 0.422)
+                status["cost_usd"] = (elapsed_seconds / 3600) * hourly_rate
+
+                # Update ETA if we have progress info
+                progress_pct = status.get("progress_pct", 0)
+                if progress_pct > 0 and elapsed_seconds > 10:
+                    # Estimate remaining time from progress rate
+                    time_per_pct = elapsed_seconds / progress_pct
+                    remaining_pct = 100 - progress_pct
+                    status["eta_seconds"] = time_per_pct * remaining_pct
+                elif status.get("operation") in TYPICAL_DURATIONS:
+                    # Use typical duration minus elapsed
+                    typical = TYPICAL_DURATIONS[status["operation"]]
+                    status["eta_seconds"] = max(0, typical - elapsed_seconds)
+
+            return status
+        except (json.JSONDecodeError, IOError, ValueError):
+            pass
+
+    # Return default idle status
+    return AzureOpsStatus().to_dict()
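
Taken together, azure_ops_tracker.py is file-based on both ends: a producer process writes benchmark_results/azure_ops_status.json and the dashboard polls it. A minimal end-to-end sketch using only functions defined in the file above (the docker build invocation and image tag are illustrative):

    import subprocess

    from openadapt_ml.benchmarks.azure_ops_tracker import get_tracker, read_status

    # Producer: stream Docker build output through the tracker.
    tracker = get_tracker(vm_size="Standard_D4ds_v5")
    tracker.start_operation("docker_build", build_id="build-001")
    proc = subprocess.Popen(
        ["docker", "build", "-t", "waa-auto", "."],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    for line in proc.stdout:
        tracker.append_log(line)  # keeps the last MAX_LOG_LINES (100) lines
        tracker.parse_docker_build_line(line)  # step/download progress and phase
        if tracker.is_error_line(line):
            tracker.update(error=line.strip())
    tracker.finish_operation(success=proc.wait() == 0)

    # Consumer: each dashboard poll gets recomputed elapsed time, cost, and ETA.
    status = read_status()
    print(f"{status['operation']}: {status['progress_pct']:.0f}%, "
          f"${status['cost_usd']:.2f} so far")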