openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0

openadapt_ml/models/providers/openai.py
ADDED
@@ -0,0 +1,342 @@
"""OpenAI (GPT) API provider.

Supports GPT-5.2, GPT-5.1, GPT-4o, and other OpenAI models with vision.
Implements the BaseAPIProvider interface for the Chat Completions API.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from openadapt_ml.models.providers.base import (
    BaseAPIProvider,
    AuthenticationError,
    ModelNotFoundError,
    ProviderError,
    RateLimitError,
)

if TYPE_CHECKING:
    from PIL import Image

logger = logging.getLogger(__name__)

# Default models
DEFAULT_MODEL = "gpt-4o"

# Supported models with their properties
SUPPORTED_MODELS = {
    "gpt-5.2": {"context": 128_000, "description": "Latest GPT model"},
    "gpt-5.1": {"context": 128_000, "description": "Previous GPT-5"},
    "gpt-4o": {"context": 128_000, "description": "Vision-capable, fast"},
    "gpt-4o-mini": {"context": 128_000, "description": "Cheaper, fast"},
    "gpt-4-turbo": {"context": 128_000, "description": "Previous gen turbo"},
}


class OpenAIProvider(BaseAPIProvider):
    """Provider for OpenAI's GPT models.

    Implements vision support via data URL encoded images in the Chat Completions API.
    Supports both standard chat and vision-enabled models.

    Supported models:
    - gpt-5.2: Latest and most capable
    - gpt-5.1: Previous generation GPT-5
    - gpt-4o: Fast, vision-capable
    - gpt-4o-mini: Cost-effective, vision-capable

    Example:
        >>> provider = OpenAIProvider()
        >>> client = provider.create_client(api_key)
        >>> response = provider.send_message(
        ...     client,
        ...     model="gpt-5.2",
        ...     system="You are a GUI agent.",
        ...     content=[
        ...         {"type": "text", "text": "Click the submit button"},
        ...         provider.encode_image(screenshot),
        ...     ],
        ... )

    Note:
        OpenAI uses data URLs for images (data:image/png;base64,...).
        This differs from Anthropic's explicit source object format.

    Attributes:
        name: Returns 'openai'.
    """

    @property
    def name(self) -> str:
        """Provider name."""
        return "openai"

    @property
    def env_key_name(self) -> str:
        """Environment variable name for API key."""
        return "OPENAI_API_KEY"

    @property
    def default_model(self) -> str:
        """Default model to use."""
        return DEFAULT_MODEL

    @property
    def supported_models(self) -> dict[str, dict[str, Any]]:
        """Dictionary of supported models and their properties."""
        return SUPPORTED_MODELS

    def create_client(self, api_key: str) -> Any:
        """Create OpenAI client.

        Args:
            api_key: OpenAI API key.

        Returns:
            OpenAI client instance.

        Raises:
            ImportError: If openai package not installed.
            AuthenticationError: If API key is empty.
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "openai package is required for provider='openai'. "
                "Install with: uv add openai"
            ) from e

        if not api_key or not api_key.strip():
            raise AuthenticationError(
                "OpenAI API key cannot be empty. "
                "Get a key from https://platform.openai.com/api-keys"
            )

        logger.debug("Creating OpenAI client")
        return OpenAI(api_key=api_key)

    def send_message(
        self,
        client: Any,
        model: str,
        system: str,
        content: list[dict[str, Any]],
        max_tokens: int = 1024,
        temperature: float = 0.1,
    ) -> str:
        """Send message using OpenAI Chat Completions API.

        Args:
            client: OpenAI client from create_client().
            model: Model ID (e.g., 'gpt-5.2', 'gpt-4o').
            system: System prompt.
            content: List of content blocks (text and images).
            max_tokens: Max response tokens.
            temperature: Sampling temperature (0.0-2.0 for OpenAI).

        Returns:
            Model response text.

        Raises:
            AuthenticationError: If API key is invalid.
            RateLimitError: If rate limit exceeded.
            ModelNotFoundError: If model doesn't exist.
            ProviderError: For other API errors.
        """
        logger.debug(f"Sending message to {model} with {len(content)} content blocks")

        messages = []

        if system:
            messages.append({"role": "system", "content": system})

        messages.append({"role": "user", "content": content})

        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_completion_tokens=max_tokens,
                temperature=temperature,
            )

            result = response.choices[0].message.content or ""
            logger.debug(f"Received response: {len(result)} chars")
            return result

        except Exception as e:
            error_str = str(e).lower()

            # Map common errors to specific exceptions
            if (
                "authentication" in error_str
                or "api_key" in error_str
                or "invalid_api_key" in error_str
            ):
                raise AuthenticationError(f"OpenAI authentication failed: {e}") from e
            elif "rate_limit" in error_str or "429" in error_str:
                raise RateLimitError(f"OpenAI rate limit exceeded: {e}") from e
            elif "model_not_found" in error_str or "does not exist" in error_str:
                raise ModelNotFoundError(f"Model '{model}' not found: {e}") from e
            else:
                raise ProviderError(f"OpenAI API error: {e}") from e

    def encode_image(self, image: "Image") -> dict[str, Any]:
        """Encode image for OpenAI API.

        OpenAI uses data URLs for images in the format:
            data:image/<type>;base64,<data>

        Args:
            image: PIL Image.

        Returns:
            Image content block for OpenAI API in format:
            {
                "type": "image_url",
                "image_url": {
                    "url": "data:image/png;base64,..."
                }
            }
        """
        base64_data = self.image_to_base64(image, "PNG")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{base64_data}",
            },
        }

    def encode_image_with_detail(
        self,
        image: "Image",
        detail: str = "auto",
    ) -> dict[str, Any]:
        """Encode image with detail level specification.

        OpenAI supports different detail levels for vision processing:
        - "low": Fixed 512x512, 85 tokens, fast
        - "high": Scaled up to 2048x2048, more tokens, detailed
        - "auto": Let the model decide based on image size

        Args:
            image: PIL Image.
            detail: Detail level ("low", "high", "auto").

        Returns:
            Image content block with detail specification.
        """
        base64_data = self.image_to_base64(image, "PNG")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{base64_data}",
                "detail": detail,
            },
        }

    def encode_image_from_url(
        self,
        url: str,
        detail: str = "auto",
    ) -> dict[str, Any]:
        """Create image content block from URL.

        OpenAI natively supports URL-based images, so no fetching needed.

        Args:
            url: Image URL.
            detail: Detail level ("low", "high", "auto").

        Returns:
            Image content block for OpenAI API.
        """
        return {
            "type": "image_url",
            "image_url": {
                "url": url,
                "detail": detail,
            },
        }

    def encode_image_from_bytes(
        self,
        image_bytes: bytes,
        media_type: str = "image/png",
    ) -> dict[str, Any]:
        """Encode raw image bytes for OpenAI API.

        Args:
            image_bytes: Raw image bytes.
            media_type: MIME type of the image.

        Returns:
            Image content block for OpenAI API.
        """
        import base64

        base64_data = base64.b64encode(image_bytes).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{media_type};base64,{base64_data}",
            },
        }

    def send_with_tools(
        self,
        client: Any,
        model: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        tool_choice: str | dict[str, Any] = "auto",
        max_tokens: int = 1024,
        temperature: float = 0.1,
    ) -> Any:
        """Send message with function calling/tools support.

        OpenAI supports function calling which can be useful for structured
        action extraction in GUI automation.

        Args:
            client: OpenAI client.
            model: Model ID.
            messages: Chat messages.
            tools: Tool definitions.
            tool_choice: Tool choice strategy.
            max_tokens: Max response tokens.
            temperature: Sampling temperature.

        Returns:
            Raw API response (for tool call handling).

        Example:
            >>> tools = [{
            ...     "type": "function",
            ...     "function": {
            ...         "name": "click",
            ...         "parameters": {
            ...             "type": "object",
            ...             "properties": {
            ...                 "x": {"type": "number"},
            ...                 "y": {"type": "number"}
            ...             }
            ...         }
            ...     }
            ... }]
            >>> response = provider.send_with_tools(client, model, messages, tools)
        """
        try:
            return client.chat.completions.create(
                model=model,
                messages=messages,
                tools=tools,
                tool_choice=tool_choice,
                max_completion_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as e:
            raise ProviderError(f"OpenAI tools API error: {e}") from e
openadapt_ml/models/qwen_vl.py
CHANGED
@@ -2,15 +2,20 @@ from __future__ import annotations

from typing import Any, Dict, List, Optional

-from PIL import Image
import torch
from peft import LoraConfig, PeftModel, get_peft_model
-from transformers import
+from transformers import (
+    AutoProcessor,
+    Qwen3VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+)

from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device


-def _process_vision_info(
+def _process_vision_info(
+    messages: List[Dict[str, Any]],
+) -> tuple[list[list[Any]], list[list[Any]]]:
    """Minimal stand-in for qwen_vl_utils.process_vision_info.

    For our use case we only need to extract image/video entries from the
@@ -100,10 +105,12 @@ class QwenVLAdapter(BaseVLMAdapter):
        processor = AutoProcessor.from_pretrained(model_name)

        # Configure image resolution for faster training
-        if max_pixels is not None and hasattr(processor,
+        if max_pixels is not None and hasattr(processor, "image_processor"):
            processor.image_processor.max_pixels = max_pixels
-            print(
-
+            print(
+                f"Set max_pixels to {max_pixels} ({int(max_pixels**0.5)}x{int(max_pixels**0.5)} approx)"
+            )
+        if min_pixels is not None and hasattr(processor, "image_processor"):
            processor.image_processor.min_pixels = min_pixels

        model_kwargs: Dict[str, Any] = {}
@@ -121,7 +128,9 @@ class QwenVLAdapter(BaseVLMAdapter):
        if lora_config is not None:
            if isinstance(lora_config, dict):
                lora_weights_path = lora_config.get("weights_path")
-                lora_cfg_clean = {
+                lora_cfg_clean = {
+                    k: v for k, v in lora_config.items() if k != "weights_path"
+                }
            else:
                lora_cfg_clean = lora_config

@@ -184,10 +193,12 @@ class QwenVLAdapter(BaseVLMAdapter):
                },
            ]
            if assistant_text:
-                qwen_messages_full.append(
-
-
-
+                qwen_messages_full.append(
+                    {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": assistant_text}],
+                    }
+                )
            batch_messages_full.append(qwen_messages_full)

        # User-only messages (for label masking)
@@ -250,7 +261,11 @@ class QwenVLAdapter(BaseVLMAdapter):
            # Padding token is typically 0 or a special value
            # For Qwen, we look for the first occurrence of pad token
            pad_token_id = self.processor.tokenizer.pad_token_id
-            user_ids_no_pad =
+            user_ids_no_pad = (
+                user_ids[user_ids != pad_token_id]
+                if pad_token_id is not None
+                else user_ids
+            )
            user_len = len(user_ids_no_pad)

            # Check if user sequence is a prefix of full sequence
@@ -261,7 +276,10 @@ class QwenVLAdapter(BaseVLMAdapter):
                labels[i, user_len:] = full_ids[user_len:]

        # Ensure padding tokens are masked in labels
-        if
+        if (
+            hasattr(self.processor.tokenizer, "pad_token_id")
+            and self.processor.tokenizer.pad_token_id is not None
+        ):
            labels[input_ids_full == self.processor.tokenizer.pad_token_id] = -100

        inputs_full["labels"] = labels
@@ -300,10 +318,12 @@ class QwenVLAdapter(BaseVLMAdapter):
                }
            ]
            if assistant_text:
-                qwen_messages.append(
-
-
-
+                qwen_messages.append(
+                    {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": assistant_text}],
+                    }
+                )

            batch_messages.append(qwen_messages)

@@ -339,14 +359,20 @@ class QwenVLAdapter(BaseVLMAdapter):
        labels = input_ids.clone()

        # Mask padding tokens
-        if
+        if (
+            hasattr(self.processor.tokenizer, "pad_token_id")
+            and self.processor.tokenizer.pad_token_id is not None
+        ):
            labels[input_ids == self.processor.tokenizer.pad_token_id] = -100

        inputs["labels"] = labels
        return inputs

    def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor:  # type: ignore[override]
-        inputs = {
+        inputs = {
+            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+            for k, v in inputs.items()
+        }
        outputs = self.model(**inputs)
        # Hugging Face causal LM models return `loss` when `labels` are provided.
        return outputs.loss  # type: ignore[no-any-return]
@@ -420,6 +446,7 @@ class QwenVLAdapter(BaseVLMAdapter):
    def save_checkpoint(self, path: str) -> None:
        """Save the LoRA adapter weights to a directory."""
        from pathlib import Path
+
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        # Save the PEFT adapter (LoRA weights only, not base model)
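
The padding-mask pattern reformatted in the hunks above can be shown in isolation; this sketch (not from the package) illustrates how positions holding the pad token are set to -100 so the Hugging Face causal LM loss ignores them. The pad_token_id value of 0 is assumed for the example.

import torch

pad_token_id = 0  # assumed for illustration; Qwen tokenizers expose their own value
input_ids = torch.tensor([[101, 42, 7, 0, 0]])

labels = input_ids.clone()
if pad_token_id is not None:
    # Same masking as in the diff: -100 is the ignore index for the loss.
    labels[input_ids == pad_token_id] = -100

print(labels)  # tensor([[ 101,   42,    7, -100, -100]])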

openadapt_ml/perception/__init__.py
ADDED
@@ -0,0 +1,35 @@
"""
Perception Integration Module

Bridges openadapt-grounding (UI element detection) with openadapt-ml (ML schema).

This module provides:
- UIElementGraph: A wrapper class for parsed UI elements
- Conversion utilities between grounding types and ML schema types

Usage:
    from openadapt_ml.perception import UIElementGraph

    # From parser output
    graph = UIElementGraph.from_parser_output(elements, source="omniparser")

    # Access elements
    for element in graph.elements:
        print(f"{element.role}: {element.name}")

Requires:
    pip install openadapt-grounding
    # or: uv add openadapt-grounding
"""

from openadapt_ml.perception.integration import (
    UIElementGraph,
    element_to_ui_element,
    ui_element_to_element,
)

__all__ = [
    "UIElementGraph",
    "element_to_ui_element",
    "ui_element_to_element",
]
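
A short usage sketch mirroring the module docstring above (not from the package); parser_elements is a placeholder for output from an openadapt-grounding parser such as OmniParser and is left empty here so the snippet runs as-is.

from openadapt_ml.perception import UIElementGraph

parser_elements: list = []  # replace with real openadapt-grounding parser output
graph = UIElementGraph.from_parser_output(parser_elements, source="omniparser")

for element in graph.elements:
    print(f"{element.role}: {element.name}")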