cost_katana-1.0.0-py3-none-any.whl
This diff shows the content of a publicly released package version as published to one of the supported registries. It is provided for informational purposes only and reflects the package exactly as it appears in its public registry.
- cost_katana/__init__.py +69 -0
- cost_katana/cli.py +303 -0
- cost_katana/client.py +235 -0
- cost_katana/config.py +181 -0
- cost_katana/exceptions.py +39 -0
- cost_katana/models.py +343 -0
- cost_katana-1.0.0.dist-info/METADATA +425 -0
- cost_katana-1.0.0.dist-info/RECORD +12 -0
- cost_katana-1.0.0.dist-info/WHEEL +5 -0
- cost_katana-1.0.0.dist-info/entry_points.txt +2 -0
- cost_katana-1.0.0.dist-info/licenses/LICENSE +21 -0
- cost_katana-1.0.0.dist-info/top_level.txt +1 -0
cost_katana/config.py
ADDED
@@ -0,0 +1,181 @@
"""
Configuration management for Cost Katana
"""

import json
import os
from typing import Dict, Any, Optional
from dataclasses import dataclass, asdict
from pathlib import Path

@dataclass
class Config:
    """Configuration class for Cost Katana client"""

    api_key: Optional[str] = None
    base_url: str = "https://cost-katana-backend.store"
    timeout: int = 30
    max_retries: int = 3
    retry_delay: float = 1.0
    default_model: str = "nova-lite"
    default_temperature: float = 0.7
    default_max_tokens: int = 2000
    default_chat_mode: str = "balanced"
    enable_analytics: bool = True
    enable_optimization: bool = True
    enable_failover: bool = True
    cost_limit_per_request: Optional[float] = None
    cost_limit_per_day: Optional[float] = None

    def __post_init__(self):
        """Load from environment variables if not set"""
        if not self.api_key:
            self.api_key = os.getenv('COST_KATANA_API_KEY')

        # Override with environment variables if they exist
        if os.getenv('COST_KATANA_BASE_URL'):
            self.base_url = os.getenv('COST_KATANA_BASE_URL')
        if os.getenv('COST_KATANA_DEFAULT_MODEL'):
            self.default_model = os.getenv('COST_KATANA_DEFAULT_MODEL')
        if os.getenv('COST_KATANA_TIMEOUT'):
            self.timeout = int(os.getenv('COST_KATANA_TIMEOUT'))

    @classmethod
    def from_file(cls, config_path: str) -> 'Config':
        """
        Load configuration from JSON file.

        Args:
            config_path: Path to JSON configuration file

        Returns:
            Config instance

        Example config.json:
            {
                "api_key": "dak_your_key_here",
                "base_url": "https://api.costkatana.com",
                "default_model": "claude-3-sonnet",
                "default_temperature": 0.3,
                "cost_limit_per_day": 100.0,
                "providers": {
                    "anthropic": {
                        "priority": 1,
                        "models": ["claude-3-sonnet", "claude-3-haiku"]
                    },
                    "openai": {
                        "priority": 2,
                        "models": ["gpt-4", "gpt-3.5-turbo"]
                    }
                }
            }
        """
        config_path = Path(config_path).expanduser()

        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in config file: {e}")

        # Extract known fields
        config_fields = {
            field.name for field in cls.__dataclass_fields__.values()
        }

        config_data = {k: v for k, v in data.items() if k in config_fields}
        config = cls(**config_data)

        # Store additional data
        config._extra_data = {k: v for k, v in data.items() if k not in config_fields}

        return config

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary"""
        result = asdict(self)

        # Add extra data if it exists
        if hasattr(self, '_extra_data'):
            result.update(self._extra_data)

        return result

    def save(self, config_path: str):
        """Save configuration to JSON file"""
        config_path = Path(config_path).expanduser()
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    def get_provider_config(self, provider: str) -> Dict[str, Any]:
        """Get configuration for a specific provider"""
        if hasattr(self, '_extra_data') and 'providers' in self._extra_data:
            return self._extra_data['providers'].get(provider, {})
        return {}

    def get_model_mapping(self, model_name: str) -> str:
        """
        Map user-friendly model names to internal model IDs.
        This allows users to use names like 'gemini-2.0-flash' while
        the backend uses the actual model IDs.
        """
        # Default mapping - can be overridden in config file
        # Based on actual models available from Cost Katana Backend
        default_mappings = {
            # Amazon Nova models (primary recommendation)
            'nova-micro': 'amazon.nova-micro-v1:0',
            'nova-lite': 'amazon.nova-lite-v1:0',
            'nova-pro': 'amazon.nova-pro-v1:0',
            'fast': 'amazon.nova-micro-v1:0',
            'balanced': 'amazon.nova-lite-v1:0',
            'powerful': 'amazon.nova-pro-v1:0',

            # Anthropic Claude models
            'claude-3-haiku': 'anthropic.claude-3-haiku-20240307-v1:0',
            'claude-3-sonnet': 'anthropic.claude-3-sonnet-20240229-v1:0',
            'claude-3-opus': 'anthropic.claude-3-opus-20240229-v1:0',
            'claude-3.5-haiku': 'anthropic.claude-3-5-haiku-20241022-v1:0',
            'claude-3.5-sonnet': 'anthropic.claude-3-5-sonnet-20241022-v2:0',
            'claude': 'anthropic.claude-3-5-sonnet-20241022-v2:0',

            # Meta Llama models
            'llama-3.1-8b': 'meta.llama3-1-8b-instruct-v1:0',
            'llama-3.1-70b': 'meta.llama3-1-70b-instruct-v1:0',
            'llama-3.1-405b': 'meta.llama3-1-405b-instruct-v1:0',
            'llama-3.2-1b': 'meta.llama3-2-1b-instruct-v1:0',
            'llama-3.2-3b': 'meta.llama3-2-3b-instruct-v1:0',

            # Mistral models
            'mistral-7b': 'mistral.mistral-7b-instruct-v0:2',
            'mixtral-8x7b': 'mistral.mixtral-8x7b-instruct-v0:1',
            'mistral-large': 'mistral.mistral-large-2402-v1:0',

            # Cohere models
            'command': 'cohere.command-text-v14',
            'command-light': 'cohere.command-light-text-v14',
            'command-r': 'cohere.command-r-v1:0',
            'command-r-plus': 'cohere.command-r-plus-v1:0',

            # AI21 models
            'jamba': 'ai21.jamba-instruct-v1:0',
            'j2-ultra': 'ai21.j2-ultra-v1',
            'j2-mid': 'ai21.j2-mid-v1',

            # Backwards compatibility aliases
            'gemini-2.0-flash': 'amazon.nova-lite-v1:0',  # Map to similar performance
            'gemini-pro': 'amazon.nova-pro-v1:0',
            'gpt-4': 'anthropic.claude-3-5-sonnet-20241022-v2:0',
            'gpt-3.5-turbo': 'anthropic.claude-3-haiku-20240307-v1:0',
        }

        # Check for custom mappings in config
        if hasattr(self, '_extra_data') and 'model_mappings' in self._extra_data:
            custom_mappings = self._extra_data['model_mappings']
            default_mappings.update(custom_mappings)

        return default_mappings.get(model_name, model_name)
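The mapping and persistence helpers above compose as follows. A minimal usage sketch, not part of the package; the key and file path are hypothetical placeholders:

from cost_katana.config import Config

# Defaults come from the dataclass fields; __post_init__ then overlays
# any COST_KATANA_* environment variables.
config = Config(api_key="dak_example_key")  # hypothetical key
print(config.get_model_mapping("claude"))           # -> anthropic.claude-3-5-sonnet-20241022-v2:0
print(config.get_model_mapping("my-custom-model"))  # unmapped names pass through unchanged

# Round trip: save() writes to_dict() as JSON, from_file() reads it back
# and stashes unknown keys (e.g. "providers") in _extra_data.
config.save("~/.cost_katana/config.json")
restored = Config.from_file("~/.cost_katana/config.json")
assert restored.default_model == config.default_model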
cost_katana/exceptions.py
ADDED
@@ -0,0 +1,39 @@
"""
Custom exceptions for Cost Katana
"""

class CostKatanaError(Exception):
    """Base exception for Cost Katana errors"""
    pass

class AuthenticationError(CostKatanaError):
    """Raised when authentication fails"""
    pass

class ModelNotAvailableError(CostKatanaError):
    """Raised when requested model is not available"""
    pass

class RateLimitError(CostKatanaError):
    """Raised when rate limit is exceeded"""
    pass

class CostLimitExceededError(CostKatanaError):
    """Raised when cost limits are exceeded"""
    pass

class ConversationNotFoundError(CostKatanaError):
    """Raised when conversation is not found"""
    pass

class InvalidConfigurationError(CostKatanaError):
    """Raised when configuration is invalid"""
    pass

class NetworkError(CostKatanaError):
    """Raised when network requests fail"""
    pass

class ModelTimeoutError(CostKatanaError):
    """Raised when model request times out"""
    pass
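Since every exception derives from CostKatanaError, callers can handle failures at whatever granularity they need. A hedged sketch of the intended pattern; the handler policy here is illustrative, not part of the package:

from cost_katana.exceptions import (
    AuthenticationError,
    CostKatanaError,
    RateLimitError,
)

def call_with_handling(fn, *args, **kwargs):
    """Illustrative wrapper: map the exception hierarchy to outcomes."""
    try:
        return fn(*args, **kwargs)
    except AuthenticationError:
        raise SystemExit("Authentication failed - check COST_KATANA_API_KEY")
    except RateLimitError:
        return None  # caller may back off and retry
    except CostKatanaError as e:
        # The base class catches any other library-specific failure.
        print(f"Cost Katana error: {e}")
        return None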
cost_katana/models.py
ADDED
@@ -0,0 +1,343 @@
"""
Generative AI Models - Simple interface similar to google-generative-ai
"""

import time
from typing import Dict, Any, Optional, List, Iterator, Union
from dataclasses import dataclass
from .client import CostKatanaClient
from .exceptions import CostKatanaError, ModelNotAvailableError

@dataclass
class GenerationConfig:
    """Configuration for text generation"""
    temperature: float = 0.7
    max_output_tokens: int = 2000
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    candidate_count: int = 1
    stop_sequences: Optional[List[str]] = None

@dataclass
class UsageMetadata:
    """Usage metadata returned with responses"""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    latency: float
    model: str
    optimizations_applied: Optional[List[str]] = None
    cache_hit: bool = False
    agent_path: Optional[List[str]] = None
    risk_level: Optional[str] = None

class GenerateContentResponse:
    """Response from generate_content method"""

    def __init__(self, response_data: Dict[str, Any]):
        self._data = response_data
        self._text = response_data.get('data', {}).get('response', '')

        # Extract usage metadata
        data = response_data.get('data', {})
        self.usage_metadata = UsageMetadata(
            prompt_tokens=data.get('tokenCount', 0),  # This might need adjustment based on actual response
            completion_tokens=data.get('tokenCount', 0),
            total_tokens=data.get('tokenCount', 0),
            cost=data.get('cost', 0.0),
            latency=data.get('latency', 0.0),
            model=data.get('model', ''),
            optimizations_applied=data.get('optimizationsApplied'),
            cache_hit=data.get('cacheHit', False),
            agent_path=data.get('agentPath'),
            risk_level=data.get('riskLevel')
        )

        # Store thinking/reasoning if available
        self.thinking = data.get('thinking')

    @property
    def text(self) -> str:
        """Get the response text"""
        return self._text

    @property
    def parts(self) -> List[Dict[str, Any]]:
        """Get response parts (for compatibility)"""
        return [{'text': self._text}] if self._text else []

    def __str__(self) -> str:
        return self._text

    def __repr__(self) -> str:
        return f"GenerateContentResponse(text='{self._text[:50]}...', cost=${self.usage_metadata.cost:.4f})"

class ChatSession:
    """A chat session for maintaining conversation context"""

    def __init__(
        self,
        client: CostKatanaClient,
        model_id: str,
        generation_config: Optional[GenerationConfig] = None,
        conversation_id: str = None
    ):
        self.client = client
        self.model_id = model_id
        self.generation_config = generation_config or GenerationConfig()
        self.conversation_id = conversation_id
        self.history: List[Dict[str, Any]] = []

        # Create conversation if not provided
        if not self.conversation_id:
            try:
                conv_response = self.client.create_conversation(
                    title=f"Chat with {model_id}",
                    model_id=model_id
                )
                self.conversation_id = conv_response['data']['id']
            except Exception as e:
                raise CostKatanaError(f"Failed to create conversation: {str(e)}")

    def send_message(
        self,
        message: str,
        **kwargs
    ) -> GenerateContentResponse:
        """
        Send a message in the chat session.

        Args:
            message: The message to send
            **kwargs: Additional parameters to override defaults

        Returns:
            GenerateContentResponse with the model's reply

        Example:
            response = chat.send_message("What's the weather like?")
            print(response.text)
        """
        # Merge generation config with kwargs
        params = {
            'temperature': kwargs.get('temperature', self.generation_config.temperature),
            'max_tokens': kwargs.get('max_tokens', self.generation_config.max_output_tokens),
            'chat_mode': kwargs.get('chat_mode', 'balanced'),
            'use_multi_agent': kwargs.get('use_multi_agent', False),
        }

        # Add any additional parameters
        for key, value in kwargs.items():
            if key not in params:
                params[key] = value

        try:
            response_data = self.client.send_message(
                message=message,
                model_id=self.model_id,
                conversation_id=self.conversation_id,
                **params
            )

            # Add to history
            self.history.append({
                'role': 'user',
                'content': message,
                'timestamp': time.time()
            })

            response_text = response_data.get('data', {}).get('response', '')
            self.history.append({
                'role': 'assistant',
                'content': response_text,
                'timestamp': time.time(),
                'metadata': response_data.get('data', {})
            })

            return GenerateContentResponse(response_data)

        except Exception as e:
            if isinstance(e, CostKatanaError):
                raise
            raise CostKatanaError(f"Failed to send message: {str(e)}")

    def get_history(self) -> List[Dict[str, Any]]:
        """Get the conversation history"""
        try:
            history_response = self.client.get_conversation_history(self.conversation_id)
            return history_response.get('data', [])
        except Exception as e:
            # Fall back to local history if API call fails
            return self.history

    def clear_history(self):
        """Clear the local conversation history"""
        self.history = []

    def delete_conversation(self):
        """Delete the conversation from the server"""
        try:
            self.client.delete_conversation(self.conversation_id)
            self.conversation_id = None
            self.history = []
        except Exception as e:
            raise CostKatanaError(f"Failed to delete conversation: {str(e)}")

class GenerativeModel:
    """
    A generative AI model with a simple interface similar to google-generative-ai.
    All requests are routed through Cost Katana for optimization and cost management.
    """

    def __init__(
        self,
        client: CostKatanaClient,
        model_name: str,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs
    ):
        """
        Initialize a generative model.

        Args:
            client: Cost Katana client instance
            model_name: Name of the model (e.g., 'gemini-2.0-flash', 'claude-3-sonnet')
            generation_config: Generation configuration
            **kwargs: Additional model parameters
        """
        self.client = client
        self.model_name = model_name
        self.model_id = client.config.get_model_mapping(model_name)
        self.generation_config = generation_config or GenerationConfig()
        self.model_params = kwargs

        # Validate model is available
        self._validate_model()

    def _validate_model(self):
        """Validate that the model is available"""
        try:
            available_models = self.client.get_available_models()
            model_ids = [model.get('id', model.get('modelId', '')) for model in available_models]

            if self.model_id not in model_ids and self.model_name not in model_ids:
                raise ModelNotAvailableError(
                    f"Model '{self.model_name}' (ID: {self.model_id}) is not available. "
                    f"Available models: {', '.join(model_ids[:5])}..."
                )
        except ModelNotAvailableError:
            raise
        except Exception as e:
            # If we can't validate, log but don't fail - the model might still work
            print(f"Warning: Could not validate model availability: {e}")

    def generate_content(
        self,
        prompt: Union[str, List[str]],
        generation_config: Optional[GenerationConfig] = None,
        **kwargs
    ) -> GenerateContentResponse:
        """
        Generate content from a prompt.

        Args:
            prompt: Text prompt or list of prompts
            generation_config: Generation configuration (overrides instance config)
            **kwargs: Additional parameters

        Returns:
            GenerateContentResponse with the generated content

        Example:
            model = cost_katana.GenerativeModel('gemini-2.0-flash')
            response = model.generate_content("Tell me about AI")
            print(response.text)
            print(f"Cost: ${response.usage_metadata.cost:.4f}")
        """
        # Handle multiple prompts
        if isinstance(prompt, list):
            prompt = "\n\n".join(str(p) for p in prompt)

        # Use provided config or instance config
        config = generation_config or self.generation_config

        # Prepare parameters
        params = {
            'temperature': kwargs.get('temperature', config.temperature),
            'max_tokens': kwargs.get('max_tokens', config.max_output_tokens),
            'chat_mode': kwargs.get('chat_mode', 'balanced'),
            'use_multi_agent': kwargs.get('use_multi_agent', False),
        }

        # Add any additional parameters from model_params or kwargs
        params.update(self.model_params)
        for key, value in kwargs.items():
            if key not in params:
                params[key] = value

        try:
            response_data = self.client.send_message(
                message=prompt,
                model_id=self.model_id,
                **params
            )

            return GenerateContentResponse(response_data)

        except Exception as e:
            if isinstance(e, CostKatanaError):
                raise
            raise CostKatanaError(f"Failed to generate content: {str(e)}")

    def start_chat(
        self,
        history: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> ChatSession:
        """
        Start a chat session.

        Args:
            history: Optional conversation history
            **kwargs: Additional chat configuration

        Returns:
            ChatSession instance

        Example:
            model = cost_katana.GenerativeModel('gemini-2.0-flash')
            chat = model.start_chat()
            response = chat.send_message("Hello!")
            print(response.text)
        """
        chat_session = ChatSession(
            client=self.client,
            model_id=self.model_id,
            generation_config=self.generation_config,
            **kwargs
        )

        # Add history if provided
        if history:
            chat_session.history = history

        return chat_session

    def count_tokens(self, prompt: str) -> Dict[str, int]:
        """
        Count tokens in a prompt (estimated).
        Note: This is a client-side estimate. Actual tokenization happens on the server.
        """
        # Simple word-based estimation - not accurate but gives an idea
        words = len(prompt.split())
        estimated_tokens = int(words * 1.3)  # Rough approximation

        return {
            'total_tokens': estimated_tokens,
            'prompt_tokens': estimated_tokens,
            'completion_tokens': 0
        }

    def __repr__(self) -> str:
        return f"GenerativeModel(model_name='{self.model_name}', model_id='{self.model_id}')"
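Tying models.py together, an end-to-end sketch based on the docstring examples above. It assumes the top-level cost_katana.GenerativeModel wrapper shown in those docstrings constructs the client internally, and that COST_KATANA_API_KEY is set in the environment (as config.py reads it); both concern code not shown in this diff:

import cost_katana
from cost_katana.models import GenerationConfig

# One-shot generation; 'claude-3-sonnet' is resolved to a backend model ID
# by Config.get_model_mapping().
model = cost_katana.GenerativeModel(
    'claude-3-sonnet',
    generation_config=GenerationConfig(temperature=0.3, max_output_tokens=500),
)
response = model.generate_content("Tell me about AI")
print(response.text)
print(f"Cost: ${response.usage_metadata.cost:.4f}")

# Multi-turn chat: ChatSession creates a server-side conversation and
# mirrors it in a local history list.
chat = model.start_chat()
reply = chat.send_message("Hello!")
print(reply.usage_metadata.model, reply.usage_metadata.total_tokens)
chat.delete_conversation()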