sentienceapi 0.90.16__py3-none-any.whl → 0.98.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sentienceapi might be problematic.

Files changed (90)
  1. sentience/__init__.py +120 -6
  2. sentience/_extension_loader.py +156 -1
  3. sentience/action_executor.py +217 -0
  4. sentience/actions.py +758 -30
  5. sentience/agent.py +806 -293
  6. sentience/agent_config.py +3 -0
  7. sentience/agent_runtime.py +840 -0
  8. sentience/asserts/__init__.py +70 -0
  9. sentience/asserts/expect.py +621 -0
  10. sentience/asserts/query.py +383 -0
  11. sentience/async_api.py +89 -1141
  12. sentience/backends/__init__.py +137 -0
  13. sentience/backends/actions.py +372 -0
  14. sentience/backends/browser_use_adapter.py +241 -0
  15. sentience/backends/cdp_backend.py +393 -0
  16. sentience/backends/exceptions.py +211 -0
  17. sentience/backends/playwright_backend.py +194 -0
  18. sentience/backends/protocol.py +216 -0
  19. sentience/backends/sentience_context.py +469 -0
  20. sentience/backends/snapshot.py +483 -0
  21. sentience/base_agent.py +95 -0
  22. sentience/browser.py +678 -39
  23. sentience/browser_evaluator.py +299 -0
  24. sentience/canonicalization.py +207 -0
  25. sentience/cloud_tracing.py +507 -42
  26. sentience/constants.py +6 -0
  27. sentience/conversational_agent.py +77 -43
  28. sentience/cursor_policy.py +142 -0
  29. sentience/element_filter.py +136 -0
  30. sentience/expect.py +98 -2
  31. sentience/extension/background.js +56 -185
  32. sentience/extension/content.js +150 -287
  33. sentience/extension/injected_api.js +1088 -1368
  34. sentience/extension/manifest.json +1 -1
  35. sentience/extension/pkg/sentience_core.d.ts +22 -22
  36. sentience/extension/pkg/sentience_core.js +275 -433
  37. sentience/extension/pkg/sentience_core_bg.wasm +0 -0
  38. sentience/extension/release.json +47 -47
  39. sentience/failure_artifacts.py +241 -0
  40. sentience/formatting.py +9 -53
  41. sentience/inspector.py +183 -1
  42. sentience/integrations/__init__.py +6 -0
  43. sentience/integrations/langchain/__init__.py +12 -0
  44. sentience/integrations/langchain/context.py +18 -0
  45. sentience/integrations/langchain/core.py +326 -0
  46. sentience/integrations/langchain/tools.py +180 -0
  47. sentience/integrations/models.py +46 -0
  48. sentience/integrations/pydanticai/__init__.py +15 -0
  49. sentience/integrations/pydanticai/deps.py +20 -0
  50. sentience/integrations/pydanticai/toolset.py +468 -0
  51. sentience/llm_interaction_handler.py +191 -0
  52. sentience/llm_provider.py +765 -66
  53. sentience/llm_provider_utils.py +120 -0
  54. sentience/llm_response_builder.py +153 -0
  55. sentience/models.py +595 -3
  56. sentience/ordinal.py +280 -0
  57. sentience/overlay.py +109 -2
  58. sentience/protocols.py +228 -0
  59. sentience/query.py +67 -5
  60. sentience/read.py +95 -3
  61. sentience/recorder.py +223 -3
  62. sentience/schemas/trace_v1.json +128 -9
  63. sentience/screenshot.py +48 -2
  64. sentience/sentience_methods.py +86 -0
  65. sentience/snapshot.py +599 -55
  66. sentience/snapshot_diff.py +126 -0
  67. sentience/text_search.py +120 -5
  68. sentience/trace_event_builder.py +148 -0
  69. sentience/trace_file_manager.py +197 -0
  70. sentience/trace_indexing/index_schema.py +95 -7
  71. sentience/trace_indexing/indexer.py +105 -48
  72. sentience/tracer_factory.py +120 -9
  73. sentience/tracing.py +172 -8
  74. sentience/utils/__init__.py +40 -0
  75. sentience/utils/browser.py +46 -0
  76. sentience/{utils.py → utils/element.py} +3 -42
  77. sentience/utils/formatting.py +59 -0
  78. sentience/verification.py +618 -0
  79. sentience/visual_agent.py +2058 -0
  80. sentience/wait.py +68 -2
  81. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/METADATA +199 -40
  82. sentienceapi-0.98.0.dist-info/RECORD +92 -0
  83. sentience/extension/test-content.js +0 -4
  84. sentienceapi-0.90.16.dist-info/RECORD +0 -50
  85. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/WHEEL +0 -0
  86. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/entry_points.txt +0 -0
  87. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE +0 -0
  88. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE-APACHE +0 -0
  89. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/licenses/LICENSE-MIT +0 -0
  90. {sentienceapi-0.90.16.dist-info → sentienceapi-0.98.0.dist-info}/top_level.txt +0 -0
sentience/llm_provider.py CHANGED
@@ -5,6 +5,10 @@ Enables "Bring Your Own Brain" (BYOB) pattern - plug in any LLM provider
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Any
+
+from .llm_provider_utils import get_api_key_from_env, handle_provider_error, require_package
+from .llm_response_builder import LLMResponseBuilder
 
 
 @dataclass
@@ -31,6 +35,15 @@ class LLMProvider(ABC):
     - Any other completion API
     """
 
+    def __init__(self, model: str):
+        """
+        Initialize LLM provider with model name.
+
+        Args:
+            model: Model identifier (e.g., "gpt-4o", "claude-3-sonnet")
+        """
+        self._model_name = model
+
     @abstractmethod
     def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse:
         """
@@ -67,6 +80,48 @@ class LLMProvider(ABC):
         """
         pass
 
+    def supports_vision(self) -> bool:
+        """
+        Whether this provider supports image input for vision tasks.
+
+        Override in subclasses that support vision-capable models.
+
+        Returns:
+            True if provider supports vision, False otherwise
+        """
+        return False
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate a response with image input (for vision-capable models).
+
+        This method is used for vision fallback in assertions and visual agents.
+        Override in subclasses that support vision-capable models.
+
+        Args:
+            system_prompt: System instruction/context
+            user_prompt: User query/request
+            image_base64: Base64-encoded image (PNG or JPEG)
+            **kwargs: Provider-specific parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            LLMResponse with content and token usage
+
+        Raises:
+            NotImplementedError: If provider doesn't support vision
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not support vision. "
+            "Use a vision-capable provider like OpenAIProvider with GPT-4o "
+            "or AnthropicProvider with Claude 3."
+        )
+
 
 class OpenAIProvider(LLMProvider):
     """
@@ -95,13 +150,16 @@ class OpenAIProvider(LLMProvider):
             base_url: Custom API base URL (for compatible APIs)
             organization: OpenAI organization ID
         """
-        try:
-            from openai import OpenAI
-        except ImportError:
-            raise ImportError("OpenAI package not installed. Install with: pip install openai")
+        super().__init__(model)  # Initialize base class with model name
+
+        OpenAI = require_package(
+            "openai",
+            "openai",
+            "OpenAI",
+            "pip install openai",
+        )
 
         self.client = OpenAI(api_key=api_key, base_url=base_url, organization=organization)
-        self._model_name = model
 
     def generate(
         self,
@@ -148,12 +206,15 @@ class OpenAIProvider(LLMProvider):
         api_params.update(kwargs)
 
         # Call OpenAI API
-        response = self.client.chat.completions.create(**api_params)
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "OpenAI", "generate response")
 
         choice = response.choices[0]
         usage = response.usage
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_openai_format(
             content=choice.message.content,
             prompt_tokens=usage.prompt_tokens if usage else None,
             completion_tokens=usage.completion_tokens if usage else None,
@@ -167,6 +228,92 @@ class OpenAIProvider(LLMProvider):
         model_lower = self._model_name.lower()
         return any(x in model_lower for x in ["gpt-4", "gpt-3.5"])
 
+    def supports_vision(self) -> bool:
+        """GPT-4o, GPT-4-turbo, and GPT-4-vision support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using OpenAI Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature (0.0 = deterministic)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional OpenAI API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use gpt-4o, gpt-4-turbo, or gpt-4-vision-preview."
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Vision message format with image_url
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        )
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call OpenAI API
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "OpenAI", "generate response with image")
+
+        choice = response.choices[0]
+        usage = response.usage
+
+        return LLMResponseBuilder.from_openai_format(
+            content=choice.message.content,
+            prompt_tokens=usage.prompt_tokens if usage else None,
+            completion_tokens=usage.completion_tokens if usage else None,
+            total_tokens=usage.total_tokens if usage else None,
+            model_name=response.model,
+            finish_reason=choice.finish_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
@@ -191,15 +338,16 @@ class AnthropicProvider(LLMProvider):
             api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var)
             model: Model name (claude-3-opus, claude-3-sonnet, claude-3-haiku, etc.)
         """
-        try:
-            from anthropic import Anthropic
-        except ImportError:
-            raise ImportError(
-                "Anthropic package not installed. Install with: pip install anthropic"
-            )
+        super().__init__(model)  # Initialize base class with model name
+
+        Anthropic = require_package(
+            "anthropic",
+            "anthropic",
+            "Anthropic",
+            "pip install anthropic",
+        )
 
         self.client = Anthropic(api_key=api_key)
-        self._model_name = model
 
     def generate(
         self,
@@ -237,27 +385,113 @@ class AnthropicProvider(LLMProvider):
         api_params.update(kwargs)
 
         # Call Anthropic API
-        response = self.client.messages.create(**api_params)
+        try:
+            response = self.client.messages.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "Anthropic", "generate response")
 
         content = response.content[0].text if response.content else ""
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_anthropic_format(
             content=content,
-            prompt_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
-            completion_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
-            total_tokens=(
-                (response.usage.input_tokens + response.usage.output_tokens)
-                if hasattr(response, "usage")
-                else None
-            ),
+            input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
+            output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
             model_name=response.model,
-            finish_reason=response.stop_reason,
+            stop_reason=response.stop_reason,
         )
 
     def supports_json_mode(self) -> bool:
         """Anthropic doesn't have native JSON mode (requires prompt engineering)"""
         return False
 
+    def supports_vision(self) -> bool:
+        """Claude 3 models (Opus, Sonnet, Haiku) all support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["claude-3", "claude-3.5"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int = 1024,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using Anthropic Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens to generate (required by Anthropic)
+            **kwargs: Additional Anthropic API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use Claude 3 models (claude-3-opus, claude-3-sonnet, claude-3-haiku)."
+            )
+
+        # Anthropic vision message format
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": image_base64,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": user_prompt,
+                    },
+                ],
+            }
+        ]
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "messages": messages,
+        }
+
+        if system_prompt:
+            api_params["system"] = system_prompt
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call Anthropic API
+        try:
+            response = self.client.messages.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "Anthropic", "generate response with image")
+
+        content = response.content[0].text if response.content else ""
+
+        return LLMResponseBuilder.from_anthropic_format(
+            content=content,
+            input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
+            output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
+            model_name=response.model,
+            stop_reason=response.stop_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
@@ -285,13 +519,16 @@ class GLMProvider(LLMProvider):
             api_key: Zhipu AI API key (or set GLM_API_KEY env var)
             model: Model name (glm-4-plus, glm-4, glm-4-air, glm-4-flash, etc.)
         """
-        try:
-            from zhipuai import ZhipuAI
-        except ImportError:
-            raise ImportError("ZhipuAI package not installed. Install with: pip install zhipuai")
+        super().__init__(model)  # Initialize base class with model name
+
+        ZhipuAI = require_package(
+            "zhipuai",
+            "zhipuai",
+            "ZhipuAI",
+            "pip install zhipuai",
+        )
 
         self.client = ZhipuAI(api_key=api_key)
-        self._model_name = model
 
     def generate(
         self,
@@ -333,12 +570,15 @@ class GLMProvider(LLMProvider):
         api_params.update(kwargs)
 
         # Call GLM API
-        response = self.client.chat.completions.create(**api_params)
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "GLM", "generate response")
 
         choice = response.choices[0]
         usage = response.usage
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_openai_format(
             content=choice.message.content,
             prompt_tokens=usage.prompt_tokens if usage else None,
             completion_tokens=usage.completion_tokens if usage else None,
@@ -378,25 +618,20 @@ class GeminiProvider(LLMProvider):
             api_key: Google API key (or set GEMINI_API_KEY or GOOGLE_API_KEY env var)
             model: Model name (gemini-2.0-flash-exp, gemini-1.5-pro, gemini-1.5-flash, etc.)
         """
-        try:
-            import google.generativeai as genai
-        except ImportError:
-            raise ImportError(
-                "Google Generative AI package not installed. Install with: pip install google-generativeai"
-            )
+        super().__init__(model)  # Initialize base class with model name
 
-        # Configure API key
+        genai = require_package(
+            "google-generativeai",
+            "google.generativeai",
+            install_command="pip install google-generativeai",
+        )
+
+        # Configure API key (check parameter first, then environment variables)
+        api_key = get_api_key_from_env(["GEMINI_API_KEY", "GOOGLE_API_KEY"], api_key)
         if api_key:
             genai.configure(api_key=api_key)
-        else:
-            import os
-
-            api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
-            if api_key:
-                genai.configure(api_key=api_key)
 
         self.genai = genai
-        self._model_name = model
         self.model = genai.GenerativeModel(model)
 
     def generate(
@@ -435,7 +670,10 @@ class GeminiProvider(LLMProvider):
         generation_config.update(kwargs)
 
         # Call Gemini API
-        response = self.model.generate_content(full_prompt, generation_config=generation_config)
+        try:
+            response = self.model.generate_content(full_prompt, generation_config=generation_config)
+        except Exception as e:
+            handle_provider_error(e, "Gemini", "generate response")
 
         # Extract content
         content = response.text if response.text else ""
@@ -450,13 +688,12 @@ class GeminiProvider(LLMProvider):
             completion_tokens = response.usage_metadata.candidates_token_count
             total_tokens = response.usage_metadata.total_token_count
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_gemini_format(
             content=content,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             model_name=self._model_name,
-            finish_reason=None,  # Gemini uses different finish reason format
         )
 
     def supports_json_mode(self) -> bool:
@@ -503,16 +740,24 @@ class LocalLLMProvider(LLMProvider):
             load_in_8bit: Use 8-bit quantization (saves 50% memory)
             torch_dtype: Data type ("auto", "float16", "bfloat16", "float32")
         """
+        super().__init__(model_name)  # Initialize base class with model name
+
+        # Import required packages with consistent error handling.
+        # These are optional dependencies, so keep them out of module import-time.
         try:
-            import torch
-            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-        except ImportError:
+            import torch  # type: ignore[import-not-found]
+            from transformers import (  # type: ignore[import-not-found]
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                BitsAndBytesConfig,
+            )
+        except ImportError as exc:
             raise ImportError(
                 "transformers and torch required for local LLM. "
                 "Install with: pip install transformers torch"
-            )
+            ) from exc
 
-        self._model_name = model_name
+        self._torch = torch
 
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -533,21 +778,44 @@ class LocalLLMProvider(LLMProvider):
         elif load_in_8bit:
             quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
+        device = (device or "auto").strip().lower()
+
         # Determine torch dtype
         if torch_dtype == "auto":
-            dtype = torch.float16 if device != "cpu" else torch.float32
+            dtype = torch.float16 if device not in {"cpu"} else torch.float32
         else:
             dtype = getattr(torch, torch_dtype)
 
-        # Load model
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            quantization_config=quantization_config,
-            torch_dtype=dtype if quantization_config is None else None,
-            device_map=device,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True,
-        )
+        # device_map is a Transformers concept (not a literal "cpu/mps/cuda" device string).
+        # - "auto" enables Accelerate device mapping.
+        # - Otherwise, we load normally and then move the model to the requested device.
+        device_map: str | None = "auto" if device == "auto" else None
+
+        def _load(*, device_map_override: str | None) -> Any:
+            return AutoModelForCausalLM.from_pretrained(
+                model_name,
+                quantization_config=quantization_config,
+                torch_dtype=dtype if quantization_config is None else None,
+                device_map=device_map_override,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            )
+
+        try:
+            self.model = _load(device_map_override=device_map)
+        except KeyError as e:
+            # Some envs / accelerate versions can crash on auto mapping (e.g. KeyError: 'cpu').
+            # Keep demo ergonomics: default stays "auto", but we gracefully fall back.
+            if device == "auto" and ("cpu" in str(e).lower()):
+                device = "cpu"
+                dtype = torch.float32
+                self.model = _load(device_map_override=None)
+            else:
+                raise
+
+        # If we didn't use device_map, move model explicitly (only safe for non-quantized loads).
+        if device_map is None and quantization_config is None and device in {"cpu", "cuda", "mps"}:
+            self.model = self.model.to(device)
         self.model.eval()
 
     def generate(
@@ -573,7 +841,7 @@ class LocalLLMProvider(LLMProvider):
         Returns:
             LLMResponse object
         """
-        import torch
+        torch = self._torch
 
         # Auto-determine sampling based on temperature
         do_sample = temperature > 0
@@ -620,11 +888,10 @@ class LocalLLMProvider(LLMProvider):
         generated_tokens = outputs[0][input_length:]
         response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_local_format(
             content=response_text,
             prompt_tokens=input_length,
             completion_tokens=len(generated_tokens),
-            total_tokens=input_length + len(generated_tokens),
             model_name=self._model_name,
         )
 
@@ -635,3 +902,435 @@ class LocalLLMProvider(LLMProvider):
     @property
     def model_name(self) -> str:
         return self._model_name
+
+
+class LocalVisionLLMProvider(LLMProvider):
+    """
+    Local vision-language LLM provider using HuggingFace Transformers.
+
+    Intended for models like:
+    - Qwen/Qwen3-VL-8B-Instruct
+
+    Notes on Mac (MPS) + quantization:
+    - Transformers BitsAndBytes (4-bit/8-bit) typically requires CUDA and does NOT work on MPS.
+    - If you want quantized local vision on Apple Silicon, you may prefer MLX-based stacks
+      (e.g., mlx-vlm) or llama.cpp/gguf pipelines.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
+        device: str = "auto",
+        torch_dtype: str = "auto",
+        load_in_4bit: bool = False,
+        load_in_8bit: bool = False,
+        trust_remote_code: bool = True,
+    ):
+        super().__init__(model_name)
+
+        # Import required packages with consistent error handling
+        try:
+            import torch  # type: ignore[import-not-found]
+            from transformers import AutoProcessor  # type: ignore[import-not-found]
+        except ImportError as exc:
+            raise ImportError(
+                "transformers and torch are required for LocalVisionLLMProvider. "
+                "Install with: pip install transformers torch"
+            ) from exc
+
+        self._torch = torch
+
+        # Resolve device
+        if device == "auto":
+            if (
+                getattr(torch.backends, "mps", None) is not None
+                and torch.backends.mps.is_available()
+            ):
+                device = "mps"
+            elif torch.cuda.is_available():
+                device = "cuda"
+            else:
+                device = "cpu"
+
+        if device == "mps" and (load_in_4bit or load_in_8bit):
+            raise ValueError(
+                "Quantized (4-bit/8-bit) Transformers loading is typically not supported on Apple MPS. "
+                "Set load_in_4bit/load_in_8bit to False for MPS, or use a different local runtime "
+                "(e.g., MLX/llama.cpp) for quantized vision models."
+            )
+
+        # Determine torch dtype
+        if torch_dtype == "auto":
+            dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
+        else:
+            dtype = getattr(torch, torch_dtype)
+
+        # Load processor
+        self.processor = AutoProcessor.from_pretrained(
+            model_name, trust_remote_code=trust_remote_code
+        )
+
+        # Load model (prefer vision2seq; fall back with guidance)
+        try:
+            import importlib
+
+            transformers = importlib.import_module("transformers")
+            AutoModelForVision2Seq = getattr(transformers, "AutoModelForVision2Seq", None)
+            if AutoModelForVision2Seq is None:
+                raise AttributeError("transformers.AutoModelForVision2Seq is not available")
+
+            self.model = AutoModelForVision2Seq.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                low_cpu_mem_usage=True,
+            )
+        except Exception as exc:
+            # Some transformers versions/models don't expose AutoModelForVision2Seq.
+            # We fail loudly with a helpful message rather than silently doing text-only.
+            raise ImportError(
+                "Failed to load a vision-capable Transformers model. "
+                "Try upgrading transformers (vision models often require newer versions), "
+                "or use a model class supported by your installed transformers build."
+            ) from exc
+
+        # Move to device
+        self.device = device
+        self.model.to(device)
+
+        self.model.eval()
+
+    def supports_json_mode(self) -> bool:
+        return False
+
+    def supports_vision(self) -> bool:
+        return True
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    def generate(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        max_new_tokens: int = 512,
+        temperature: float = 0.1,
+        top_p: float = 0.9,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Text-only generation (no image). Provided for interface completeness.
+        """
+        torch = self._torch
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        if hasattr(self.processor, "apply_chat_template"):
+            prompt = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+
+        inputs = self.processor(text=[prompt], return_tensors="pt")
+        inputs = {
+            k: (v.to(self.model.device) if hasattr(v, "to") else v) for k, v in inputs.items()
+        }
+
+        do_sample = temperature > 0
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature if do_sample else 1.0,
+                top_p=top_p,
+                **kwargs,
+            )
+
+        # Decode
+        input_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
+        generated = outputs[0][input_len:]
+        if hasattr(self.processor, "batch_decode"):
+            text = self.processor.batch_decode([generated], skip_special_tokens=True)[0].strip()
+        else:
+            text = str(generated)
+
+        return LLMResponseBuilder.from_local_format(
+            content=text,
+            prompt_tokens=int(input_len) if input_len else None,
+            completion_tokens=int(generated.shape[0]) if hasattr(generated, "shape") else None,
+            model_name=self._model_name,
+        )
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        max_new_tokens: int = 256,
+        temperature: float = 0.0,
+        top_p: float = 0.9,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Vision generation using an image + prompt.
+
+        This is used by vision fallback in assertions and by visual agents.
+        """
+        torch = self._torch
+
+        # Lazy import PIL to avoid adding a hard dependency for text-only users.
+        try:
+            from PIL import Image  # type: ignore[import-not-found]
+        except ImportError as exc:
+            raise ImportError(
+                "Pillow is required for LocalVisionLLMProvider image input. Install with: pip install pillow"
+            ) from exc
+
+        import base64
+        import io
+
+        img_bytes = base64.b64decode(image_base64)
+        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+
+        # Prefer processor chat template if available (needed by many VL models).
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": user_prompt},
+                ],
+            }
+        )
+
+        if hasattr(self.processor, "apply_chat_template"):
+            prompt = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        else:
+            raise NotImplementedError(
+                "This local vision model/processor does not expose apply_chat_template(). "
+                "Install/upgrade to a Transformers version that supports your model's chat template."
+            )
+
+        inputs = self.processor(text=[prompt], images=[image], return_tensors="pt")
+        inputs = {
+            k: (v.to(self.model.device) if hasattr(v, "to") else v) for k, v in inputs.items()
+        }
+
+        do_sample = temperature > 0
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature if do_sample else 1.0,
+                top_p=top_p,
+                **kwargs,
+            )
+
+        input_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
+        generated = outputs[0][input_len:]
+
+        if hasattr(self.processor, "batch_decode"):
+            text = self.processor.batch_decode([generated], skip_special_tokens=True)[0].strip()
+        elif hasattr(self.processor, "tokenizer") and hasattr(self.processor.tokenizer, "decode"):
+            text = self.processor.tokenizer.decode(generated, skip_special_tokens=True).strip()
+        else:
+            text = ""
+
+        return LLMResponseBuilder.from_local_format(
+            content=text,
+            prompt_tokens=int(input_len) if input_len else None,
+            completion_tokens=int(generated.shape[0]) if hasattr(generated, "shape") else None,
+            model_name=self._model_name,
+        )
+
+
+class MLXVLMProvider(LLMProvider):
+    """
+    Local vision-language provider using MLX-VLM (Apple Silicon optimized).
+
+    Recommended for running *quantized* vision models on Mac (M1/M2/M3/M4), e.g.:
+    - mlx-community/Qwen3-VL-8B-Instruct-3bit
+
+    Optional dependencies:
+    - mlx-vlm
+    - pillow
+
+    Notes:
+    - MLX-VLM APIs can vary across versions; this provider tries a couple common call shapes.
+    - For best results, use an MLX-converted model repo under `mlx-community/`.
+    """
+
+    def __init__(
+        self,
+        model: str = "mlx-community/Qwen3-VL-8B-Instruct-3bit",
+        *,
+        default_max_tokens: int = 256,
+        default_temperature: float = 0.0,
+        **kwargs,
+    ):
+        super().__init__(model)
+        self._default_max_tokens = default_max_tokens
+        self._default_temperature = default_temperature
+        self._default_kwargs = dict(kwargs)
+
+        # Lazy imports to keep base SDK light.
+        try:
+            import importlib
+
+            self._mlx_vlm = importlib.import_module("mlx_vlm")
+        except ImportError as exc:
+            raise ImportError(
+                "mlx-vlm is required for MLXVLMProvider. Install with: pip install mlx-vlm"
+            ) from exc
+
+        try:
+            from PIL import Image  # type: ignore[import-not-found]
+
+            self._PIL_Image = Image
+        except ImportError as exc:
+            raise ImportError(
+                "Pillow is required for MLXVLMProvider. Install with: pip install pillow"
+            ) from exc
+
+        # Some mlx_vlm versions expose load(model_id) -> (model, processor)
+        self._model = None
+        self._processor = None
+        load_fn = getattr(self._mlx_vlm, "load", None)
+        if callable(load_fn):
+            try:
+                loaded = load_fn(model)
+                if isinstance(loaded, tuple) and len(loaded) >= 2:
+                    self._model, self._processor = loaded[0], loaded[1]
+            except Exception:
+                # Keep it lazy; we'll try loading on demand during generate_with_image().
+                self._model, self._processor = None, None
+
+    def supports_json_mode(self) -> bool:
+        return False
+
+    def supports_vision(self) -> bool:
+        return True
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse:
+        """
+        Text-only generation is not a primary MLX-VLM use-case. We attempt it if the installed
+        mlx_vlm exposes a compatible `generate()` signature; otherwise, raise a clear error.
+        """
+        generate_fn = getattr(self._mlx_vlm, "generate", None)
+        if not callable(generate_fn):
+            raise NotImplementedError("mlx_vlm.generate is not available in your mlx-vlm install.")
+
+        prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+        max_tokens = kwargs.pop("max_tokens", self._default_max_tokens)
+        temperature = kwargs.pop("temperature", self._default_temperature)
+        merged_kwargs = {**self._default_kwargs, **kwargs}
+
+        try:
+            out = generate_fn(
+                self._model_name,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+        except TypeError as exc:
+            if self._model is None or self._processor is None:
+                raise NotImplementedError(
+                    "Text-only generation is not supported by this mlx-vlm version without a loaded model."
+                ) from exc
+            out = generate_fn(
+                self._model,
+                self._processor,
+                prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+
+        text = getattr(out, "text", None) or getattr(out, "output", None) or str(out)
+        return LLMResponseBuilder.from_local_format(
+            content=str(text).strip(),
+            prompt_tokens=None,
+            completion_tokens=None,
+            model_name=self._model_name,
+        )
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        **kwargs,
+    ) -> LLMResponse:
+        import base64
+        import io
+
+        generate_fn = getattr(self._mlx_vlm, "generate", None)
+        if not callable(generate_fn):
+            raise NotImplementedError("mlx_vlm.generate is not available in your mlx-vlm install.")
+
+        img_bytes = base64.b64decode(image_base64)
+        image = self._PIL_Image.open(io.BytesIO(img_bytes)).convert("RGB")
+
+        prompt = (system_prompt + "\n\n" if system_prompt else "") + user_prompt
+        max_tokens = kwargs.pop("max_tokens", self._default_max_tokens)
+        temperature = kwargs.pop("temperature", self._default_temperature)
+        merged_kwargs = {**self._default_kwargs, **kwargs}
+
+        # Try a couple common MLX-VLM call shapes.
+        try:
+            # 1) generate(model_id, image=..., prompt=...)
+            out = generate_fn(
+                self._model_name,
+                image=image,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+        except TypeError as exc:
+            # 2) generate(model, processor, prompt, image, ...)
+            if self._model is None or self._processor is None:
+                load_fn = getattr(self._mlx_vlm, "load", None)
+                if callable(load_fn):
+                    loaded = load_fn(self._model_name)
+                    if isinstance(loaded, tuple) and len(loaded) >= 2:
+                        self._model, self._processor = loaded[0], loaded[1]
+            if self._model is None or self._processor is None:
+                raise NotImplementedError(
+                    "Unable to call mlx_vlm.generate with your installed mlx-vlm version. "
+                    "Please upgrade mlx-vlm or use LocalVisionLLMProvider (Transformers backend)."
+                ) from exc
+            out = generate_fn(
+                self._model,
+                self._processor,
+                prompt,
+                image,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **merged_kwargs,
+            )
+
+        text = getattr(out, "text", None) or getattr(out, "output", None) or str(out)
+        return LLMResponseBuilder.from_local_format(
+            content=str(text).strip(),
+            prompt_tokens=None,
+            completion_tokens=None,
+            model_name=self._model_name,
+        )
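
For orientation, here is a minimal usage sketch (not part of the diff) of the vision-capable BYOB interface this release adds to sentience/llm_provider.py. The screenshot path, the prompts, and any constructor argument beyond api_key and model are illustrative assumptions, not confirmed API details.

# Hedged sketch: exercising supports_vision() / generate_with_image() from 0.98.0.
# Assumes a vision-capable model; the file "screenshot.png" is a placeholder.
import base64

from sentience.llm_provider import OpenAIProvider

provider = OpenAIProvider(api_key="sk-...", model="gpt-4o")

with open("screenshot.png", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("ascii")

if provider.supports_vision():
    response = provider.generate_with_image(
        system_prompt="You are a UI verification assistant.",
        user_prompt="Is the checkout button visible in this screenshot?",
        image_base64=image_base64,
    )
    print(response.content)
else:
    # Providers without vision raise NotImplementedError from generate_with_image().
    print("Fall back to text-only provider.generate(...).")

Based on the diff, AnthropicProvider, LocalVisionLLMProvider, and MLXVLMProvider expose the same supports_vision() / generate_with_image() surface, so they should be swappable in a sketch like this, subject to their optional dependencies (anthropic, transformers/torch/pillow, or mlx-vlm respectively).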