kader 0.1.6__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,690 @@
+"""
+Google LLM Provider implementation.
+
+Provides synchronous and asynchronous access to Google Gemini models
+via the Google Gen AI SDK.
+"""
+
+import os
+from typing import AsyncIterator, Iterator
+
+from google import genai
+from google.genai import types
+
+# Import config to ensure ~/.kader/.env is loaded
+import kader.config  # noqa: F401
+
+from .base import (
+    BaseLLMProvider,
+    CostInfo,
+    LLMResponse,
+    Message,
+    ModelConfig,
+    ModelInfo,
+    ModelPricing,
+    StreamChunk,
+    Usage,
+)
+
+# Pricing data for Gemini models (per 1M tokens, in USD)
+GEMINI_PRICING: dict[str, ModelPricing] = {
+    "gemini-2.5-flash": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-flash-preview-05-20": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.5-pro-preview-05-06": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.0-flash": ModelPricing(
+        input_cost_per_million=0.10,
+        output_cost_per_million=0.40,
+        cached_input_cost_per_million=0.025,
+    ),
+    "gemini-2.0-flash-lite": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-flash": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=5.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+}
+
+
+class GoogleProvider(BaseLLMProvider):
+    """
+    Google LLM Provider.
+
+    Provides access to Google Gemini models with full support
+    for synchronous and asynchronous operations, including streaming.
+
+    The API key is loaded from (in order of priority):
+    1. The `api_key` parameter passed to the constructor
+    2. The GEMINI_API_KEY environment variable (loaded from ~/.kader/.env)
+    3. The GOOGLE_API_KEY environment variable
+
+    Example:
+        provider = GoogleProvider(model="gemini-2.5-flash")
+        response = provider.invoke([Message.user("Hello!")])
+        print(response.content)
+    """
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str | None = None,
+        default_config: ModelConfig | None = None,
+    ) -> None:
+        """
+        Initialize the Google provider.
+
+        Args:
+            model: The Gemini model identifier (e.g., "gemini-2.5-flash")
+            api_key: Optional API key. If not provided, uses GEMINI_API_KEY
+                from ~/.kader/.env or GOOGLE_API_KEY environment variable.
+            default_config: Default configuration for all requests
+        """
+        super().__init__(model=model, default_config=default_config)
+
+        # Resolve API key: parameter > GEMINI_API_KEY > GOOGLE_API_KEY
+        if api_key is None:
+            api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get(
+                "GOOGLE_API_KEY"
+            )
+            # Filter out empty strings from the .env default
+            if api_key == "":
+                api_key = None
+
+        self._api_key = api_key
+        self._client = genai.Client(api_key=api_key) if api_key else genai.Client()
+
+    def _convert_messages(
+        self, messages: list[Message]
+    ) -> tuple[list[types.Content], str | None]:
+        """
+        Convert Message objects to Google GenAI Content format.
+
+        Returns:
+            Tuple of (contents list, system_instruction if present)
+        """
+        contents: list[types.Content] = []
+        system_instruction: str | None = None
+
+        for msg in messages:
+            if msg.role == "system":
+                # System messages are handled separately in Google's API
+                system_instruction = msg.content
+            elif msg.role == "user":
+                contents.append(
+                    types.Content(
+                        role="user",
+                        parts=[types.Part.from_text(text=msg.content)],
+                    )
+                )
+            elif msg.role == "assistant":
+                parts: list[types.Part] = []
+                if msg.content:
+                    parts.append(types.Part.from_text(text=msg.content))
+                if msg.tool_calls:
+                    for tc in msg.tool_calls:
+                        parts.append(
+                            types.Part.from_function_call(
+                                name=tc["function"]["name"],
+                                args=tc["function"]["arguments"]
+                                if isinstance(tc["function"]["arguments"], dict)
+                                else {},
+                            )
+                        )
+                contents.append(types.Content(role="model", parts=parts))
+            elif msg.role == "tool":
+                contents.append(
+                    types.Content(
+                        role="tool",
+                        parts=[
+                            types.Part.from_function_response(
+                                name=msg.name or "tool",
+                                response={"result": msg.content},
+                            )
+                        ],
+                    )
+                )
+
+        return contents, system_instruction
+
+    def _convert_config_to_generate_config(
+        self, config: ModelConfig, system_instruction: str | None = None
+    ) -> types.GenerateContentConfig:
+        """Convert ModelConfig to Google GenerateContentConfig."""
+        generate_config = types.GenerateContentConfig(
+            temperature=config.temperature if config.temperature != 1.0 else None,
+            max_output_tokens=config.max_tokens,
+            top_p=config.top_p if config.top_p != 1.0 else None,
+            top_k=config.top_k,
+            stop_sequences=config.stop_sequences,
+            system_instruction=system_instruction,
+        )
+
+        # Handle tools - convert from dict format to Google's FunctionDeclaration format
+        if config.tools:
+            google_tools = self._convert_tools_to_google_format(config.tools)
+            if google_tools:
+                generate_config.tools = google_tools
+
+        # Handle response format
+        if config.response_format:
+            resp_format_type = config.response_format.get("type")
+            if resp_format_type == "json_object":
+                generate_config.response_mime_type = "application/json"
+
+        return generate_config
+
+    def _convert_tools_to_google_format(
+        self, tools: list[dict]
+    ) -> list[types.Tool] | None:
+        """
+        Convert tool definitions from dict format to Google's Tool format.
+
+        Args:
+            tools: List of tool definitions (from to_google_format or to_openai_format)
+
+        Returns:
+            List of Google Tool objects, or None if no valid tools
+        """
+        if not tools:
+            return None
+
+        function_declarations: list[types.FunctionDeclaration] = []
+
+        for tool in tools:
+            # Handle OpenAI format (type: "function", function: {...})
+            if tool.get("type") == "function" and "function" in tool:
+                func_def = tool["function"]
+                name = func_def.get("name", "")
+                description = func_def.get("description", "")
+                parameters = func_def.get("parameters", {})
+            # Handle Google format (directly has name, description, parameters)
+            elif "name" in tool:
+                name = tool.get("name", "")
+                description = tool.get("description", "")
+                parameters = tool.get("parameters", {})
+            else:
+                continue
+
+            if not name:
+                continue
+
+            # Create FunctionDeclaration
+            try:
+                func_decl = types.FunctionDeclaration(
+                    name=name,
+                    description=description,
+                    parameters=parameters if parameters else None,
+                )
+                function_declarations.append(func_decl)
+            except Exception:
+                # Skip invalid function declarations
+                continue
+
+        if not function_declarations:
+            return None
+
+        # Wrap all function declarations in a single Tool
+        return [types.Tool(function_declarations=function_declarations)]
+
+    def _parse_response(self, response, model: str) -> LLMResponse:
+        """Parse Google GenAI response to LLMResponse."""
+        # Extract content
+        content = ""
+        tool_calls = None
+
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if candidate.content and candidate.content.parts:
+                text_parts = []
+                function_calls = []
+
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text_parts.append(part.text)
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        function_calls.append(
+                            {
+                                "id": f"call_{len(function_calls)}",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        )
+
+                content = "".join(text_parts)
+                if function_calls:
+                    tool_calls = function_calls
+
+        # Extract usage
+        usage = Usage()
+        if hasattr(response, "usage_metadata") and response.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(response.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    response.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+                cached_tokens=getattr(
+                    response.usage_metadata, "cached_content_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = "stop"
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason or "max_tokens" in reason:
+                    finish_reason = "length"
+                elif "tool" in reason or "function" in reason:
+                    finish_reason = "tool_calls"
+                elif "safety" in reason or "filter" in reason:
+                    finish_reason = "content_filter"
+
+        # Calculate cost
+        cost = self.estimate_cost(usage)
+
+        return LLMResponse(
+            content=content,
+            model=model,
+            usage=usage,
+            finish_reason=finish_reason,
+            cost=cost,
+            tool_calls=tool_calls,
+            raw_response=response,
+        )
+
+    def _parse_stream_chunk(
+        self, chunk, accumulated_content: str, model: str
+    ) -> StreamChunk:
+        """Parse streaming chunk to StreamChunk."""
+        delta = ""
+        tool_calls = None
+
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if candidate.content and candidate.content.parts:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        delta = part.text
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        tool_calls = [
+                            {
+                                "id": "call_0",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        ]
+
+        # Extract usage from final chunk
+        usage = None
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(chunk.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    chunk.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = None
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason:
+                    finish_reason = "length"
+
+        return StreamChunk(
+            content=accumulated_content + delta,
+            delta=delta,
+            finish_reason=finish_reason,
+            usage=usage,
+            tool_calls=tool_calls,
+        )
+
+    # -------------------------------------------------------------------------
+    # Synchronous Methods
+    # -------------------------------------------------------------------------
+
+    def invoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Synchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = self._client.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    def stream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> Iterator[StreamChunk]:
+        """
+        Synchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = self._client.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Asynchronous Methods
+    # -------------------------------------------------------------------------
+
+    async def ainvoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Asynchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = await self._client.aio.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    async def astream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """
+        Asynchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = await self._client.aio.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        async for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Token & Cost Methods
+    # -------------------------------------------------------------------------
+
+    def count_tokens(
+        self,
+        text: str | list[Message],
+    ) -> int:
+        """
+        Count the number of tokens in the given text or messages.
+
+        Args:
+            text: A string or list of messages to count tokens for
+
+        Returns:
+            Number of tokens
+        """
+        try:
+            if isinstance(text, str):
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=text,
+                )
+            else:
+                contents, _ = self._convert_messages(text)
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=contents,
+                )
+            return getattr(response, "total_tokens", 0) or 0
+        except Exception:
+            # Fallback to character-based estimation
+            if isinstance(text, str):
+                return len(text) // 4
+            else:
+                total_chars = sum(len(msg.content) for msg in text)
+                return total_chars // 4
+
+    def estimate_cost(
+        self,
+        usage: Usage,
+    ) -> CostInfo:
+        """
+        Estimate the cost for the given token usage.
+
+        Args:
+            usage: Token usage information
+
+        Returns:
+            CostInfo with cost breakdown
+        """
+        # Try to find exact pricing, then fall back to base model name
+        pricing = GEMINI_PRICING.get(self._model)
+
+        if not pricing:
+            # Try to match by prefix (e.g., "gemini-2.5-flash-preview" -> "gemini-2.5-flash")
+            for model_prefix, model_pricing in GEMINI_PRICING.items():
+                if self._model.startswith(model_prefix):
+                    pricing = model_pricing
+                    break
+
+        if not pricing:
+            # Default to gemini-2.5-flash pricing if unknown model
+            pricing = GEMINI_PRICING.get(
+                "gemini-2.5-flash",
+                ModelPricing(
+                    input_cost_per_million=0.15,
+                    output_cost_per_million=0.60,
+                ),
+            )
+
+        return pricing.calculate_cost(usage)
+
+    # -------------------------------------------------------------------------
+    # Utility Methods
+    # -------------------------------------------------------------------------
+
+    def get_model_info(self) -> ModelInfo | None:
+        """Get information about the current model."""
+        try:
+            model_info = self._client.models.get(model=self._model)
+
+            return ModelInfo(
+                name=self._model,
+                provider="google",
+                context_window=getattr(model_info, "input_token_limit", 0) or 128000,
+                max_output_tokens=getattr(model_info, "output_token_limit", None),
+                pricing=GEMINI_PRICING.get(self._model),
+                supports_tools=True,
+                supports_streaming=True,
+                supports_json_mode=True,
+                supports_vision=True,
+                capabilities={
+                    "display_name": getattr(model_info, "display_name", None),
+                    "description": getattr(model_info, "description", None),
+                },
+            )
+        except Exception:
+            return None
+
+    @classmethod
+    def get_supported_models(cls, api_key: str | None = None) -> list[str]:
+        """
+        Get list of models available from Google.
+
+        Args:
+            api_key: Optional API key
+
+        Returns:
+            List of available model names that support generation
+        """
+        try:
+            client = genai.Client(api_key=api_key) if api_key else genai.Client()
+            models = []
+
+            for model in client.models.list():
+                model_name = getattr(model, "name", "")
+                # Filter to only include gemini models that support generateContent
+                if model_name and "gemini" in model_name.lower():
+                    supported_methods = getattr(
+                        model, "supported_generation_methods", []
+                    )
+                    if supported_methods is None:
+                        supported_methods = []
+                    # Include models that support content generation
+                    if (
+                        any("generateContent" in method for method in supported_methods)
+                        or not supported_methods
+                    ):
+                        # Extract just the model ID from full path
+                        # e.g., "models/gemini-2.5-flash" -> "gemini-2.5-flash"
+                        if "/" in model_name:
+                            model_name = model_name.split("/")[-1]
+                        models.append(model_name)
+
+            return models
+        except Exception:
+            return []
+
+    def list_models(self) -> list[str]:
+        """List all available Gemini models."""
+        return self.get_supported_models(self._api_key)
kader/providers/ollama.py CHANGED
@@ -433,11 +433,11 @@ class OllamaProvider(BaseLLMProvider):
             models_config = {}
             for model in models:
                 models_config[model] = client.show(model)
+            accepted_capabilities = ["completion", "tools"]
             return [
                 model
                 for model, config in models_config.items()
-                if config.capabilities
-                in [["completion", "tools", "thinking"], ["completion", "tools"]]
+                if set(accepted_capabilities).issubset(set(config.capabilities))
             ]
         except Exception:
             return []
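This change loosens the previous filter: instead of requiring the reported capability list to exactly equal ["completion", "tools", "thinking"] or ["completion", "tools"], a model is now accepted whenever its capabilities include at least "completion" and "tools", regardless of order or extra entries. A small sketch of the difference, using hypothetical capability lists:

# Sketch of old vs. new predicate (capability lists here are hypothetical examples)
accepted_capabilities = ["completion", "tools"]

old_accept = ["completion", "tools", "vision"] in [
    ["completion", "tools", "thinking"],
    ["completion", "tools"],
]  # False: extra "vision" entry fails the exact-list match

new_accept = set(accepted_capabilities).issubset({"completion", "tools", "vision"})  # True
new_reject = set(accepted_capabilities).issubset({"completion", "thinking"})         # False: no "tools"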