prompture-0.0.46-py3-none-any.whl → prompture-0.0.47-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -98,7 +98,12 @@ class AsyncLMStudioDriver(AsyncDriver):
         if "choices" not in response_data or not response_data["choices"]:
             raise ValueError(f"Unexpected response format: {response_data}")
 
-        text = response_data["choices"][0]["message"]["content"]
+        message = response_data["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
 
         usage = response_data.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
@@ -114,7 +119,10 @@ class AsyncLMStudioDriver(AsyncDriver):
             "model_name": merged_options.get("model", self.model),
         }
 
-        return {"text": text, "meta": meta}
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # -- Model management (LM Studio 0.4.0+) ----------------------------------
 
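The LM Studio change above establishes the pattern repeated across the drivers below: `text` falls back to `reasoning_content` when a reasoning model leaves `content` empty, and the raw reasoning is additionally surfaced as an optional top-level key on the result. A minimal consumption sketch; the `driver` object and its `generate_messages` coroutine are assumptions for illustration, only the result keys come from the diff:

from typing import Any


async def show_reply(driver: Any) -> None:
    # `generate_messages` is an assumed entry point; the diff only shows
    # the shape of the dict these driver methods return.
    result: dict[str, Any] = await driver.generate_messages(
        [{"role": "user", "content": "Summarize CRDTs in one sentence."}],
        {},
    )
    print(result["text"])  # may be reasoning text if content was empty
    if "reasoning_content" in result:  # present only for reasoning models
        print("reasoning:", result["reasoning_content"])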
@@ -138,10 +138,11 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
 
         message = resp["choices"][0]["message"]
         text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
 
         # Reasoning models may return content in reasoning_content when content is empty
-        if not text and message.get("reasoning_content"):
-            text = message["reasoning_content"]
+        if not text and reasoning_content:
+            text = reasoning_content
 
         # Structured output fallback: if we used json_schema mode and got an
         # empty response, retry with json_object mode and schema in the prompt.
@@ -184,8 +185,9 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             resp = fb_resp
             fb_message = fb_resp["choices"][0]["message"]
             text = fb_message.get("content") or ""
-            if not text and fb_message.get("reasoning_content"):
-                text = fb_message["reasoning_content"]
+            reasoning_content = fb_message.get("reasoning_content")
+            if not text and reasoning_content:
+                text = reasoning_content
 
         total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
 
@@ -198,7 +200,10 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             "model_name": model,
         }
 
-        return {"text": text, "meta": meta}
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # ------------------------------------------------------------------
     # Tool use
@@ -271,11 +276,12 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
         }
 
         choice = resp["choices"][0]
-        text = choice["message"].get("content") or ""
+        message = choice["message"]
+        text = message.get("content") or ""
         stop_reason = choice.get("finish_reason")
 
         tool_calls_out: list[dict[str, Any]] = []
-        for tc in choice["message"].get("tool_calls", []):
+        for tc in message.get("tool_calls", []):
             try:
                 args = json.loads(tc["function"]["arguments"])
             except (json.JSONDecodeError, TypeError):
@@ -288,13 +294,21 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 }
             )
 
-        return {
+        result: dict[str, Any] = {
             "text": text,
             "meta": meta,
             "tool_calls": tool_calls_out,
             "stop_reason": stop_reason,
         }
 
+        # Preserve reasoning_content for reasoning models so the
+        # conversation loop can include it when sending the assistant
+        # message back (Moonshot requires it on subsequent requests).
+        if message.get("reasoning_content") is not None:
+            result["reasoning_content"] = message["reasoning_content"]
+
+        return result
+
     # ------------------------------------------------------------------
     # Streaming
     # ------------------------------------------------------------------
@@ -325,6 +339,7 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             data["temperature"] = opts["temperature"]
 
         full_text = ""
+        full_reasoning = ""
        prompt_tokens = 0
        completion_tokens = 0
 
@@ -359,9 +374,11 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 if choices:
                     delta = choices[0].get("delta", {})
                     content = delta.get("content") or ""
-                    # Reasoning models stream thinking via reasoning_content
-                    if not content:
-                        content = delta.get("reasoning_content") or ""
+                    reasoning_chunk = delta.get("reasoning_content") or ""
+                    if reasoning_chunk:
+                        full_reasoning += reasoning_chunk
+                    if not content and reasoning_chunk:
+                        content = reasoning_chunk
                     if content:
                         full_text += content
                         yield {"type": "delta", "text": content}
@@ -369,7 +386,7 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
         total_tokens = prompt_tokens + completion_tokens
         total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
 
-        yield {
+        done_chunk: dict[str, Any] = {
             "type": "done",
             "text": full_text,
             "meta": {
@@ -381,3 +398,6 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 "model_name": model,
             },
         }
+        if full_reasoning:
+            done_chunk["reasoning_content"] = full_reasoning
+        yield done_chunk
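The streaming change mirrors the non-streaming one: reasoning deltas are now accumulated into `full_reasoning` and attached to the final `done` chunk, while still being yielded as ordinary `delta` text whenever `content` is empty. A sketch of consuming that chunk protocol, assuming `stream` is the async iterator produced by the driver's streaming method (which is not itself shown in this diff):

from typing import Any, AsyncIterator


async def collect(stream: AsyncIterator[dict[str, Any]]) -> tuple[str, str]:
    # Returns (full_text, full_reasoning) per the chunk shapes shown above.
    async for chunk in stream:
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            # reasoning_content is set only when reasoning tokens were streamed
            return chunk["text"], chunk.get("reasoning_content", "")
    return "", ""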
@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+import json
 import logging
 import os
+import uuid
 from typing import Any
 
 import httpx
@@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -80,6 +83,88 @@ class AsyncOllamaDriver(AsyncDriver):
 
         return {"text": response_data.get("response", ""), "meta": meta}
 
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls via Ollama's /api/chat endpoint."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(chat_endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        message = response_data.get("message", {})
+        text = message.get("content") or ""
+        stop_reason = response_data.get("done_reason", "stop")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in message.get("tool_calls", []):
+            func = tc.get("function", {})
+            # Ollama returns arguments as a dict already (no JSON string parsing needed)
+            args = func.get("arguments", {})
+            if isinstance(args, str):
+                try:
+                    args = json.loads(args)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+            tool_calls_out.append({
+                # Ollama does not return tool_call IDs — generate one locally
+                "id": f"call_{uuid.uuid4().hex[:24]}",
+                "name": func.get("name", ""),
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
         messages = self._prepare_messages(messages)
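The new Ollama method takes OpenAI-style tool definitions (the same shape Ollama's /api/chat accepts) and normalizes the reply into the shared `tool_calls` format, generating local IDs since Ollama omits them. A hedged usage sketch; the tool schema and prompt are illustrative, and the driver is passed in because its import path and constructor are not shown in this diff:

from typing import Any

# Illustrative OpenAI-style tool definition, as accepted by /api/chat.
WEATHER_TOOL: dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}


async def call_tools(driver: Any) -> None:
    # `driver` is assumed to be an AsyncOllamaDriver instance.
    result = await driver.generate_messages_with_tools(
        messages=[{"role": "user", "content": "What's the weather in Oslo?"}],
        tools=[WEATHER_TOOL],
        options={"model": "llama3.1"},  # model name is an assumption
    )
    for call in result["tool_calls"]:
        # IDs look like call_<hex> because the driver generates them locally
        print(call["id"], call["name"], call["arguments"])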
@@ -122,8 +122,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "model_name": model,
         }
 
-        text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        message = resp["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # ------------------------------------------------------------------
     # Tool use
@@ -196,18 +205,23 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 args = json.loads(tc["function"]["arguments"])
             except (json.JSONDecodeError, TypeError):
                 args = {}
-            tool_calls_out.append({
-                "id": tc["id"],
-                "name": tc["function"]["name"],
-                "arguments": args,
-            })
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
 
-        return {
+        result: dict[str, Any] = {
             "text": text,
             "meta": meta,
             "tool_calls": tool_calls_out,
             "stop_reason": stop_reason,
         }
+        if choice["message"].get("reasoning_content") is not None:
+            result["reasoning_content"] = choice["message"]["reasoning_content"]
+        return result
 
     # ------------------------------------------------------------------
     # Streaming
@@ -238,21 +252,25 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             data["temperature"] = opts["temperature"]
 
         full_text = ""
+        full_reasoning = ""
         prompt_tokens = 0
         completion_tokens = 0
 
-        async with httpx.AsyncClient() as client, client.stream(
-            "POST",
-            f"{self.base_url}/chat/completions",
-            headers=self.headers,
-            json=data,
-            timeout=120,
-        ) as response:
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
             response.raise_for_status()
             async for line in response.aiter_lines():
                 if not line or not line.startswith("data: "):
                     continue
-                payload = line[len("data: "):]
+                payload = line[len("data: ") :]
                 if payload.strip() == "[DONE]":
                     break
                 try:
@@ -270,6 +288,11 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 if choices:
                     delta = choices[0].get("delta", {})
                     content = delta.get("content", "")
+                    reasoning_chunk = delta.get("reasoning_content") or ""
+                    if reasoning_chunk:
+                        full_reasoning += reasoning_chunk
+                    if not content and reasoning_chunk:
+                        content = reasoning_chunk
                     if content:
                         full_text += content
                         yield {"type": "delta", "text": content}
@@ -277,7 +300,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
         total_tokens = prompt_tokens + completion_tokens
         total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
 
-        yield {
+        done_chunk: dict[str, Any] = {
             "type": "done",
             "text": full_text,
             "meta": {
@@ -289,3 +312,6 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 "model_name": model,
             },
         }
+        if full_reasoning:
+            done_chunk["reasoning_content"] = full_reasoning
+        yield done_chunk
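One side note on the reflowed request block above: grouping several context managers inside parentheses, as the new `async with (...)` form does, is only officially supported from Python 3.10. On older interpreters the equivalent is plain nesting, roughly like this standalone sketch (the URL, headers, and payload are parameters here because the original code reads them from `self`):

import httpx


async def stream_chat(url: str, headers: dict[str, str], data: dict) -> None:
    # Pre-3.10 spelling of the same two context managers, nested explicitly.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, headers=headers, json=data, timeout=120) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                print(line)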
@@ -2,6 +2,7 @@
 Requires the `openai` package.
 """
 
+import json
 import os
 from typing import Any
 
@@ -17,6 +18,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
@@ -164,3 +166,78 @@ class AzureDriver(CostMixin, Driver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("azure", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": self.deployment_id,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
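Because Azure, Grok, Groq, Moonshot, OpenRouter, and Ollama now all return this same normalized shape, a caller can dispatch tool calls uniformly. A simplified sketch under that assumption; the handler registry is hypothetical, the `tool` message follows the OpenAI convention these drivers wrap, and a real loop would also append the assistant turn (including its tool calls and any `reasoning_content`) before the tool results:

import json
from typing import Any, Callable


def dispatch_tool_calls(
    result: dict[str, Any],
    handlers: dict[str, Callable[..., Any]],
    messages: list[dict[str, Any]],
) -> None:
    # `result` is the normalized dict returned by generate_messages_with_tools.
    for call in result["tool_calls"]:
        output = handlers[call["name"]](**call["arguments"])
        messages.append(
            {
                "role": "tool",  # OpenAI-style tool-result message
                "tool_call_id": call["id"],
                "content": json.dumps(output),
            }
        )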
@@ -2,6 +2,7 @@
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -13,6 +14,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
@@ -152,5 +154,102 @@ class GrokDriver(CostMixin, Driver):
             "model_name": model,
         }
 
-        text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        message = resp["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("grok", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            payload["tool_choice"] = options["tool_choice"]
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        try:
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        result: dict[str, Any] = {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+        if choice["message"].get("reasoning_content") is not None:
+            result["reasoning_content"] = choice["message"]["reasoning_content"]
+        return result
@@ -2,6 +2,7 @@
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -16,6 +17,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
@@ -120,5 +122,93 @@ class GroqDriver(CostMixin, Driver):
         }
 
         # Extract generated text
-        text = resp.choices[0].message.content
-        return {"text": text, "meta": meta}
+        text = resp.choices[0].message.content or ""
+        reasoning_content = getattr(resp.choices[0].message, "reasoning_content", None)
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("groq", model, using_tool_use=True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append(
+                    {
+                        "id": tc.id,
+                        "name": tc.function.name,
+                        "arguments": args,
+                    }
+                )
+
+        result: dict[str, Any] = {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+        reasoning_content = getattr(choice.message, "reasoning_content", None)
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result