openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
"""Azure async inference queue for live training feedback.
|
|
2
|
+
|
|
3
|
+
This module implements Phase 2 of the live inference design:
|
|
4
|
+
- Training instance uploads checkpoints to Azure Blob Storage
|
|
5
|
+
- Triggers Azure Queue Storage message with checkpoint info
|
|
6
|
+
- Inference worker polls queue and runs inference
|
|
7
|
+
- Results uploaded to Blob Storage for dashboard to display
|
|
8
|
+
|
|
9
|
+
Architecture:
|
|
10
|
+
Training GPU → Blob (checkpoints) → Queue (jobs) → Inference GPU → Blob (comparisons)
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
# Submit checkpoint for async inference (called by trainer)
|
|
14
|
+
uv run python -m openadapt_ml.cloud.azure_inference inference-submit \
|
|
15
|
+
--checkpoint checkpoints/epoch_1 \
|
|
16
|
+
--capture /path/to/capture
|
|
17
|
+
|
|
18
|
+
# Start inference worker (runs on separate instance)
|
|
19
|
+
uv run python -m openadapt_ml.cloud.azure_inference inference-worker \
|
|
20
|
+
--model Qwen/Qwen2.5-VL-3B
|
|
21
|
+
|
|
22
|
+
# Watch for new comparison results
|
|
23
|
+
uv run python -m openadapt_ml.cloud.azure_inference inference-watch
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import json
|
|
29
|
+
import logging
|
|
30
|
+
import time
|
|
31
|
+
from dataclasses import dataclass
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
from typing import Any
|
|
34
|
+
|
|
35
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class InferenceJob:
    """Metadata describing one queued inference job.

    Attributes:
        job_id: Identifier assigned when the job is submitted.
        checkpoint_blob: Blob name recorded for the uploaded checkpoint.
        checkpoint_epoch: Epoch the checkpoint belongs to.
        capture_path: Path to the capture data used for inference.
        submitted_at: Submission time as a Unix timestamp.
        status: Lifecycle state: "pending", "running", "completed" or "failed".
        output_blob: Blob name of the comparison HTML, once produced.
        error: Error message if the job failed.
    """

    # Required fields, in positional-constructor order.
    job_id: str
    checkpoint_blob: str
    checkpoint_epoch: int
    capture_path: str
    submitted_at: float

    # Fields that change as the job progresses.
    status: str = "pending"
    output_blob: str | None = None
    error: str | None = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AzureInferenceQueue:
    """Manage async inference jobs via Azure Queue Storage and Blob Storage.

    This class provides the infrastructure for running inference on
    checkpoints while training continues:

    1. submit_checkpoint() - upload a checkpoint directory to blob storage and
       enqueue an inference job that points at it.
    2. poll_and_process() - worker loop that dequeues jobs, downloads the
       checkpoint, runs inference and uploads the comparison HTML.
    3. watch_comparisons() - poll the comparisons container for new results.

    Checkpoint files are stored in ``checkpoints_container`` as
    ``checkpoints/epoch_<N>/<relative file path>``. Comparison results are
    written to ``comparisons/epoch_<N>_comparison.html`` in
    ``comparisons_container``.

    Authentication uses an Azure Storage connection string, either passed
    explicitly or read from ``openadapt_ml.config.settings``.
    """

    def __init__(
        self,
        storage_connection_string: str | None = None,
        queue_name: str = "inference-jobs",
        checkpoints_container: str = "checkpoints",
        comparisons_container: str = "comparisons",
    ):
        """Initialize Azure inference queue.

        Args:
            storage_connection_string: Azure Storage connection string. When
                empty, ``settings.azure_storage_connection_string`` is used.
            queue_name: Queue name for inference jobs.
            checkpoints_container: Blob container for checkpoints.
            comparisons_container: Blob container for comparison results.

        Raises:
            ImportError: If azure-storage-blob or azure-storage-queue not installed.
            ValueError: If connection string not provided and not in settings.
        """
        # Lazy import so this module can be imported without the Azure SDK.
        try:
            from azure.storage.blob import BlobServiceClient
            from azure.storage.queue import QueueClient
        except ImportError as e:
            raise ImportError(
                "Azure Storage SDK not installed. Install with: "
                "pip install azure-storage-blob azure-storage-queue"
            ) from e

        self._BlobServiceClient = BlobServiceClient
        self._QueueClient = QueueClient

        # Fall back to project settings when no connection string is given.
        if not storage_connection_string:
            from openadapt_ml.config import settings

            storage_connection_string = settings.azure_storage_connection_string
            if not storage_connection_string:
                raise ValueError(
                    "AZURE_STORAGE_CONNECTION_STRING not set. "
                    "Run 'python scripts/setup_azure.py' to configure Azure storage."
                )

        self.connection_string = storage_connection_string
        self.queue_name = queue_name
        self.checkpoints_container = checkpoints_container
        self.comparisons_container = comparisons_container

        self.blob_service = self._BlobServiceClient.from_connection_string(
            storage_connection_string
        )
        self.queue_client = self._QueueClient.from_connection_string(
            storage_connection_string, queue_name
        )

        logger.info("Initialized Azure inference queue: %s", queue_name)

    # ------------------------------------------------------------------
    # Blob naming helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _epoch_prefix(epoch: int) -> str:
        """Return the '/'-terminated blob prefix holding an epoch's checkpoint."""
        return f"checkpoints/epoch_{epoch}/"

    @staticmethod
    def _checkpoint_prefix(checkpoint_blob: str) -> str:
        """Return the '/'-terminated parent prefix of ``checkpoint_blob``.

        The trailing slash matters: a bare ``checkpoints/epoch_1`` prefix would
        also match ``checkpoints/epoch_10/...``. A name without any '/' yields
        the empty prefix.
        """
        head, sep, _ = checkpoint_blob.rpartition("/")
        return f"{head}{sep}"

    @staticmethod
    def _relative_blob_name(blob_name: str, prefix: str) -> str | None:
        """Return ``blob_name`` relative to ``prefix``, or None if it should be skipped.

        Skips blobs outside ``prefix``, directory placeholders (names ending in
        '/'), and names with empty, '.' or '..' segments, which could otherwise
        write outside the download directory.
        """
        if not blob_name.startswith(prefix):
            return None
        relative = blob_name[len(prefix):]
        if any(part in ("", ".", "..") for part in relative.split("/")):
            return None
        return relative

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit_checkpoint(
        self, checkpoint_path: str | Path, capture_path: str | Path, epoch: int = 0
    ) -> InferenceJob:
        """Upload a checkpoint directory and queue an inference job for it.

        This is called by the trainer after saving a checkpoint. Every file
        under ``checkpoint_path`` is uploaded (overwriting existing blobs) to
        ``checkpoints/epoch_<epoch>/<relative path>``. A JSON job message is
        then sent to the queue.

        Args:
            checkpoint_path: Local path to checkpoint directory.
            capture_path: Path to capture data (for inference); forwarded to
                the worker as a string.
            epoch: Epoch number for this checkpoint.

        Returns:
            InferenceJob with job metadata. Its ``checkpoint_blob`` is
            ``checkpoints/epoch_<epoch>/<checkpoint dir name>``; workers
            download every blob under that name's parent prefix.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is not an existing directory.
            ValueError: If the checkpoint directory contains no files.
        """
        checkpoint_path = Path(checkpoint_path)
        capture_path = Path(capture_path)

        # Fail fast: an empty upload would queue a job that can never succeed.
        if not checkpoint_path.is_dir():
            raise FileNotFoundError(
                f"Checkpoint directory not found: {checkpoint_path}"
            )
        files = sorted(p for p in checkpoint_path.rglob("*") if p.is_file())
        if not files:
            raise ValueError(
                f"Checkpoint directory contains no files: {checkpoint_path}"
            )

        job_id = f"inference_{int(time.time())}_{epoch}"

        prefix = self._epoch_prefix(epoch)
        blob_name = f"{prefix}{checkpoint_path.name}"
        logger.info("Uploading checkpoint to %s...", blob_name)

        for file_path in files:
            # as_posix() keeps blob names '/'-separated on Windows as well.
            file_blob_name = prefix + file_path.relative_to(checkpoint_path).as_posix()
            file_blob_client = self.blob_service.get_blob_client(
                container=self.checkpoints_container, blob=file_blob_name
            )
            with open(file_path, "rb") as f:
                file_blob_client.upload_blob(f, overwrite=True)
            logger.debug("  Uploaded %s", file_blob_name)

        job = InferenceJob(
            job_id=job_id,
            checkpoint_blob=blob_name,
            checkpoint_epoch=epoch,
            capture_path=str(capture_path),
            submitted_at=time.time(),
        )

        # The message carries only the fields a worker needs to rebuild the job.
        job_message = {
            "job_id": job.job_id,
            "checkpoint_blob": job.checkpoint_blob,
            "checkpoint_epoch": job.checkpoint_epoch,
            "capture_path": job.capture_path,
            "submitted_at": job.submitted_at,
        }
        self.queue_client.send_message(json.dumps(job_message))
        logger.info("Queued inference job: %s", job_id)

        return job

    def poll_and_process(
        self,
        adapter,
        max_messages: int = 1,
        visibility_timeout: int = 3600,
    ) -> None:
        """Worker loop: poll the queue and run inference for each job. Never returns.

        This is meant to run on a separate GPU instance. Each iteration
        requests up to ``max_messages`` messages and processes them one after
        another. It then sleeps 5 s, plus an extra 10 s after an unexpected
        error.

        Args:
            adapter: VLM adapter for inference (e.g., Qwen adapter); must
                provide ``load_lora_weights(path)``.
            max_messages: Maximum messages to process per iteration.
            visibility_timeout: How long a received message stays hidden from
                other consumers while it is processed (seconds).
        """
        import itertools

        logger.info("Starting inference worker...")

        while True:
            try:
                messages = self.queue_client.receive_messages(
                    messages_per_page=max_messages,
                    visibility_timeout=visibility_timeout,
                )
                # The returned pager keeps fetching pages until the queue is
                # empty. That would hide every queued job for
                # ``visibility_timeout`` while this worker handles them one by
                # one. Only consume the first batch.
                for msg in itertools.islice(messages, max_messages):
                    self._process_message(adapter, msg)
            except Exception:
                logger.exception("Worker error")
                time.sleep(10)  # Back off on errors

            time.sleep(5)  # Poll interval

    def watch_comparisons(self, poll_interval: int = 10) -> None:
        """Poll for new comparison results and log each new blob. Never returns.

        This can be used by the dashboard to discover new comparison files.
        On the first poll, every existing blob is reported as new. Errors are
        logged and polling continues.

        Args:
            poll_interval: How often to check for new files (seconds).
        """
        logger.info("Watching for new comparison results...")
        seen_blobs: set[str] = set()
        container_client = self.blob_service.get_container_client(
            self.comparisons_container
        )

        while True:
            try:
                for blob in container_client.list_blobs():
                    if blob.name in seen_blobs:
                        continue
                    seen_blobs.add(blob.name)
                    logger.info("New comparison: %s", blob.name)
            except Exception:
                logger.exception("Watch error")

            time.sleep(poll_interval)

    # ------------------------------------------------------------------
    # Worker internals
    # ------------------------------------------------------------------

    def _process_message(self, adapter: Any, msg: Any) -> None:
        """Run one queued job: download, infer, upload, then delete the message.

        On any failure the message is not deleted. It becomes visible again
        after the visibility timeout and is retried.
        """
        import tempfile

        # Stays None if the message cannot be parsed, so the failure path never
        # touches a stale job from a previous iteration.
        job: InferenceJob | None = None
        try:
            job = InferenceJob(**json.loads(msg.content))
            logger.info("Processing job: %s", job.job_id)
            job.status = "running"

            # Private per-job scratch space, removed afterwards so checkpoints
            # do not pile up on the worker's disk. The job_id from the message
            # is deliberately not used in the path.
            with tempfile.TemporaryDirectory(prefix="inference_job_") as scratch:
                scratch_dir = Path(scratch)
                checkpoint_dir = scratch_dir / "checkpoint"
                checkpoint_dir.mkdir()
                self._download_checkpoint(job.checkpoint_blob, checkpoint_dir)

                output_path = scratch_dir / "comparison.html"
                self._run_inference(
                    adapter=adapter,
                    checkpoint_path=checkpoint_dir,
                    capture_path=job.capture_path,
                    output_path=output_path,
                )

                output_blob = (
                    f"comparisons/epoch_{job.checkpoint_epoch}_comparison.html"
                )
                self._upload_comparison(output_path, output_blob)

            job.status = "completed"
            job.output_blob = output_blob
            logger.info("Job completed: %s", job.job_id)

            self.queue_client.delete_message(msg)

        except Exception as e:
            if job is not None:
                job.status = "failed"
                job.error = str(e)
            logger.exception(
                "Job failed: %s",
                job.job_id if job is not None else "<unparseable message>",
            )
            # Don't delete message - let it become visible again for retry

    def _download_checkpoint(self, blob_name: str, local_dir: Path) -> None:
        """Download every checkpoint file stored under ``blob_name``'s parent prefix.

        Files are written to ``local_dir``, preserving their paths relative to
        the prefix (e.g. ``checkpoints/epoch_1/``).

        Raises:
            FileNotFoundError: If no checkpoint files exist under the prefix.
        """
        container_client = self.blob_service.get_container_client(
            self.checkpoints_container
        )
        prefix = self._checkpoint_prefix(blob_name)

        downloaded = 0
        for blob in container_client.list_blobs(name_starts_with=prefix):
            relative = self._relative_blob_name(blob.name, prefix)
            if relative is None:
                logger.debug("Skipping blob %s", blob.name)
                continue

            local_path = local_dir.joinpath(*relative.split("/"))
            local_path.parent.mkdir(parents=True, exist_ok=True)

            blob_client = self.blob_service.get_blob_client(
                container=self.checkpoints_container, blob=blob.name
            )
            with open(local_path, "wb") as f:
                # Stream to disk instead of buffering the whole file in memory.
                blob_client.download_blob().readinto(f)

            downloaded += 1
            logger.debug("Downloaded %s to %s", blob.name, local_path)

        if not downloaded:
            raise FileNotFoundError(
                f"No checkpoint files under '{prefix}' in container "
                f"'{self.checkpoints_container}'"
            )

    def _run_inference(
        self,
        adapter: Any,
        checkpoint_path: Path,
        capture_path: str,
        output_path: Path,
    ) -> None:
        """Load the checkpoint into ``adapter`` and write the comparison HTML.

        Delegates to ``generate_comparison`` from ``openadapt_ml.scripts.compare``.
        """
        # Imported here to avoid circular dependencies.
        from openadapt_ml.scripts.compare import generate_comparison

        logger.info("Running inference on %s...", capture_path)

        adapter.load_lora_weights(str(checkpoint_path))

        generate_comparison(
            capture_path=capture_path,
            adapter=adapter,
            output_path=str(output_path),
        )

        logger.info("Comparison saved to %s", output_path)

    def _upload_comparison(self, local_path: Path, blob_name: str) -> None:
        """Upload the comparison HTML to the comparisons container.

        The blob is stored with ``Content-Type: text/html`` so that it renders
        when opened by URL. Without it, Azure defaults to
        ``application/octet-stream`` and browsers download the file instead.
        """
        from azure.storage.blob import ContentSettings

        blob_client = self.blob_service.get_blob_client(
            container=self.comparisons_container, blob=blob_name
        )
        with open(local_path, "rb") as f:
            blob_client.upload_blob(
                f,
                overwrite=True,
                content_settings=ContentSettings(content_type="text/html"),
            )

        logger.info("Uploaded comparison to %s", blob_name)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def main():
    """CLI for Azure async inference.

    Subcommands:
        inference-submit  Upload a checkpoint and queue an inference job.
        inference-worker  Run the inference worker loop (never returns).
        inference-watch   Log new comparison results as they appear.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Azure async inference queue",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Command")

    # Submit checkpoint for inference
    submit_parser = subparsers.add_parser(
        "inference-submit", help="Submit checkpoint for async inference"
    )
    submit_parser.add_argument(
        "--checkpoint", "-c", required=True, help="Path to checkpoint directory"
    )
    submit_parser.add_argument(
        "--capture", required=True, help="Path to capture data"
    )
    submit_parser.add_argument(
        "--epoch", "-e", type=int, default=0, help="Epoch number"
    )

    # Start inference worker
    worker_parser = subparsers.add_parser(
        "inference-worker", help="Start inference worker process"
    )
    worker_parser.add_argument(
        "--model",
        "-m",
        default="Qwen/Qwen2.5-VL-3B",
        help="VLM model to use for inference",
    )

    # Watch for new comparisons
    watch_parser = subparsers.add_parser(
        "inference-watch", help="Watch for new comparison results"
    )
    watch_parser.add_argument(
        "--interval", "-i", type=int, default=10, help="Poll interval in seconds"
    )

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    # Progress is reported through logger.info. Without a configured handler,
    # Python's default WARNING threshold would hide it.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    queue = AzureInferenceQueue()

    if args.command == "inference-submit":
        print("Submitting checkpoint for inference...")
        job = queue.submit_checkpoint(
            checkpoint_path=args.checkpoint,
            capture_path=args.capture,
            epoch=args.epoch,
        )
        print(f"Job submitted: {job.job_id}")
        print(f"Checkpoint uploaded to: {job.checkpoint_blob}")

    elif args.command == "inference-worker":
        # NOTE(review): this package ships openadapt_ml/models/qwen_vl.py and no
        # openadapt_ml.adapters package, so the original import path could not
        # resolve. The legacy path is kept as a fallback. Confirm the class name
        # and that its constructor accepts `model_name=`.
        try:
            from openadapt_ml.models.qwen_vl import QwenVLAdapter
        except ImportError:
            from openadapt_ml.adapters.qwen import QwenVLAdapter

        print(f"Starting inference worker with model: {args.model}")
        adapter = QwenVLAdapter(model_name=args.model)
        queue.poll_and_process(adapter)

    elif args.command == "inference-watch":
        queue.watch_comparisons(poll_interval=args.interval)


# CLI entry point, e.g. `python -m openadapt_ml.cloud.azure_inference inference-watch`.
if __name__ == "__main__":
    main()
|