openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0

openadapt_ml/models/qwen_vl.py
@@ -0,0 +1,426 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

from PIL import Image
import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device


def _process_vision_info(messages: List[Dict[str, Any]]) -> tuple[list[list[Any]], list[list[Any]]]:
    """Minimal stand-in for qwen_vl_utils.process_vision_info.

    For our use case we only need to extract image/video entries from the
    message content structure expected by Qwen-VL examples, where each
    message has a `content` list of dicts with `type` in {"image", "video"}.

    Returns (image_inputs, video_inputs), each a list-of-lists suitable for
    passing to AutoProcessor.
    """

    image_inputs: list[list[Any]] = []
    video_inputs: list[list[Any]] = []

    current_images: list[Any] = []
    current_videos: list[Any] = []

    for message in messages:
        content = message.get("content", [])
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            t = item.get("type")
            if t == "image":
                current_images.append(item.get("image"))
            elif t == "video":
                current_videos.append(item.get("video"))

    if current_images:
        image_inputs.append(current_images)
    if current_videos:
        video_inputs.append(current_videos)

    return image_inputs, video_inputs


class QwenVLAdapter(BaseVLMAdapter):
    """Adapter for Qwen-family VLMs using Hugging Face + PEFT.

    This is a minimal skeleton that:
    - loads a base model + processor
    - optionally applies a LoRA adapter
    - implements the BaseVLMAdapter interface

    The exact chat/image encoding and loss masking will be filled in
    once we wire a concrete training loop.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        processor: Any,
        device: Optional[torch.device] = None,
        version: str = "qwen3",
    ) -> None:
        super().__init__(model=model, processor=processor, device=device)
        self.version = version

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        lora_config: Optional[LoraConfig | Dict[str, Any]] = None,
        load_in_4bit: bool = False,
        device: Optional[torch.device] = None,
        max_pixels: Optional[int] = None,
        min_pixels: Optional[int] = None,
    ) -> "QwenVLAdapter":
        """Load base Qwen model + processor and attach optional LoRA adapter.

        Args:
            max_pixels: Maximum image size in pixels (e.g., 512*512=262144 for faster training).
                If None, uses model default (very large).
            min_pixels: Minimum image size in pixels. If None, uses model default.
        """

        if "Qwen3-VL" in model_name or "Qwen3VL" in model_name:
            version = "qwen3"
            model_cls = Qwen3VLForConditionalGeneration
        elif "Qwen2.5-VL" in model_name or "Qwen2_5" in model_name:
            version = "qwen2_5"
            model_cls = Qwen2_5_VLForConditionalGeneration
        else:
            raise ValueError(f"Unrecognized Qwen-VL model name: {model_name}")

        processor = AutoProcessor.from_pretrained(model_name)

        # Configure image resolution for faster training
        if max_pixels is not None and hasattr(processor, 'image_processor'):
            processor.image_processor.max_pixels = max_pixels
            print(f"Set max_pixels to {max_pixels} ({int(max_pixels**0.5)}x{int(max_pixels**0.5)} approx)")
        if min_pixels is not None and hasattr(processor, 'image_processor'):
            processor.image_processor.min_pixels = min_pixels

        model_kwargs: Dict[str, Any] = {}
        if load_in_4bit:
            model_kwargs["load_in_4bit"] = True

        model = model_cls.from_pretrained(model_name, **model_kwargs)

        # Support two modes for LoRA:
        # - config-based (no weights_path): create a fresh adapter.
        # - weights-based (weights_path in dict): load an existing adapter.
        lora_weights_path: Optional[str] = None
        lora_cfg_clean: Optional[LoraConfig | Dict[str, Any]] = None

        if lora_config is not None:
            if isinstance(lora_config, dict):
                lora_weights_path = lora_config.get("weights_path")
                lora_cfg_clean = {k: v for k, v in lora_config.items() if k != "weights_path"}
            else:
                lora_cfg_clean = lora_config

        if lora_weights_path:
            # Load an existing adapter onto the base model.
            model = PeftModel.from_pretrained(model, lora_weights_path)
        elif lora_cfg_clean is not None:
            if isinstance(lora_cfg_clean, dict):
                lora_cfg_clean = LoraConfig(**lora_cfg_clean)
            model = get_peft_model(model, lora_cfg_clean)  # type: ignore[arg-type]

        if device is None:
            device = get_default_device()

        return cls(model=model, processor=processor, device=device, version=version)

    def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:  # type: ignore[override]
        """Convert SFT samples into model inputs for Qwen-VL.

        Supports true batching by processing multiple samples simultaneously.
        Uses processor.apply_chat_template with padding=True and truncation=True
        for multi-sample tokenization. Computes assistant-only labels per sample.

        Args:
            batch: List of SFT-style samples, each with "images" and "messages" keys.

        Returns:
            Dict with input_ids, attention_mask, pixel_values, image_grid_thw, and labels.
        """

        if self.version == "qwen3":
            # Build batch of messages for all samples
            batch_messages_full: List[List[Dict[str, Any]]] = []
            batch_messages_user_only: List[List[Dict[str, Any]]] = []

            for sample in batch:
                image_paths = sample["images"]
                if not image_paths:
                    raise ValueError("Sample is missing image paths")
                image_path = image_paths[0]

                messages = sample["messages"]
                user_text = ""
                assistant_text = ""
                for m in messages:
                    role = m.get("role")
                    if role == "user":
                        user_text = m.get("content", "")
                    elif role == "assistant":
                        assistant_text = m.get("content", "")

                # Full messages (user + assistant)
                qwen_messages_full: List[Dict[str, Any]] = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image_path},
                            {"type": "text", "text": user_text},
                        ],
                    },
                ]
                if assistant_text:
                    qwen_messages_full.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": assistant_text}],
                    })
                batch_messages_full.append(qwen_messages_full)

                # User-only messages (for label masking)
                qwen_messages_user_only: List[Dict[str, Any]] = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image_path},
                            {"type": "text", "text": user_text},
                        ],
                    }
                ]
                batch_messages_user_only.append(qwen_messages_user_only)

            # Tokenize full batch with padding and truncation
            inputs_full = self.processor.apply_chat_template(  # type: ignore[call-arg]
                batch_messages_full,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=True,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )

            # Tokenize user-only batch for label masking
            inputs_user = self.processor.apply_chat_template(  # type: ignore[call-arg]
                batch_messages_user_only,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )

            input_ids_full = inputs_full["input_ids"]  # [batch_size, seq_len]
            input_ids_user = inputs_user["input_ids"]  # [batch_size, seq_len_user]

            # Initialize labels with full input_ids, then mask per sample
            labels = input_ids_full.clone()

            # Compute assistant-only labels per sample
            batch_size = input_ids_full.size(0)
            for i in range(batch_size):
                sample = batch[i]
                messages = sample["messages"]
                assistant_text = ""
                for m in messages:
                    if m.get("role") == "assistant":
                        assistant_text = m.get("content", "")

                if assistant_text:
                    # Find where user prompt ends and assistant response begins
                    # The user-only sequence should be a prefix of the full sequence
                    full_ids = input_ids_full[i]
                    user_ids = input_ids_user[i]

                    # Remove padding from user_ids to find actual sequence length
                    # Padding token is typically 0 or a special value
                    # For Qwen, we look for the first occurrence of pad token
                    pad_token_id = self.processor.tokenizer.pad_token_id
                    user_ids_no_pad = user_ids[user_ids != pad_token_id] if pad_token_id is not None else user_ids
                    user_len = len(user_ids_no_pad)

                    # Check if user sequence is a prefix of full sequence
                    if user_len <= full_ids.size(0):
                        # Mask everything except assistant tokens
                        labels[i, :] = -100
                        # Only supervise on tokens after the user prompt
                        labels[i, user_len:] = full_ids[user_len:]

            # Ensure padding tokens are masked in labels
            if hasattr(self.processor.tokenizer, 'pad_token_id') and self.processor.tokenizer.pad_token_id is not None:
                labels[input_ids_full == self.processor.tokenizer.pad_token_id] = -100

            inputs_full["labels"] = labels
            return inputs_full

        else:  # qwen2_5
            # Build batch of messages
            batch_messages: List[List[Dict[str, Any]]] = []
            batch_texts: List[str] = []
            all_image_inputs: List[List[Any]] = []
            all_video_inputs: List[List[Any]] = []

            for sample in batch:
                image_paths = sample["images"]
                if not image_paths:
                    raise ValueError("Sample is missing image paths")
                image_path = image_paths[0]

                messages = sample["messages"]
                user_text = ""
                assistant_text = ""
                for m in messages:
                    role = m.get("role")
                    if role == "user":
                        user_text = m.get("content", "")
                    elif role == "assistant":
                        assistant_text = m.get("content", "")

                qwen_messages: List[Dict[str, Any]] = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image_path},
                            {"type": "text", "text": user_text},
                        ],
                    }
                ]
                if assistant_text:
                    qwen_messages.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": assistant_text}],
                    })

                batch_messages.append(qwen_messages)

                # Convert to text for this sample
                text = self.processor.apply_chat_template(  # type: ignore[call-arg]
                    qwen_messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                batch_texts.append(text)

                # Extract vision info
                image_inputs, video_inputs = _process_vision_info(qwen_messages)
                if image_inputs:
                    all_image_inputs.extend(image_inputs)
                if video_inputs:
                    all_video_inputs.extend(video_inputs)

            # Process batch with padding
            videos_arg = all_video_inputs if all_video_inputs else None
            inputs = self.processor(  # type: ignore[call-arg]
                text=batch_texts,
                images=all_image_inputs if all_image_inputs else None,
                videos=videos_arg,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

            # For qwen2_5, use full sequence supervision for now
            # (can be refined to assistant-only if needed)
            input_ids = inputs["input_ids"]
            labels = input_ids.clone()

            # Mask padding tokens
            if hasattr(self.processor.tokenizer, 'pad_token_id') and self.processor.tokenizer.pad_token_id is not None:
                labels[input_ids == self.processor.tokenizer.pad_token_id] = -100

            inputs["labels"] = labels
            return inputs

    def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor:  # type: ignore[override]
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # Hugging Face causal LM models return `loss` when `labels` are provided.
        return outputs.loss  # type: ignore[no-any-return]

    def generate(self, sample: Dict[str, Any], max_new_tokens: int = 64) -> str:  # type: ignore[override]
        """Generate assistant text for a single SFT-style sample.

        We pass system + user messages to the chat template with
        `add_generation_prompt=True` and let the model generate the
        assistant continuation.
        """

        image_paths = sample["images"]
        if not image_paths:
            raise ValueError("Sample is missing image paths")
        image_path = image_paths[0]

        messages = sample["messages"]
        user_text = ""
        for m in messages:
            if m.get("role") == "user":
                user_text = m.get("content", "")

        qwen_messages: List[Dict[str, Any]] = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": user_text},
                ],
            }
        ]

        if self.version == "qwen3":
            inputs = self.processor.apply_chat_template(  # type: ignore[call-arg]
                qwen_messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device)
        else:
            text = self.processor.apply_chat_template(  # type: ignore[call-arg]
                qwen_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            image_inputs, video_inputs = _process_vision_info(qwen_messages)
            videos_arg = video_inputs if video_inputs else None
            inputs = self.processor(  # type: ignore[call-arg]
                text=[text],
                images=image_inputs,
                videos=videos_arg,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

        with torch.no_grad():
            generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Decode only the GENERATED tokens (not the input prompt)
        # This is critical - otherwise we return "user Goal: ... assistant ..." instead of just the response
        input_len = inputs["input_ids"].shape[1]
        generated_ids = generation[:, input_len:]
        text = self.processor.batch_decode(  # type: ignore[call-arg]
            generated_ids,
            skip_special_tokens=True,
        )[0]
        return text

    def save_checkpoint(self, path: str) -> None:
        """Save the LoRA adapter weights to a directory."""
        from pathlib import Path
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        # Save the PEFT adapter (LoRA weights only, not base model)
        self.model.save_pretrained(str(save_path))
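
For orientation, a minimal usage sketch of the adapter above (not part of the package): the model name, LoRA hyperparameters, and file paths are illustrative assumptions; the sample structure ("images" plus "messages" keys) follows the prepare_inputs docstring.

# Hedged usage sketch; names and paths below are illustrative, not shipped defaults.
from peft import LoraConfig
from openadapt_ml.models.qwen_vl import QwenVLAdapter

adapter = QwenVLAdapter.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",           # any name containing "Qwen2.5-VL" or "Qwen3-VL"
    lora_config=LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"]),
    max_pixels=512 * 512,                    # cap image resolution for faster training
)

# SFT-style sample as expected by prepare_inputs() / generate():
sample = {
    "images": ["screenshots/step_000.png"],  # hypothetical path
    "messages": [
        {"role": "user", "content": "Goal: log in. What is the next action?"},
        {"role": "assistant", "content": 'Action: CLICK(x=0.42, y=0.31)'},
    ],
}

inputs = adapter.prepare_inputs([sample])    # batched tokenization + assistant-only labels
loss = adapter.compute_loss(inputs)          # HF causal-LM loss computed from the labels
prediction = adapter.generate(sample, max_new_tokens=64)
adapter.save_checkpoint("checkpoints/lora")  # writes only the LoRA adapter weights

Note that save_checkpoint stores only the LoRA weights, so reloading goes through from_pretrained with lora_config={"weights_path": ...} on top of the base model.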

File without changes

openadapt_ml/runtime/policy.py
@@ -0,0 +1,182 @@
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from openadapt_ml.models.base_adapter import BaseVLMAdapter
from openadapt_ml.schemas.sessions import Action


# Coordinate-based DSL patterns
_CLICK_RE = re.compile(r"CLICK\(x=([0-9]*\.?[0-9]+),\s*y=([0-9]*\.?[0-9]+)\)")
_TYPE_RE = re.compile(r'TYPE\(text="([^"\\]*(?:\\.[^"\\]*)*)"\)')
_WAIT_RE = re.compile(r"\bWAIT\s*\(\s*\)")
_DONE_RE = re.compile(r"\bDONE\s*\(\s*\)")

# SoM (Set-of-Marks) index-based DSL patterns
_CLICK_SOM_RE = re.compile(r"CLICK\(\[(\d+)\]\)")
_TYPE_SOM_RE = re.compile(r'TYPE\(\[(\d+)\],\s*["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
_TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')


@dataclass
class PolicyOutput:
    """Result of a single policy step."""
    action: Action
    thought: Optional[str] = None
    state: Optional[Dict[str, Any]] = None
    raw_text: str = ""


def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
    """Parse Thought / State / Action blocks from model output.

    Expected format:
        Thought: [reasoning]
        State: {"success": false, "progress": 0.5, ...}
        Action: CLICK(x=0.42, y=0.31)

    Returns:
        (thought, state, action_str):
        - thought: Content after 'Thought:' up to 'State:' or 'Action:'
        - state: Parsed JSON dict from 'State:' line, or None
        - action_str: Content after 'Action:', or whole text if missing

    Note: We look for the LAST occurrence of 'Action:' to handle cases where
    the user prompt template also contains 'Action:' placeholders.
    """
    thought: Optional[str] = None
    state: Optional[Dict[str, Any]] = None
    action_str: str = text.strip()

    # Extract Thought - find the LAST occurrence (model's response, not template)
    thought_matches = list(re.finditer(r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE))
    if thought_matches:
        thought = thought_matches[-1].group(1).strip()

    # Extract State (JSON on same line or next line) - last occurrence
    state_matches = list(re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE))
    if state_matches:
        try:
            state = json.loads(state_matches[-1].group(1))
        except json.JSONDecodeError:
            state = None

    # Extract Action - find the LAST occurrence to get the model's actual action,
    # not the placeholder in the prompt template
    action_matches = list(re.finditer(r"Action:\s*(.+?)(?=\n|$)", text, re.IGNORECASE))
    if action_matches:
        action_str = action_matches[-1].group(1).strip()

    return thought, state, action_str


class AgentPolicy:
    """Runtime policy wrapper around a trained VLM adapter.

    Formats goal-conditioned inputs and parses textual actions into
    structured `Action` objects.
    """

    def __init__(self, adapter: BaseVLMAdapter) -> None:
        self.adapter = adapter

    def _build_sample(self, image: Image.Image, goal: str) -> Dict[str, Any]:
        # For runtime we keep the same structure as SFT samples but use
        # an in-memory image. The adapter's generate method currently expects
        # paths, so we require the caller to supply a path-based sample. For
        # now, we leave responsibility for image loading to the caller; this
        # method is kept for future extensibility.
        raise NotImplementedError(
            "AgentPolicy._build_sample is not used directly; pass a sample dict "
            "compatible with the adapter's `generate` method."
        )

    def _parse_action(self, text: str) -> Action:
        """Parse a DSL action string into an Action object.

        Supported formats (coordinate-based):
        - CLICK(x=<float>, y=<float>)
        - TYPE(text="...")

        Supported formats (SoM index-based):
        - CLICK([N])
        - TYPE([N], "text")
        - TYPE("text")

        Common formats:
        - WAIT()
        - DONE()

        Returns Action(type="failed") if no valid action is found.
        """
        # Try SoM patterns first (index-based)
        # CLICK([N])
        m = _CLICK_SOM_RE.search(text)
        if m:
            idx = int(m.group(1))
            return Action(type="click", element_index=idx)

        # TYPE([N], "text")
        m = _TYPE_SOM_RE.search(text)
        if m:
            idx = int(m.group(1))
            raw_text = m.group(2)
            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
            return Action(type="type", text=unescaped, element_index=idx)

        # TYPE("text") - SoM style without index
        m = _TYPE_SOM_SIMPLE_RE.search(text)
        if m:
            raw_text = m.group(1)
            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
            return Action(type="type", text=unescaped)

        # Coordinate-based patterns
        # CLICK(x=..., y=...)
        m = _CLICK_RE.search(text)
        if m:
            x = float(m.group(1))
            y = float(m.group(2))
            # Clamp to [0, 1]
            x = max(0.0, min(1.0, x))
            y = max(0.0, min(1.0, y))
            return Action(type="click", x=x, y=y)

        # TYPE(text="...")
        m = _TYPE_RE.search(text)
        if m:
            # Unescape the text content
            raw_text = m.group(1)
            unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
            return Action(type="type", text=unescaped)

        # WAIT()
        if _WAIT_RE.search(text):
            return Action(type="wait")

        # DONE()
        if _DONE_RE.search(text):
            return Action(type="done")

        # Fallback
        return Action(type="failed", raw={"text": text})

    def predict_action_from_sample(
        self, sample: Dict[str, Any], max_new_tokens: int = 150
    ) -> Tuple[Action, Optional[str], Optional[Dict[str, Any]], str]:
        """Run the adapter on a pre-built SFT-style sample and parse the result.

        Returns (Action, thought, state, raw_text) where:
        - thought: Reasoning text from 'Thought:' block
        - state: Parsed JSON dict from 'State:' block (may contain 'success' bool)
        - raw_text: The raw model output text for debugging
        """
        text = self.adapter.generate(sample, max_new_tokens=max_new_tokens)
        thought, state, action_str = parse_thought_state_action(text)
        action = self._parse_action(action_str)
        return action, thought, state, text
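
A short, hedged example of the parsing path above; the model output string is invented, but the Thought/State/Action layout follows the parse_thought_state_action docstring.

# Hedged example; the raw model output below is illustrative only.
from openadapt_ml.runtime.policy import parse_thought_state_action

raw = (
    "Thought: The username field is focused, so type the username next.\n"
    'State: {"success": false, "progress": 0.25}\n'
    'Action: TYPE(text="alice")'
)
thought, state, action_str = parse_thought_state_action(raw)
# thought    -> "The username field is focused, so type the username next."
# state      -> {"success": False, "progress": 0.25}
# action_str -> 'TYPE(text="alice")'

With a trained adapter, AgentPolicy(adapter).predict_action_from_sample(sample) runs the same parsing on real model output; for the string above the parsed result would be Action(type="type", text="alice").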

openadapt_ml/schemas/__init__.py
@@ -0,0 +1,53 @@
"""Schema definitions and validation for openadapt-ml.

Core data structures:
- Action: A single GUI action (click, type, scroll, etc.)
- Observation: GUI state observation (screenshot, accessibility tree, etc.)
- Step: One timestep containing observation + action
- Episode: A single task attempt / workflow instance
- Session: Container for multiple episodes

Validation:
- validate_episode(): Validate an Episode object
- validate_session(): Validate a Session object
- validate_episodes(): Validate a list of Episodes
- ValidationError: Raised on schema violations
"""

from openadapt_ml.schemas.sessions import (
    Action,
    ActionType,
    Episode,
    Observation,
    Session,
    Step,
)
from openadapt_ml.schemas.validation import (
    ValidationError,
    summarize_episodes,
    validate_action,
    validate_episode,
    validate_episodes,
    validate_observation,
    validate_session,
    validate_step,
)

__all__ = [
    # Core types
    "Action",
    "ActionType",
    "Episode",
    "Observation",
    "Session",
    "Step",
    # Validation
    "ValidationError",
    "validate_action",
    "validate_episode",
    "validate_episodes",
    "validate_observation",
    "validate_session",
    "validate_step",
    "summarize_episodes",
]