hindsight-api 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +1 -1
- hindsight_api/api/http.py +7 -19
- hindsight_api/api/mcp.py +45 -5
- hindsight_api/config.py +115 -11
- hindsight_api/daemon.py +4 -1
- hindsight_api/engine/consolidation/consolidator.py +39 -3
- hindsight_api/engine/cross_encoder.py +7 -99
- hindsight_api/engine/embeddings.py +3 -93
- hindsight_api/engine/interface.py +0 -43
- hindsight_api/engine/llm_wrapper.py +93 -22
- hindsight_api/engine/memory_engine.py +37 -138
- hindsight_api/engine/response_models.py +1 -21
- hindsight_api/engine/retain/fact_extraction.py +19 -23
- hindsight_api/engine/retain/orchestrator.py +1 -4
- hindsight_api/engine/utils.py +0 -3
- hindsight_api/main.py +27 -12
- hindsight_api/mcp_tools.py +31 -12
- hindsight_api/metrics.py +3 -3
- hindsight_api/pg0.py +1 -1
- hindsight_api/worker/main.py +11 -11
- hindsight_api/worker/poller.py +226 -97
- {hindsight_api-0.4.1.dist-info → hindsight_api-0.4.3.dist-info}/METADATA +2 -1
- {hindsight_api-0.4.1.dist-info → hindsight_api-0.4.3.dist-info}/RECORD +25 -25
- {hindsight_api-0.4.1.dist-info → hindsight_api-0.4.3.dist-info}/WHEEL +0 -0
- {hindsight_api-0.4.1.dist-info → hindsight_api-0.4.3.dist-info}/entry_points.txt +0 -0
hindsight_api/engine/cross_encoder.py

@@ -178,108 +178,16 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         else:
             logger.info("Reranker: local provider initialized (using existing executor)")
 
-    def _is_xpc_error(self, error: Exception) -> bool:
-        """
-        Check if an error is an XPC connection error (macOS daemon issue).
-
-        On macOS, long-running daemons can lose XPC connections to system services
-        when the process is idle for extended periods.
-        """
-        error_str = str(error).lower()
-        return "xpc_error_connection_invalid" in error_str or "xpc error" in error_str
-
-    def _reinitialize_model_sync(self) -> None:
-        """
-        Clear and reinitialize the cross-encoder model synchronously.
-
-        This is used to recover from XPC errors on macOS where the
-        PyTorch/MPS backend loses its connection to system services.
-        """
-        logger.warning(f"Reinitializing reranker model {self.model_name} due to backend error")
-
-        # Clear existing model
-        self._model = None
-
-        # Force garbage collection to free resources
-        import gc
-
-        import torch
-
-        gc.collect()
-
-        # If using CUDA/MPS, clear the cache
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            try:
-                torch.mps.empty_cache()
-            except AttributeError:
-                pass  # Method might not exist in all PyTorch versions
-
-        # Reinitialize the model
-        try:
-            from sentence_transformers import CrossEncoder
-        except ImportError:
-            raise ImportError(
-                "sentence-transformers is required for LocalSTCrossEncoder. "
-                "Install it with: pip install sentence-transformers"
-            )
-
-        # Determine device based on hardware availability
-        if self.force_cpu:
-            device = "cpu"
-        else:
-            # Wrap in try-except to gracefully handle any device detection issues
-            device = "cpu"  # Default to CPU
-            try:
-                has_gpu = torch.cuda.is_available() or (
-                    hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-                )
-                if has_gpu:
-                    device = None  # Let sentence-transformers auto-detect GPU/MPS
-            except Exception as e:
-                logger.warning(f"Failed to detect GPU/MPS during reinit, falling back to CPU: {e}")
-
-        self._model = CrossEncoder(
-            self.model_name,
-            device=device,
-            model_kwargs={"low_cpu_mem_usage": False},
-        )
-
-        logger.info("Reranker: local provider reinitialized successfully")
-
-    def _predict_with_recovery(self, pairs: list[tuple[str, str]]) -> list[float]:
-        """
-        Predict with automatic recovery from XPC errors.
-
-        This runs synchronously in the thread pool.
-        """
-        max_retries = 1
-        for attempt in range(max_retries + 1):
-            try:
-                scores = self._model.predict(pairs, show_progress_bar=False)
-                return scores.tolist() if hasattr(scores, "tolist") else list(scores)
-            except Exception as e:
-                # Check if this is an XPC error (macOS daemon issue)
-                if self._is_xpc_error(e) and attempt < max_retries:
-                    logger.warning(f"XPC error detected in reranker (attempt {attempt + 1}): {e}")
-                    try:
-                        self._reinitialize_model_sync()
-                        logger.info("Reranker reinitialized successfully, retrying prediction")
-                        continue
-                    except Exception as reinit_error:
-                        logger.error(f"Failed to reinitialize reranker: {reinit_error}")
-                        raise Exception(f"Failed to recover from XPC error: {str(e)}")
-                else:
-                    # Not an XPC error or out of retries
-                    raise
+    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Synchronous prediction wrapper for thread pool execution."""
+        scores = self._model.predict(pairs, show_progress_bar=False)
+        return scores.tolist() if hasattr(scores, "tolist") else list(scores)
 
     async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs for relevance.
 
         Uses a dedicated thread pool with limited workers to prevent CPU thrashing.
-        Automatically recovers from XPC errors on macOS by reinitializing the model.
 
         Args:
             pairs: List of (query, document) tuples to score
@@ -294,7 +202,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         loop = asyncio.get_event_loop()
         return await loop.run_in_executor(
             LocalSTCrossEncoder._executor,
-            self._predict_with_recovery,
+            self._predict_sync,
             pairs,
         )
 
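The two hunks above drop the macOS XPC-recovery path and instead dispatch a plain synchronous scorer to the shared, size-limited thread pool. A minimal self-contained sketch of that pattern, with hypothetical names rather than the package's actual classes:

```python
# Sketch of the sync-wrapper + run_in_executor pattern the new code uses.
# "Reranker" and the placeholder scores are illustrative only.
import asyncio
from concurrent.futures import ThreadPoolExecutor


class Reranker:
    # Shared, size-limited pool, analogous to LocalSTCrossEncoder._executor in the diff.
    _executor = ThreadPoolExecutor(max_workers=2)

    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
        # Placeholder for the CPU-bound model call (CrossEncoder.predict in the real code).
        return [float(len(q) + len(d)) for q, d in pairs]

    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self._executor, self._predict_sync, pairs)


async def main() -> None:
    print(await Reranker().predict([("query", "document")]))


asyncio.run(main())
```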
@@ -706,7 +614,7 @@ class FlashRankCrossEncoder(CrossEncoderModel):
             return
 
         try:
-            from flashrank import Ranker
+            from flashrank import Ranker
         except ImportError:
             raise ImportError("flashrank is required for FlashRankCrossEncoder. Install it with: pip install flashrank")
 
@@ -733,7 +641,7 @@ class FlashRankCrossEncoder(CrossEncoderModel):
 
     def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
         """Synchronous predict - processes each query group."""
-        from flashrank import RerankRequest
+        from flashrank import RerankRequest
 
         if not pairs:
             return []
hindsight_api/engine/embeddings.py

@@ -166,82 +166,10 @@ class LocalSTEmbeddings(Embeddings):
         self._dimension = self._model.get_sentence_embedding_dimension()
         logger.info(f"Embeddings: local provider initialized (dim: {self._dimension})")
 
-    def _is_xpc_error(self, error: Exception) -> bool:
-        """
-        Check if an error is an XPC connection error (macOS daemon issue).
-
-        On macOS, long-running daemons can lose XPC connections to system services
-        when the process is idle for extended periods.
-        """
-        error_str = str(error).lower()
-        return "xpc_error_connection_invalid" in error_str or "xpc error" in error_str
-
-    def _reinitialize_model_sync(self) -> None:
-        """
-        Clear and reinitialize the embedding model synchronously.
-
-        This is used to recover from XPC errors on macOS where the
-        PyTorch/MPS backend loses its connection to system services.
-        """
-        logger.warning(f"Reinitializing embedding model {self.model_name} due to backend error")
-
-        # Clear existing model
-        self._model = None
-
-        # Force garbage collection to free resources
-        import gc
-
-        import torch
-
-        gc.collect()
-
-        # If using CUDA/MPS, clear the cache
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            try:
-                torch.mps.empty_cache()
-            except AttributeError:
-                pass  # Method might not exist in all PyTorch versions
-
-        # Reinitialize the model (inline version of initialize() but synchronous)
-        try:
-            from sentence_transformers import SentenceTransformer
-        except ImportError:
-            raise ImportError(
-                "sentence-transformers is required for LocalSTEmbeddings. "
-                "Install it with: pip install sentence-transformers"
-            )
-
-        # Determine device based on hardware availability
-        if self.force_cpu:
-            device = "cpu"
-        else:
-            # Wrap in try-except to gracefully handle any device detection issues
-            device = "cpu"  # Default to CPU
-            try:
-                has_gpu = torch.cuda.is_available() or (
-                    hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-                )
-                if has_gpu:
-                    device = None  # Let sentence-transformers auto-detect GPU/MPS
-            except Exception as e:
-                logger.warning(f"Failed to detect GPU/MPS during reinit, falling back to CPU: {e}")
-
-        self._model = SentenceTransformer(
-            self.model_name,
-            device=device,
-            model_kwargs={"low_cpu_mem_usage": False},
-        )
-
-        logger.info("Embeddings: local provider reinitialized successfully")
-
     def encode(self, texts: list[str]) -> list[list[float]]:
         """
         Generate embeddings for a list of texts.
 
-        Automatically recovers from XPC errors on macOS by reinitializing the model.
-
         Args:
             texts: List of text strings to encode
 
@@ -251,26 +179,8 @@ class LocalSTEmbeddings(Embeddings):
         if self._model is None:
             raise RuntimeError("Embeddings not initialized. Call initialize() first.")
 
-
-
-        for attempt in range(max_retries + 1):
-            try:
-                embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
-                return [emb.tolist() for emb in embeddings]
-            except Exception as e:
-                # Check if this is an XPC error (macOS daemon issue)
-                if self._is_xpc_error(e) and attempt < max_retries:
-                    logger.warning(f"XPC error detected in embedding generation (attempt {attempt + 1}): {e}")
-                    try:
-                        self._reinitialize_model_sync()
-                        logger.info("Model reinitialized successfully, retrying embedding generation")
-                        continue
-                    except Exception as reinit_error:
-                        logger.error(f"Failed to reinitialize model: {reinit_error}")
-                        raise Exception(f"Failed to recover from XPC error: {str(e)}")
-                else:
-                    # Not an XPC error or out of retries
-                    raise
+        embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+        return [emb.tolist() for emb in embeddings]
 
 
 class RemoteTEIEmbeddings(Embeddings):
@@ -635,7 +545,7 @@ class CohereEmbeddings(Embeddings):
                 model=self.model,
                 input_type=self.input_type,
             )
-            if response.embeddings:
+            if response.embeddings and isinstance(response.embeddings, list):
                 self._dimension = len(response.embeddings[0])
 
             logger.info(f"Embeddings: Cohere provider initialized (model: {self.model}, dim: {self._dimension})")
hindsight_api/engine/interface.py

@@ -442,49 +442,6 @@ class MemoryEngineInterface(ABC):
         """
         ...
 
-    @abstractmethod
-    async def get_entity_observations(
-        self,
-        bank_id: str,
-        entity_id: str,
-        *,
-        limit: int = 10,
-        request_context: "RequestContext",
-    ) -> list[Any]:
-        """
-        Get observations for an entity.
-
-        Args:
-            bank_id: The memory bank ID.
-            entity_id: The entity ID.
-            limit: Maximum observations.
-            request_context: Request context for authentication.
-
-        Returns:
-            List of EntityObservation objects.
-        """
-        ...
-
-    @abstractmethod
-    async def regenerate_entity_observations(
-        self,
-        bank_id: str,
-        entity_id: str,
-        entity_name: str,
-        *,
-        request_context: "RequestContext",
-    ) -> None:
-        """
-        Regenerate observations for an entity.
-
-        Args:
-            bank_id: The memory bank ID.
-            entity_id: The entity ID.
-            entity_name: The entity's canonical name.
-            request_context: Request context for authentication.
-        """
-        ...
-
     # =========================================================================
     # Statistics & Operations
     # =========================================================================
hindsight_api/engine/llm_wrapper.py

@@ -16,6 +16,15 @@ from google.genai import errors as genai_errors
 from google.genai import types as genai_types
 from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
 
+# Vertex AI imports (conditional)
+try:
+    import google.auth
+    from google.oauth2 import service_account
+
+    VERTEXAI_AVAILABLE = True
+except ImportError:
+    VERTEXAI_AVAILABLE = False
+
 from ..config import (
     DEFAULT_LLM_MAX_CONCURRENT,
     DEFAULT_LLM_TIMEOUT,
@@ -88,7 +97,7 @@ class LLMProvider:
         self.groq_service_tier = groq_service_tier or os.getenv(ENV_LLM_GROQ_SERVICE_TIER, "auto")
 
         # Validate provider
-        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio", "mock"]
+        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio", "vertexai", "mock"]
         if self.provider not in valid_providers:
             raise ValueError(f"Invalid LLM provider: {self.provider}. Must be one of: {', '.join(valid_providers)}")
 
@@ -105,8 +114,51 @@ class LLMProvider:
         elif self.provider == "lmstudio":
             self.base_url = "http://localhost:1234/v1"
 
-        # Validate API key (not needed for ollama, lmstudio, or mock)
-        if self.provider not in ("ollama", "lmstudio", "mock") and not self.api_key:
+        # Vertex AI config — stored for client creation below
+        self._vertexai_project_id: str | None = None
+        self._vertexai_region: str | None = None
+        self._vertexai_credentials: Any = None
+
+        if self.provider == "vertexai":
+            from ..config import get_config
+
+            config = get_config()
+
+            self._vertexai_project_id = config.llm_vertexai_project_id
+            if not self._vertexai_project_id:
+                raise ValueError(
+                    "HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID is required for Vertex AI provider. "
+                    "Set it to your GCP project ID."
+                )
+
+            self._vertexai_region = config.llm_vertexai_region or "us-central1"
+            service_account_key = config.llm_vertexai_service_account_key
+
+            # Load explicit service account credentials if provided
+            if service_account_key:
+                if not VERTEXAI_AVAILABLE:
+                    raise ValueError(
+                        "Vertex AI service account auth requires 'google-auth' package. "
+                        "Install with: pip install google-auth"
+                    )
+                self._vertexai_credentials = service_account.Credentials.from_service_account_file(
+                    service_account_key,
+                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                )
+                logger.info(f"Vertex AI: Using service account key: {service_account_key}")
+
+            # Strip google/ prefix from model name — native SDK uses bare names
+            # e.g. "google/gemini-2.0-flash-lite-001" -> "gemini-2.0-flash-lite-001"
+            if self.model.startswith("google/"):
+                self.model = self.model[len("google/") :]
+
+            logger.info(
+                f"Vertex AI: project={self._vertexai_project_id}, region={self._vertexai_region}, "
+                f"model={self.model}, auth={'service_account' if service_account_key else 'ADC'}"
+            )
+
+        # Validate API key (not needed for ollama, lmstudio, vertexai, or mock)
+        if self.provider not in ("ollama", "lmstudio", "vertexai", "mock") and not self.api_key:
             raise ValueError(f"API key not found for {self.provider}")
 
         # Get timeout config (set HINDSIGHT_API_LLM_TIMEOUT for local LLMs that need longer timeouts)
@@ -132,6 +184,17 @@ class LLMProvider:
             if self.timeout:
                 anthropic_kwargs["timeout"] = self.timeout
             self._anthropic_client = AsyncAnthropic(**anthropic_kwargs)
+        elif self.provider == "vertexai":
+            # Native genai SDK with Vertex AI — handles ADC automatically,
+            # or uses explicit service account credentials if provided
+            client_kwargs = {
+                "vertexai": True,
+                "project": self._vertexai_project_id,
+                "location": self._vertexai_region,
+            }
+            if self._vertexai_credentials is not None:
+                client_kwargs["credentials"] = self._vertexai_credentials
+            self._gemini_client = genai.Client(**client_kwargs)
         elif self.provider in ("ollama", "lmstudio"):
             # Use dummy key if not provided for local
             api_key = self.api_key or "local"
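For reference, the client construction added above relies on the google-genai SDK's Vertex mode. A hedged standalone sketch of the equivalent call; the project and region values are placeholders, and only HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID is named in the diff itself:

```python
# Sketch only: mirrors the client_kwargs built in the hunk above.
# "my-gcp-project" and "us-central1" are placeholder values.
from google import genai

client = genai.Client(
    vertexai=True,                # route requests through Vertex AI
    project="my-gcp-project",     # corresponds to HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID
    location="us-central1",       # default region used by the new code
    # credentials=...             # optional service-account Credentials; ADC is used otherwise
)
```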
@@ -223,8 +286,8 @@ class LLMProvider:
             return_usage,
         )
 
-        # Handle Gemini
-        if self.provider == "gemini":
+        # Handle Gemini and Vertex AI providers (both use native genai SDK)
+        if self.provider in ("gemini", "vertexai"):
             return await self._call_gemini(
                 messages,
                 response_format,
@@ -342,11 +405,13 @@ class LLMProvider:
             schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
 
             if call_params["messages"] and call_params["messages"][0].get("role") == "system":
-                call_params["messages"][0]["content"] += schema_msg
+                first_msg = call_params["messages"][0]
+                if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
+                    first_msg["content"] += schema_msg
             elif call_params["messages"]:
-                call_params["messages"][0]
-
-
+                first_msg = call_params["messages"][0]
+                if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
+                    first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]
             if self.provider not in ("lmstudio", "ollama"):
                 # LM Studio and Ollama don't support json_object response format reliably
                 # We rely on the schema in the system message instead
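The hunk above hardens the JSON-schema injection: the schema text is appended to an existing system message, or prepended to the first message otherwise, and only when that message's content is actually a string. A small standalone illustration of the behavior, with hypothetical variable names:

```python
# Illustrative sketch of schema injection for providers without native JSON mode.
import json

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the release."},
]

first_msg = messages[0]
if first_msg.get("role") == "system" and isinstance(first_msg.get("content"), str):
    first_msg["content"] += schema_msg                                 # append to the system prompt
elif isinstance(first_msg.get("content"), str):
    first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]  # otherwise prepend
```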
@@ -586,8 +651,8 @@ class LLMProvider:
                 messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
             )
 
-        # Handle Gemini (convert to Gemini tool format)
-        if self.provider == "gemini":
+        # Handle Gemini and Vertex AI (convert to Gemini tool format)
+        if self.provider in ("gemini", "vertexai"):
             return await self._call_with_tools_gemini(
                 messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
             )
@@ -917,18 +982,20 @@ class LLMProvider:
         tool_calls: list[LLMToolCall] = []
 
         if response.candidates and response.candidates[0].content:
-
-
-
-
-
-
-
-
-
-
+            parts = response.candidates[0].content.parts
+            if parts:
+                for part in parts:
+                    if hasattr(part, "text") and part.text:
+                        content = part.text
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        tool_calls.append(
+                            LLMToolCall(
+                                id=f"gemini_{len(tool_calls)}",
+                                name=fc.name,
+                                arguments=dict(fc.args) if fc.args else {},
+                            )
                         )
-            )
 
         finish_reason = "tool_calls" if tool_calls else "stop"
 
@@ -1504,6 +1571,10 @@ class LLMProvider:
         """Clear the recorded mock calls."""
         self._mock_calls = []
 
+    async def cleanup(self) -> None:
+        """Clean up resources."""
+        pass
+
     @classmethod
     def for_memory(cls) -> "LLMProvider":
         """Create provider for memory operations from environment variables."""