sentienceapi 0.90.16__py3-none-any.whl → 0.98.0__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release.
This version of sentienceapi might be problematic.
- sentience/__init__.py +120 -6
- sentience/_extension_loader.py +156 -1
- sentience/action_executor.py +217 -0
- sentience/actions.py +758 -30
- sentience/agent.py +806 -293
- sentience/agent_config.py +3 -0
- sentience/agent_runtime.py +840 -0
- sentience/asserts/__init__.py +70 -0
- sentience/asserts/expect.py +621 -0
- sentience/asserts/query.py +383 -0
- sentience/async_api.py +89 -1141
- sentience/backends/__init__.py +137 -0
- sentience/backends/actions.py +372 -0
- sentience/backends/browser_use_adapter.py +241 -0
- sentience/backends/cdp_backend.py +393 -0
- sentience/backends/exceptions.py +211 -0
- sentience/backends/playwright_backend.py +194 -0
- sentience/backends/protocol.py +216 -0
- sentience/backends/sentience_context.py +469 -0
- sentience/backends/snapshot.py +483 -0
- sentience/base_agent.py +95 -0
- sentience/browser.py +678 -39
- sentience/browser_evaluator.py +299 -0
- sentience/canonicalization.py +207 -0
- sentience/cloud_tracing.py +507 -42
- sentience/constants.py +6 -0
- sentience/conversational_agent.py +77 -43
- sentience/cursor_policy.py +142 -0
- sentience/element_filter.py +136 -0
- sentience/expect.py +98 -2
- sentience/extension/background.js +56 -185
- sentience/extension/content.js +150 -287
- sentience/extension/injected_api.js +1088 -1368
- sentience/extension/manifest.json +1 -1
- sentience/extension/pkg/sentience_core.d.ts +22 -22
- sentience/extension/pkg/sentience_core.js +275 -433
- sentience/extension/pkg/sentience_core_bg.wasm +0 -0
- sentience/extension/release.json +47 -47
- sentience/failure_artifacts.py +241 -0
- sentience/formatting.py +9 -53
- sentience/inspector.py +183 -1
- sentience/integrations/__init__.py +6 -0
- sentience/integrations/langchain/__init__.py +12 -0
- sentience/integrations/langchain/context.py +18 -0
- sentience/integrations/langchain/core.py +326 -0
- sentience/integrations/langchain/tools.py +180 -0
- sentience/integrations/models.py +46 -0
- sentience/integrations/pydanticai/__init__.py +15 -0
- sentience/integrations/pydanticai/deps.py +20 -0
- sentience/integrations/pydanticai/toolset.py +468 -0
- sentience/llm_interaction_handler.py +191 -0
- sentience/llm_provider.py +765 -66
- sentience/llm_provider_utils.py +120 -0
- sentience/llm_response_builder.py +153 -0
- sentience/models.py +595 -3
- sentience/ordinal.py +280 -0
- sentience/overlay.py +109 -2
- sentience/protocols.py +228 -0
- sentience/query.py +67 -5
- sentience/read.py +95 -3
- sentience/recorder.py +223 -3
- sentience/schemas/trace_v1.json +128 -9
- sentience/screenshot.py +48 -2
- sentience/sentience_methods.py +86 -0
- sentience/snapshot.py +599 -55
- sentience/snapshot_diff.py +126 -0
- sentience/text_search.py +120 -5
- sentience/trace_event_builder.py +148 -0
- sentience/trace_file_manager.py +197 -0
- sentience/trace_indexing/index_schema.py +95 -7
- sentience/trace_indexing/indexer.py +105 -48
- sentience/tracer_factory.py +120 -9
- sentience/tracing.py +172 -8
- sentience/utils/__init__.py +40 -0
- sentience/utils/browser.py +46 -0
- sentience/{utils.py → utils/element.py} +3 -42
- sentience/utils/formatting.py +59 -0
- sentience/verification.py +618 -0
- sentience/visual_agent.py +2058 -0
- sentience/wait.py +68 -2
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/METADATA +199 -40
- sentienceapi-0.98.0.dist-info/RECORD +92 -0
- sentience/extension/test-content.js +0 -4
- sentienceapi-0.90.16.dist-info/RECORD +0 -50
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/WHEEL +0 -0
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/entry_points.txt +0 -0
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE +0 -0
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE-APACHE +0 -0
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE-MIT +0 -0
- {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/top_level.txt +0 -0
sentience/llm_provider.py
CHANGED
@@ -5,6 +5,10 @@ Enables "Bring Your Own Brain" (BYOB) pattern - plug in any LLM provider

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Any
+
+from .llm_provider_utils import get_api_key_from_env, handle_provider_error, require_package
+from .llm_response_builder import LLMResponseBuilder


 @dataclass
@@ -31,6 +35,15 @@ class LLMProvider(ABC):
     - Any other completion API
     """

+    def __init__(self, model: str):
+        """
+        Initialize LLM provider with model name.
+
+        Args:
+            model: Model identifier (e.g., "gpt-4o", "claude-3-sonnet")
+        """
+        self._model_name = model
+
     @abstractmethod
     def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse:
         """
@@ -67,6 +80,48 @@ class LLMProvider(ABC):
         """
         pass

+    def supports_vision(self) -> bool:
+        """
+        Whether this provider supports image input for vision tasks.
+
+        Override in subclasses that support vision-capable models.
+
+        Returns:
+            True if provider supports vision, False otherwise
+        """
+        return False
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate a response with image input (for vision-capable models).
+
+        This method is used for vision fallback in assertions and visual agents.
+        Override in subclasses that support vision-capable models.
+
+        Args:
+            system_prompt: System instruction/context
+            user_prompt: User query/request
+            image_base64: Base64-encoded image (PNG or JPEG)
+            **kwargs: Provider-specific parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            LLMResponse with content and token usage
+
+        Raises:
+            NotImplementedError: If provider doesn't support vision
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not support vision. "
+            "Use a vision-capable provider like OpenAIProvider with GPT-4o "
+            "or AnthropicProvider with Claude 3."
+        )
+

 class OpenAIProvider(LLMProvider):
     """
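For orientation, a minimal sketch of how a third-party BYOB provider could opt in to the new vision hooks added above. Everything other than the `LLMProvider` interface itself is hypothetical: the `FakeVisionClient`, its methods, and the assumption that `LLMResponse(content=...)` is a valid construction are for illustration only.

# Hypothetical example; FakeVisionClient and LLMResponse(content=...) are assumptions.
from sentience.llm_provider import LLMProvider, LLMResponse


class FakeVisionClient:
    """Stand-in for a real SDK client; returns canned strings."""

    def complete(self, prompt: str) -> str:
        return f"echo: {prompt[:40]}"

    def describe(self, prompt: str, image_b64: str) -> str:
        return f"image of {len(image_b64)} base64 chars; question: {prompt[:40]}"


class MyVisionProvider(LLMProvider):
    """Minimal BYOB provider that opts in to the 0.98.0 vision hooks."""

    def __init__(self, client: FakeVisionClient, model: str = "my-vision-model"):
        super().__init__(model)  # the base class now stores the model name
        self.client = client

    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse:
        return LLMResponse(content=self.client.complete(f"{system_prompt}\n{user_prompt}"))

    def supports_json_mode(self) -> bool:
        return False

    def supports_vision(self) -> bool:  # advertise vision so fallbacks can use it
        return True

    def generate_with_image(self, system_prompt, user_prompt, image_base64, **kwargs) -> LLMResponse:
        return LLMResponse(content=self.client.describe(user_prompt, image_base64))

    @property
    def model_name(self) -> str:
        return self._model_name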
@@ -95,13 +150,16 @@ class OpenAIProvider(LLMProvider):
             base_url: Custom API base URL (for compatible APIs)
             organization: OpenAI organization ID
         """
-
-
-
-
+        super().__init__(model)  # Initialize base class with model name
+
+        OpenAI = require_package(
+            "openai",
+            "openai",
+            "OpenAI",
+            "pip install openai",
+        )

         self.client = OpenAI(api_key=api_key, base_url=base_url, organization=organization)
-        self._model_name = model

     def generate(
         self,
@@ -148,12 +206,15 @@ class OpenAIProvider(LLMProvider):
         api_params.update(kwargs)

         # Call OpenAI API
-
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "OpenAI", "generate response")

         choice = response.choices[0]
         usage = response.usage

-        return
+        return LLMResponseBuilder.from_openai_format(
             content=choice.message.content,
             prompt_tokens=usage.prompt_tokens if usage else None,
             completion_tokens=usage.completion_tokens if usage else None,
@@ -167,6 +228,92 @@ class OpenAIProvider(LLMProvider):
         model_lower = self._model_name.lower()
         return any(x in model_lower for x in ["gpt-4", "gpt-3.5"])

+    def supports_vision(self) -> bool:
+        """GPT-4o, GPT-4-turbo, and GPT-4-vision support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using OpenAI Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature (0.0 = deterministic)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional OpenAI API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use gpt-4o, gpt-4-turbo, or gpt-4-vision-preview."
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Vision message format with image_url
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        )
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call OpenAI API
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "OpenAI", "generate response with image")
+
+        choice = response.choices[0]
+        usage = response.usage
+
+        return LLMResponseBuilder.from_openai_format(
+            content=choice.message.content,
+            prompt_tokens=usage.prompt_tokens if usage else None,
+            completion_tokens=usage.completion_tokens if usage else None,
+            total_tokens=usage.total_tokens if usage else None,
+            model_name=response.model,
+            finish_reason=choice.finish_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
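A usage sketch for the OpenAI vision path above; the API key placeholder, the screenshot path, and the assumption that `LLMResponse` exposes a `.content` attribute are illustrative, not taken from this diff.

# Illustrative usage; "sk-..." and "screenshot.png" are placeholders.
import base64

from sentience.llm_provider import OpenAIProvider

provider = OpenAIProvider(api_key="sk-...", model="gpt-4o")

with open("screenshot.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

if provider.supports_vision():
    response = provider.generate_with_image(
        system_prompt="You are a UI verification assistant.",
        user_prompt="Is a 'Checkout' button visible in this screenshot?",
        image_base64=image_b64,
    )
    print(response.content)  # assumes LLMResponse exposes .content
else:
    print(f"{provider.model_name} is not vision-capable")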
@@ -191,15 +338,16 @@ class AnthropicProvider(LLMProvider):
             api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var)
             model: Model name (claude-3-opus, claude-3-sonnet, claude-3-haiku, etc.)
         """
-
-
-
-
-
-
+        super().__init__(model)  # Initialize base class with model name
+
+        Anthropic = require_package(
+            "anthropic",
+            "anthropic",
+            "Anthropic",
+            "pip install anthropic",
+        )

         self.client = Anthropic(api_key=api_key)
-        self._model_name = model

     def generate(
         self,
@@ -237,27 +385,113 @@ class AnthropicProvider(LLMProvider):
         api_params.update(kwargs)

         # Call Anthropic API
-
+        try:
+            response = self.client.messages.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "Anthropic", "generate response")

         content = response.content[0].text if response.content else ""

-        return
+        return LLMResponseBuilder.from_anthropic_format(
             content=content,
-
-
-            total_tokens=(
-                (response.usage.input_tokens + response.usage.output_tokens)
-                if hasattr(response, "usage")
-                else None
-            ),
+            input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
+            output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
             model_name=response.model,
-
+            stop_reason=response.stop_reason,
         )

     def supports_json_mode(self) -> bool:
         """Anthropic doesn't have native JSON mode (requires prompt engineering)"""
         return False

+    def supports_vision(self) -> bool:
+        """Claude 3 models (Opus, Sonnet, Haiku) all support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["claude-3", "claude-3.5"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int = 1024,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using Anthropic Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens to generate (required by Anthropic)
+            **kwargs: Additional Anthropic API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use Claude 3 models (claude-3-opus, claude-3-sonnet, claude-3-haiku)."
+            )
+
+        # Anthropic vision message format
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": image_base64,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": user_prompt,
+                    },
+                ],
+            }
+        ]
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "messages": messages,
+        }
+
+        if system_prompt:
+            api_params["system"] = system_prompt
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call Anthropic API
+        try:
+            response = self.client.messages.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "Anthropic", "generate response with image")
+
+        content = response.content[0].text if response.content else ""
+
+        return LLMResponseBuilder.from_anthropic_format(
+            content=content,
+            input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
+            output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
+            model_name=response.model,
+            stop_reason=response.stop_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
@@ -285,13 +519,16 @@ class GLMProvider(LLMProvider):
             api_key: Zhipu AI API key (or set GLM_API_KEY env var)
             model: Model name (glm-4-plus, glm-4, glm-4-air, glm-4-flash, etc.)
         """
-
-
-
-
+        super().__init__(model)  # Initialize base class with model name
+
+        ZhipuAI = require_package(
+            "zhipuai",
+            "zhipuai",
+            "ZhipuAI",
+            "pip install zhipuai",
+        )

         self.client = ZhipuAI(api_key=api_key)
-        self._model_name = model

     def generate(
         self,
@@ -333,12 +570,15 @@ class GLMProvider(LLMProvider):
         api_params.update(kwargs)

         # Call GLM API
-
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "GLM", "generate response")

         choice = response.choices[0]
         usage = response.usage

-        return
+        return LLMResponseBuilder.from_openai_format(
             content=choice.message.content,
             prompt_tokens=usage.prompt_tokens if usage else None,
             completion_tokens=usage.completion_tokens if usage else None,
@@ -378,25 +618,20 @@ class GeminiProvider(LLMProvider):
             api_key: Google API key (or set GEMINI_API_KEY or GOOGLE_API_KEY env var)
             model: Model name (gemini-2.0-flash-exp, gemini-1.5-pro, gemini-1.5-flash, etc.)
         """
-
-            import google.generativeai as genai
-        except ImportError:
-            raise ImportError(
-                "Google Generative AI package not installed. Install with: pip install google-generativeai"
-            )
+        super().__init__(model)  # Initialize base class with model name

-
+        genai = require_package(
+            "google-generativeai",
+            "google.generativeai",
+            install_command="pip install google-generativeai",
+        )
+
+        # Configure API key (check parameter first, then environment variables)
+        api_key = get_api_key_from_env(["GEMINI_API_KEY", "GOOGLE_API_KEY"], api_key)
         if api_key:
             genai.configure(api_key=api_key)
-        else:
-            import os
-
-            api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
-            if api_key:
-                genai.configure(api_key=api_key)

         self.genai = genai
-        self._model_name = model
         self.model = genai.GenerativeModel(model)

     def generate(
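`get_api_key_from_env` and `require_package` live in the new `sentience/llm_provider_utils.py`, which this diff does not show. Judging only from the call sites above, a plausible minimal shape of the key helper would be the following sketch; the real implementation may differ.

# Assumed behaviour, inferred from get_api_key_from_env(["GEMINI_API_KEY", "GOOGLE_API_KEY"], api_key):
# prefer an explicitly passed key, otherwise return the first environment variable that is set.
import os


def get_api_key_from_env(env_vars: list[str], explicit_key: str | None = None) -> str | None:
    if explicit_key:
        return explicit_key
    for name in env_vars:
        value = os.getenv(name)
        if value:
            return value
    return None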
@@ -435,7 +670,10 @@ class GeminiProvider(LLMProvider):
         generation_config.update(kwargs)

         # Call Gemini API
-
+        try:
+            response = self.model.generate_content(full_prompt, generation_config=generation_config)
+        except Exception as e:
+            handle_provider_error(e, "Gemini", "generate response")

         # Extract content
         content = response.text if response.text else ""
@@ -450,13 +688,12 @@ class GeminiProvider(LLMProvider):
             completion_tokens = response.usage_metadata.candidates_token_count
             total_tokens = response.usage_metadata.total_token_count

-        return
+        return LLMResponseBuilder.from_gemini_format(
             content=content,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             model_name=self._model_name,
-            finish_reason=None,  # Gemini uses different finish reason format
         )

     def supports_json_mode(self) -> bool:
@@ -503,16 +740,24 @@ class LocalLLMProvider(LLMProvider):
             load_in_8bit: Use 8-bit quantization (saves 50% memory)
             torch_dtype: Data type ("auto", "float16", "bfloat16", "float32")
         """
+        super().__init__(model_name)  # Initialize base class with model name
+
+        # Import required packages with consistent error handling.
+        # These are optional dependencies, so keep them out of module import-time.
         try:
-            import torch
-            from transformers import
-
+            import torch  # type: ignore[import-not-found]
+            from transformers import (  # type: ignore[import-not-found]
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                BitsAndBytesConfig,
+            )
+        except ImportError as exc:
             raise ImportError(
                 "transformers and torch required for local LLM. "
                 "Install with: pip install transformers torch"
-            )
+            ) from exc

-        self.
+        self._torch = torch

         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -533,21 +778,44 @@ class LocalLLMProvider(LLMProvider):
         elif load_in_8bit:
             quantization_config = BitsAndBytesConfig(load_in_8bit=True)

+        device = (device or "auto").strip().lower()
+
         # Determine torch dtype
         if torch_dtype == "auto":
-            dtype = torch.float16 if device
+            dtype = torch.float16 if device not in {"cpu"} else torch.float32
         else:
             dtype = getattr(torch, torch_dtype)

-        #
-
-
-
-
-
-
-
-
+        # device_map is a Transformers concept (not a literal "cpu/mps/cuda" device string).
+        # - "auto" enables Accelerate device mapping.
+        # - Otherwise, we load normally and then move the model to the requested device.
+        device_map: str | None = "auto" if device == "auto" else None
+
+        def _load(*, device_map_override: str | None) -> Any:
+            return AutoModelForCausalLM.from_pretrained(
+                model_name,
+                quantization_config=quantization_config,
+                torch_dtype=dtype if quantization_config is None else None,
+                device_map=device_map_override,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            )
+
+        try:
+            self.model = _load(device_map_override=device_map)
+        except KeyError as e:
+            # Some envs / accelerate versions can crash on auto mapping (e.g. KeyError: 'cpu').
+            # Keep demo ergonomics: default stays "auto", but we gracefully fall back.
+            if device == "auto" and ("cpu" in str(e).lower()):
+                device = "cpu"
+                dtype = torch.float32
+                self.model = _load(device_map_override=None)
+            else:
+                raise
+
+        # If we didn't use device_map, move model explicitly (only safe for non-quantized loads).
+        if device_map is None and quantization_config is None and device in {"cpu", "cuda", "mps"}:
+            self.model = self.model.to(device)
         self.model.eval()

     def generate(
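The practical effect of the block above is that callers can keep the default `device="auto"` and still get a working CPU load when Accelerate's auto mapping raises `KeyError: 'cpu'`. A construction sketch, with the model id chosen only for illustration and `.content` access assumed:

# Illustrative only; the model id is an arbitrary small causal LM.
from sentience.llm_provider import LocalLLMProvider

provider = LocalLLMProvider(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",
    device="auto",        # falls back to CPU + float32 if auto device mapping fails
    load_in_8bit=False,   # quantization still goes through BitsAndBytesConfig
    torch_dtype="auto",
)
response = provider.generate(
    system_prompt="You are a terse assistant.",
    user_prompt="Reply with the single word: ready.",
)
print(response.content)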
@@ -573,7 +841,7 @@ class LocalLLMProvider(LLMProvider):
         Returns:
             LLMResponse object
         """
-
+        torch = self._torch

         # Auto-determine sampling based on temperature
         do_sample = temperature > 0
@@ -620,11 +888,10 @@ class LocalLLMProvider(LLMProvider):
         generated_tokens = outputs[0][input_length:]
         response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

-        return
+        return LLMResponseBuilder.from_local_format(
             content=response_text,
             prompt_tokens=input_length,
             completion_tokens=len(generated_tokens),
-            total_tokens=input_length + len(generated_tokens),
             model_name=self._model_name,
         )

@@ -635,3 +902,435 @@ class LocalLLMProvider(LLMProvider):
     @property
     def model_name(self) -> str:
         return self._model_name
+
+
+class LocalVisionLLMProvider(LLMProvider):
+    """
+    Local vision-language LLM provider using HuggingFace Transformers.
+
+    Intended for models like:
+    - Qwen/Qwen3-VL-8B-Instruct
+
+    Notes on Mac (MPS) + quantization:
+    - Transformers BitsAndBytes (4-bit/8-bit) typically requires CUDA and does NOT work on MPS.
+    - If you want quantized local vision on Apple Silicon, you may prefer MLX-based stacks
+      (e.g., mlx-vlm) or llama.cpp/gguf pipelines.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
+        device: str = "auto",
+        torch_dtype: str = "auto",
+        load_in_4bit: bool = False,
+        load_in_8bit: bool = False,
+        trust_remote_code: bool = True,
+    ):
+        super().__init__(model_name)
+
+        # Import required packages with consistent error handling
+        try:
+            import torch  # type: ignore[import-not-found]
+            from transformers import AutoProcessor  # type: ignore[import-not-found]
+        except ImportError as exc:
+            raise ImportError(
+                "transformers and torch are required for LocalVisionLLMProvider. "
+                "Install with: pip install transformers torch"
+            ) from exc
+
+        self._torch = torch
+
+        # Resolve device
+        if device == "auto":
+            if (
+                getattr(torch.backends, "mps", None) is not None
+                and torch.backends.mps.is_available()
+            ):
+                device = "mps"
+            elif torch.cuda.is_available():
+                device = "cuda"
+            else:
+                device = "cpu"
+
+        if device == "mps" and (load_in_4bit or load_in_8bit):
+            raise ValueError(
+                "Quantized (4-bit/8-bit) Transformers loading is typically not supported on Apple MPS. "
+                "Set load_in_4bit/load_in_8bit to False for MPS, or use a different local runtime "
+                "(e.g., MLX/llama.cpp) for quantized vision models."
+            )
+
+        # Determine torch dtype
+        if torch_dtype == "auto":
+            dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
+        else:
+            dtype = getattr(torch, torch_dtype)
+
+        # Load processor
+        self.processor = AutoProcessor.from_pretrained(
+            model_name, trust_remote_code=trust_remote_code
+        )
+
+        # Load model (prefer vision2seq; fall back with guidance)
+        try:
+            import importlib
+
+            transformers = importlib.import_module("transformers")
+            AutoModelForVision2Seq = getattr(transformers, "AutoModelForVision2Seq", None)
+            if AutoModelForVision2Seq is None:
+                raise AttributeError("transformers.AutoModelForVision2Seq is not available")
+
+            self.model = AutoModelForVision2Seq.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                low_cpu_mem_usage=True,
+            )
+        except Exception as exc:
+            # Some transformers versions/models don't expose AutoModelForVision2Seq.
+            # We fail loudly with a helpful message rather than silently doing text-only.
+            raise ImportError(
+                "Failed to load a vision-capable Transformers model. "
+                "Try upgrading transformers (vision models often require newer versions), "
+                "or use a model class supported by your installed transformers build."
+            ) from exc
+
+        # Move to device
+        self.device = device
+        self.model.to(device)
+
+        self.model.eval()
+
+    def supports_json_mode(self) -> bool:
+        return False
+
+    def supports_vision(self) -> bool:
+        return True
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    def generate(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        max_new_tokens: int = 512,
+        temperature: float = 0.1,
+        top_p: float = 0.9,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Text-only generation (no image). Provided for interface completeness.
+        """
+        torch = self._torch
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        if hasattr(self.processor, "apply_chat_template"):
+            prompt = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+
+        inputs = self.processor(text=[prompt], return_tensors="pt")
+        inputs = {
+            k: (v.to(self.model.device) if hasattr(v, "to") else v) for k, v in inputs.items()
+        }
+
+        do_sample = temperature > 0
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature if do_sample else 1.0,
+                top_p=top_p,
+                **kwargs,
+            )
+
+        # Decode
+        input_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
+        generated = outputs[0][input_len:]
+        if hasattr(self.processor, "batch_decode"):
+            text = self.processor.batch_decode([generated], skip_special_tokens=True)[0].strip()
+        else:
+            text = str(generated)
+
+        return LLMResponseBuilder.from_local_format(
+            content=text,
+            prompt_tokens=int(input_len) if input_len else None,
+            completion_tokens=int(generated.shape[0]) if hasattr(generated, "shape") else None,
+            model_name=self._model_name,
+        )
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        max_new_tokens: int = 256,
+        temperature: float = 0.0,
+        top_p: float = 0.9,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Vision generation using an image + prompt.
+
+        This is used by vision fallback in assertions and by visual agents.
+        """
+        torch = self._torch
+
+        # Lazy import PIL to avoid adding a hard dependency for text-only users.
+        try:
+            from PIL import Image  # type: ignore[import-not-found]
+        except ImportError as exc:
+            raise ImportError(
+                "Pillow is required for LocalVisionLLMProvider image input. Install with: pip install pillow"
+            ) from exc
+
+        import base64
+        import io
+
+        img_bytes = base64.b64decode(image_base64)
+        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+
+        # Prefer processor chat template if available (needed by many VL models).
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": user_prompt},
+                ],
+            }
+        )
+
+        if hasattr(self.processor, "apply_chat_template"):
+            prompt = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            raise NotImplementedError(
+                "This local vision model/processor does not expose apply_chat_template(). "
+                "Install/upgrade to a Transformers version that supports your model's chat template."
+            )
+
+        inputs = self.processor(text=[prompt], images=[image], return_tensors="pt")
+        inputs = {
+            k: (v.to(self.model.device) if hasattr(v, "to") else v) for k, v in inputs.items()
+        }
+
+        do_sample = temperature > 0
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature if do_sample else 1.0,
+                top_p=top_p,
+                **kwargs,
+            )
+
+        input_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
+        generated = outputs[0][input_len:]
+
+        if hasattr(self.processor, "batch_decode"):
+            text = self.processor.batch_decode([generated], skip_special_tokens=True)[0].strip()
+        elif hasattr(self.processor, "tokenizer") and hasattr(self.processor.tokenizer, "decode"):
+            text = self.processor.tokenizer.decode(generated, skip_special_tokens=True).strip()
+        else:
+            text = ""
+
+        return LLMResponseBuilder.from_local_format(
+            content=text,
+            prompt_tokens=int(input_len) if input_len else None,
+            completion_tokens=int(generated.shape[0]) if hasattr(generated, "shape") else None,
+            model_name=self._model_name,
+        )
+
+
+class MLXVLMProvider(LLMProvider):
+    """
+    Local vision-language provider using MLX-VLM (Apple Silicon optimized).
+
+    Recommended for running *quantized* vision models on Mac (M1/M2/M3/M4), e.g.:
+    - mlx-community/Qwen3-VL-8B-Instruct-3bit
+
+    Optional dependencies:
+    - mlx-vlm
+    - pillow
+
+    Notes:
+    - MLX-VLM APIs can vary across versions; this provider tries a couple common call shapes.
+    - For best results, use an MLX-converted model repo under `mlx-community/`.
+    """
+
+    def __init__(
+        self,
+        model: str = "mlx-community/Qwen3-VL-8B-Instruct-3bit",
+        *,
+        default_max_tokens: int = 256,
+        default_temperature: float = 0.0,
+        **kwargs,
+    ):
+        super().__init__(model)
+        self._default_max_tokens = default_max_tokens
+        self._default_temperature = default_temperature
+        self._default_kwargs = dict(kwargs)
+
+        # Lazy imports to keep base SDK light.
+        try:
+            import importlib
+
+            self._mlx_vlm = importlib.import_module("mlx_vlm")
+        except ImportError as exc:
+            raise ImportError(
+                "mlx-vlm is required for MLXVLMProvider. Install with: pip install mlx-vlm"
+            ) from exc
+
+        try:
+            from PIL import Image  # type: ignore[import-not-found]
+
+            self._PIL_Image = Image
+        except ImportError as exc:
+            raise ImportError(
+                "Pillow is required for MLXVLMProvider. Install with: pip install pillow"
+            ) from exc
+
+        # Some mlx_vlm versions expose load(model_id) -> (model, processor)
+        self._model = None
+        self._processor = None
+        load_fn = getattr(self._mlx_vlm, "load", None)
+        if callable(load_fn):
+            try:
+                loaded = load_fn(model)
+                if isinstance(loaded, tuple) and len(loaded) >= 2:
+                    self._model, self._processor = loaded[0], loaded[1]
+            except Exception:
+                # Keep it lazy; we'll try loading on demand during generate_with_image().
+                self._model, self._processor = None, None
+
+    def supports_json_mode(self) -> bool:
+        return False
+
+    def supports_vision(self) -> bool:
+        return True
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse:
+        """
+        Text-only generation is not a primary MLX-VLM use-case. We attempt it if the installed
+        mlx_vlm exposes a compatible `generate()` signature; otherwise, raise a clear error.
+        """
+        generate_fn = getattr(self._mlx_vlm, "generate", None)
+        if not callable(generate_fn):
+            raise NotImplementedError("mlx_vlm.generate is not available in your mlx-vlm install.")
+
+        prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+        max_tokens = kwargs.pop("max_tokens", self._default_max_tokens)
+        temperature = kwargs.pop("temperature", self._default_temperature)
+        merged_kwargs = {**self._default_kwargs, **kwargs}
+
+        try:
+            out = generate_fn(
+                self._model_name,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+        except TypeError as exc:
+            if self._model is None or self._processor is None:
+                raise NotImplementedError(
+                    "Text-only generation is not supported by this mlx-vlm version without a loaded model."
+                ) from exc
+            out = generate_fn(
+                self._model,
+                self._processor,
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+
+        text = getattr(out, "text", None) or getattr(out, "output", None) or str(out)
+        return LLMResponseBuilder.from_local_format(
+            content=str(text).strip(),
+            prompt_tokens=None,
+            completion_tokens=None,
+            model_name=self._model_name,
+        )
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        **kwargs,
+    ) -> LLMResponse:
+        import base64
+        import io
+
+        generate_fn = getattr(self._mlx_vlm, "generate", None)
+        if not callable(generate_fn):
+            raise NotImplementedError("mlx_vlm.generate is not available in your mlx-vlm install.")
+
+        img_bytes = base64.b64decode(image_base64)
+        image = self._PIL_Image.open(io.BytesIO(img_bytes)).convert("RGB")
+
+        prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+        max_tokens = kwargs.pop("max_tokens", self._default_max_tokens)
+        temperature = kwargs.pop("temperature", self._default_temperature)
+        merged_kwargs = {**self._default_kwargs, **kwargs}
+
+        # Try a couple common MLX-VLM call shapes.
+        try:
+            # 1) generate(model_id, image=..., prompt=...)
+            out = generate_fn(
+                self._model_name,
+                image=image,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+        except TypeError as exc:
+            # 2) generate(model, processor, prompt, image, ...)
+            if self._model is None or self._processor is None:
+                load_fn = getattr(self._mlx_vlm, "load", None)
+                if callable(load_fn):
+                    loaded = load_fn(self._model_name)
+                    if isinstance(loaded, tuple) and len(loaded) >= 2:
+                        self._model, self._processor = loaded[0], loaded[1]
+            if self._model is None or self._processor is None:
+                raise NotImplementedError(
+                    "Unable to call mlx_vlm.generate with your installed mlx-vlm version. "
+                    "Please upgrade mlx-vlm or use LocalVisionLLMProvider (Transformers backend)."
+                ) from exc
+            out = generate_fn(
+                self._model,
+                self._processor,
+                prompt,
+                image,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+
+        text = getattr(out, "text", None) or getattr(out, "output", None) or str(out)
+        return LLMResponseBuilder.from_local_format(
+            content=str(text).strip(),
+            prompt_tokens=None,
+            completion_tokens=None,
+            model_name=self._model_name,
+        )
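Taken together, the two new classes give a local vision path on both CUDA/CPU (Transformers) and Apple Silicon (MLX). A selection sketch follows; the platform check, the screenshot path, and the `.content` access are assumptions for illustration, while the model ids are the defaults shown above.

# Illustrative backend selection for local vision inference.
import base64
import platform

from sentience.llm_provider import LocalVisionLLMProvider, MLXVLMProvider

if platform.system() == "Darwin" and platform.machine() == "arm64":
    # Apple Silicon: quantized MLX model (needs mlx-vlm + pillow)
    provider = MLXVLMProvider("mlx-community/Qwen3-VL-8B-Instruct-3bit")
else:
    # CUDA or CPU: Transformers backend (needs transformers + torch + pillow)
    provider = LocalVisionLLMProvider("Qwen/Qwen3-VL-8B-Instruct")

with open("page.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

response = provider.generate_with_image(
    system_prompt="You are a UI verification assistant.",
    user_prompt="Describe any visible error banner.",
    image_base64=image_b64,
)
print(response.content)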