openadapt_ml-0.1.0-py3-none-any.whl

Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/models/qwen_vl.py
@@ -0,0 +1,426 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from PIL import Image
+import torch
+from peft import LoraConfig, PeftModel, get_peft_model
+from transformers import AutoProcessor, Qwen3VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+
+from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device
+
+
+def _process_vision_info(messages: List[Dict[str, Any]]) -> tuple[list[list[Any]], list[list[Any]]]:
+    """Minimal stand-in for qwen_vl_utils.process_vision_info.
+
+    For our use case we only need to extract image/video entries from the
+    message content structure expected by Qwen-VL examples, where each
+    message has a `content` list of dicts with `type` in {"image", "video"}.
+
+    Returns (image_inputs, video_inputs), each a list-of-lists suitable for
+    passing to AutoProcessor.
+    """
+
+    image_inputs: list[list[Any]] = []
+    video_inputs: list[list[Any]] = []
+
+    current_images: list[Any] = []
+    current_videos: list[Any] = []
+
+    for message in messages:
+        content = message.get("content", [])
+        if not isinstance(content, list):
+            continue
+        for item in content:
+            if not isinstance(item, dict):
+                continue
+            t = item.get("type")
+            if t == "image":
+                current_images.append(item.get("image"))
+            elif t == "video":
+                current_videos.append(item.get("video"))
+
+    if current_images:
+        image_inputs.append(current_images)
+    if current_videos:
+        video_inputs.append(current_videos)
+
+    return image_inputs, video_inputs
+
+
+class QwenVLAdapter(BaseVLMAdapter):
+    """Adapter for Qwen-family VLMs using Hugging Face + PEFT.
+
+    This is a minimal skeleton that:
+    - loads a base model + processor
+    - optionally applies a LoRA adapter
+    - implements the BaseVLMAdapter interface
+
+    The exact chat/image encoding and loss masking will be filled in
+    once we wire a concrete training loop.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        processor: Any,
+        device: Optional[torch.device] = None,
+        version: str = "qwen3",
+    ) -> None:
+        super().__init__(model=model, processor=processor, device=device)
+        self.version = version
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name: str,
+        lora_config: Optional[LoraConfig | Dict[str, Any]] = None,
+        load_in_4bit: bool = False,
+        device: Optional[torch.device] = None,
+        max_pixels: Optional[int] = None,
+        min_pixels: Optional[int] = None,
+    ) -> "QwenVLAdapter":
+        """Load base Qwen model + processor and attach optional LoRA adapter.
+
+        Args:
+            max_pixels: Maximum image size in pixels (e.g., 512*512=262144 for faster training).
+                If None, uses model default (very large).
+            min_pixels: Minimum image size in pixels. If None, uses model default.
+        """
+
+        if "Qwen3-VL" in model_name or "Qwen3VL" in model_name:
+            version = "qwen3"
+            model_cls = Qwen3VLForConditionalGeneration
+        elif "Qwen2.5-VL" in model_name or "Qwen2_5" in model_name:
+            version = "qwen2_5"
+            model_cls = Qwen2_5_VLForConditionalGeneration
+        else:
+            raise ValueError(f"Unrecognized Qwen-VL model name: {model_name}")
+
+        processor = AutoProcessor.from_pretrained(model_name)
+
+        # Configure image resolution for faster training
+        if max_pixels is not None and hasattr(processor, 'image_processor'):
+            processor.image_processor.max_pixels = max_pixels
+            print(f"Set max_pixels to {max_pixels} ({int(max_pixels**0.5)}x{int(max_pixels**0.5)} approx)")
+        if min_pixels is not None and hasattr(processor, 'image_processor'):
+            processor.image_processor.min_pixels = min_pixels
+
+        model_kwargs: Dict[str, Any] = {}
+        if load_in_4bit:
+            model_kwargs["load_in_4bit"] = True
+
+        model = model_cls.from_pretrained(model_name, **model_kwargs)
+
+        # Support two modes for LoRA:
+        # - config-based (no weights_path): create a fresh adapter.
+        # - weights-based (weights_path in dict): load an existing adapter.
+        lora_weights_path: Optional[str] = None
+        lora_cfg_clean: Optional[LoraConfig | Dict[str, Any]] = None
+
+        if lora_config is not None:
+            if isinstance(lora_config, dict):
+                lora_weights_path = lora_config.get("weights_path")
+                lora_cfg_clean = {k: v for k, v in lora_config.items() if k != "weights_path"}
+            else:
+                lora_cfg_clean = lora_config
+
+        if lora_weights_path:
+            # Load an existing adapter onto the base model.
+            model = PeftModel.from_pretrained(model, lora_weights_path)
+        elif lora_cfg_clean is not None:
+            if isinstance(lora_cfg_clean, dict):
+                lora_cfg_clean = LoraConfig(**lora_cfg_clean)
+            model = get_peft_model(model, lora_cfg_clean)  # type: ignore[arg-type]
+
+        if device is None:
+            device = get_default_device()
+
+        return cls(model=model, processor=processor, device=device, version=version)
+
+    def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:  # type: ignore[override]
+        """Convert SFT samples into model inputs for Qwen-VL.
+
+        Supports true batching by processing multiple samples simultaneously.
+        Uses processor.apply_chat_template with padding=True and truncation=True
+        for multi-sample tokenization. Computes assistant-only labels per sample.
+
+        Args:
+            batch: List of SFT-style samples, each with "images" and "messages" keys.
+
+        Returns:
+            Dict with input_ids, attention_mask, pixel_values, image_grid_thw, and labels.
+        """
+
+        if self.version == "qwen3":
+            # Build batch of messages for all samples
+            batch_messages_full: List[List[Dict[str, Any]]] = []
+            batch_messages_user_only: List[List[Dict[str, Any]]] = []
+
+            for sample in batch:
+                image_paths = sample["images"]
+                if not image_paths:
+                    raise ValueError("Sample is missing image paths")
+                image_path = image_paths[0]
+
+                messages = sample["messages"]
+                user_text = ""
+                assistant_text = ""
+                for m in messages:
+                    role = m.get("role")
+                    if role == "user":
+                        user_text = m.get("content", "")
+                    elif role == "assistant":
+                        assistant_text = m.get("content", "")
+
+                # Full messages (user + assistant)
+                qwen_messages_full: List[Dict[str, Any]] = [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image_path},
+                            {"type": "text", "text": user_text},
+                        ],
+                    },
+                ]
+                if assistant_text:
+                    qwen_messages_full.append({
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": assistant_text}],
+                    })
+                batch_messages_full.append(qwen_messages_full)
+
+                # User-only messages (for label masking)
+                qwen_messages_user_only: List[Dict[str, Any]] = [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image_path},
+                            {"type": "text", "text": user_text},
+                        ],
+                    }
+                ]
+                batch_messages_user_only.append(qwen_messages_user_only)
+
+            # Tokenize full batch with padding and truncation
+            inputs_full = self.processor.apply_chat_template(  # type: ignore[call-arg]
+                batch_messages_full,
+                tokenize=True,
+                add_generation_prompt=False,
+                return_dict=True,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            )
+
+            # Tokenize user-only batch for label masking
+            inputs_user = self.processor.apply_chat_template(  # type: ignore[call-arg]
+                batch_messages_user_only,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            )
+
+            input_ids_full = inputs_full["input_ids"]  # [batch_size, seq_len]
+            input_ids_user = inputs_user["input_ids"]  # [batch_size, seq_len_user]
+
+            # Initialize labels with full input_ids, then mask per sample
+            labels = input_ids_full.clone()
+
+            # Compute assistant-only labels per sample
+            batch_size = input_ids_full.size(0)
+            for i in range(batch_size):
+                sample = batch[i]
+                messages = sample["messages"]
+                assistant_text = ""
+                for m in messages:
+                    if m.get("role") == "assistant":
+                        assistant_text = m.get("content", "")
+
+                if assistant_text:
+                    # Find where user prompt ends and assistant response begins
+                    # The user-only sequence should be a prefix of the full sequence
+                    full_ids = input_ids_full[i]
+                    user_ids = input_ids_user[i]
+
+                    # Remove padding from user_ids to find actual sequence length
+                    # Padding token is typically 0 or a special value
+                    # For Qwen, we look for the first occurrence of pad token
+                    pad_token_id = self.processor.tokenizer.pad_token_id
+                    user_ids_no_pad = user_ids[user_ids != pad_token_id] if pad_token_id is not None else user_ids
+                    user_len = len(user_ids_no_pad)
+
+                    # Check if user sequence is a prefix of full sequence
+                    if user_len <= full_ids.size(0):
+                        # Mask everything except assistant tokens
+                        labels[i, :] = -100
+                        # Only supervise on tokens after the user prompt
+                        labels[i, user_len:] = full_ids[user_len:]
+
+            # Ensure padding tokens are masked in labels
+            if hasattr(self.processor.tokenizer, 'pad_token_id') and self.processor.tokenizer.pad_token_id is not None:
+                labels[input_ids_full == self.processor.tokenizer.pad_token_id] = -100
+
+            inputs_full["labels"] = labels
+            return inputs_full
+
+        else:  # qwen2_5
+            # Build batch of messages
+            batch_messages: List[List[Dict[str, Any]]] = []
+            batch_texts: List[str] = []
+            all_image_inputs: List[List[Any]] = []
+            all_video_inputs: List[List[Any]] = []
+
+            for sample in batch:
+                image_paths = sample["images"]
+                if not image_paths:
+                    raise ValueError("Sample is missing image paths")
+                image_path = image_paths[0]
+
+                messages = sample["messages"]
+                user_text = ""
+                assistant_text = ""
+                for m in messages:
+                    role = m.get("role")
+                    if role == "user":
+                        user_text = m.get("content", "")
+                    elif role == "assistant":
+                        assistant_text = m.get("content", "")
+
+                qwen_messages: List[Dict[str, Any]] = [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image_path},
+                            {"type": "text", "text": user_text},
+                        ],
+                    }
+                ]
+                if assistant_text:
+                    qwen_messages.append({
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": assistant_text}],
+                    })
+
+                batch_messages.append(qwen_messages)
+
+                # Convert to text for this sample
+                text = self.processor.apply_chat_template(  # type: ignore[call-arg]
+                    qwen_messages,
+                    tokenize=False,
+                    add_generation_prompt=False,
+                )
+                batch_texts.append(text)
+
+                # Extract vision info
+                image_inputs, video_inputs = _process_vision_info(qwen_messages)
+                if image_inputs:
+                    all_image_inputs.extend(image_inputs)
+                if video_inputs:
+                    all_video_inputs.extend(video_inputs)
+
+            # Process batch with padding
+            videos_arg = all_video_inputs if all_video_inputs else None
+            inputs = self.processor(  # type: ignore[call-arg]
+                text=batch_texts,
+                images=all_image_inputs if all_image_inputs else None,
+                videos=videos_arg,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            # For qwen2_5, use full sequence supervision for now
+            # (can be refined to assistant-only if needed)
+            input_ids = inputs["input_ids"]
+            labels = input_ids.clone()
+
+            # Mask padding tokens
+            if hasattr(self.processor.tokenizer, 'pad_token_id') and self.processor.tokenizer.pad_token_id is not None:
+                labels[input_ids == self.processor.tokenizer.pad_token_id] = -100
+
+            inputs["labels"] = labels
+            return inputs
+
+    def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor:  # type: ignore[override]
+        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+        outputs = self.model(**inputs)
+        # Hugging Face causal LM models return `loss` when `labels` are provided.
+        return outputs.loss  # type: ignore[no-any-return]
+
+    def generate(self, sample: Dict[str, Any], max_new_tokens: int = 64) -> str:  # type: ignore[override]
+        """Generate assistant text for a single SFT-style sample.
+
+        We pass system + user messages to the chat template with
+        `add_generation_prompt=True` and let the model generate the
+        assistant continuation.
+        """
+
+        image_paths = sample["images"]
+        if not image_paths:
+            raise ValueError("Sample is missing image paths")
+        image_path = image_paths[0]
+
+        messages = sample["messages"]
+        user_text = ""
+        for m in messages:
+            if m.get("role") == "user":
+                user_text = m.get("content", "")
+
+        qwen_messages: List[Dict[str, Any]] = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image_path},
+                    {"type": "text", "text": user_text},
+                ],
+            }
+        ]
+
+        if self.version == "qwen3":
+            inputs = self.processor.apply_chat_template(  # type: ignore[call-arg]
+                qwen_messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_dict=True,
+                return_tensors="pt",
+            ).to(self.device)
+        else:
+            text = self.processor.apply_chat_template(  # type: ignore[call-arg]
+                qwen_messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            image_inputs, video_inputs = _process_vision_info(qwen_messages)
+            videos_arg = video_inputs if video_inputs else None
+            inputs = self.processor(  # type: ignore[call-arg]
+                text=[text],
+                images=image_inputs,
+                videos=videos_arg,
+                padding=True,
+                return_tensors="pt",
+            ).to(self.device)
+
+        with torch.no_grad():
+            generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
+
+        # Decode only the GENERATED tokens (not the input prompt)
+        # This is critical - otherwise we return "user Goal: ... assistant ..." instead of just the response
+        input_len = inputs["input_ids"].shape[1]
+        generated_ids = generation[:, input_len:]
+        text = self.processor.batch_decode(  # type: ignore[call-arg]
+            generated_ids,
+            skip_special_tokens=True,
+        )[0]
+        return text
+
+    def save_checkpoint(self, path: str) -> None:
+        """Save the LoRA adapter weights to a directory."""
+        from pathlib import Path
+        save_path = Path(path)
+        save_path.mkdir(parents=True, exist_ok=True)
+        # Save the PEFT adapter (LoRA weights only, not base model)
+        self.model.save_pretrained(str(save_path))
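For context, a minimal usage sketch of this adapter (illustrative only: the checkpoint name, LoRA hyperparameters, and sample contents below are assumptions, not part of the packaged code, and device placement of the model is assumed to be handled by BaseVLMAdapter):

    from openadapt_ml.models.qwen_vl import QwenVLAdapter

    # Hypothetical values; any checkpoint whose name matches the version checks in
    # from_pretrained ("Qwen3-VL" or "Qwen2.5-VL") is dispatched the same way.
    adapter = QwenVLAdapter.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct",
        lora_config={"r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"]},
        max_pixels=512 * 512,  # cap image resolution for faster training
    )

    # SFT-style sample: screenshot path(s) plus plain-text user/assistant messages,
    # mirroring the keys that prepare_inputs and generate read above.
    sample = {
        "images": ["screenshot.png"],
        "messages": [
            {"role": "user", "content": "Goal: log in. What is the next action?"},
            {"role": "assistant", "content": "Action: CLICK(x=0.42, y=0.31)"},
        ],
    }
    inputs = adapter.prepare_inputs([sample])   # batched tokenization + label masking
    loss = adapter.compute_loss(inputs)         # HF causal-LM loss over the labels
    prediction = adapter.generate(sample, max_new_tokens=64)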
File without changes
openadapt_ml/runtime/policy.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from openadapt_ml.models.base_adapter import BaseVLMAdapter
+from openadapt_ml.schemas.sessions import Action
+
+
+# Coordinate-based DSL patterns
+_CLICK_RE = re.compile(r"CLICK\(x=([0-9]*\.?[0-9]+),\s*y=([0-9]*\.?[0-9]+)\)")
+_TYPE_RE = re.compile(r'TYPE\(text="([^"\\]*(?:\\.[^"\\]*)*)"\)')
+_WAIT_RE = re.compile(r"\bWAIT\s*\(\s*\)")
+_DONE_RE = re.compile(r"\bDONE\s*\(\s*\)")
+
+# SoM (Set-of-Marks) index-based DSL patterns
+_CLICK_SOM_RE = re.compile(r"CLICK\(\[(\d+)\]\)")
+_TYPE_SOM_RE = re.compile(r'TYPE\(\[(\d+)\],\s*["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
+_TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
+
+
+@dataclass
+class PolicyOutput:
+    """Result of a single policy step."""
+    action: Action
+    thought: Optional[str] = None
+    state: Optional[Dict[str, Any]] = None
+    raw_text: str = ""
+
+
+def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
+    """Parse Thought / State / Action blocks from model output.
+
+    Expected format:
+        Thought: [reasoning]
+        State: {"success": false, "progress": 0.5, ...}
+        Action: CLICK(x=0.42, y=0.31)
+
+    Returns:
+        (thought, state, action_str):
+        - thought: Content after 'Thought:' up to 'State:' or 'Action:'
+        - state: Parsed JSON dict from 'State:' line, or None
+        - action_str: Content after 'Action:', or whole text if missing
+
+    Note: We look for the LAST occurrence of 'Action:' to handle cases where
+    the user prompt template also contains 'Action:' placeholders.
+    """
+    thought: Optional[str] = None
+    state: Optional[Dict[str, Any]] = None
+    action_str: str = text.strip()
+
+    # Extract Thought - find the LAST occurrence (model's response, not template)
+    thought_matches = list(re.finditer(r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE))
+    if thought_matches:
+        thought = thought_matches[-1].group(1).strip()
+
+    # Extract State (JSON on same line or next line) - last occurrence
+    state_matches = list(re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE))
+    if state_matches:
+        try:
+            state = json.loads(state_matches[-1].group(1))
+        except json.JSONDecodeError:
+            state = None
+
+    # Extract Action - find the LAST occurrence to get the model's actual action
+    # not the placeholder in the prompt template
+    action_matches = list(re.finditer(r"Action:\s*(.+?)(?=\n|$)", text, re.IGNORECASE))
+    if action_matches:
+        action_str = action_matches[-1].group(1).strip()
+
+    return thought, state, action_str
+
+
+class AgentPolicy:
+    """Runtime policy wrapper around a trained VLM adapter.
+
+    Formats goal-conditioned inputs and parses textual actions into
+    structured `Action` objects.
+    """
+
+    def __init__(self, adapter: BaseVLMAdapter) -> None:
+        self.adapter = adapter
+
+    def _build_sample(self, image: Image.Image, goal: str) -> Dict[str, Any]:
+        # For runtime we keep the same structure as SFT samples but use
+        # an in-memory image. The adapter's generate method currently expects
+        # paths, so we require the caller to supply a path-based sample. For
+        # now, we leave responsibility for image loading to the caller; this
+        # method is kept for future extensibility.
+        raise NotImplementedError(
+            "AgentPolicy._build_sample is not used directly; pass a sample dict "
+            "compatible with the adapter's `generate` method."
+        )
+
+    def _parse_action(self, text: str) -> Action:
+        """Parse a DSL action string into an Action object.
+
+        Supported formats (coordinate-based):
+        - CLICK(x=<float>, y=<float>)
+        - TYPE(text="...")
+
+        Supported formats (SoM index-based):
+        - CLICK([N])
+        - TYPE([N], "text")
+        - TYPE("text")
+
+        Common formats:
+        - WAIT()
+        - DONE()
+
+        Returns Action(type="failed") if no valid action is found.
+        """
+        # Try SoM patterns first (index-based)
+        # CLICK([N])
+        m = _CLICK_SOM_RE.search(text)
+        if m:
+            idx = int(m.group(1))
+            return Action(type="click", element_index=idx)
+
+        # TYPE([N], "text")
+        m = _TYPE_SOM_RE.search(text)
+        if m:
+            idx = int(m.group(1))
+            raw_text = m.group(2)
+            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
+            return Action(type="type", text=unescaped, element_index=idx)
+
+        # TYPE("text") - SoM style without index
+        m = _TYPE_SOM_SIMPLE_RE.search(text)
+        if m:
+            raw_text = m.group(1)
+            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
+            return Action(type="type", text=unescaped)
+
+        # Coordinate-based patterns
+        # CLICK(x=..., y=...)
+        m = _CLICK_RE.search(text)
+        if m:
+            x = float(m.group(1))
+            y = float(m.group(2))
+            # Clamp to [0, 1]
+            x = max(0.0, min(1.0, x))
+            y = max(0.0, min(1.0, y))
+            return Action(type="click", x=x, y=y)
+
+        # TYPE(text="...")
+        m = _TYPE_RE.search(text)
+        if m:
+            # Unescape the text content
+            raw_text = m.group(1)
+            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
+            return Action(type="type", text=unescaped)
+
+        # WAIT()
+        if _WAIT_RE.search(text):
+            return Action(type="wait")
+
+        # DONE()
+        if _DONE_RE.search(text):
+            return Action(type="done")
+
+        # Fallback
+        return Action(type="failed", raw={"text": text})
+
+    def predict_action_from_sample(
+        self, sample: Dict[str, Any], max_new_tokens: int = 150
+    ) -> Tuple[Action, Optional[str], Optional[Dict[str, Any]], str]:
+        """Run the adapter on a pre-built SFT-style sample and parse the result.
+
+        Returns (Action, thought, state, raw_text) where:
+        - thought: Reasoning text from 'Thought:' block
+        - state: Parsed JSON dict from 'State:' block (may contain 'success' bool)
+        - raw_text: The raw model output text for debugging
+        """
+        text = self.adapter.generate(sample, max_new_tokens=max_new_tokens)
+        thought, state, action_str = parse_thought_state_action(text)
+        action = self._parse_action(action_str)
+        return action, thought, state, text
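A short parsing example for the Thought / State / Action format above (the raw text is made up; only parse_thought_state_action and AgentPolicy come from this module):

    from openadapt_ml.runtime.policy import AgentPolicy, parse_thought_state_action

    raw = (
        "Thought: The username field is focused, so type the username next.\n"
        'State: {"success": false, "progress": 0.5}\n'
        'Action: TYPE(text="alice")\n'
    )
    thought, state, action_str = parse_thought_state_action(raw)
    # thought     -> "The username field is focused, so type the username next."
    # state       -> {"success": False, "progress": 0.5}
    # action_str  -> 'TYPE(text="alice")'

    # With a trained adapter (e.g. the QwenVLAdapter above), a full step is:
    # policy = AgentPolicy(adapter)
    # action, thought, state, raw_text = policy.predict_action_from_sample(sample)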
openadapt_ml/schemas/__init__.py
@@ -0,0 +1,53 @@
+"""Schema definitions and validation for openadapt-ml.
+
+Core data structures:
+- Action: A single GUI action (click, type, scroll, etc.)
+- Observation: GUI state observation (screenshot, accessibility tree, etc.)
+- Step: One timestep containing observation + action
+- Episode: A single task attempt / workflow instance
+- Session: Container for multiple episodes
+
+Validation:
+- validate_episode(): Validate an Episode object
+- validate_session(): Validate a Session object
+- validate_episodes(): Validate a list of Episodes
+- ValidationError: Raised on schema violations
+"""
+
+from openadapt_ml.schemas.sessions import (
+    Action,
+    ActionType,
+    Episode,
+    Observation,
+    Session,
+    Step,
+)
+from openadapt_ml.schemas.validation import (
+    ValidationError,
+    summarize_episodes,
+    validate_action,
+    validate_episode,
+    validate_episodes,
+    validate_observation,
+    validate_session,
+    validate_step,
+)
+
+__all__ = [
+    # Core types
+    "Action",
+    "ActionType",
+    "Episode",
+    "Observation",
+    "Session",
+    "Step",
+    # Validation
+    "ValidationError",
+    "validate_action",
+    "validate_episode",
+    "validate_episodes",
+    "validate_observation",
+    "validate_session",
+    "validate_step",
+    "summarize_episodes",
+]
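A hedged import sketch for these re-exports (the Action keyword fields mirror how runtime/policy.py constructs actions; the exact validate_action signature is an assumption):

    from openadapt_ml.schemas import Action, ValidationError, validate_action

    action = Action(type="click", x=0.42, y=0.31)  # same construction style as AgentPolicy._parse_action
    try:
        validate_action(action)  # assumed to take a single Action and raise on schema violations
    except ValidationError as exc:
        print(f"invalid action: {exc}")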