prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (67)
  1. prompture/__init__.py +146 -23
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +607 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +169 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +55 -0
  9. prompture/cli.py +63 -4
  10. prompture/conversation.py +631 -0
  11. prompture/core.py +876 -263
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +164 -0
  14. prompture/driver.py +168 -5
  15. prompture/drivers/__init__.py +173 -69
  16. prompture/drivers/airllm_driver.py +109 -0
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +251 -34
  32. prompture/drivers/google_driver.py +107 -38
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +157 -23
  39. prompture/drivers/openai_driver.py +178 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/scaffold/__init__.py +1 -0
  47. prompture/scaffold/generator.py +84 -0
  48. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  49. prompture/scaffold/templates/README.md.j2 +41 -0
  50. prompture/scaffold/templates/config.py.j2 +21 -0
  51. prompture/scaffold/templates/env.example.j2 +8 -0
  52. prompture/scaffold/templates/main.py.j2 +86 -0
  53. prompture/scaffold/templates/models.py.j2 +40 -0
  54. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  55. prompture/server.py +183 -0
  56. prompture/session.py +117 -0
  57. prompture/settings.py +18 -1
  58. prompture/tools.py +219 -267
  59. prompture/tools_schema.py +254 -0
  60. prompture/validator.py +3 -3
  61. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
  62. prompture-0.0.35.dist-info/RECORD +66 -0
  63. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
  64. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  65. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
  66. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
  67. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
prompture/drivers/grok_driver.py
@@ -1,15 +1,21 @@
  """xAI Grok driver.
  Requires the `requests` package. Uses GROK_API_KEY env var.
  """
+
  import os
- from typing import Any, Dict
+ from typing import Any
+
  import requests
 
+ from ..cost_mixin import CostMixin
  from ..driver import Driver
 
 
- class GrokDriver(Driver):
+ class GrokDriver(CostMixin, Driver):
+     supports_json_mode = True
+
      # Pricing per 1M tokens based on xAI's documentation
