kader-0.1.6-py3-none-any.whl → kader-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/app.py +108 -36
- cli/app.tcss +20 -0
- cli/llm_factory.py +165 -0
- cli/utils.py +19 -11
- cli/widgets/conversation.py +50 -4
- kader/__init__.py +2 -0
- kader/agent/agents.py +8 -0
- kader/agent/base.py +84 -7
- kader/config.py +10 -2
- kader/memory/types.py +60 -0
- kader/prompts/__init__.py +9 -1
- kader/prompts/agent_prompts.py +28 -0
- kader/prompts/templates/executor_agent.j2 +70 -0
- kader/prompts/templates/kader_planner.j2 +71 -0
- kader/providers/__init__.py +2 -0
- kader/providers/google.py +690 -0
- kader/providers/ollama.py +2 -2
- kader/tools/__init__.py +26 -0
- kader/tools/agent.py +452 -0
- kader/tools/filesys.py +1 -1
- kader/tools/todo.py +43 -2
- kader/utils/__init__.py +10 -0
- kader/utils/checkpointer.py +371 -0
- kader/utils/context_aggregator.py +347 -0
- kader/workflows/__init__.py +13 -0
- kader/workflows/base.py +71 -0
- kader/workflows/planner_executor.py +251 -0
- {kader-0.1.6.dist-info → kader-1.1.0.dist-info}/METADATA +39 -1
- kader-1.1.0.dist-info/RECORD +56 -0
- kader-0.1.6.dist-info/RECORD +0 -45
- {kader-0.1.6.dist-info → kader-1.1.0.dist-info}/WHEEL +0 -0
- {kader-0.1.6.dist-info → kader-1.1.0.dist-info}/entry_points.txt +0 -0
kader/providers/google.py
ADDED

@@ -0,0 +1,690 @@
+"""
+Google LLM Provider implementation.
+
+Provides synchronous and asynchronous access to Google Gemini models
+via the Google Gen AI SDK.
+"""
+
+import os
+from typing import AsyncIterator, Iterator
+
+from google import genai
+from google.genai import types
+
+# Import config to ensure ~/.kader/.env is loaded
+import kader.config  # noqa: F401
+
+from .base import (
+    BaseLLMProvider,
+    CostInfo,
+    LLMResponse,
+    Message,
+    ModelConfig,
+    ModelInfo,
+    ModelPricing,
+    StreamChunk,
+    Usage,
+)
+
+# Pricing data for Gemini models (per 1M tokens, in USD)
+GEMINI_PRICING: dict[str, ModelPricing] = {
+    "gemini-2.5-flash": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-flash-preview-05-20": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.5-pro-preview-05-06": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.0-flash": ModelPricing(
+        input_cost_per_million=0.10,
+        output_cost_per_million=0.40,
+        cached_input_cost_per_million=0.025,
+    ),
+    "gemini-2.0-flash-lite": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-flash": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=5.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+}
+
+
+class GoogleProvider(BaseLLMProvider):
+    """
+    Google LLM Provider.
+
+    Provides access to Google Gemini models with full support
+    for synchronous and asynchronous operations, including streaming.
+
+    The API key is loaded from (in order of priority):
+    1. The `api_key` parameter passed to the constructor
+    2. The GEMINI_API_KEY environment variable (loaded from ~/.kader/.env)
+    3. The GOOGLE_API_KEY environment variable
+
+    Example:
+        provider = GoogleProvider(model="gemini-2.5-flash")
+        response = provider.invoke([Message.user("Hello!")])
+        print(response.content)
+    """
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str | None = None,
+        default_config: ModelConfig | None = None,
+    ) -> None:
+        """
+        Initialize the Google provider.
+
+        Args:
+            model: The Gemini model identifier (e.g., "gemini-2.5-flash")
+            api_key: Optional API key. If not provided, uses GEMINI_API_KEY
+                from ~/.kader/.env or GOOGLE_API_KEY environment variable.
+            default_config: Default configuration for all requests
+        """
+        super().__init__(model=model, default_config=default_config)
+
+        # Resolve API key: parameter > GEMINI_API_KEY > GOOGLE_API_KEY
+        if api_key is None:
+            api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get(
+                "GOOGLE_API_KEY"
+            )
+        # Filter out empty strings from the .env default
+        if api_key == "":
+            api_key = None
+
+        self._api_key = api_key
+        self._client = genai.Client(api_key=api_key) if api_key else genai.Client()
+
+    def _convert_messages(
+        self, messages: list[Message]
+    ) -> tuple[list[types.Content], str | None]:
+        """
+        Convert Message objects to Google GenAI Content format.
+
+        Returns:
+            Tuple of (contents list, system_instruction if present)
+        """
+        contents: list[types.Content] = []
+        system_instruction: str | None = None
+
+        for msg in messages:
+            if msg.role == "system":
+                # System messages are handled separately in Google's API
+                system_instruction = msg.content
+            elif msg.role == "user":
+                contents.append(
+                    types.Content(
+                        role="user",
+                        parts=[types.Part.from_text(text=msg.content)],
+                    )
+                )
+            elif msg.role == "assistant":
+                parts: list[types.Part] = []
+                if msg.content:
+                    parts.append(types.Part.from_text(text=msg.content))
+                if msg.tool_calls:
+                    for tc in msg.tool_calls:
+                        parts.append(
+                            types.Part.from_function_call(
+                                name=tc["function"]["name"],
+                                args=tc["function"]["arguments"]
+                                if isinstance(tc["function"]["arguments"], dict)
+                                else {},
+                            )
+                        )
+                contents.append(types.Content(role="model", parts=parts))
+            elif msg.role == "tool":
+                contents.append(
+                    types.Content(
+                        role="tool",
+                        parts=[
+                            types.Part.from_function_response(
+                                name=msg.name or "tool",
+                                response={"result": msg.content},
+                            )
+                        ],
+                    )
+                )
+
+        return contents, system_instruction
+
+    def _convert_config_to_generate_config(
+        self, config: ModelConfig, system_instruction: str | None = None
+    ) -> types.GenerateContentConfig:
+        """Convert ModelConfig to Google GenerateContentConfig."""
+        generate_config = types.GenerateContentConfig(
+            temperature=config.temperature if config.temperature != 1.0 else None,
+            max_output_tokens=config.max_tokens,
+            top_p=config.top_p if config.top_p != 1.0 else None,
+            top_k=config.top_k,
+            stop_sequences=config.stop_sequences,
+            system_instruction=system_instruction,
+        )
+
+        # Handle tools - convert from dict format to Google's FunctionDeclaration format
+        if config.tools:
+            google_tools = self._convert_tools_to_google_format(config.tools)
+            if google_tools:
+                generate_config.tools = google_tools
+
+        # Handle response format
+        if config.response_format:
+            resp_format_type = config.response_format.get("type")
+            if resp_format_type == "json_object":
+                generate_config.response_mime_type = "application/json"
+
+        return generate_config
+
+    def _convert_tools_to_google_format(
+        self, tools: list[dict]
+    ) -> list[types.Tool] | None:
+        """
+        Convert tool definitions from dict format to Google's Tool format.
+
+        Args:
+            tools: List of tool definitions (from to_google_format or to_openai_format)
+
+        Returns:
+            List of Google Tool objects, or None if no valid tools
+        """
+        if not tools:
+            return None
+
+        function_declarations: list[types.FunctionDeclaration] = []
+
+        for tool in tools:
+            # Handle OpenAI format (type: "function", function: {...})
+            if tool.get("type") == "function" and "function" in tool:
+                func_def = tool["function"]
+                name = func_def.get("name", "")
+                description = func_def.get("description", "")
+                parameters = func_def.get("parameters", {})
+            # Handle Google format (directly has name, description, parameters)
+            elif "name" in tool:
+                name = tool.get("name", "")
+                description = tool.get("description", "")
+                parameters = tool.get("parameters", {})
+            else:
+                continue
+
+            if not name:
+                continue
+
+            # Create FunctionDeclaration
+            try:
+                func_decl = types.FunctionDeclaration(
+                    name=name,
+                    description=description,
+                    parameters=parameters if parameters else None,
+                )
+                function_declarations.append(func_decl)
+            except Exception:
+                # Skip invalid function declarations
+                continue
+
+        if not function_declarations:
+            return None
+
+        # Wrap all function declarations in a single Tool
+        return [types.Tool(function_declarations=function_declarations)]
+
+    def _parse_response(self, response, model: str) -> LLMResponse:
+        """Parse Google GenAI response to LLMResponse."""
+        # Extract content
+        content = ""
+        tool_calls = None
+
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if candidate.content and candidate.content.parts:
+                text_parts = []
+                function_calls = []
+
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text_parts.append(part.text)
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        function_calls.append(
+                            {
+                                "id": f"call_{len(function_calls)}",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        )
+
+                content = "".join(text_parts)
+                if function_calls:
+                    tool_calls = function_calls
+
+        # Extract usage
+        usage = Usage()
+        if hasattr(response, "usage_metadata") and response.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(response.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    response.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+                cached_tokens=getattr(
+                    response.usage_metadata, "cached_content_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = "stop"
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason or "max_tokens" in reason:
+                    finish_reason = "length"
+                elif "tool" in reason or "function" in reason:
+                    finish_reason = "tool_calls"
+                elif "safety" in reason or "filter" in reason:
+                    finish_reason = "content_filter"
+
+        # Calculate cost
+        cost = self.estimate_cost(usage)
+
+        return LLMResponse(
+            content=content,
+            model=model,
+            usage=usage,
+            finish_reason=finish_reason,
+            cost=cost,
+            tool_calls=tool_calls,
+            raw_response=response,
+        )
+
+    def _parse_stream_chunk(
+        self, chunk, accumulated_content: str, model: str
+    ) -> StreamChunk:
+        """Parse streaming chunk to StreamChunk."""
+        delta = ""
+        tool_calls = None
+
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if candidate.content and candidate.content.parts:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        delta = part.text
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        tool_calls = [
+                            {
+                                "id": "call_0",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        ]
+
+        # Extract usage from final chunk
+        usage = None
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(chunk.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    chunk.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = None
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason:
+                    finish_reason = "length"
+
+        return StreamChunk(
+            content=accumulated_content + delta,
+            delta=delta,
+            finish_reason=finish_reason,
+            usage=usage,
+            tool_calls=tool_calls,
+        )
+
+    # -------------------------------------------------------------------------
+    # Synchronous Methods
+    # -------------------------------------------------------------------------
+
+    def invoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Synchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = self._client.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    def stream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> Iterator[StreamChunk]:
+        """
+        Synchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = self._client.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Asynchronous Methods
+    # -------------------------------------------------------------------------
+
+    async def ainvoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Asynchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = await self._client.aio.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    async def astream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """
+        Asynchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = await self._client.aio.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        async for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Token & Cost Methods
+    # -------------------------------------------------------------------------
+
+    def count_tokens(
+        self,
+        text: str | list[Message],
+    ) -> int:
+        """
+        Count the number of tokens in the given text or messages.
+
+        Args:
+            text: A string or list of messages to count tokens for
+
+        Returns:
+            Number of tokens
+        """
+        try:
+            if isinstance(text, str):
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=text,
+                )
+            else:
+                contents, _ = self._convert_messages(text)
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=contents,
+                )
+            return getattr(response, "total_tokens", 0) or 0
+        except Exception:
+            # Fallback to character-based estimation
+            if isinstance(text, str):
+                return len(text) // 4
+            else:
+                total_chars = sum(len(msg.content) for msg in text)
+                return total_chars // 4
+
+    def estimate_cost(
+        self,
+        usage: Usage,
+    ) -> CostInfo:
+        """
+        Estimate the cost for the given token usage.
+
+        Args:
+            usage: Token usage information
+
+        Returns:
+            CostInfo with cost breakdown
+        """
+        # Try to find exact pricing, then fall back to base model name
+        pricing = GEMINI_PRICING.get(self._model)
+
+        if not pricing:
+            # Try to match by prefix (e.g., "gemini-2.5-flash-preview" -> "gemini-2.5-flash")
+            for model_prefix, model_pricing in GEMINI_PRICING.items():
+                if self._model.startswith(model_prefix):
+                    pricing = model_pricing
+                    break
+
+        if not pricing:
+            # Default to gemini-2.5-flash pricing if unknown model
+            pricing = GEMINI_PRICING.get(
+                "gemini-2.5-flash",
+                ModelPricing(
+                    input_cost_per_million=0.15,
+                    output_cost_per_million=0.60,
+                ),
+            )
+
+        return pricing.calculate_cost(usage)
+
+    # -------------------------------------------------------------------------
+    # Utility Methods
+    # -------------------------------------------------------------------------
+
+    def get_model_info(self) -> ModelInfo | None:
+        """Get information about the current model."""
+        try:
+            model_info = self._client.models.get(model=self._model)
+
+            return ModelInfo(
+                name=self._model,
+                provider="google",
+                context_window=getattr(model_info, "input_token_limit", 0) or 128000,
+                max_output_tokens=getattr(model_info, "output_token_limit", None),
+                pricing=GEMINI_PRICING.get(self._model),
+                supports_tools=True,
+                supports_streaming=True,
+                supports_json_mode=True,
+                supports_vision=True,
+                capabilities={
+                    "display_name": getattr(model_info, "display_name", None),
+                    "description": getattr(model_info, "description", None),
+                },
+            )
+        except Exception:
+            return None
+
+    @classmethod
+    def get_supported_models(cls, api_key: str | None = None) -> list[str]:
+        """
+        Get list of models available from Google.
+
+        Args:
+            api_key: Optional API key
+
+        Returns:
+            List of available model names that support generation
+        """
+        try:
+            client = genai.Client(api_key=api_key) if api_key else genai.Client()
+            models = []
+
+            for model in client.models.list():
+                model_name = getattr(model, "name", "")
+                # Filter to only include gemini models that support generateContent
+                if model_name and "gemini" in model_name.lower():
+                    supported_methods = getattr(
+                        model, "supported_generation_methods", []
+                    )
+                    if supported_methods is None:
+                        supported_methods = []
+                    # Include models that support content generation
+                    if (
+                        any("generateContent" in method for method in supported_methods)
+                        or not supported_methods
+                    ):
+                        # Extract just the model ID from full path
+                        # e.g., "models/gemini-2.5-flash" -> "gemini-2.5-flash"
+                        if "/" in model_name:
+                            model_name = model_name.split("/")[-1]
+                        models.append(model_name)
+
+            return models
+        except Exception:
+            return []
+
+    def list_models(self) -> list[str]:
+        """List all available Gemini models."""
+        return self.get_supported_models(self._api_key)
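For reference, a minimal usage sketch of the new GoogleProvider, assembled from the class docstring and method signatures in the diff above. The import paths and the Message.user helper are inferred from this package's file layout and docstring, so treat them as assumptions rather than documented API.

    # Hypothetical usage sketch based on the diff above; import paths are assumed.
    from kader.providers.base import Message
    from kader.providers.google import GoogleProvider

    provider = GoogleProvider(model="gemini-2.5-flash")  # reads GEMINI_API_KEY or GOOGLE_API_KEY

    # Blocking call: returns an LLMResponse with content, usage, and cost
    response = provider.invoke([Message.user("Hello!")])
    print(response.content, response.cost)

    # Streaming call: each StreamChunk carries the incremental text delta
    for chunk in provider.stream([Message.user("Summarize this diff.")]):
        print(chunk.delta, end="", flush=True)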
kader/providers/ollama.py
CHANGED
@@ -433,11 +433,11 @@ class OllamaProvider(BaseLLMProvider):
             models_config = {}
             for model in models:
                 models_config[model] = client.show(model)
+            accepted_capabilities = ["completion", "tools"]
             return [
                 model
                 for model, config in models_config.items()
-                if config.capabilities
-                in [["completion", "tools", "thinking"], ["completion", "tools"]]
+                if set(accepted_capabilities).issubset(set(config.capabilities))
             ]
         except Exception:
             return []
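A small standalone sketch (not package code) of what the relaxed capability filter above accepts: the old check kept a model only if its capability list exactly equaled ["completion", "tools"] or ["completion", "tools", "thinking"], while the new subset check keeps any model whose capabilities include both "completion" and "tools", in any order and alongside any extras.

    # Standalone illustration of the new subset-based filter (example capability lists are assumed).
    accepted_capabilities = ["completion", "tools"]

    for capabilities in (
        ["completion", "tools"],                        # accepted by both the old and new check
        ["completion", "tools", "thinking"],            # accepted by both the old and new check
        ["tools", "completion", "vision", "thinking"],  # accepted only by the new subset check
    ):
        print(capabilities, set(accepted_capabilities).issubset(set(capabilities)))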