prompture 0.0.46.dev1__py3-none-any.whl → 0.0.47.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prompture/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.46.dev1'
-__version_tuple__ = version_tuple = (0, 0, 46, 'dev1')
+__version__ = version = '0.0.47.dev1'
+__version_tuple__ = version_tuple = (0, 0, 47, 'dev1')
 
 __commit_id__ = commit_id = None
prompture/drivers/async_azure_driver.py CHANGED

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Any
 
@@ -18,6 +19,7 @@ from .azure_driver import AzureDriver
 class AsyncAzureDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = AzureDriver.MODEL_PRICING
@@ -122,3 +124,78 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("azure", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": self.deployment_id,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
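For context, a minimal sketch of how the new method might be called. Nothing below comes from the diff itself: the bare constructor call, the get_weather tool, and the message content are illustrative assumptions; only the method name, its parameters, and the shape of the returned dict are taken from the code above.

    import asyncio

    from prompture.drivers.async_azure_driver import AsyncAzureDriver

    async def main() -> None:
        # Assumption: real construction needs Azure credentials/deployment config
        driver = AsyncAzureDriver()
        tools = [{
            "type": "function",  # OpenAI-style function tool schema
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }]
        result = await driver.generate_messages_with_tools(
            messages=[{"role": "user", "content": "What is the weather in Oslo?"}],
            tools=tools,
            options={"max_tokens": 256},
        )
        # Normalized shape returned by every driver in this release:
        # {"text": ..., "meta": ..., "tool_calls": [...], "stop_reason": ...}
        for call in result["tool_calls"]:
            print(call["id"], call["name"], call["arguments"])

    asyncio.run(main())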
prompture/drivers/async_grok_driver.py CHANGED

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Any
 
@@ -14,6 +15,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
@@ -95,3 +97,91 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
         text = resp["choices"][0]["message"]["content"]
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("grok", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            payload["tool_choice"] = options["tool_choice"]
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+            except Exception as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
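The Grok drivers are the only ones in this diff that forward a tool_choice option. As a standalone illustration of the payload assembly above (the tokens_param and supports_temperature defaults here are assumptions; the real values come from _get_model_config, which is not part of this diff):

    from typing import Any

    def build_tool_use_payload(
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
        model: str,
        tokens_param: str = "max_tokens",
        supports_temperature: bool = True,
    ) -> dict[str, Any]:
        # Mirrors the body of AsyncGrokDriver.generate_messages_with_tools
        opts = {"temperature": 1.0, "max_tokens": 512, **options}
        payload: dict[str, Any] = {"model": model, "messages": messages, "tools": tools}
        payload[tokens_param] = opts.get("max_tokens", 512)
        if supports_temperature and "temperature" in opts:
            payload["temperature"] = opts["temperature"]
        if "tool_choice" in options:
            # e.g. {"type": "function", "function": {"name": "get_weather"}} forces a tool
            payload["tool_choice"] = options["tool_choice"]
        return payload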
prompture/drivers/async_groq_driver.py CHANGED

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Any
 
@@ -17,6 +18,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_tool_use = True
    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
@@ -88,3 +90,77 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("groq", model, using_tool_use=True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
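None of these methods execute the tools they return; the caller is expected to do that and feed results back. A hedged sketch of one round trip built on top of the driver API above: the follow-up message shapes are the OpenAI-style convention these Chat Completions APIs share, and the registry dict is an assumption for illustration, not a prompture API.

    import json
    from typing import Any, Callable

    async def run_one_tool_round_trip(
        driver: Any,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        registry: dict[str, Callable[..., Any]],
    ) -> str:
        result = await driver.generate_messages_with_tools(messages, tools, options={})
        for call in result["tool_calls"]:
            output = registry[call["name"]](**call["arguments"])  # execute locally
            # Echo the assistant's tool call, then attach the tool result
            messages.append({
                "role": "assistant",
                "tool_calls": [{
                    "id": call["id"],
                    "type": "function",
                    "function": {"name": call["name"], "arguments": json.dumps(call["arguments"])},
                }],
            })
            messages.append({"role": "tool", "tool_call_id": call["id"], "content": json.dumps(output)})
        final = await driver.generate_messages_with_tools(messages, tools, options={})
        return final["text"]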
prompture/drivers/async_ollama_driver.py CHANGED

@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+import json
 import logging
 import os
+import uuid
 from typing import Any
 
 import httpx
@@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -80,6 +83,88 @@ class AsyncOllamaDriver(AsyncDriver):
 
         return {"text": response_data.get("response", ""), "meta": meta}
 
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls via Ollama's /api/chat endpoint."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(chat_endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        message = response_data.get("message", {})
+        text = message.get("content") or ""
+        stop_reason = response_data.get("done_reason", "stop")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in message.get("tool_calls", []):
+            func = tc.get("function", {})
+            # Ollama returns arguments as a dict already (no JSON string parsing needed)
+            args = func.get("arguments", {})
+            if isinstance(args, str):
+                try:
+                    args = json.loads(args)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+            tool_calls_out.append({
+                # Ollama does not return tool_call IDs — generate one locally
+                "id": f"call_{uuid.uuid4().hex[:24]}",
+                "name": func.get("name", ""),
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
         messages = self._prepare_messages(messages)
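Two Ollama-specific details in the method above are worth isolating: the chat URL is derived by rewriting the configured generate endpoint, and tool-call IDs are minted locally because Ollama's response omits them. Both behaviors, extracted verbatim into standalone helpers:

    import uuid

    def to_chat_endpoint(endpoint: str) -> str:
        # Same rewrite the driver performs on self.endpoint
        return endpoint.replace("/api/generate", "/api/chat")

    def local_tool_call_id() -> str:
        # Ollama returns no tool_call IDs, so the driver fabricates OpenAI-style ones
        return f"call_{uuid.uuid4().hex[:24]}"

    assert to_chat_endpoint("http://localhost:11434/api/generate") == "http://localhost:11434/api/chat"
    assert local_tool_call_id().startswith("call_")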
prompture/drivers/azure_driver.py CHANGED

@@ -2,6 +2,7 @@
 Requires the `openai` package.
 """
 
+import json
 import os
 from typing import Any
 
@@ -17,6 +18,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
@@ -164,3 +166,78 @@ class AzureDriver(CostMixin, Driver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("azure", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": self.deployment_id,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
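The sync driver mirrors the async version, including the tokens_param lookup. That indirection exists because the token-limit parameter name varies by model family on OpenAI-compatible APIs: older chat models take max_tokens, while some newer reasoning models require max_completion_tokens. The mapping itself lives in _get_model_config, which is not part of this diff; a sketch of the effect, with assumed values:

    from typing import Any

    def apply_token_limit(kwargs: dict[str, Any], tokens_param: str, limit: int = 512) -> dict[str, Any]:
        # Equivalent to: kwargs[tokens_param] = opts.get("max_tokens", 512)
        kwargs[tokens_param] = limit
        return kwargs

    assert apply_token_limit({}, "max_tokens") == {"max_tokens": 512}
    assert apply_token_limit({}, "max_completion_tokens") == {"max_completion_tokens": 512}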
prompture/drivers/grok_driver.py CHANGED

@@ -2,6 +2,7 @@
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -13,6 +14,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
@@ -154,3 +156,86 @@ class GrokDriver(CostMixin, Driver):
 
         text = resp["choices"][0]["message"]["content"]
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("grok", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            payload["tool_choice"] = options["tool_choice"]
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        try:
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
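All four OpenAI-compatible drivers in this diff normalize tool arguments the same way: the provider returns them as a JSON string that may be missing (None) or malformed, and both failure modes collapse to an empty dict. Extracted as a standalone helper with identical behavior:

    import json
    from typing import Any

    def parse_tool_arguments(raw: Any) -> Any:
        # TypeError covers raw=None; JSONDecodeError covers malformed strings
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    assert parse_tool_arguments('{"city": "Oslo"}') == {"city": "Oslo"}
    assert parse_tool_arguments(None) == {}
    assert parse_tool_arguments("not json") == {}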
prompture/drivers/groq_driver.py CHANGED

@@ -2,6 +2,7 @@
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -16,6 +17,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
@@ -122,3 +124,77 @@ class GroqDriver(CostMixin, Driver):
         # Extract generated text
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("groq", model, using_tool_use=True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
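Every generate_messages_with_tools implementation in this diff returns the same four keys. Descriptive types for that contract, inferred from the return statements above (the package itself does not necessarily define these; Python 3.10+ is assumed for the | syntax):

    from typing import Any, TypedDict

    class ToolCall(TypedDict):
        id: str                    # provider-assigned, or locally generated for Ollama
        name: str
        arguments: dict[str, Any]

    class ToolUseResult(TypedDict):
        text: str                  # assistant text, "" when the model only calls tools
        meta: dict[str, Any]       # token counts, cost, raw_response, model_name, ...
        tool_calls: list[ToolCall]
        stop_reason: str | None    # e.g. "tool_calls" or "stop" on OpenAI-style APIs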
prompture/drivers/ollama_driver.py CHANGED

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import uuid
 from collections.abc import Iterator
 from typing import Any, Optional
 
@@ -15,6 +16,7 @@ class OllamaDriver(Driver):
     supports_json_mode = True
     supports_json_schema = True
     supports_streaming = True
+    supports_tool_use = True
     supports_vision = True
 
     # Ollama is free – costs are always zero.
@@ -131,6 +133,95 @@ class OllamaDriver(Driver):
         # Ollama returns text in "response"
         return {"text": response_data.get("response", ""), "meta": meta}
 
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls via Ollama's /api/chat endpoint."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        try:
+            logger.debug(f"Sending tool use request to Ollama endpoint: {chat_endpoint}")
+            r = requests.post(chat_endpoint, json=payload, timeout=120)
+            r.raise_for_status()
+            response_data = r.json()
+
+            if not isinstance(response_data, dict):
+                raise ValueError(f"Expected dict response, got {type(response_data)}")
+        except requests.exceptions.ConnectionError:
+            raise
+        except requests.exceptions.HTTPError:
+            raise
+        except json.JSONDecodeError as e:
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
+        except Exception as e:
+            raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        message = response_data.get("message", {})
+        text = message.get("content") or ""
+        stop_reason = response_data.get("done_reason", "stop")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in message.get("tool_calls", []):
+            func = tc.get("function", {})
+            # Ollama returns arguments as a dict already (no JSON string parsing needed)
+            args = func.get("arguments", {})
+            if isinstance(args, str):
+                try:
+                    args = json.loads(args)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+            tool_calls_out.append({
+                # Ollama does not return tool_call IDs — generate one locally
+                "id": f"call_{uuid.uuid4().hex[:24]}",
+                "name": func.get("name", ""),
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
     # ------------------------------------------------------------------
     # Streaming
     # ------------------------------------------------------------------
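Unlike the OpenAI-compatible providers, Ollama typically returns tool arguments as an already-parsed dict, so the string branch above is only a fallback. The normalization logic in isolation, matching the loop above:

    import json
    from typing import Any

    def normalize_arguments(func: dict[str, Any]) -> Any:
        args = func.get("arguments", {})
        if isinstance(args, str):
            # Fallback: only needed if a JSON string slips through
            try:
                args = json.loads(args)
            except (json.JSONDecodeError, TypeError):
                args = {}
        return args

    assert normalize_arguments({"arguments": {"city": "Oslo"}}) == {"city": "Oslo"}
    assert normalize_arguments({"arguments": '{"city": "Oslo"}'}) == {"city": "Oslo"}
    assert normalize_arguments({}) == {}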
prompture-0.0.46.dev1.dist-info/METADATA → prompture-0.0.47.dev1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: prompture
-Version: 0.0.46.dev1
+Version: 0.0.47.dev1
 Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
 Author-email: Juan Denis <juan@vene.co>
 License-Expression: MIT
prompture-0.0.46.dev1.dist-info/RECORD → prompture-0.0.47.dev1.dist-info/RECORD

@@ -1,5 +1,5 @@
 prompture/__init__.py,sha256=cJnkefDpiyFbU77juw4tXPdKJQWoJ-c6XBFt2v-e5Q4,7455
-prompture/_version.py,sha256=Q4g4A2kqcigcQ4G9LQEozxYyQXClh8yXQX-QLy-EQaw,719
+prompture/_version.py,sha256=m4L2kLiZktyjsO5dlv6VYgYlU0JGlYNdugMyoHzVbXk,719
 prompture/agent.py,sha256=-8qdo_Lz20GGssCe5B_QPxb5Kct71YtKHh5vZgrSYik,34748
 prompture/agent_types.py,sha256=Icl16PQI-ThGLMFCU43adtQA6cqETbsPn4KssKBI4xc,4664
 prompture/async_agent.py,sha256=_6_IRb-LGzZxGxfPVy43SIWByUoQfN-5XnUWahVP6r8,33110
@@ -36,32 +36,32 @@ prompture/aio/__init__.py,sha256=bKqTu4Jxld16aP_7SP9wU5au45UBIb041ORo4E4HzVo,181
 prompture/drivers/__init__.py,sha256=r8wBYGKD7C7v4CqcyRNoaITzGVyxasoiAU6jBYsPZio,8178
 prompture/drivers/airllm_driver.py,sha256=SaTh7e7Plvuct_TfRqQvsJsKHvvM_3iVqhBtlciM-Kw,3858
 prompture/drivers/async_airllm_driver.py,sha256=1hIWLXfyyIg9tXaOE22tLJvFyNwHnOi1M5BIKnV8ysk,908
-prompture/drivers/async_azure_driver.py,sha256=uXPMStCn5jMnLFpiLYBvTheZm2dNlwKmSLWL3J2s8es,4544
+prompture/drivers/async_azure_driver.py,sha256=s__y_EGQkK7UZjxiyF08uql8F09cnbJ0q7aFuxzreIw,7328
 prompture/drivers/async_claude_driver.py,sha256=oawbFVVMtRlikQOmu3jRjbdpoeu95JqTF1YHLKO3ybE,10576
 prompture/drivers/async_google_driver.py,sha256=LTUgCXJjzuTDGzsCsmY2-xH2KdTLJD7htwO49ZNFOdE,13711
-prompture/drivers/async_grok_driver.py,sha256=s3bXEGhVrMyw10CowkBhs5522mhipWJyWWu-xVixzyg,3538
-prompture/drivers/async_groq_driver.py,sha256=pjAh_bgZWSWaNSm5XrU-u3gRV6YSGwNG5NfAbkYeJ84,3067
+prompture/drivers/async_grok_driver.py,sha256=4oOGT4SzsheulU_QK0ZSqj4-THrFAOCeZwIqIslnW14,6858
+prompture/drivers/async_groq_driver.py,sha256=iORpf0wcqPfS4zKCg4BTWpQCoHV2klkQVTQ1W-jhjUE,5755
 prompture/drivers/async_hugging_driver.py,sha256=IblxqU6TpNUiigZ0BCgNkAgzpUr2FtPHJOZnOZMnHF0,2152
 prompture/drivers/async_lmstudio_driver.py,sha256=rPn2qVPm6UE2APzAn7ZHYTELUwr0dQMi8XHv6gAhyH8,5782
 prompture/drivers/async_local_http_driver.py,sha256=qoigIf-w3_c2dbVdM6m1e2RMAWP4Gk4VzVs5hM3lPvQ,1609
 prompture/drivers/async_modelscope_driver.py,sha256=wzHYGLf9qE9KXRFZYtN1hZS10Bw1m1Wy6HcmyUD67HM,10170
 prompture/drivers/async_moonshot_driver.py,sha256=Jl6rGlW3SsneFfmBiDo0RBZQN5c3-08kwax369me01E,14798
-prompture/drivers/async_ollama_driver.py,sha256=FaSXtFXrgeVHIe0b90Vg6rGeSTWLpPnjaThh9Ai7qQo,5042
+prompture/drivers/async_ollama_driver.py,sha256=pFtCvh5bHe_qwGy-jIJbyG_zmnPbNbagJCGxCTJMdPU,8244
 prompture/drivers/async_openai_driver.py,sha256=COa_JE-AgKowKJpmRnfDJp4RSQKZel_7WswxOzvLksM,9044
 prompture/drivers/async_openrouter_driver.py,sha256=GnOMY67CCV3HV83lCC-CxcngwrUnuc7G-AX7fb1DYpg,10698
 prompture/drivers/async_registry.py,sha256=JFEnXNPm-8AAUCiNLoKuYBSCYEK-4BmAen5t55QrMvg,5223
 prompture/drivers/async_zai_driver.py,sha256=zXHxske1CtK8dDTGY-D_kiyZZ_NfceNTJlyTpKn0R4c,10727
-prompture/drivers/azure_driver.py,sha256=zwCRNJRm18XEfYeqpFCDLMEEyY0vIGdqrwKk9ng6s4s,5798
+prompture/drivers/azure_driver.py,sha256=gQFffA29gOr-GZ25fNXTokV8-mEmffeV9CT_UBZ3yXc,8565
 prompture/drivers/claude_driver.py,sha256=C8Av3DXP2x3f35jEv8BRwEM_4vh0cfmLsy3t5dsR6aM,11837
 prompture/drivers/google_driver.py,sha256=Zck5VUsW37kDgohXz3cUWRmZ88OfhmTpVD-qzAVMp-8,16318
-prompture/drivers/grok_driver.py,sha256=CzAXKAbbWmbE8qLFZxxoEhf4Qzbtc9YqDX7kkCsE4dk,5320
-prompture/drivers/groq_driver.py,sha256=61LKHhYyRiFkHKbLKFYX10fqjpL_INtPY_Zeb55AV0o,4221
+prompture/drivers/grok_driver.py,sha256=mNfPgOsJR53_5Ep6aYnfKGy7lnZMqN8bxrqKep4CiF0,8408
+prompture/drivers/groq_driver.py,sha256=olr1t7V71ET8Z-7VyRwb75_iYEiZg8-n5qs1edZ2erw,6897
 prompture/drivers/hugging_driver.py,sha256=gZir3XnM77VfYIdnu3S1pRftlZJM6G3L8bgGn5esg-Q,2346
 prompture/drivers/lmstudio_driver.py,sha256=9ZnJ1l5LuWAjkH2WKfFjZprNMVIXoSC7qXDNDTxm-tA,6748
 prompture/drivers/local_http_driver.py,sha256=QJgEf9kAmy8YZ5fb8FHnWuhoDoZYNd8at4jegzNVJH0,1658
 prompture/drivers/modelscope_driver.py,sha256=yTxTG7j5f7zz4CjbrV8J0VKeoBmxv69F40bfp8nq6AE,10651
 prompture/drivers/moonshot_driver.py,sha256=MtlvtUUwE4WtzCKo_pJJ5wATB-h2GU4zY9jbGo3a_-g,18264
-prompture/drivers/ollama_driver.py,sha256=k9xeUwFp91OrDbjkbYI-F8CDFy5ew-zQ0btXqwbXXWM,10220
+prompture/drivers/ollama_driver.py,sha256=SJtMRtAr8geUB4y5GIZxPr-RJ0C3q7yqigYei2b4luM,13710
 prompture/drivers/openai_driver.py,sha256=DqdMhxF8M2HdOY5vfsFrz0h23lqBoQlbxV3xUdHvZho,10548
 prompture/drivers/openrouter_driver.py,sha256=DaG1H99s8GaOgJXZK4TP28HM7U4wiLu9wHXzWZleW_U,12589
 prompture/drivers/registry.py,sha256=Dg_5w9alnIPKhOnsR9Xspuf5T7roBGu0r_L2Cf-UhXs,9926
@@ -76,9 +76,9 @@ prompture/scaffold/templates/env.example.j2,sha256=eESKr1KWgyrczO6d-nwAhQwSpf_G-
 prompture/scaffold/templates/main.py.j2,sha256=TEgc5OvsZOEX0JthkSW1NI_yLwgoeVN_x97Ibg-vyWY,2632
 prompture/scaffold/templates/models.py.j2,sha256=JrZ99GCVK6TKWapskVRSwCssGrTu5cGZ_r46fOhY2GE,858
 prompture/scaffold/templates/requirements.txt.j2,sha256=m3S5fi1hq9KG9l_9j317rjwWww0a43WMKd8VnUWv2A4,102
-prompture-0.0.46.dev1.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
-prompture-0.0.46.dev1.dist-info/METADATA,sha256=-PVUbm089WB89t_CyNBqDWyM4N0Feq2-R9E_-OXSqqE,10842
-prompture-0.0.46.dev1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-prompture-0.0.46.dev1.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
-prompture-0.0.46.dev1.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
-prompture-0.0.46.dev1.dist-info/RECORD,,
+prompture-0.0.47.dev1.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
+prompture-0.0.47.dev1.dist-info/METADATA,sha256=gxnbPKPzC1F715GdpLjy6LchTZ3mlQTQHrjnoGUibDQ,10842
+prompture-0.0.47.dev1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+prompture-0.0.47.dev1.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
+prompture-0.0.47.dev1.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
+prompture-0.0.47.dev1.dist-info/RECORD,,