openadapt-ml 0.1.0 (openadapt_ml-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/cloud/__init__.py
@@ -0,0 +1,5 @@
+ """Cloud GPU providers for training."""
+
+ from openadapt_ml.cloud.lambda_labs import LambdaLabsClient
+
+ __all__ = ["LambdaLabsClient"]
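Since `__init__.py` re-exports the client, a top-level import works; a minimal sketch (the client's constructor arguments live in lambda_labs.py and are not shown in this diff):

    from openadapt_ml.cloud import LambdaLabsClient  # same object as openadapt_ml.cloud.lambda_labs.LambdaLabsClient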
openadapt_ml/cloud/azure_inference.py
@@ -0,0 +1,441 @@
+ """Azure async inference queue for live training feedback.
+
+ This module implements Phase 2 of the live inference design:
+ - Training instance uploads checkpoints to Azure Blob Storage
+ - Triggers Azure Queue Storage message with checkpoint info
+ - Inference worker polls queue and runs inference
+ - Results uploaded to Blob Storage for dashboard to display
+
+ Architecture:
+     Training GPU → Blob (checkpoints) → Queue (jobs) → Inference GPU → Blob (comparisons)
+
+ Usage:
+     # Submit checkpoint for async inference (called by trainer)
+     uv run python -m openadapt_ml.cloud.azure_inference inference-submit \
+         --checkpoint checkpoints/epoch_1 \
+         --capture /path/to/capture
+
+     # Start inference worker (runs on separate instance)
+     uv run python -m openadapt_ml.cloud.azure_inference inference-worker \
+         --model Qwen/Qwen2.5-VL-3B
+
+     # Watch for new comparison results
+     uv run python -m openadapt_ml.cloud.azure_inference inference-watch
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class InferenceJob:
+     """Inference job metadata."""
+
+     job_id: str
+     checkpoint_blob: str  # Path in blob storage
+     checkpoint_epoch: int
+     capture_path: str  # Path to capture data
+     submitted_at: float
+     status: str = "pending"  # pending, running, completed, failed
+     output_blob: str | None = None  # Path to comparison HTML in blob
+     error: str | None = None
+
+
+ class AzureInferenceQueue:
+     """Manages async inference jobs via Azure Queue Storage.
+
+     This class provides the core infrastructure for async inference during training:
+     1. submit_checkpoint() - Upload checkpoint to blob, queue inference job
+     2. poll_and_process() - Worker loop that processes jobs from queue
+     3. watch_comparisons() - Poll for new comparison results
+
+     Authentication uses the existing AzureConfig pattern from benchmarks/azure.py.
+     """
+
+     def __init__(
+         self,
+         storage_connection_string: str | None = None,
+         queue_name: str = "inference-jobs",
+         checkpoints_container: str = "checkpoints",
+         comparisons_container: str = "comparisons",
+     ):
+         """Initialize Azure inference queue.
+
+         Args:
+             storage_connection_string: Azure Storage connection string (from .env)
+             queue_name: Queue name for inference jobs
+             checkpoints_container: Blob container for checkpoints
+             comparisons_container: Blob container for comparison results
+
+         Raises:
+             ImportError: If azure-storage-blob or azure-storage-queue not installed
+             ValueError: If connection string not provided and not in settings
+         """
+         # Lazy import Azure SDK
+         try:
+             from azure.storage.blob import BlobServiceClient
+             from azure.storage.queue import QueueClient
+
+             self._BlobServiceClient = BlobServiceClient
+             self._QueueClient = QueueClient
+         except ImportError as e:
+             raise ImportError(
+                 "Azure Storage SDK not installed. Install with: "
+                 "pip install azure-storage-blob azure-storage-queue"
+             ) from e
+
+         # Get connection string from settings if not provided
+         if not storage_connection_string:
+             from openadapt_ml.config import settings
+
+             storage_connection_string = settings.azure_storage_connection_string
+             if not storage_connection_string:
+                 raise ValueError(
+                     "AZURE_STORAGE_CONNECTION_STRING not set. "
+                     "Run 'python scripts/setup_azure.py' to configure Azure storage."
+                 )
+
+         self.connection_string = storage_connection_string
+         self.queue_name = queue_name
+         self.checkpoints_container = checkpoints_container
+         self.comparisons_container = comparisons_container
+
+         # Initialize clients
+         self.blob_service = self._BlobServiceClient.from_connection_string(
+             storage_connection_string
+         )
+         self.queue_client = self._QueueClient.from_connection_string(
+             storage_connection_string, queue_name
+         )
+
+         logger.info(f"Initialized Azure inference queue: {queue_name}")
+
+     def submit_checkpoint(
+         self, checkpoint_path: str | Path, capture_path: str | Path, epoch: int = 0
+     ) -> InferenceJob:
+         """Upload checkpoint and queue inference job.
+
+         This is called by the trainer after saving a checkpoint.
+         It uploads the checkpoint to blob storage and adds a message to the queue.
+
+         Args:
+             checkpoint_path: Local path to checkpoint directory
+             capture_path: Path to capture data (for inference)
+             epoch: Epoch number for this checkpoint
+
+         Returns:
+             InferenceJob with job metadata
+         """
+         checkpoint_path = Path(checkpoint_path)
+         capture_path = Path(capture_path)
+
+         # Generate unique job ID
+         job_id = f"inference_{int(time.time())}_{epoch}"
+
+         # Upload checkpoint to blob storage
+         blob_name = f"checkpoints/epoch_{epoch}/{checkpoint_path.name}"
+         logger.info(f"Uploading checkpoint to {blob_name}...")
+
+         # Upload all files in checkpoint directory
+         for file_path in checkpoint_path.rglob("*"):
+             if file_path.is_file():
+                 relative_path = file_path.relative_to(checkpoint_path)
+                 file_blob_name = f"checkpoints/epoch_{epoch}/{relative_path}"
+                 file_blob_client = self.blob_service.get_blob_client(
+                     container=self.checkpoints_container, blob=file_blob_name
+                 )
+
+                 with open(file_path, "rb") as f:
+                     file_blob_client.upload_blob(f, overwrite=True)
+                 logger.debug(f"  Uploaded {file_blob_name}")
+
+         # Create job metadata
+         job = InferenceJob(
+             job_id=job_id,
+             checkpoint_blob=blob_name,
+             checkpoint_epoch=epoch,
+             capture_path=str(capture_path),
+             submitted_at=time.time(),
+         )
+
+         # Queue inference job
+         job_message = {
+             "job_id": job.job_id,
+             "checkpoint_blob": job.checkpoint_blob,
+             "checkpoint_epoch": job.checkpoint_epoch,
+             "capture_path": job.capture_path,
+             "submitted_at": job.submitted_at,
+         }
+
+         self.queue_client.send_message(json.dumps(job_message))
+         logger.info(f"Queued inference job: {job_id}")
+
+         return job
+
+     def poll_and_process(
+         self,
+         adapter: Any,
+         max_messages: int = 1,
+         visibility_timeout: int = 3600,
+     ) -> None:
+         """Worker: poll queue and run inference.
+
+         This is the main worker loop that runs on a separate GPU instance.
+         It continuously polls the queue for new jobs, downloads checkpoints,
+         runs inference, and uploads results.
+
+         Args:
+             adapter: VLM adapter for inference (e.g., Qwen adapter)
+             max_messages: Maximum messages to process per iteration
+             visibility_timeout: How long to hide message while processing (seconds)
+         """
+         logger.info("Starting inference worker...")
+
+         while True:
+             try:
+                 # Poll for messages
+                 messages = self.queue_client.receive_messages(
+                     messages_per_page=max_messages,
+                     visibility_timeout=visibility_timeout,
+                 )
+
+                 for msg in messages:
+                     job: InferenceJob | None = None
+                     try:
+                         # Parse job metadata
+                         job_data = json.loads(msg.content)
+                         job = InferenceJob(**job_data)
+
+                         logger.info(f"Processing job: {job.job_id}")
+                         job.status = "running"
+
+                         # Download checkpoint from blob
+                         checkpoint_dir = Path(f"/tmp/checkpoints/{job.job_id}")
+                         checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+                         self._download_checkpoint(job.checkpoint_blob, checkpoint_dir)
+
+                         # Run inference
+                         output_path = Path(f"/tmp/comparisons/{job.job_id}.html")
+                         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+                         self._run_inference(
+                             adapter=adapter,
+                             checkpoint_path=checkpoint_dir,
+                             capture_path=job.capture_path,
+                             output_path=output_path,
+                         )
+
+                         # Upload result to blob
+                         output_blob = (
+                             f"comparisons/epoch_{job.checkpoint_epoch}_comparison.html"
+                         )
+                         self._upload_comparison(output_path, output_blob)
+
+                         job.status = "completed"
+                         job.output_blob = output_blob
+
+                         logger.info(f"Job completed: {job.job_id}")
+
+                         # Delete message from queue
+                         self.queue_client.delete_message(msg)
+
+                     except Exception as e:
+                         logger.error(f"Job failed: {e}")
+                         # job is None if the message never parsed into a job
+                         if job is not None:
+                             job.status = "failed"
+                             job.error = str(e)
+                         # Don't delete message - let it become visible again for retry
+
+             except Exception as e:
+                 logger.error(f"Worker error: {e}")
+                 time.sleep(10)  # Back off on errors
+
+             # Poll interval
+             time.sleep(5)
+
+     def watch_comparisons(self, poll_interval: int = 10) -> None:
+         """Poll for new comparison results.
+
+         This can be used by the dashboard to discover new comparison files.
+
+         Args:
+             poll_interval: How often to check for new files (seconds)
+         """
+         logger.info("Watching for new comparison results...")
+         seen_blobs = set()
+
+         while True:
+             try:
+                 # List blobs in comparisons container
+                 container_client = self.blob_service.get_container_client(
+                     self.comparisons_container
+                 )
+                 blobs = container_client.list_blobs()
+
+                 for blob in blobs:
+                     if blob.name not in seen_blobs:
+                         logger.info(f"New comparison: {blob.name}")
+                         seen_blobs.add(blob.name)
+
+                         # Optionally download and open in browser
+                         # self._download_and_open(blob.name)
+
+             except Exception as e:
+                 logger.error(f"Watch error: {e}")
+
+             time.sleep(poll_interval)
+
+     def _download_checkpoint(self, blob_name: str, local_dir: Path) -> None:
+         """Download checkpoint from blob storage."""
+         # List all files under the checkpoint blob prefix
+         container_client = self.blob_service.get_container_client(
+             self.checkpoints_container
+         )
+         blob_prefix = "/".join(blob_name.split("/")[:-1])  # Get directory prefix
+
+         blobs = container_client.list_blobs(name_starts_with=blob_prefix)
+
+         for blob in blobs:
+             # Download each file
+             local_path = local_dir / Path(blob.name).relative_to(blob_prefix)
+             local_path.parent.mkdir(parents=True, exist_ok=True)
+
+             blob_client = self.blob_service.get_blob_client(
+                 container=self.checkpoints_container, blob=blob.name
+             )
+
+             with open(local_path, "wb") as f:
+                 download_stream = blob_client.download_blob()
+                 f.write(download_stream.readall())
+
+             logger.debug(f"Downloaded {blob.name} to {local_path}")
+
+     def _run_inference(
+         self,
+         adapter: Any,
+         checkpoint_path: Path,
+         capture_path: str,
+         output_path: Path,
+     ) -> None:
+         """Run inference and generate comparison HTML.
+
+         This wraps the comparison generation logic from scripts/compare.py.
+         """
+         # Import here to avoid circular dependencies
+         from openadapt_ml.scripts.compare import generate_comparison
+
+         logger.info(f"Running inference on {capture_path}...")
+
+         # Load checkpoint into adapter
+         adapter.load_lora_weights(str(checkpoint_path))
+
+         # Generate comparison
+         generate_comparison(
+             capture_path=capture_path,
+             adapter=adapter,
+             output_path=str(output_path),
+         )
+
+         logger.info(f"Comparison saved to {output_path}")
+
+     def _upload_comparison(self, local_path: Path, blob_name: str) -> None:
+         """Upload comparison HTML to blob storage."""
+         blob_client = self.blob_service.get_blob_client(
+             container=self.comparisons_container, blob=blob_name
+         )
+
+         with open(local_path, "rb") as f:
+             blob_client.upload_blob(f, overwrite=True)
+
+         logger.info(f"Uploaded comparison to {blob_name}")
+
+
+ def main():
+     """CLI for Azure async inference."""
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Azure async inference queue",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+     )
+     subparsers = parser.add_subparsers(dest="command", help="Command")
+
+     # Submit checkpoint for inference
+     submit_parser = subparsers.add_parser(
+         "inference-submit", help="Submit checkpoint for async inference"
+     )
+     submit_parser.add_argument(
+         "--checkpoint", "-c", required=True, help="Path to checkpoint directory"
+     )
+     submit_parser.add_argument(
+         "--capture", required=True, help="Path to capture data"
+     )
+     submit_parser.add_argument(
+         "--epoch", "-e", type=int, default=0, help="Epoch number"
+     )
+
+     # Start inference worker
+     worker_parser = subparsers.add_parser(
+         "inference-worker", help="Start inference worker process"
+     )
+     worker_parser.add_argument(
+         "--model",
+         "-m",
+         default="Qwen/Qwen2.5-VL-3B",
+         help="VLM model to use for inference",
+     )
+
+     # Watch for new comparisons
+     watch_parser = subparsers.add_parser(
+         "inference-watch", help="Watch for new comparison results"
+     )
+     watch_parser.add_argument(
+         "--interval", "-i", type=int, default=10, help="Poll interval in seconds"
+     )
+
+     args = parser.parse_args()
+
+     if not args.command:
+         parser.print_help()
+         return
+
+     # Initialize queue
+     queue = AzureInferenceQueue()
+
+     if args.command == "inference-submit":
+         # Submit checkpoint for inference
+         print("Submitting checkpoint for inference...")
+         job = queue.submit_checkpoint(
+             checkpoint_path=args.checkpoint,
+             capture_path=args.capture,
+             epoch=args.epoch,
+         )
+         print(f"Job submitted: {job.job_id}")
+         print(f"Checkpoint uploaded to: {job.checkpoint_blob}")
+
+     elif args.command == "inference-worker":
+         # Start inference worker
+         # NOTE: this wheel ships the Qwen adapter in models/qwen_vl.py;
+         # there is no adapters/ package in the file list above.
+         from openadapt_ml.models.qwen_vl import QwenVLAdapter
+
+         print(f"Starting inference worker with model: {args.model}")
+         adapter = QwenVLAdapter(model_name=args.model)
+         queue.poll_and_process(adapter)
+
+     elif args.command == "inference-watch":
+         # Watch for new comparisons
+         queue.watch_comparisons(poll_interval=args.interval)
+
+
+ if __name__ == "__main__":
+     main()
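The trainer-side integration implied by the module docstring ("called by the trainer") is not part of this file. A minimal sketch of that calling pattern, where train_one_epoch and save_checkpoint are hypothetical placeholders and only AzureInferenceQueue comes from this wheel:

    from pathlib import Path

    from openadapt_ml.cloud.azure_inference import AzureInferenceQueue

    queue = AzureInferenceQueue()  # reads AZURE_STORAGE_CONNECTION_STRING via settings

    for epoch in range(3):
        train_one_epoch()                  # hypothetical training step
        ckpt_dir = Path(f"checkpoints/epoch_{epoch}")
        save_checkpoint(ckpt_dir)          # hypothetical checkpoint writer
        # Upload checkpoint files and enqueue a job; a process started with
        # `inference-worker` picks it up and uploads comparison HTML to blob storage.
        job = queue.submit_checkpoint(ckpt_dir, capture_path="/path/to/capture", epoch=epoch)
        print(f"queued {job.job_id} -> {job.checkpoint_blob}")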