kader 1.0.0-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/app.py +10 -6
- cli/llm_factory.py +165 -0
- cli/utils.py +19 -11
- kader/agent/base.py +16 -2
- kader/config.py +10 -2
- kader/providers/__init__.py +2 -0
- kader/providers/google.py +690 -0
- {kader-1.0.0.dist-info → kader-1.1.0.dist-info}/METADATA +2 -1
- {kader-1.0.0.dist-info → kader-1.1.0.dist-info}/RECORD +11 -9
- {kader-1.0.0.dist-info → kader-1.1.0.dist-info}/WHEEL +0 -0
- {kader-1.0.0.dist-info → kader-1.1.0.dist-info}/entry_points.txt +0 -0
cli/app.py
CHANGED
@@ -26,6 +26,7 @@ from kader.memory import (
 )
 from kader.workflows import PlannerExecutorWorkflow
 
+from .llm_factory import LLMProviderFactory
 from .utils import (
     DEFAULT_MODEL,
     HELP_TEXT,
@@ -114,9 +115,13 @@ class KaderApp(App):
 
     def _create_workflow(self, model_name: str) -> PlannerExecutorWorkflow:
         """Create a new PlannerExecutorWorkflow with the specified model."""
+        # Create provider using factory (supports provider:model format)
+        provider = LLMProviderFactory.create_provider(model_name)
+
         return PlannerExecutorWorkflow(
             name="kader_cli",
-            model_name=model_name,
+            provider=provider,
+            model_name=model_name,  # Keep for reference
             interrupt_before_tool=True,
             tool_confirmation_callback=self._tool_confirmation_callback,
             direct_execution_callback=self._direct_execution_callback,
@@ -268,13 +273,12 @@ class KaderApp(App):
 
     async def _show_model_selector(self, conversation: ConversationView) -> None:
         """Show the model selector widget."""
-        from kader.providers import OllamaProvider
-
         try:
-            models
+            # Get models from all available providers
+            models = LLMProviderFactory.get_flat_model_list()
             if not models:
                 conversation.add_message(
-                    "## Models (^^)\n\n*No models found.
+                    "## Models (^^)\n\n*No models found. Check provider configurations.*",
                     "assistant",
                 )
                 return
@@ -569,7 +573,7 @@ Please resize your terminal."""
 
         except Exception as e:
             spinner.stop()
-            error_msg = f"(-) **Error:** {str(e)}\n\nMake sure
+            error_msg = f"(-) **Error:** {str(e)}\n\nMake sure the provider for `{self._current_model}` is configured and available."
             conversation.add_message(error_msg, "assistant")
             self.notify(f"Error: {e}", severity="error")
 
cli/llm_factory.py
ADDED
@@ -0,0 +1,165 @@
+"""LLM Provider Factory for Kader CLI.
+
+Factory pattern implementation for creating LLM provider instances
+with automatic provider detection based on model name format.
+"""
+
+from typing import Optional
+
+from kader.providers import GoogleProvider, OllamaProvider
+from kader.providers.base import BaseLLMProvider, ModelConfig
+
+
+class LLMProviderFactory:
+    """
+    Factory for creating LLM provider instances.
+
+    Supports multiple providers with automatic detection based on model name format.
+    Model names can be specified as:
+    - "provider:model" (e.g., "google:gemini-2.5-flash", "ollama:kimi-k2.5:cloud")
+    - "model" (defaults to Ollama for backward compatibility)
+
+    Example:
+        factory = LLMProviderFactory()
+        provider = factory.create_provider("google:gemini-2.5-flash")
+
+        # Or with default provider (Ollama)
+        provider = factory.create_provider("kimi-k2.5:cloud")
+    """
+
+    # Registered provider classes
+    PROVIDERS: dict[str, type[BaseLLMProvider]] = {
+        "ollama": OllamaProvider,
+        "google": GoogleProvider,
+    }
+
+    # Default provider when no prefix is specified
+    DEFAULT_PROVIDER = "ollama"
+
+    @classmethod
+    def parse_model_name(cls, model_string: str) -> tuple[str, str]:
+        """
+        Parse model string to extract provider and model name.
+
+        Args:
+            model_string: Model string in format "provider:model" or just "model"
+
+        Returns:
+            Tuple of (provider_name, model_name)
+        """
+        # Check if the string starts with a known provider prefix
+        for provider_name in cls.PROVIDERS.keys():
+            prefix = f"{provider_name}:"
+            if model_string.lower().startswith(prefix):
+                return provider_name, model_string[len(prefix) :]
+
+        # No known provider prefix found, use default
+        return cls.DEFAULT_PROVIDER, model_string
+
+    @classmethod
+    def create_provider(
+        cls,
+        model_string: str,
+        config: Optional[ModelConfig] = None,
+    ) -> BaseLLMProvider:
+        """
+        Create an LLM provider instance.
+
+        Args:
+            model_string: Model identifier (e.g., "google:gemini-2.5-flash" or "kimi-k2.5:cloud")
+            config: Optional model configuration
+
+        Returns:
+            Configured provider instance
+
+        Raises:
+            ValueError: If provider is not supported
+        """
+        provider_name, model_name = cls.parse_model_name(model_string)
+
+        provider_class = cls.PROVIDERS.get(provider_name)
+        if not provider_class:
+            supported = ", ".join(cls.PROVIDERS.keys())
+            raise ValueError(
+                f"Unknown provider: {provider_name}. Supported: {supported}"
+            )
+
+        return provider_class(model=model_name, default_config=config)
+
+    @classmethod
+    def get_all_models(cls) -> dict[str, list[str]]:
+        """
+        Get all available models from all registered providers.
+
+        Returns:
+            Dictionary mapping provider names to their available models
+            (with provider prefix included in model names)
+        """
+        models: dict[str, list[str]] = {}
+
+        # Get Ollama models
+        try:
+            ollama_models = OllamaProvider.get_supported_models()
+            models["ollama"] = [f"ollama:{m}" for m in ollama_models]
+        except Exception:
+            models["ollama"] = []
+
+        # Get Google models
+        try:
+            google_models = GoogleProvider.get_supported_models()
+            models["google"] = [f"google:{m}" for m in google_models]
+        except Exception:
+            models["google"] = []
+
+        return models
+
+    @classmethod
+    def get_flat_model_list(cls) -> list[str]:
+        """
+        Get a flattened list of all available models with provider prefixes.
+
+        Returns:
+            List of model strings in "provider:model" format
+        """
+        all_models = cls.get_all_models()
+        flat_list: list[str] = []
+        for models in all_models.values():
+            flat_list.extend(models)
+        return flat_list
+
+    @classmethod
+    def is_provider_available(cls, provider_name: str) -> bool:
+        """
+        Check if a provider is available and configured.
+
+        Args:
+            provider_name: Name of the provider to check
+
+        Returns:
+            True if provider is available and has models, False otherwise
+        """
+        provider_name = provider_name.lower()
+        if provider_name not in cls.PROVIDERS:
+            return False
+
+        # Try to get models to verify provider is working
+        try:
+            provider_class = cls.PROVIDERS[provider_name]
+            models = provider_class.get_supported_models()
+            return len(models) > 0
+        except Exception:
+            return False
+
+    @classmethod
+    def get_provider_name(cls, model_string: str) -> str:
+        """
+        Get the provider name for a given model string.
+
+        Args:
+            model_string: Model string in format "provider:model" or just "model"
+
+        Returns:
+            Provider name (e.g., "ollama", "google")
+        """
+        provider_name, _ = cls.parse_model_name(model_string)
+        return provider_name
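The factory's docstring above already sketches the intended call pattern; pulling it together with the cli/app.py call site, usage looks roughly like this. This is a sketch based only on the code in this diff — the `cli.llm_factory` import path is assumed from the wheel layout:

    from cli.llm_factory import LLMProviderFactory

    # An explicit provider prefix selects Google Gemini via the new GoogleProvider.
    provider = LLMProviderFactory.create_provider("google:gemini-2.5-flash")

    # A bare model name keeps the old behaviour and falls back to Ollama.
    assert LLMProviderFactory.parse_model_name("kimi-k2.5:cloud") == ("ollama", "kimi-k2.5:cloud")

    # The model selector in cli/app.py lists every model as "provider:model".
    for name in LLMProviderFactory.get_flat_model_list():
        print(name)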
cli/utils.py
CHANGED
@@ -1,9 +1,9 @@
 """Utility constants and helpers for Kader CLI."""
 
-from
+from .llm_factory import LLMProviderFactory
 
-# Default model
-DEFAULT_MODEL = "kimi-k2.5:cloud"
+# Default model (with provider prefix for clarity)
+DEFAULT_MODEL = "ollama:kimi-k2.5:cloud"
 
 HELP_TEXT = """## Kader CLI Commands
 
@@ -40,24 +40,32 @@ HELP_TEXT = """## Kader CLI Commands
 ### Tips:
 - Type any question to chat with the AI
 - Use **Tab** to navigate between panels
+- Model format: `provider:model` (e.g., `google:gemini-2.5-flash`)
 """
 
 
 def get_models_text() -> str:
-    """Get formatted text of available
+    """Get formatted text of available models from all providers."""
     try:
-
-
-
+        all_models = LLMProviderFactory.get_all_models()
+        flat_list = LLMProviderFactory.get_flat_model_list()
+
+        if not flat_list:
+            return "## Available Models (^^)\n\n*No models found. Check provider configurations.*"
 
         lines = [
             "## Available Models (^^)\n",
-            "| Model | Status |",
-            "
+            "| Provider | Model | Status |",
+            "|----------|-------|--------|",
         ]
-        for
-
+        for provider_name, provider_models in all_models.items():
+            for model in provider_models:
+                lines.append(f"| {provider_name.title()} | `{model}` | (+) Available |")
+
         lines.append(f"\n*Currently using: **{DEFAULT_MODEL}***")
+        lines.append(
+            "\n> (!) Tip: Use `provider:model` format (e.g., `google:gemini-2.5-flash`)"
+        )
         return "\n".join(lines)
     except Exception as e:
        return f"## Available Models (^^)\n\n*Error fetching models: {e}*"
kader/agent/base.py
CHANGED
@@ -23,7 +23,9 @@ from kader.providers.base import (
     Message,
     ModelConfig,
     StreamChunk,
+    Usage,
 )
+from kader.providers.google import GoogleProvider
 from kader.providers.ollama import OllamaProvider
 from kader.tools import BaseTool, ToolRegistry
 
@@ -222,6 +224,8 @@ class BaseAgent:
         provider_type = "openai"
         if isinstance(self.provider, OllamaProvider):
             provider_type = "ollama"
+        elif isinstance(self.provider, GoogleProvider):
+            provider_type = "google"
 
         base_config = ModelConfig(
             temperature=base_config.temperature,
@@ -624,7 +628,12 @@ class BaseAgent:
         )
 
         # estimate the cost...
-
+        usage_obj = Usage(
+            prompt_tokens=token_usage["prompt_tokens"],
+            completion_tokens=token_usage["completion_tokens"],
+            total_tokens=token_usage["total_tokens"],
+        )
+        estimated_cost = self.provider.estimate_cost(usage_obj)
 
         # Calculate and log cost
         agent_logger.calculate_cost(
@@ -796,7 +805,12 @@
         )
 
         # estimate the cost...
-
+        usage_obj = Usage(
+            prompt_tokens=token_usage["prompt_tokens"],
+            completion_tokens=token_usage["completion_tokens"],
+            total_tokens=token_usage["total_tokens"],
+        )
+        estimated_cost = self.provider.estimate_cost(usage_obj)
 
         # Calculate and log cost
         agent_logger.calculate_cost(
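The change above routes cost estimation through the provider's own estimate_cost(). A minimal sketch of that pattern, with token counts invented purely for illustration:

    from kader.providers.base import Usage
    from kader.providers.google import GoogleProvider

    provider = GoogleProvider(model="gemini-2.5-flash")

    # token_usage mirrors the dict the agent already tracks; the numbers are made up.
    token_usage = {"prompt_tokens": 1200, "completion_tokens": 350, "total_tokens": 1550}
    usage_obj = Usage(
        prompt_tokens=token_usage["prompt_tokens"],
        completion_tokens=token_usage["completion_tokens"],
        total_tokens=token_usage["total_tokens"],
    )
    # Returns a CostInfo priced from the GEMINI_PRICING table in kader/providers/google.py.
    estimated_cost = provider.estimate_cost(usage_obj)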
kader/config.py
CHANGED
@@ -69,13 +69,21 @@ def ensure_kader_directory():
 def ensure_env_file(kader_dir):
     """
     Ensure that the .env file exists in the .kader directory with the
-    required
+    required API key configurations.
     """
     env_file = kader_dir / ".env"
 
     # Create the .env file if it doesn't exist
     if not env_file.exists():
-
+        default_env_content = """# Kader Configuration
+# Ollama API Key (for local Ollama models)
+OLLAMA_API_KEY=''
+
+# Google Gemini API Key (for Google Gemini models)
+# Get your API key from: https://aistudio.google.com/apikey
+GEMINI_API_KEY=''
+"""
+        env_file.write_text(default_env_content, encoding="utf-8")
 
     # Set appropriate permissions for the .env file on Unix-like systems
     if not sys.platform.startswith("win"):
kader/providers/__init__.py
CHANGED
kader/providers/google.py
ADDED
@@ -0,0 +1,690 @@
+"""
+Google LLM Provider implementation.
+
+Provides synchronous and asynchronous access to Google Gemini models
+via the Google Gen AI SDK.
+"""
+
+import os
+from typing import AsyncIterator, Iterator
+
+from google import genai
+from google.genai import types
+
+# Import config to ensure ~/.kader/.env is loaded
+import kader.config  # noqa: F401
+
+from .base import (
+    BaseLLMProvider,
+    CostInfo,
+    LLMResponse,
+    Message,
+    ModelConfig,
+    ModelInfo,
+    ModelPricing,
+    StreamChunk,
+    Usage,
+)
+
+# Pricing data for Gemini models (per 1M tokens, in USD)
+GEMINI_PRICING: dict[str, ModelPricing] = {
+    "gemini-2.5-flash": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-flash-preview-05-20": ModelPricing(
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        cached_input_cost_per_million=0.0375,
+    ),
+    "gemini-2.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.5-pro-preview-05-06": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=10.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+    "gemini-2.0-flash": ModelPricing(
+        input_cost_per_million=0.10,
+        output_cost_per_million=0.40,
+        cached_input_cost_per_million=0.025,
+    ),
+    "gemini-2.0-flash-lite": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-flash": ModelPricing(
+        input_cost_per_million=0.075,
+        output_cost_per_million=0.30,
+        cached_input_cost_per_million=0.01875,
+    ),
+    "gemini-1.5-pro": ModelPricing(
+        input_cost_per_million=1.25,
+        output_cost_per_million=5.00,
+        cached_input_cost_per_million=0.3125,
+    ),
+}
+
+
+class GoogleProvider(BaseLLMProvider):
+    """
+    Google LLM Provider.
+
+    Provides access to Google Gemini models with full support
+    for synchronous and asynchronous operations, including streaming.
+
+    The API key is loaded from (in order of priority):
+    1. The `api_key` parameter passed to the constructor
+    2. The GEMINI_API_KEY environment variable (loaded from ~/.kader/.env)
+    3. The GOOGLE_API_KEY environment variable
+
+    Example:
+        provider = GoogleProvider(model="gemini-2.5-flash")
+        response = provider.invoke([Message.user("Hello!")])
+        print(response.content)
+    """
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str | None = None,
+        default_config: ModelConfig | None = None,
+    ) -> None:
+        """
+        Initialize the Google provider.
+
+        Args:
+            model: The Gemini model identifier (e.g., "gemini-2.5-flash")
+            api_key: Optional API key. If not provided, uses GEMINI_API_KEY
+                from ~/.kader/.env or GOOGLE_API_KEY environment variable.
+            default_config: Default configuration for all requests
+        """
+        super().__init__(model=model, default_config=default_config)
+
+        # Resolve API key: parameter > GEMINI_API_KEY > GOOGLE_API_KEY
+        if api_key is None:
+            api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get(
+                "GOOGLE_API_KEY"
+            )
+        # Filter out empty strings from the .env default
+        if api_key == "":
+            api_key = None
+
+        self._api_key = api_key
+        self._client = genai.Client(api_key=api_key) if api_key else genai.Client()
+
+    def _convert_messages(
+        self, messages: list[Message]
+    ) -> tuple[list[types.Content], str | None]:
+        """
+        Convert Message objects to Google GenAI Content format.
+
+        Returns:
+            Tuple of (contents list, system_instruction if present)
+        """
+        contents: list[types.Content] = []
+        system_instruction: str | None = None
+
+        for msg in messages:
+            if msg.role == "system":
+                # System messages are handled separately in Google's API
+                system_instruction = msg.content
+            elif msg.role == "user":
+                contents.append(
+                    types.Content(
+                        role="user",
+                        parts=[types.Part.from_text(text=msg.content)],
+                    )
+                )
+            elif msg.role == "assistant":
+                parts: list[types.Part] = []
+                if msg.content:
+                    parts.append(types.Part.from_text(text=msg.content))
+                if msg.tool_calls:
+                    for tc in msg.tool_calls:
+                        parts.append(
+                            types.Part.from_function_call(
+                                name=tc["function"]["name"],
+                                args=tc["function"]["arguments"]
+                                if isinstance(tc["function"]["arguments"], dict)
+                                else {},
+                            )
+                        )
+                contents.append(types.Content(role="model", parts=parts))
+            elif msg.role == "tool":
+                contents.append(
+                    types.Content(
+                        role="tool",
+                        parts=[
+                            types.Part.from_function_response(
+                                name=msg.name or "tool",
+                                response={"result": msg.content},
+                            )
+                        ],
+                    )
+                )
+
+        return contents, system_instruction
+
+    def _convert_config_to_generate_config(
+        self, config: ModelConfig, system_instruction: str | None = None
+    ) -> types.GenerateContentConfig:
+        """Convert ModelConfig to Google GenerateContentConfig."""
+        generate_config = types.GenerateContentConfig(
+            temperature=config.temperature if config.temperature != 1.0 else None,
+            max_output_tokens=config.max_tokens,
+            top_p=config.top_p if config.top_p != 1.0 else None,
+            top_k=config.top_k,
+            stop_sequences=config.stop_sequences,
+            system_instruction=system_instruction,
+        )
+
+        # Handle tools - convert from dict format to Google's FunctionDeclaration format
+        if config.tools:
+            google_tools = self._convert_tools_to_google_format(config.tools)
+            if google_tools:
+                generate_config.tools = google_tools
+
+        # Handle response format
+        if config.response_format:
+            resp_format_type = config.response_format.get("type")
+            if resp_format_type == "json_object":
+                generate_config.response_mime_type = "application/json"
+
+        return generate_config
+
+    def _convert_tools_to_google_format(
+        self, tools: list[dict]
+    ) -> list[types.Tool] | None:
+        """
+        Convert tool definitions from dict format to Google's Tool format.
+
+        Args:
+            tools: List of tool definitions (from to_google_format or to_openai_format)
+
+        Returns:
+            List of Google Tool objects, or None if no valid tools
+        """
+        if not tools:
+            return None
+
+        function_declarations: list[types.FunctionDeclaration] = []
+
+        for tool in tools:
+            # Handle OpenAI format (type: "function", function: {...})
+            if tool.get("type") == "function" and "function" in tool:
+                func_def = tool["function"]
+                name = func_def.get("name", "")
+                description = func_def.get("description", "")
+                parameters = func_def.get("parameters", {})
+            # Handle Google format (directly has name, description, parameters)
+            elif "name" in tool:
+                name = tool.get("name", "")
+                description = tool.get("description", "")
+                parameters = tool.get("parameters", {})
+            else:
+                continue
+
+            if not name:
+                continue
+
+            # Create FunctionDeclaration
+            try:
+                func_decl = types.FunctionDeclaration(
+                    name=name,
+                    description=description,
+                    parameters=parameters if parameters else None,
+                )
+                function_declarations.append(func_decl)
+            except Exception:
+                # Skip invalid function declarations
+                continue
+
+        if not function_declarations:
+            return None
+
+        # Wrap all function declarations in a single Tool
+        return [types.Tool(function_declarations=function_declarations)]
+
+    def _parse_response(self, response, model: str) -> LLMResponse:
+        """Parse Google GenAI response to LLMResponse."""
+        # Extract content
+        content = ""
+        tool_calls = None
+
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if candidate.content and candidate.content.parts:
+                text_parts = []
+                function_calls = []
+
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text_parts.append(part.text)
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        function_calls.append(
+                            {
+                                "id": f"call_{len(function_calls)}",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        )
+
+                content = "".join(text_parts)
+                if function_calls:
+                    tool_calls = function_calls
+
+        # Extract usage
+        usage = Usage()
+        if hasattr(response, "usage_metadata") and response.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(response.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    response.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+                cached_tokens=getattr(
+                    response.usage_metadata, "cached_content_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = "stop"
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason or "max_tokens" in reason:
+                    finish_reason = "length"
+                elif "tool" in reason or "function" in reason:
+                    finish_reason = "tool_calls"
+                elif "safety" in reason or "filter" in reason:
+                    finish_reason = "content_filter"
+
+        # Calculate cost
+        cost = self.estimate_cost(usage)
+
+        return LLMResponse(
+            content=content,
+            model=model,
+            usage=usage,
+            finish_reason=finish_reason,
+            cost=cost,
+            tool_calls=tool_calls,
+            raw_response=response,
+        )
+
+    def _parse_stream_chunk(
+        self, chunk, accumulated_content: str, model: str
+    ) -> StreamChunk:
+        """Parse streaming chunk to StreamChunk."""
+        delta = ""
+        tool_calls = None
+
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if candidate.content and candidate.content.parts:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        delta = part.text
+                    if hasattr(part, "function_call") and part.function_call:
+                        fc = part.function_call
+                        tool_calls = [
+                            {
+                                "id": "call_0",
+                                "type": "function",
+                                "function": {
+                                    "name": fc.name,
+                                    "arguments": dict(fc.args) if fc.args else {},
+                                },
+                            }
+                        ]
+
+        # Extract usage from final chunk
+        usage = None
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            usage = Usage(
+                prompt_tokens=getattr(chunk.usage_metadata, "prompt_token_count", 0)
+                or 0,
+                completion_tokens=getattr(
+                    chunk.usage_metadata, "candidates_token_count", 0
+                )
+                or 0,
+            )
+
+        # Determine finish reason
+        finish_reason = None
+        if chunk.candidates and len(chunk.candidates) > 0:
+            candidate = chunk.candidates[0]
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                reason = str(candidate.finish_reason).lower()
+                if "stop" in reason:
+                    finish_reason = "stop"
+                elif "length" in reason:
+                    finish_reason = "length"
+
+        return StreamChunk(
+            content=accumulated_content + delta,
+            delta=delta,
+            finish_reason=finish_reason,
+            usage=usage,
+            tool_calls=tool_calls,
+        )
+
+    # -------------------------------------------------------------------------
+    # Synchronous Methods
+    # -------------------------------------------------------------------------
+
+    def invoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Synchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = self._client.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    def stream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> Iterator[StreamChunk]:
+        """
+        Synchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = self._client.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Asynchronous Methods
+    # -------------------------------------------------------------------------
+
+    async def ainvoke(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> LLMResponse:
+        """
+        Asynchronously invoke the Google Gemini model.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Returns:
+            LLMResponse with the model's response
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response = await self._client.aio.models.generate_content(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        llm_response = self._parse_response(response, self._model)
+        self._update_tracking(llm_response)
+        return llm_response
+
+    async def astream(
+        self,
+        messages: list[Message],
+        config: ModelConfig | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """
+        Asynchronously stream the Google Gemini model response.
+
+        Args:
+            messages: List of messages in the conversation
+            config: Optional configuration overrides
+
+        Yields:
+            StreamChunk objects as they arrive
+        """
+        merged_config = self._merge_config(config)
+        contents, system_instruction = self._convert_messages(messages)
+        generate_config = self._convert_config_to_generate_config(
+            merged_config, system_instruction
+        )
+
+        response_stream = await self._client.aio.models.generate_content_stream(
+            model=self._model,
+            contents=contents,
+            config=generate_config,
+        )
+
+        accumulated_content = ""
+        async for chunk in response_stream:
+            stream_chunk = self._parse_stream_chunk(
+                chunk, accumulated_content, self._model
+            )
+            accumulated_content = stream_chunk.content
+            yield stream_chunk
+
+            # Update tracking on final chunk
+            if stream_chunk.is_final and stream_chunk.usage:
+                final_response = LLMResponse(
+                    content=accumulated_content,
+                    model=self._model,
+                    usage=stream_chunk.usage,
+                    finish_reason=stream_chunk.finish_reason,
+                    cost=self.estimate_cost(stream_chunk.usage),
+                )
+                self._update_tracking(final_response)
+
+    # -------------------------------------------------------------------------
+    # Token & Cost Methods
+    # -------------------------------------------------------------------------
+
+    def count_tokens(
+        self,
+        text: str | list[Message],
+    ) -> int:
+        """
+        Count the number of tokens in the given text or messages.
+
+        Args:
+            text: A string or list of messages to count tokens for
+
+        Returns:
+            Number of tokens
+        """
+        try:
+            if isinstance(text, str):
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=text,
+                )
+            else:
+                contents, _ = self._convert_messages(text)
+                response = self._client.models.count_tokens(
+                    model=self._model,
+                    contents=contents,
+                )
+            return getattr(response, "total_tokens", 0) or 0
+        except Exception:
+            # Fallback to character-based estimation
+            if isinstance(text, str):
+                return len(text) // 4
+            else:
+                total_chars = sum(len(msg.content) for msg in text)
+                return total_chars // 4
+
+    def estimate_cost(
+        self,
+        usage: Usage,
+    ) -> CostInfo:
+        """
+        Estimate the cost for the given token usage.
+
+        Args:
+            usage: Token usage information
+
+        Returns:
+            CostInfo with cost breakdown
+        """
+        # Try to find exact pricing, then fall back to base model name
+        pricing = GEMINI_PRICING.get(self._model)
+
+        if not pricing:
+            # Try to match by prefix (e.g., "gemini-2.5-flash-preview" -> "gemini-2.5-flash")
+            for model_prefix, model_pricing in GEMINI_PRICING.items():
+                if self._model.startswith(model_prefix):
+                    pricing = model_pricing
+                    break
+
+        if not pricing:
+            # Default to gemini-2.5-flash pricing if unknown model
+            pricing = GEMINI_PRICING.get(
+                "gemini-2.5-flash",
+                ModelPricing(
+                    input_cost_per_million=0.15,
+                    output_cost_per_million=0.60,
+                ),
+            )
+
+        return pricing.calculate_cost(usage)
+
+    # -------------------------------------------------------------------------
+    # Utility Methods
+    # -------------------------------------------------------------------------
+
+    def get_model_info(self) -> ModelInfo | None:
+        """Get information about the current model."""
+        try:
+            model_info = self._client.models.get(model=self._model)
+
+            return ModelInfo(
+                name=self._model,
+                provider="google",
+                context_window=getattr(model_info, "input_token_limit", 0) or 128000,
+                max_output_tokens=getattr(model_info, "output_token_limit", None),
+                pricing=GEMINI_PRICING.get(self._model),
+                supports_tools=True,
+                supports_streaming=True,
+                supports_json_mode=True,
+                supports_vision=True,
+                capabilities={
+                    "display_name": getattr(model_info, "display_name", None),
+                    "description": getattr(model_info, "description", None),
+                },
+            )
+        except Exception:
+            return None
+
+    @classmethod
+    def get_supported_models(cls, api_key: str | None = None) -> list[str]:
+        """
+        Get list of models available from Google.
+
+        Args:
+            api_key: Optional API key
+
+        Returns:
+            List of available model names that support generation
+        """
+        try:
+            client = genai.Client(api_key=api_key) if api_key else genai.Client()
+            models = []
+
+            for model in client.models.list():
+                model_name = getattr(model, "name", "")
+                # Filter to only include gemini models that support generateContent
+                if model_name and "gemini" in model_name.lower():
+                    supported_methods = getattr(
+                        model, "supported_generation_methods", []
+                    )
+                    if supported_methods is None:
+                        supported_methods = []
+                    # Include models that support content generation
+                    if (
+                        any("generateContent" in method for method in supported_methods)
+                        or not supported_methods
+                    ):
+                        # Extract just the model ID from full path
+                        # e.g., "models/gemini-2.5-flash" -> "gemini-2.5-flash"
+                        if "/" in model_name:
+                            model_name = model_name.split("/")[-1]
+                        models.append(model_name)
+
+            return models
+        except Exception:
+            return []
+
+    def list_models(self) -> list[str]:
+        """List all available Gemini models."""
+        return self.get_supported_models(self._api_key)
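For reference, the provider's own docstring example, expanded into a slightly longer sketch. It assumes GEMINI_API_KEY (or GOOGLE_API_KEY) is set, e.g. via the ~/.kader/.env file that kader/config.py now writes:

    from kader.providers.base import Message, Usage
    from kader.providers.google import GoogleProvider

    provider = GoogleProvider(model="gemini-2.5-flash")
    response = provider.invoke([Message.user("Hello!")])
    print(response.content, response.cost)

    # Pricing sanity check against the GEMINI_PRICING table:
    # 1M input + 1M output tokens on gemini-2.5-flash is roughly $0.15 + $0.60.
    print(provider.estimate_cost(Usage(prompt_tokens=1_000_000, completion_tokens=1_000_000, total_tokens=2_000_000)))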
{kader-1.0.0.dist-info → kader-1.1.0.dist-info}/METADATA
CHANGED
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: kader
-Version: 1.0.0
+Version: 1.1.0
 Summary: kader coding agent
 Requires-Python: >=3.11
 Requires-Dist: aiofiles>=25.1.0
 Requires-Dist: faiss-cpu>=1.9.0
+Requires-Dist: google-genai>=1.61.0
 Requires-Dist: jinja2>=3.1.6
 Requires-Dist: loguru>=0.7.3
 Requires-Dist: ollama>=0.6.1
{kader-1.0.0.dist-info → kader-1.1.0.dist-info}/RECORD
CHANGED
@@ -1,18 +1,19 @@
 cli/README.md,sha256=DY3X7w6LPka1GzhtTrGwhpkFmx0YyRpcTCHjFmti3Yg,4654
 cli/__init__.py,sha256=OAi_KSwcuYXR0sRxKuw1DYQrz1jbu8p7vn41_99f36I,107
 cli/__main__.py,sha256=xO2JVjCsh691b-cjSBAEKocJeUeI3P0gfUqM-f1Mp1A,95
-cli/app.py,sha256=
+cli/app.py,sha256=3JwSr4224k2g7DUrWN3IGkCbzNzYDJdKBoIOfUCDeJM,28313
 cli/app.tcss,sha256=szNaXxCEo0QfFyM1klPB21GzNaXV3wCxTbDm9e71ioA,4488
-cli/
+cli/llm_factory.py,sha256=wXZtCgf2yBeYYKBwbR2WqSbiYxB_26bTMOqcKlbOACE,5214
+cli/utils.py,sha256=y4unmXrANLk-tTEumotq0wp-sreBlLVYKRxz4BMyWVM,2178
 cli/widgets/__init__.py,sha256=1vj31CrJyxZROLthkKr79i_GbNyj8g3q60ZQPbJHK5k,300
 cli/widgets/confirmation.py,sha256=7hXqGyhW5V9fmtjgiWR4z2fJWmKxWhUH9RigqDrTKp4,9396
 cli/widgets/conversation.py,sha256=n99b9wjgrw4WTbWX4entK2Jx4xcP-n-F0KPJXC4w2oM,2720
 cli/widgets/loading.py,sha256=wlhQ47ppSj8vCEqjrbG2mk1yKnfo8dWC5429Z2q1-0g,1689
 kader/__init__.py,sha256=lv08nSC3h5YLdBU4WqXMz2YHHojy7mcBPMbfP251Rjo,654
-kader/config.py,sha256=
+kader/config.py,sha256=ra0VAUnbuo4rvIuMIqq3G9Kg5YYpzrc1Wp9bBMED9vo,4503
 kader/agent/__init__.py,sha256=UJzUw9NIzggCrhIBHC6nJnfzkhCjCZnIzmD6uUn2SNA,159
 kader/agent/agents.py,sha256=qG594bZ71tbTshwhSrbKOIpkSz1y3FjcLztxPKJRrfE,4837
-kader/agent/base.py,sha256=
+kader/agent/base.py,sha256=uqLv4WEzEw9Nga9e_B4i2oriu9S3XquI4GfOPIKEaVI,40163
 kader/agent/logger.py,sha256=3vFwz_yycSBU-5mcdalfZ3KBVT9P_20Q-WT5Al8yIXo,5796
 kader/memory/__init__.py,sha256=VUzzhGOWvO_2aYB6uuavmtNI8l94K7H3uPn4_1MVUUs,1473
 kader/memory/conversation.py,sha256=h6Bamd8_rYnk0Bwt4MJWZRfv2wxCcg6eUxPvzP-tIyA,11810
@@ -26,8 +27,9 @@ kader/prompts/templates/executor_agent.j2,sha256=YtKH2LBbY4FrGxam-3Q0YPnnNnLxcV_
 kader/prompts/templates/kader_planner.j2,sha256=ONpeuu6OvNuxv8d6zrjYSF1QFPoIDFBCqe7P0RmcD-I,3429
 kader/prompts/templates/planning_agent.j2,sha256=Uc4SnMPv4vKWchhO0RLRNjbEio5CVlRgqDJG_dgM2Pk,1315
 kader/prompts/templates/react_agent.j2,sha256=yME6Qgj2WTW8jRZ_yuQcY6xlXKcV7YUv5sz5ZfCl8P4,606
-kader/providers/__init__.py,sha256=
+kader/providers/__init__.py,sha256=6vZvD0nPIMblZWddn7hZjO6a0VlL5ZKBDuLn7OE_0_w,211
 kader/providers/base.py,sha256=gxpomjRAX9q3Qf4GHYxdiGI_GsRW9BG7PM38SKUAeCk,17105
+kader/providers/google.py,sha256=UaTl2jodVNcsJ7SwARjIf6Q4XGtqA_h_g7cfx903xZE,24724
 kader/providers/mock.py,sha256=VBuOFFPvDWn4QVFS9HXlwu3jswP0NNNxrMyL4Qgvm50,2723
 kader/providers/ollama.py,sha256=R5F0zlmbGGwSxNVURU0dWa-gMG_V-CmVZdRvy_GMmuw,15577
 kader/tools/README.md,sha256=lmw-Ghm8ie2pNcSTL4sJ7OKerkGvbXmlD9Zi87hiC-8,14347
@@ -48,7 +50,7 @@ kader/utils/context_aggregator.py,sha256=5_2suuWSsJZhJ60zWTIkiEx5R5EIIdXak7MU98z
 kader/workflows/__init__.py,sha256=qaarPRT7xcY86dHmAUM6IQpLedKKBayFiASZr8-dSSA,297
 kader/workflows/base.py,sha256=BCTMMWE-LW_qIU7TWZgTzu82EMem6Uy2hJv0sa7buc0,1892
 kader/workflows/planner_executor.py,sha256=VK4bCGvoUJ0eezNmkVb-iPjis1HsZFjlAUTtmluF9zw,9392
-kader-1.
-kader-1.
-kader-1.
-kader-1.
+kader-1.1.0.dist-info/METADATA,sha256=BI_L65fwCoPTFwx7Un9P2B804lYc3TVcjOGx_jwIys4,10968
+kader-1.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+kader-1.1.0.dist-info/entry_points.txt,sha256=TK0VOtrfDFqZ8JQfxpuAHHvDLHyoiafUjS-VOixl02c,39
+kader-1.1.0.dist-info/RECORD,,
{kader-1.0.0.dist-info → kader-1.1.0.dist-info}/WHEEL
File without changes
{kader-1.0.0.dist-info → kader-1.1.0.dist-info}/entry_points.txt
File without changes