hindsight-api 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,745 @@
1
+ """
2
+ OpenAI-compatible LLM provider supporting OpenAI, Groq, Ollama, and LMStudio.
3
+
4
+ This provider handles all OpenAI API-compatible models including:
5
+ - OpenAI: GPT-4, GPT-4o, GPT-5, o1, o3 (reasoning models)
6
+ - Groq: Fast inference with seed control and service tiers
7
+ - Ollama: Local models served through Ollama's native chat API
8
+ - LMStudio: Local models with OpenAI-compatible API
9
+
10
+ Features:
11
+ - Reasoning models with extended thinking (o1, o3, GPT-5 families)
12
+ - Strict JSON schema enforcement (OpenAI)
13
+ - Provider-specific parameters (Groq seed, service tier)
14
+ - Native Ollama chat API for better structured output (JSON schema passed via 'format')
15
+ - Automatic token limit handling per model family
16
+ """
17
+
18
+ import asyncio
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import time
24
+ from typing import Any
25
+
26
+ import httpx
27
+ from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
28
+
29
+ from hindsight_api.config import DEFAULT_LLM_TIMEOUT, ENV_LLM_TIMEOUT
30
+ from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
31
+ from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
32
+ from hindsight_api.metrics import get_metrics_collector
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Seed applied to every Groq request for deterministic behavior
37
+ DEFAULT_LLM_SEED = 4242
38
+
39
+
40
+ class OpenAICompatibleLLM(LLMInterface):
41
+ """
42
+ LLM provider for OpenAI-compatible APIs.
43
+
44
+ Supports:
45
+ - OpenAI: Standard models (GPT-4, GPT-4o) and reasoning models (o1, o3, GPT-5)
46
+ - Groq: Fast inference with seed control and service tiers
47
+ - Ollama: Local models via the native chat API for better structured output
48
+ - LMStudio: Local models with OpenAI-compatible API
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ provider: str,
54
+ api_key: str,
55
+ base_url: str,
56
+ model: str,
57
+ reasoning_effort: str = "low",
58
+ timeout: float | None = None,
59
+ groq_service_tier: str | None = None,
60
+ **kwargs: Any,
61
+ ):
62
+ """
63
+ Initialize OpenAI-compatible LLM provider.
64
+
65
+ Args:
66
+ provider: Provider name ("openai", "groq", "ollama", "lmstudio").
67
+ api_key: API key (optional for ollama/lmstudio).
68
+ base_url: Base URL for the API (uses defaults for groq/ollama/lmstudio if empty).
69
+ model: Model name.
70
+ reasoning_effort: Reasoning effort level for supported models ("low", "medium", "high").
71
+ timeout: Request timeout in seconds (uses env var or 300s default).
72
+ groq_service_tier: Groq service tier ("on_demand", "flex", "auto").
73
+ **kwargs: Additional provider-specific parameters.
74
+ """
75
+ super().__init__(provider, api_key, base_url, model, reasoning_effort, **kwargs)
76
+
77
+ # Validate provider
78
+ valid_providers = ["openai", "groq", "ollama", "lmstudio"]
79
+ if self.provider not in valid_providers:
80
+ raise ValueError(f"OpenAICompatibleLLM only supports: {', '.join(valid_providers)}. Got: {self.provider}")
81
+
82
+ # Set default base URLs
83
+ if not self.base_url:
84
+ if self.provider == "groq":
85
+ self.base_url = "https://api.groq.com/openai/v1"
86
+ elif self.provider == "ollama":
87
+ self.base_url = "http://localhost:11434/v1"
88
+ elif self.provider == "lmstudio":
89
+ self.base_url = "http://localhost:1234/v1"
90
+
91
+ # For ollama/lmstudio, use dummy key if not provided
92
+ if self.provider in ("ollama", "lmstudio") and not self.api_key:
93
+ self.api_key = "local"
94
+
95
+ # Validate API key for cloud providers
96
+ if self.provider in ("openai", "groq") and not self.api_key:
97
+ raise ValueError(f"API key is required for {self.provider}")
98
+
99
+ # Groq service tier configuration
100
+ self.groq_service_tier = groq_service_tier or os.getenv("HINDSIGHT_API_LLM_GROQ_SERVICE_TIER", "auto")
101
+
102
+ # Get timeout config
103
+ self.timeout = timeout or float(os.getenv(ENV_LLM_TIMEOUT, str(DEFAULT_LLM_TIMEOUT)))
104
+
105
+ # Create OpenAI client
106
+ client_kwargs: dict[str, Any] = {"api_key": self.api_key, "max_retries": 0}
107
+ if self.base_url:
108
+ client_kwargs["base_url"] = self.base_url
109
+ if self.timeout:
110
+ client_kwargs["timeout"] = self.timeout
111
+
112
+ self._client = AsyncOpenAI(**client_kwargs)
113
+ logger.info(
114
+ f"OpenAI-compatible client initialized: provider={self.provider}, model={self.model}, "
115
+ f"base_url={self.base_url or 'default'}"
116
+ )
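A sketch of the Groq-specific configuration resolved by this constructor; the environment variable holding the key and the model name are assumptions used only for illustration:

import os

groq_llm = OpenAICompatibleLLM(
    provider="groq",
    api_key=os.environ["GROQ_API_KEY"],   # hypothetical env var holding the key
    base_url="",                          # -> https://api.groq.com/openai/v1
    model="llama-3.3-70b-versatile",      # example Groq model name
    groq_service_tier=None,               # None -> HINDSIGHT_API_LLM_GROQ_SERVICE_TIER, else "auto"
)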
117
+
118
+ async def verify_connection(self) -> None:
119
+ """
120
+ Verify that the provider is configured correctly by making a simple test call.
121
+
122
+ Raises:
123
+ RuntimeError: If the connection test fails.
124
+ """
125
+ try:
126
+ logger.info(f"Verifying connection: {self.provider}/{self.model}")
127
+ await self.call(
128
+ messages=[{"role": "user", "content": "Say 'ok'"}],
129
+ max_completion_tokens=100,
130
+ max_retries=2,
131
+ initial_backoff=0.5,
132
+ max_backoff=2.0,
133
+ )
134
+ logger.info(f"Connection verified: {self.provider}/{self.model}")
135
+ except Exception as e:
136
+ raise RuntimeError(f"Connection verification failed for {self.provider}/{self.model}: {e}") from e
137
+
138
+ def _supports_reasoning_model(self) -> bool:
139
+ """Check if the current model is a reasoning model (o1, o3, GPT-5, DeepSeek)."""
140
+ model_lower = self.model.lower()
141
+ return any(x in model_lower for x in ["gpt-5", "o1", "o3", "deepseek"])
142
+
143
+ def _get_max_reasoning_tokens(self) -> int | None:
144
+ """Get max reasoning tokens for reasoning models."""
145
+ model_lower = self.model.lower()
146
+
147
+ # GPT-4 and GPT-4.1 models have different caps
148
+ if any(x in model_lower for x in ["gpt-4.1", "gpt-4-"]):
149
+ return 32000
150
+ elif "gpt-4o" in model_lower:
151
+ return 16384
152
+
153
+ return None
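For reference, a self-contained restatement of the two heuristics above with a few illustrative model names; the helper mirrors the logic of the methods and is not part of the package:

def _expected(model: str) -> tuple[bool, int | None]:
    # (is_reasoning_model, completion-token cap) for a given model name.
    m = model.lower()
    is_reasoning = any(x in m for x in ["gpt-5", "o1", "o3", "deepseek"])
    if any(x in m for x in ["gpt-4.1", "gpt-4-"]):
        cap: int | None = 32000
    elif "gpt-4o" in m:
        cap = 16384
    else:
        cap = None
    return is_reasoning, cap

assert _expected("o3-mini") == (True, None)
assert _expected("gpt-5-mini") == (True, None)
assert _expected("gpt-4o-mini") == (False, 16384)
assert _expected("gpt-4.1") == (False, 32000)
assert _expected("llama3.1") == (False, None)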
154
+
155
+ async def call(
156
+ self,
157
+ messages: list[dict[str, str]],
158
+ response_format: Any | None = None,
159
+ max_completion_tokens: int | None = None,
160
+ temperature: float | None = None,
161
+ scope: str = "memory",
162
+ max_retries: int = 10,
163
+ initial_backoff: float = 1.0,
164
+ max_backoff: float = 60.0,
165
+ skip_validation: bool = False,
166
+ strict_schema: bool = False,
167
+ return_usage: bool = False,
168
+ ) -> Any:
169
+ """
170
+ Make an LLM API call with retry logic.
171
+
172
+ Args:
173
+ messages: List of message dicts with 'role' and 'content'.
174
+ response_format: Optional Pydantic model for structured output.
175
+ max_completion_tokens: Maximum tokens in response.
176
+ temperature: Sampling temperature (0.0-2.0).
177
+ scope: Scope identifier for tracking.
178
+ max_retries: Maximum retry attempts.
179
+ initial_backoff: Initial backoff time in seconds.
180
+ max_backoff: Maximum backoff time in seconds.
181
+ skip_validation: Return raw JSON without Pydantic validation.
182
+ strict_schema: Use strict JSON schema enforcement (OpenAI only).
183
+ return_usage: If True, return tuple (result, TokenUsage) instead of just result.
184
+
185
+ Returns:
186
+ If return_usage=False: Parsed response if response_format is provided, otherwise text content.
187
+ If return_usage=True: Tuple of (result, TokenUsage) with token counts.
188
+
189
+ Raises:
190
+ OutputTooLongError: If output exceeds token limits.
191
+ Exception: Re-raises API errors after retries exhausted.
192
+ """
193
+ # Handle Ollama with native API for structured output (better schema enforcement)
194
+ if self.provider == "ollama" and response_format is not None:
195
+ return await self._call_ollama_native(
196
+ messages=messages,
197
+ response_format=response_format,
198
+ max_completion_tokens=max_completion_tokens,
199
+ temperature=temperature,
200
+ max_retries=max_retries,
201
+ initial_backoff=initial_backoff,
202
+ max_backoff=max_backoff,
203
+ skip_validation=skip_validation,
204
+ scope=scope,
205
+ return_usage=return_usage,
206
+ )
207
+
208
+ start_time = time.time()
209
+
210
+ # Build call parameters
211
+ call_params: dict[str, Any] = {
212
+ "model": self.model,
213
+ "messages": messages,
214
+ }
215
+
216
+ # Check if model supports reasoning parameter
217
+ is_reasoning_model = self._supports_reasoning_model()
218
+
219
+ # Apply model-specific token limits
220
+ if max_completion_tokens is not None:
221
+ max_tokens_cap = self._get_max_reasoning_tokens()
222
+ if max_tokens_cap and max_completion_tokens > max_tokens_cap:
223
+ max_completion_tokens = max_tokens_cap
224
+ # For reasoning models, enforce minimum to ensure space for reasoning + output
225
+ if is_reasoning_model and max_completion_tokens < 16000:
226
+ max_completion_tokens = 16000
227
+ call_params["max_completion_tokens"] = max_completion_tokens
228
+
229
+ # Temperature - reasoning models don't support custom temperature
230
+ if temperature is not None and not is_reasoning_model:
231
+ call_params["temperature"] = temperature
232
+
233
+ # Set reasoning_effort for reasoning models
234
+ if is_reasoning_model:
235
+ call_params["reasoning_effort"] = self.reasoning_effort
236
+
237
+ # Provider-specific parameters
238
+ if self.provider == "groq":
239
+ call_params["seed"] = DEFAULT_LLM_SEED
240
+ extra_body: dict[str, Any] = {}
241
+ # Add service_tier if configured
242
+ if self.groq_service_tier:
243
+ extra_body["service_tier"] = self.groq_service_tier
244
+ # Add reasoning parameters for reasoning models
245
+ if is_reasoning_model:
246
+ extra_body["include_reasoning"] = False
247
+ if extra_body:
248
+ call_params["extra_body"] = extra_body
249
+
250
+ # Prepare response format ONCE before retry loop
251
+ if response_format is not None:
252
+ schema = None
253
+ if hasattr(response_format, "model_json_schema"):
254
+ schema = response_format.model_json_schema()
255
+
256
+ if strict_schema and schema is not None:
257
+ # Use OpenAI's strict JSON schema enforcement
258
+ call_params["response_format"] = {
259
+ "type": "json_schema",
260
+ "json_schema": {
261
+ "name": "response",
262
+ "strict": True,
263
+ "schema": schema,
264
+ },
265
+ }
266
+ else:
267
+ # Soft enforcement: add schema to prompt and use json_object mode
268
+ if schema is not None:
269
+ schema_msg = (
270
+ f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
271
+ )
272
+
273
+ if call_params["messages"] and call_params["messages"][0].get("role") == "system":
274
+ first_msg = call_params["messages"][0]
275
+ if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
276
+ first_msg["content"] += schema_msg
277
+ elif call_params["messages"]:
278
+ first_msg = call_params["messages"][0]
279
+ if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
280
+ first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]
281
+ if self.provider not in ("lmstudio", "ollama"):
282
+ # LM Studio and Ollama don't support json_object response format reliably
283
+ call_params["response_format"] = {"type": "json_object"}
284
+
285
+ last_exception = None
286
+
287
+ for attempt in range(max_retries + 1):
288
+ try:
289
+ if response_format is not None:
290
+ response = await self._client.chat.completions.create(**call_params)
291
+
292
+ content = response.choices[0].message.content
293
+
294
+ # Strip reasoning model thinking tags
295
+ # Supports: <think>, <thinking>, <reasoning>, |startthink|/|endthink|
296
+ if content:
297
+ original_len = len(content)
298
+ content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
299
+ content = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL)
300
+ content = re.sub(r"<reasoning>.*?</reasoning>", "", content, flags=re.DOTALL)
301
+ content = re.sub(r"\|startthink\|.*?\|endthink\|", "", content, flags=re.DOTALL)
302
+ content = content.strip()
303
+ if len(content) < original_len:
304
+ logger.debug(f"Stripped {original_len - len(content)} chars of reasoning tokens")
305
+
306
+ # For local models, they may wrap JSON in markdown code blocks
307
+ if self.provider in ("lmstudio", "ollama"):
308
+ clean_content = content
309
+ if "```json" in content:
310
+ clean_content = content.split("```json")[1].split("```")[0].strip()
311
+ elif "```" in content:
312
+ clean_content = content.split("```")[1].split("```")[0].strip()
313
+ try:
314
+ json_data = json.loads(clean_content)
315
+ except json.JSONDecodeError:
316
+ # Fallback to parsing raw content
317
+ json_data = json.loads(content)
318
+ else:
319
+ # Parse JSON directly; on failure, log the raw LLM response for debugging
320
+ try:
321
+ json_data = json.loads(content)
322
+ except json.JSONDecodeError as json_err:
323
+ # Truncate content for logging
324
+ content_preview = content[:500] if content else "<empty>"
325
+ if content and len(content) > 700:
326
+ content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
327
+ logger.warning(
328
+ f"JSON parse error from LLM response (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
329
+ f" Model: {self.provider}/{self.model}\n"
330
+ f" Content length: {len(content) if content else 0} chars\n"
331
+ f" Content preview: {content_preview!r}\n"
332
+ f" Finish reason: {response.choices[0].finish_reason if response.choices else 'unknown'}"
333
+ )
334
+ # Retry on JSON parse errors
335
+ if attempt < max_retries:
336
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
337
+ await asyncio.sleep(backoff)
338
+ last_exception = json_err
339
+ continue
340
+ else:
341
+ logger.error(f"JSON parse error after {max_retries + 1} attempts, giving up")
342
+ raise
343
+
344
+ if skip_validation:
345
+ result = json_data
346
+ else:
347
+ result = response_format.model_validate(json_data)
348
+ else:
349
+ response = await self._client.chat.completions.create(**call_params)
350
+ result = response.choices[0].message.content
351
+
352
+ # Record token usage metrics
353
+ duration = time.time() - start_time
354
+ usage = response.usage
355
+ input_tokens = (usage.prompt_tokens or 0) if usage else 0
356
+ output_tokens = (usage.completion_tokens or 0) if usage else 0
357
+ total_tokens = (usage.total_tokens or 0) if usage else 0
358
+
359
+ # Record LLM metrics
360
+ metrics = get_metrics_collector()
361
+ metrics.record_llm_call(
362
+ provider=self.provider,
363
+ model=self.model,
364
+ scope=scope,
365
+ duration=duration,
366
+ input_tokens=input_tokens,
367
+ output_tokens=output_tokens,
368
+ success=True,
369
+ )
370
+
371
+ # Log slow calls
372
+ if duration > 10.0 and usage:
373
+ ratio = max(1, output_tokens) / max(1, input_tokens)
374
+ cached_tokens = 0
375
+ if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
376
+ cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
377
+ cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
378
+ logger.info(
379
+ f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
380
+ f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
381
+ f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s, ratio out/in={ratio:.2f}"
382
+ )
383
+
384
+ if return_usage:
385
+ token_usage = TokenUsage(
386
+ input_tokens=input_tokens,
387
+ output_tokens=output_tokens,
388
+ total_tokens=total_tokens,
389
+ )
390
+ return result, token_usage
391
+ return result
392
+
393
+ except LengthFinishReasonError as e:
394
+ logger.warning(f"LLM output exceeded token limits: {str(e)}")
395
+ raise OutputTooLongError(
396
+ "LLM output exceeded token limits. Input may need to be split into smaller chunks."
397
+ ) from e
398
+
399
+ except APIConnectionError as e:
400
+ last_exception = e
401
+ status_code = getattr(e, "status_code", None) or getattr(
402
+ getattr(e, "response", None), "status_code", None
403
+ )
404
+ logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
405
+ if attempt < max_retries:
406
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
407
+ await asyncio.sleep(backoff)
408
+ continue
409
+ else:
410
+ logger.error(f"Connection error after {max_retries + 1} attempts: {str(e)}")
411
+ raise
412
+
413
+ except APIStatusError as e:
414
+ # Fast fail only on 401 (unauthorized) and 403 (forbidden)
415
+ if e.status_code in (401, 403):
416
+ logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
417
+ raise
418
+
419
+ # Handle tool_use_failed errors: the model emitted its response as a tool call
420
+ if e.status_code == 400 and response_format is not None:
421
+ try:
422
+ error_body = e.body if hasattr(e, "body") else {}
423
+ if isinstance(error_body, dict):
424
+ error_info: dict[str, Any] = error_body.get("error") or {}
425
+ if error_info.get("code") == "tool_use_failed":
426
+ failed_gen = error_info.get("failed_generation", "")
427
+ if failed_gen:
428
+ # Parse tool call format and convert to expected format
429
+ tool_call = json.loads(failed_gen)
430
+ tool_name = tool_call.get("name", "")
431
+ tool_args = tool_call.get("arguments", {})
432
+ converted = {"actions": [{"tool": tool_name, **tool_args}]}
433
+ if skip_validation:
434
+ result = converted
435
+ else:
436
+ result = response_format.model_validate(converted)
437
+
438
+ # Record metrics
439
+ duration = time.time() - start_time
440
+ metrics = get_metrics_collector()
441
+ metrics.record_llm_call(
442
+ provider=self.provider,
443
+ model=self.model,
444
+ scope=scope,
445
+ duration=duration,
446
+ input_tokens=0,
447
+ output_tokens=0,
448
+ success=True,
449
+ )
450
+ if return_usage:
451
+ return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
452
+ return result
453
+ except (json.JSONDecodeError, KeyError, TypeError):
454
+ pass # Failed to parse tool_use_failed, continue with normal retry
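# Illustrative shape of the tool_use_failed recovery above (hypothetical tool name):
#   failed_generation: {"name": "record_fact", "arguments": {"text": "..."}}
#   converted payload:  {"actions": [{"tool": "record_fact", "text": "..."}]}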
455
+
456
+ last_exception = e
457
+ if attempt < max_retries:
458
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
459
+ jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)  # +/-20% pseudo-random jitter derived from the clock
460
+ sleep_time = backoff + jitter
461
+ await asyncio.sleep(sleep_time)
462
+ else:
463
+ logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
464
+ raise
465
+
466
+ except Exception:  # non-API errors are not retried; propagate immediately
467
+ raise
468
+
469
+ if last_exception:
470
+ raise last_exception
471
+ raise RuntimeError("LLM call failed after all retries with no exception captured")
472
+
473
+ async def call_with_tools(
474
+ self,
475
+ messages: list[dict[str, Any]],
476
+ tools: list[dict[str, Any]],
477
+ max_completion_tokens: int | None = None,
478
+ temperature: float | None = None,
479
+ scope: str = "tools",
480
+ max_retries: int = 5,
481
+ initial_backoff: float = 1.0,
482
+ max_backoff: float = 30.0,
483
+ tool_choice: str | dict[str, Any] = "auto",
484
+ ) -> LLMToolCallResult:
485
+ """
486
+ Make an LLM API call with tool/function calling support.
487
+
488
+ Args:
489
+ messages: List of message dicts. Can include tool results with role='tool'.
490
+ tools: List of tool definitions in OpenAI format.
491
+ max_completion_tokens: Maximum tokens in response.
492
+ temperature: Sampling temperature (0.0-2.0).
493
+ scope: Scope identifier for tracking.
494
+ max_retries: Maximum retry attempts.
495
+ initial_backoff: Initial backoff time in seconds.
496
+ max_backoff: Maximum backoff time in seconds.
497
+ tool_choice: How to choose tools - "auto", "none", "required", or specific function.
498
+
499
+ Returns:
500
+ LLMToolCallResult with content and/or tool_calls.
501
+ """
502
+ start_time = time.time()
503
+
504
+ # Build call parameters
505
+ call_params: dict[str, Any] = {
506
+ "model": self.model,
507
+ "messages": messages,
508
+ "tools": tools,
509
+ "tool_choice": tool_choice,
510
+ }
511
+
512
+ if max_completion_tokens is not None:
513
+ call_params["max_completion_tokens"] = max_completion_tokens
514
+ if temperature is not None:
515
+ call_params["temperature"] = temperature
516
+
517
+ # Provider-specific parameters
518
+ if self.provider == "groq":
519
+ call_params["seed"] = DEFAULT_LLM_SEED
520
+
521
+ last_exception = None
522
+
523
+ for attempt in range(max_retries + 1):
524
+ try:
525
+ response = await self._client.chat.completions.create(**call_params)
526
+
527
+ message = response.choices[0].message
528
+ finish_reason = response.choices[0].finish_reason
529
+
530
+ # Extract tool calls if present
531
+ tool_calls: list[LLMToolCall] = []
532
+ if message.tool_calls:
533
+ for tc in message.tool_calls:
534
+ try:
535
+ args = json.loads(tc.function.arguments) if tc.function.arguments else {}
536
+ except json.JSONDecodeError:
537
+ args = {"_raw": tc.function.arguments}
538
+ tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
539
+
540
+ content = message.content
541
+
542
+ # Record metrics
543
+ duration = time.time() - start_time
544
+ usage = response.usage
545
+ input_tokens = (usage.prompt_tokens or 0) if usage else 0
546
+ output_tokens = (usage.completion_tokens or 0) if usage else 0
547
+
548
+ metrics = get_metrics_collector()
549
+ metrics.record_llm_call(
550
+ provider=self.provider,
551
+ model=self.model,
552
+ scope=scope,
553
+ duration=duration,
554
+ input_tokens=input_tokens,
555
+ output_tokens=output_tokens,
556
+ success=True,
557
+ )
558
+
559
+ return LLMToolCallResult(
560
+ content=content,
561
+ tool_calls=tool_calls,
562
+ finish_reason=finish_reason,
563
+ input_tokens=input_tokens,
564
+ output_tokens=output_tokens,
565
+ )
566
+
567
+ except APIConnectionError as e:
568
+ last_exception = e
569
+ if attempt < max_retries:
570
+ await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
571
+ continue
572
+ raise
573
+
574
+ except APIStatusError as e:
575
+ if e.status_code in (401, 403):
576
+ raise
577
+ last_exception = e
578
+ if attempt < max_retries:
579
+ await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
580
+ continue
581
+ raise
582
+
583
+ except Exception:
584
+ raise
585
+
586
+ if last_exception:
587
+ raise last_exception
588
+ raise RuntimeError("Tool call failed after all retries")
589
+
590
+ async def _call_ollama_native(
591
+ self,
592
+ messages: list[dict[str, str]],
593
+ response_format: Any,
594
+ max_completion_tokens: int | None,
595
+ temperature: float | None,
596
+ max_retries: int,
597
+ initial_backoff: float,
598
+ max_backoff: float,
599
+ skip_validation: bool,
600
+ scope: str = "memory",
601
+ return_usage: bool = False,
602
+ ) -> Any:
603
+ """
604
+ Call Ollama using native API with JSON schema enforcement.
605
+
606
+ Ollama's native API supports passing a full JSON schema in the 'format' parameter,
607
+ which provides better structured output control than the OpenAI-compatible API.
608
+ """
609
+ start_time = time.time()
610
+
611
+ # Get the JSON schema from the Pydantic model
612
+ schema = response_format.model_json_schema() if hasattr(response_format, "model_json_schema") else None
613
+
614
+ # Build the base URL for Ollama's native API
615
+ # Default OpenAI-compatible URL is http://localhost:11434/v1
616
+ # Native API is at http://localhost:11434/api/chat
617
+ base_url = self.base_url or "http://localhost:11434/v1"
618
+ if base_url.endswith("/v1"):
619
+ native_url = base_url[:-3] + "/api/chat"
620
+ else:
621
+ native_url = base_url.rstrip("/") + "/api/chat"
622
+
623
+ # Build request payload
624
+ payload: dict[str, Any] = {
625
+ "model": self.model,
626
+ "messages": messages,
627
+ "stream": False,
628
+ }
629
+
630
+ # Add schema as format parameter for structured output
631
+ if schema:
632
+ payload["format"] = schema
633
+
634
+ # Add optional parameters with optimized defaults for Ollama
635
+ options: dict[str, Any] = {
636
+ "num_ctx": 16384, # 16k context window for larger prompts
637
+ "num_batch": 512, # Optimal batch size for prompt processing
638
+ }
639
+ if max_completion_tokens:
640
+ options["num_predict"] = max_completion_tokens
641
+ if temperature is not None:
642
+ options["temperature"] = temperature
643
+ payload["options"] = options
644
+
645
+ last_exception = None
646
+
647
+ async with httpx.AsyncClient(timeout=300.0) as client:  # fixed 300s timeout; self.timeout is not applied here
648
+ for attempt in range(max_retries + 1):
649
+ try:
650
+ response = await client.post(native_url, json=payload)
651
+ response.raise_for_status()
652
+
653
+ result = response.json()
654
+ content = result.get("message", {}).get("content", "")
655
+
656
+ # Parse JSON response
657
+ try:
658
+ json_data = json.loads(content)
659
+ except json.JSONDecodeError as json_err:
660
+ content_preview = content[:500] if content else "<empty>"
661
+ if content and len(content) > 700:
662
+ content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
663
+ logger.warning(
664
+ f"Ollama JSON parse error (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
665
+ f" Model: ollama/{self.model}\n"
666
+ f" Content length: {len(content) if content else 0} chars\n"
667
+ f" Content preview: {content_preview!r}"
668
+ )
669
+ if attempt < max_retries:
670
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
671
+ await asyncio.sleep(backoff)
672
+ last_exception = json_err
673
+ continue
674
+ else:
675
+ raise
676
+
677
+ # Extract token usage from Ollama response
678
+ duration = time.time() - start_time
679
+ input_tokens = result.get("prompt_eval_count", 0) or 0
680
+ output_tokens = result.get("eval_count", 0) or 0
681
+ total_tokens = input_tokens + output_tokens
682
+
683
+ # Record LLM metrics
684
+ metrics = get_metrics_collector()
685
+ metrics.record_llm_call(
686
+ provider=self.provider,
687
+ model=self.model,
688
+ scope=scope,
689
+ duration=duration,
690
+ input_tokens=input_tokens,
691
+ output_tokens=output_tokens,
692
+ success=True,
693
+ )
694
+
695
+ # Validate against Pydantic model or return raw JSON
696
+ if skip_validation:
697
+ validated_result = json_data
698
+ else:
699
+ validated_result = response_format.model_validate(json_data)
700
+
701
+ if return_usage:
702
+ token_usage = TokenUsage(
703
+ input_tokens=input_tokens,
704
+ output_tokens=output_tokens,
705
+ total_tokens=total_tokens,
706
+ )
707
+ return validated_result, token_usage
708
+ return validated_result
709
+
710
+ except httpx.HTTPStatusError as e:
711
+ last_exception = e
712
+ if attempt < max_retries:
713
+ logger.warning(
714
+ f"Ollama HTTP error (attempt {attempt + 1}/{max_retries + 1}): {e.response.status_code}"
715
+ )
716
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
717
+ await asyncio.sleep(backoff)
718
+ continue
719
+ else:
720
+ logger.error(f"Ollama HTTP error after {max_retries + 1} attempts: {e}")
721
+ raise
722
+
723
+ except httpx.RequestError as e:
724
+ last_exception = e
725
+ if attempt < max_retries:
726
+ logger.warning(f"Ollama connection error (attempt {attempt + 1}/{max_retries + 1}): {e}")
727
+ backoff = min(initial_backoff * (2**attempt), max_backoff)
728
+ await asyncio.sleep(backoff)
729
+ continue
730
+ else:
731
+ logger.error(f"Ollama connection error after {max_retries + 1} attempts: {e}")
732
+ raise
733
+
734
+ except Exception as e:
735
+ logger.error(f"Unexpected error during Ollama call: {type(e).__name__}: {e}")
736
+ raise
737
+
738
+ if last_exception:
739
+ raise last_exception
740
+ raise RuntimeError("Ollama call failed after all retries")
741
+
742
+ async def cleanup(self) -> None:
743
+ """Clean up resources (close OpenAI client connections)."""
744
+ if hasattr(self, "_client") and self._client:
745
+ await self._client.close()
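Finally, an end-to-end sketch of the provider's lifecycle (construct, verify, call, clean up); the import location, key, and model name are assumptions for illustration:

import asyncio

async def main() -> None:
    llm = OpenAICompatibleLLM(       # assumed to be importable from hindsight_api
        provider="openai",
        api_key="sk-...",
        base_url="",
        model="gpt-4o-mini",
    )
    try:
        await llm.verify_connection()
        text = await llm.call(
            messages=[{"role": "user", "content": "One sentence on vector databases."}],
            max_completion_tokens=200,
        )
        print(text)
    finally:
        await llm.cleanup()

if __name__ == "__main__":
    asyncio.run(main())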