+     _PRICING_UNIT = 1_000_000
      MODEL_PRICING = {
          "grok-code-fast-1": {
              "prompt": 0.20,
@@ -72,19 +78,16 @@ class GrokDriver(Driver):
          self.model = model
          self.api_base = "https://api.x.ai/v1"
 
-     def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
-         """Generate completion using Grok API.
+     supports_messages = True
 
-         Args:
-             prompt: Input prompt
-             options: Generation options
-
-         Returns:
-             Dict containing generated text and metadata
-
-         Raises:
-             RuntimeError: If API key is missing or request fails
-         """
+     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return self._do_generate(messages, options)
+
+     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return self._do_generate(messages, options)
+
+     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
          if not self.api_key:
              raise RuntimeError("GROK_API_KEY environment variable is required")
 
@@ -101,7 +104,7 @@ class GrokDriver(Driver):
          # Base request payload
          payload = {
              "model": model,
-             "messages": [{"role": "user", "content": prompt}],
+             "messages": messages,
          }
 
          # Add token limit with correct parameter name
@@ -111,33 +114,27 @@ class GrokDriver(Driver):
          if supports_temperature and "temperature" in opts:
              payload["temperature"] = opts["temperature"]
 
-         headers = {
-             "Authorization": f"Bearer {self.api_key}",
-             "Content-Type": "application/json"
-         }
+         # Native JSON mode support
+         if options.get("json_mode"):
+             payload["response_format"] = {"type": "json_object"}
+
+         headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
          try:
-             response = requests.post(
-                 f"{self.api_base}/chat/completions",
-                 headers=headers,
-                 json=payload
-             )
+             response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
              response.raise_for_status()
              resp = response.json()
          except requests.exceptions.RequestException as e:
-             raise RuntimeError(f"Grok API request failed: {str(e)}")
+             raise RuntimeError(f"Grok API request failed: {e!s}") from e
 
          # Extract usage info
          usage = resp.get("usage", {})
          prompt_tokens = usage.get("prompt_tokens", 0)
-         completion_tokens = usage.get("completion_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
          total_tokens = usage.get("total_tokens", 0)
 
-         # Calculate cost
-         model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-         prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
-         completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
-         total_cost = prompt_cost + completion_cost
+         # Calculate cost via shared mixin
+         total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
 
          # Standardized meta object
          meta = {
@@ -150,4 +147,4 @@ class GrokDriver(Driver):
          }
 
          text = resp["choices"][0]["message"]["content"]
-         return {"text": text, "meta": meta}
+         return {"text": text, "meta": meta}
prompture/drivers/groq_driver.py
@@ -1,18 +1,22 @@
  """Groq driver for prompture.
  Requires the `groq` package. Uses GROQ_API_KEY env var.
  """
+
  import os
- from typing import Any, Dict
+ from typing import Any
 
  try:
      import groq
  except Exception:
      groq = None
 
+ from ..cost_mixin import CostMixin
  from ..driver import Driver
 
 
- class GroqDriver(Driver):
+ class GroqDriver(CostMixin, Driver):
+     supports_json_mode = True
+
      # Approximate pricing per 1K tokens (to be updated with official pricing)
      # Each model entry defines token parameters and temperature support
      MODEL_PRICING = {
@@ -32,7 +36,7 @@ class GroqDriver(Driver):
 
      def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
          """Initialize Groq driver.
-
+
          Args:
              api_key: Groq API key (defaults to GROQ_API_KEY env var)
              model: Model to use (defaults to llama2-70b-4096)
@@ -44,20 +48,16 @@ class GroqDriver(Driver):
          else:
              self.client = None
 
-     def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
-         """Generate completion using Groq API.
-
-         Args:
-             prompt: Input prompt
-             options: Generation options
-
-         Returns:
-             Dict containing generated text and metadata
-
-         Raises:
-             RuntimeError: If groq package is not installed
-             groq.error.*: Various Groq API errors
-         """
+     supports_messages = True
+
+     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return self._do_generate(messages, options)
+
+     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return self._do_generate(messages, options)
+
+     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
          if self.client is None:
              raise RuntimeError("groq package is not installed")
 
@@ -74,7 +74,7 @@ class GroqDriver(Driver):
          # Base kwargs for API call
          kwargs = {
              "model": model,
-             "messages": [{"role": "user", "content": prompt}],
+             "messages": messages,
          }
 
          # Set token limit with correct parameter name
@@ -84,23 +84,24 @@ class GroqDriver(Driver):
          if supports_temperature and "temperature" in opts:
              kwargs["temperature"] = opts["temperature"]
 
+         # Native JSON mode support
+         if options.get("json_mode"):
+             kwargs["response_format"] = {"type": "json_object"}
+
          try:
              resp = self.client.chat.completions.create(**kwargs)
-         except Exception as e:
+         except Exception:
              # Re-raise any Groq API errors
              raise
 
          # Extract usage statistics
          usage = getattr(resp, "usage", None)
          prompt_tokens = getattr(usage, "prompt_tokens", 0)
-         completion_tokens = getattr(usage, "completion_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
          total_tokens = getattr(usage, "total_tokens", 0)
 
-         # Calculate costs
-         model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-         prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-         completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-         total_cost = prompt_cost + completion_cost
+         # Calculate cost via shared mixin
+         total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
 
          # Standard metadata object
          meta = {
@@ -114,4 +115,4 @@ class GroqDriver(Driver):
 
          # Extract generated text
          text = resp.choices[0].message.content
-         return {"text": text, "meta": meta}
+         return {"text": text, "meta": meta}
prompture/drivers/hugging_driver.py
@@ -1,14 +1,14 @@
  import os
+ from typing import Any
+
  import requests
+
  from ..driver import Driver
- from typing import Any, Dict
 
 
  class HuggingFaceDriver(Driver):
      # Hugging Face is usage-based (credits/subscription), but we set costs to 0 for now.
-     MODEL_PRICING = {
-         "default": {"prompt": 0.0, "completion": 0.0}
-     }
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
      def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
          self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
@@ -22,7 +22,7 @@ class HuggingFaceDriver(Driver):
 
          self.headers = {"Authorization": f"Bearer {self.token}"}
 
-     def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
          payload = {
              "inputs": prompt,
              "parameters": options, # HF allows temperature, max_new_tokens, etc. here
@@ -32,7 +32,7 @@ class HuggingFaceDriver(Driver):
              r.raise_for_status()
              response_data = r.json()
          except Exception as e:
-             raise RuntimeError(f"HuggingFaceDriver request failed: {e}")
+             raise RuntimeError(f"HuggingFaceDriver request failed: {e}") from e
 
          # Different HF models return slightly different response formats
          # Text-generation models usually return [{"generated_text": "..."}]
@@ -1,26 +1,26 @@
1
- import os
2
1
  import json
3
- import requests
4
2
  import logging
3
+ import os
4
+ from typing import Any, Optional
5
+
6
+ import requests
7
+
5
8
  from ..driver import Driver
6
- from typing import Any, Dict
7
9
 
8
10
  logger = logging.getLogger(__name__)
9
11
 
10
12
 
11
13
  class LMStudioDriver(Driver):
14
+ supports_json_mode = True
15
+
12
16
  # LM Studio is local – costs are always zero.
13
- MODEL_PRICING = {
14
- "default": {"prompt": 0.0, "completion": 0.0}
15
- }
17
+ MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
16
18
 
17
19
  def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
18
20
  # Allow override via env var
19
- self.endpoint = endpoint or os.getenv(
20
- "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
21
- )
21
+ self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
22
22
  self.model = model
23
- self.options: Dict[str, Any] = {}
23
+ self.options: dict[str, Any] = {}
24
24
 
25
25
  # Validate connection to LM Studio server
26
26
  self._validate_connection()
@@ -38,17 +38,30 @@ class LMStudioDriver(Driver):
38
38
  except requests.exceptions.RequestException as e:
39
39
  logger.warning(f"Could not validate connection to LM Studio server: {e}")
40
40
 
41
- def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
41
+ supports_messages = True
42
+
43
+ def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
44
+ messages = [{"role": "user", "content": prompt}]
45
+ return self._do_generate(messages, options)
46
+
47
+ def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
48
+ return self._do_generate(messages, options)
49
+
50
+ def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
42
51
  merged_options = self.options.copy()
43
52
  if options:
44
53
  merged_options.update(options)
45
54
 
46
55
  payload = {
47
56
  "model": merged_options.get("model", self.model),
48
- "messages": [{"role": "user", "content": prompt}],
57
+ "messages": messages,
49
58
  "temperature": merged_options.get("temperature", 0.7),
50
59
  }
51
60
 
61
+ # Native JSON mode support
62
+ if merged_options.get("json_mode"):
63
+ payload["response_format"] = {"type": "json_object"}
64
+
52
65
  try:
53
66
  logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
54
67
  logger.debug(f"Request payload: {payload}")
@@ -70,7 +83,7 @@ class LMStudioDriver(Driver):
70
83
  raise
71
84
  except Exception as e:
72
85
  logger.error(f"Unexpected error in LM Studio request: {e}")
73
- raise RuntimeError(f"LM Studio request failed: {e}")
86
+ raise RuntimeError(f"LM Studio request failed: {e}") from e
74
87
 
75
88
  # Extract text
76
89
  text = response_data["choices"][0]["message"]["content"]
prompture/drivers/local_http_driver.py
@@ -1,27 +1,27 @@
  import os
+ from typing import Any
+
  import requests
+
  from ..driver import Driver
- from typing import Any, Dict
 
 
  class LocalHTTPDriver(Driver):
      # Default: no cost; extend if your local service has pricing logic
-     MODEL_PRICING = {
-         "default": {"prompt": 0.0, "completion": 0.0}
-     }
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
      def __init__(self, endpoint: str | None = None, model: str = "local-model"):
          self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
          self.model = model
 
-     def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
          payload = {"prompt": prompt, "options": options}
          try:
              r = requests.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
              r.raise_for_status()
              response_data = r.json()
          except Exception as e:
-             raise RuntimeError(f"LocalHTTPDriver request failed: {e}")
+             raise RuntimeError(f"LocalHTTPDriver request failed: {e}") from e
 
          # If the local API already provides {"text": "...", "meta": {...}}, just return it
          if "text" in response_data and "meta" in response_data:
prompture/drivers/ollama_driver.py
@@ -1,38 +1,40 @@
- import os
  import json
- import requests
  import logging
+ import os
+ from collections.abc import Iterator
+ from typing import Any, Optional
+
+ import requests
+
  from ..driver import Driver
- from typing import Any, Dict
 
 
  logger = logging.getLogger(__name__)
 
 
  class OllamaDriver(Driver):
+     supports_json_mode = True
+     supports_streaming = True
+
      # Ollama is free – costs are always zero.
-     MODEL_PRICING = {
-         "default": {"prompt": 0.0, "completion": 0.0}
-     }
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
      def __init__(self, endpoint: str | None = None, model: str = "llama3"):
          # Allow override via env var
-         self.endpoint = endpoint or os.getenv(
-             "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
-         )
+         self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
          self.model = model
          self.options = {}  # Initialize empty options dict
-
+
          # Validate connection to Ollama server
          self._validate_connection()
-
+
      def _validate_connection(self):
          """Validate connection to the Ollama server."""
          try:
              # Send a simple HEAD request to check if server is accessible
              # Use the base API endpoint without the specific path
-             base_url = self.endpoint.split('/api/')[0]
+             base_url = self.endpoint.split("/api/")[0]
              health_url = f"{base_url}/api/version"
-
+
              logger.debug(f"Validating connection to Ollama server at: {health_url}")
              response = requests.head(health_url, timeout=5)
              response.raise_for_status()
@@ -42,7 +44,9 @@ class OllamaDriver(Driver):
          # We don't raise an error here to allow for delayed server startup
          # The actual error will be raised when generate() is called
 
-     def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+     supports_messages = True
+
+     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
          # Merge instance options with call-specific options
          merged_options = self.options.copy()
          if options:
@@ -54,6 +58,10 @@ class OllamaDriver(Driver):
              "stream": False,
          }
 
+         # Native JSON mode support
+         if merged_options.get("json_mode"):
+             payload["format"] = "json"
+
          # Add any Ollama-specific options from merged_options
          if "temperature" in merged_options:
              payload["temperature"] = merged_options["temperature"]
@@ -65,21 +73,21 @@ class OllamaDriver(Driver):
          try:
              logger.debug(f"Sending request to Ollama endpoint: {self.endpoint}")
              logger.debug(f"Request payload: {payload}")
-
+
              r = requests.post(self.endpoint, json=payload, timeout=120)
              logger.debug(f"Response status code: {r.status_code}")
-
+
              r.raise_for_status()
-
+
              response_text = r.text
              logger.debug(f"Raw response text: {response_text}")
-
+
              response_data = r.json()
              logger.debug(f"Parsed response data: {response_data}")
-
+
              if not isinstance(response_data, dict):
                  raise ValueError(f"Expected dict response, got {type(response_data)}")
-
+
          except requests.exceptions.ConnectionError as e:
              logger.error(f"Connection error to Ollama endpoint: {e}")
              # Preserve original exception
@@ -91,11 +99,11 @@ class OllamaDriver(Driver):
          except json.JSONDecodeError as e:
              logger.error(f"Failed to decode JSON response: {e}")
              # Re-raise JSONDecodeError with more context
-             raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos)
+             raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
          except Exception as e:
              logger.error(f"Unexpected error in Ollama request: {e}")
              # Only wrap unknown exceptions in RuntimeError
-             raise RuntimeError(f"Ollama request failed: {e}")
+             raise RuntimeError(f"Ollama request failed: {e}") from e
 
          # Extract token counts
          prompt_tokens = response_data.get("prompt_eval_count", 0)
@@ -113,4 +121,130 @@ class OllamaDriver(Driver):
          }
 
          # Ollama returns text in "response"
-         return {"text": response_data.get("response", ""), "meta": meta}
+         return {"text": response_data.get("response", ""), "meta": meta}
+
+     # ------------------------------------------------------------------
+     # Streaming
+     # ------------------------------------------------------------------
+
+     def generate_messages_stream(
+         self,
+         messages: list[dict[str, Any]],
+         options: dict[str, Any],
+     ) -> Iterator[dict[str, Any]]:
+         """Yield response chunks via Ollama streaming API."""
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+         payload: dict[str, Any] = {
+             "model": merged_options.get("model", self.model),
+             "messages": messages,
+             "stream": True,
+         }
+
+         if merged_options.get("json_mode"):
+             payload["format"] = "json"
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         full_text = ""
+         prompt_tokens = 0
+         completion_tokens = 0
+
+         r = requests.post(chat_endpoint, json=payload, timeout=120, stream=True)
+         r.raise_for_status()
+
+         for line in r.iter_lines():
+             if not line:
+                 continue
+             chunk = json.loads(line)
+             if chunk.get("done"):
+                 prompt_tokens = chunk.get("prompt_eval_count", 0)
+                 completion_tokens = chunk.get("eval_count", 0)
+             else:
+                 content = chunk.get("message", {}).get("content", "")
+                 if content:
+                     full_text += content
+                     yield {"type": "delta", "text": content}
+
+         total_tokens = prompt_tokens + completion_tokens
+         yield {
+             "type": "done",
+             "text": full_text,
+             "meta": {
+                 "prompt_tokens": prompt_tokens,
+                 "completion_tokens": completion_tokens,
+                 "total_tokens": total_tokens,
+                 "cost": 0.0,
+                 "raw_response": {},
+                 "model_name": merged_options.get("model", self.model),
+             },
+         }
+
+     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         # Derive the chat endpoint from the generate endpoint
+         chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+         payload: dict[str, Any] = {
+             "model": merged_options.get("model", self.model),
+             "messages": messages,
+             "stream": False,
+         }
+
+         # Native JSON mode support
+         if merged_options.get("json_mode"):
+             payload["format"] = "json"
+
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         try:
+             logger.debug(f"Sending chat request to Ollama endpoint: {chat_endpoint}")
+             r = requests.post(chat_endpoint, json=payload, timeout=120)
+             r.raise_for_status()
+             response_data = r.json()
+
+             if not isinstance(response_data, dict):
+                 raise ValueError(f"Expected dict response, got {type(response_data)}")
+         except requests.exceptions.ConnectionError:
+             raise
+         except requests.exceptions.HTTPError:
+             raise
+         except json.JSONDecodeError as e:
+             raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
+         except Exception as e:
+             raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+         prompt_tokens = response_data.get("prompt_eval_count", 0)
+         completion_tokens = response_data.get("eval_count", 0)
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         # Chat endpoint returns response in message.content
+         message = response_data.get("message", {})
+         text = message.get("content", "")
+         return {"text": text, "meta": meta}