cost-katana 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cost_katana/config.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Configuration management for Cost Katana
+ """
+
+ import json
+ import os
+ from typing import Dict, Any, Optional
+ from dataclasses import dataclass, asdict
+ from pathlib import Path
+
+ @dataclass
+ class Config:
+     """Configuration class for Cost Katana client"""
+
+     api_key: Optional[str] = None
+     base_url: str = "https://cost-katana-backend.store"
+     timeout: int = 30
+     max_retries: int = 3
+     retry_delay: float = 1.0
+     default_model: str = "nova-lite"
+     default_temperature: float = 0.7
+     default_max_tokens: int = 2000
+     default_chat_mode: str = "balanced"
+     enable_analytics: bool = True
+     enable_optimization: bool = True
+     enable_failover: bool = True
+     cost_limit_per_request: Optional[float] = None
+     cost_limit_per_day: Optional[float] = None
+
+     def __post_init__(self):
+         """Load from environment variables if not set"""
+         if not self.api_key:
+             self.api_key = os.getenv('COST_KATANA_API_KEY')
+
+         # Override with environment variables if they exist
+         if os.getenv('COST_KATANA_BASE_URL'):
+             self.base_url = os.getenv('COST_KATANA_BASE_URL')
+         if os.getenv('COST_KATANA_DEFAULT_MODEL'):
+             self.default_model = os.getenv('COST_KATANA_DEFAULT_MODEL')
+         if os.getenv('COST_KATANA_TIMEOUT'):
+             self.timeout = int(os.getenv('COST_KATANA_TIMEOUT'))
+
+     @classmethod
+     def from_file(cls, config_path: str) -> 'Config':
+         """
+         Load configuration from JSON file.
+
+         Args:
+             config_path: Path to JSON configuration file
+
+         Returns:
+             Config instance
+
+         Example config.json:
+             {
+                 "api_key": "dak_your_key_here",
+                 "base_url": "https://api.costkatana.com",
+                 "default_model": "claude-3-sonnet",
+                 "default_temperature": 0.3,
+                 "cost_limit_per_day": 100.0,
+                 "providers": {
+                     "anthropic": {
+                         "priority": 1,
+                         "models": ["claude-3-sonnet", "claude-3-haiku"]
+                     },
+                     "openai": {
+                         "priority": 2,
+                         "models": ["gpt-4", "gpt-3.5-turbo"]
+                     }
+                 }
+             }
+         """
+         config_path = Path(config_path).expanduser()
+
+         if not config_path.exists():
+             raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+         try:
+             with open(config_path, 'r', encoding='utf-8') as f:
+                 data = json.load(f)
+         except json.JSONDecodeError as e:
+             raise ValueError(f"Invalid JSON in config file: {e}")
+
+         # Extract known fields
+         config_fields = {
+             field.name for field in cls.__dataclass_fields__.values()
+         }
+
+         config_data = {k: v for k, v in data.items() if k in config_fields}
+         config = cls(**config_data)
+
+         # Store additional data
+         config._extra_data = {k: v for k, v in data.items() if k not in config_fields}
+
+         return config
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert config to dictionary"""
+         result = asdict(self)
+
+         # Add extra data if it exists
+         if hasattr(self, '_extra_data'):
+             result.update(self._extra_data)
+
+         return result
+
+     def save(self, config_path: str):
+         """Save configuration to JSON file"""
+         config_path = Path(config_path).expanduser()
+         config_path.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(config_path, 'w', encoding='utf-8') as f:
+             json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+
+     def get_provider_config(self, provider: str) -> Dict[str, Any]:
+         """Get configuration for a specific provider"""
+         if hasattr(self, '_extra_data') and 'providers' in self._extra_data:
+             return self._extra_data['providers'].get(provider, {})
+         return {}
+
+     def get_model_mapping(self, model_name: str) -> str:
+         """
+         Map user-friendly model names to internal model IDs.
+         This allows users to use names like 'gemini-2.0-flash' while
+         the backend uses the actual model IDs.
+         """
+         # Default mapping - can be overridden in config file
+         # Based on actual models available from Cost Katana Backend
+         default_mappings = {
+             # Amazon Nova models (primary recommendation)
+             'nova-micro': 'amazon.nova-micro-v1:0',
+             'nova-lite': 'amazon.nova-lite-v1:0',
+             'nova-pro': 'amazon.nova-pro-v1:0',
+             'fast': 'amazon.nova-micro-v1:0',
+             'balanced': 'amazon.nova-lite-v1:0',
+             'powerful': 'amazon.nova-pro-v1:0',
+
+             # Anthropic Claude models
+             'claude-3-haiku': 'anthropic.claude-3-haiku-20240307-v1:0',
+             'claude-3-sonnet': 'anthropic.claude-3-sonnet-20240229-v1:0',
+             'claude-3-opus': 'anthropic.claude-3-opus-20240229-v1:0',
+             'claude-3.5-haiku': 'anthropic.claude-3-5-haiku-20241022-v1:0',
+             'claude-3.5-sonnet': 'anthropic.claude-3-5-sonnet-20241022-v2:0',
+             'claude': 'anthropic.claude-3-5-sonnet-20241022-v2:0',
+
+             # Meta Llama models
+             'llama-3.1-8b': 'meta.llama3-1-8b-instruct-v1:0',
+             'llama-3.1-70b': 'meta.llama3-1-70b-instruct-v1:0',
+             'llama-3.1-405b': 'meta.llama3-1-405b-instruct-v1:0',
+             'llama-3.2-1b': 'meta.llama3-2-1b-instruct-v1:0',
+             'llama-3.2-3b': 'meta.llama3-2-3b-instruct-v1:0',
+
+             # Mistral models
+             'mistral-7b': 'mistral.mistral-7b-instruct-v0:2',
+             'mixtral-8x7b': 'mistral.mixtral-8x7b-instruct-v0:1',
+             'mistral-large': 'mistral.mistral-large-2402-v1:0',
+
+             # Cohere models
+             'command': 'cohere.command-text-v14',
+             'command-light': 'cohere.command-light-text-v14',
+             'command-r': 'cohere.command-r-v1:0',
+             'command-r-plus': 'cohere.command-r-plus-v1:0',
+
+             # AI21 models
+             'jamba': 'ai21.jamba-instruct-v1:0',
+             'j2-ultra': 'ai21.j2-ultra-v1',
+             'j2-mid': 'ai21.j2-mid-v1',
+
+             # Backwards compatibility aliases
+             'gemini-2.0-flash': 'amazon.nova-lite-v1:0',  # Map to similar performance
+             'gemini-pro': 'amazon.nova-pro-v1:0',
+             'gpt-4': 'anthropic.claude-3-5-sonnet-20241022-v2:0',
+             'gpt-3.5-turbo': 'anthropic.claude-3-haiku-20240307-v1:0',
+         }
+
+         # Check for custom mappings in config
+         if hasattr(self, '_extra_data') and 'model_mappings' in self._extra_data:
+             custom_mappings = self._extra_data['model_mappings']
+             default_mappings.update(custom_mappings)
+
+         return default_mappings.get(model_name, model_name)
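
For orientation, a minimal sketch of how this Config class can be used, assuming only what the file above defines (the key and path below are placeholders):

    from cost_katana.config import Config

    # Construct directly; __post_init__ falls back to COST_KATANA_API_KEY
    # and related environment variables for anything left unset.
    config = Config(api_key="dak_example_key", default_model="claude-3-sonnet")

    # Or load from JSON; unknown keys are kept in _extra_data and surface
    # through get_provider_config() and custom model_mappings.
    # config = Config.from_file("~/.cost-katana/config.json")

    # Friendly names resolve to backend model IDs:
    print(config.get_model_mapping("claude"))
    # -> anthropic.claude-3-5-sonnet-20241022-v2:0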
cost_katana/exceptions.py ADDED
@@ -0,0 +1,39 @@
+ """
+ Custom exceptions for Cost Katana
+ """
+
+ class CostKatanaError(Exception):
+     """Base exception for Cost Katana errors"""
+     pass
+
+ class AuthenticationError(CostKatanaError):
+     """Raised when authentication fails"""
+     pass
+
+ class ModelNotAvailableError(CostKatanaError):
+     """Raised when requested model is not available"""
+     pass
+
+ class RateLimitError(CostKatanaError):
+     """Raised when rate limit is exceeded"""
+     pass
+
+ class CostLimitExceededError(CostKatanaError):
+     """Raised when cost limits are exceeded"""
+     pass
+
+ class ConversationNotFoundError(CostKatanaError):
+     """Raised when conversation is not found"""
+     pass
+
+ class InvalidConfigurationError(CostKatanaError):
+     """Raised when configuration is invalid"""
+     pass
+
+ class NetworkError(CostKatanaError):
+     """Raised when network requests fail"""
+     pass
+
+ class ModelTimeoutError(CostKatanaError):
+     """Raised when model request times out"""
+     pass
cost_katana/models.py ADDED
@@ -0,0 +1,343 @@
+ """
+ Generative AI Models - Simple interface similar to google-generative-ai
+ """
+
+ import time
+ from typing import Dict, Any, Optional, List, Union
+ from dataclasses import dataclass
+ from .client import CostKatanaClient
+ from .exceptions import CostKatanaError, ModelNotAvailableError
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for text generation"""
+     temperature: float = 0.7
+     max_output_tokens: int = 2000
+     top_p: Optional[float] = None
+     top_k: Optional[int] = None
+     candidate_count: int = 1
+     stop_sequences: Optional[List[str]] = None
+
+ @dataclass
+ class UsageMetadata:
+     """Usage metadata returned with responses"""
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+     cost: float
+     latency: float
+     model: str
+     optimizations_applied: Optional[List[str]] = None
+     cache_hit: bool = False
+     agent_path: Optional[List[str]] = None
+     risk_level: Optional[str] = None
+
+ class GenerateContentResponse:
+     """Response from generate_content method"""
+
+     def __init__(self, response_data: Dict[str, Any]):
+         self._data = response_data
+         self._text = response_data.get('data', {}).get('response', '')
+
+         # Extract usage metadata
+         data = response_data.get('data', {})
+         self.usage_metadata = UsageMetadata(
+             prompt_tokens=data.get('tokenCount', 0),  # backend reports a single tokenCount; no prompt/completion split
+             completion_tokens=data.get('tokenCount', 0),
+             total_tokens=data.get('tokenCount', 0),
+             cost=data.get('cost', 0.0),
+             latency=data.get('latency', 0.0),
+             model=data.get('model', ''),
+             optimizations_applied=data.get('optimizationsApplied'),
+             cache_hit=data.get('cacheHit', False),
+             agent_path=data.get('agentPath'),
+             risk_level=data.get('riskLevel')
+         )
+
+         # Store thinking/reasoning if available
+         self.thinking = data.get('thinking')
+
+     @property
+     def text(self) -> str:
+         """Get the response text"""
+         return self._text
+
+     @property
+     def parts(self) -> List[Dict[str, Any]]:
+         """Get response parts (for compatibility)"""
+         return [{'text': self._text}] if self._text else []
+
+     def __str__(self) -> str:
+         return self._text
+
+     def __repr__(self) -> str:
+         return f"GenerateContentResponse(text='{self._text[:50]}...', cost=${self.usage_metadata.cost:.4f})"
+
+ class ChatSession:
+     """A chat session for maintaining conversation context"""
+
+     def __init__(
+         self,
+         client: CostKatanaClient,
+         model_id: str,
+         generation_config: Optional[GenerationConfig] = None,
+         conversation_id: Optional[str] = None
+     ):
+         self.client = client
+         self.model_id = model_id
+         self.generation_config = generation_config or GenerationConfig()
+         self.conversation_id = conversation_id
+         self.history: List[Dict[str, Any]] = []
+
+         # Create conversation if not provided
+         if not self.conversation_id:
+             try:
+                 conv_response = self.client.create_conversation(
+                     title=f"Chat with {model_id}",
+                     model_id=model_id
+                 )
+                 self.conversation_id = conv_response['data']['id']
+             except Exception as e:
+                 raise CostKatanaError(f"Failed to create conversation: {str(e)}")
+
+     def send_message(
+         self,
+         message: str,
+         **kwargs
+     ) -> GenerateContentResponse:
+         """
+         Send a message in the chat session.
+
+         Args:
+             message: The message to send
+             **kwargs: Additional parameters to override defaults
+
+         Returns:
+             GenerateContentResponse with the model's reply
+
+         Example:
+             response = chat.send_message("What's the weather like?")
+             print(response.text)
+         """
+         # Merge generation config with kwargs
+         params = {
+             'temperature': kwargs.get('temperature', self.generation_config.temperature),
+             'max_tokens': kwargs.get('max_tokens', self.generation_config.max_output_tokens),
+             'chat_mode': kwargs.get('chat_mode', 'balanced'),
+             'use_multi_agent': kwargs.get('use_multi_agent', False),
+         }
+
+         # Add any additional parameters
+         for key, value in kwargs.items():
+             if key not in params:
+                 params[key] = value
+
+         try:
+             response_data = self.client.send_message(
+                 message=message,
+                 model_id=self.model_id,
+                 conversation_id=self.conversation_id,
+                 **params
+             )
+
+             # Add to history
+             self.history.append({
+                 'role': 'user',
+                 'content': message,
+                 'timestamp': time.time()
+             })
+
+             response_text = response_data.get('data', {}).get('response', '')
+             self.history.append({
+                 'role': 'assistant',
+                 'content': response_text,
+                 'timestamp': time.time(),
+                 'metadata': response_data.get('data', {})
+             })
+
+             return GenerateContentResponse(response_data)
+
+         except Exception as e:
+             if isinstance(e, CostKatanaError):
+                 raise
+             raise CostKatanaError(f"Failed to send message: {str(e)}")
+
+     def get_history(self) -> List[Dict[str, Any]]:
+         """Get the conversation history"""
+         try:
+             history_response = self.client.get_conversation_history(self.conversation_id)
+             return history_response.get('data', [])
+         except Exception:
+             # Fall back to local history if the API call fails
+             return self.history
+
+     def clear_history(self):
+         """Clear the local conversation history"""
+         self.history = []
+
+     def delete_conversation(self):
+         """Delete the conversation from the server"""
+         try:
+             self.client.delete_conversation(self.conversation_id)
+             self.conversation_id = None
+             self.history = []
+         except Exception as e:
+             raise CostKatanaError(f"Failed to delete conversation: {str(e)}")
+
+ class GenerativeModel:
+     """
+     A generative AI model with a simple interface similar to google-generative-ai.
+     All requests are routed through Cost Katana for optimization and cost management.
+     """
+
+     def __init__(
+         self,
+         client: CostKatanaClient,
+         model_name: str,
+         generation_config: Optional[GenerationConfig] = None,
+         **kwargs
+     ):
+         """
+         Initialize a generative model.
+
+         Args:
+             client: Cost Katana client instance
+             model_name: Name of the model (e.g., 'gemini-2.0-flash', 'claude-3-sonnet')
+             generation_config: Generation configuration
+             **kwargs: Additional model parameters
+         """
+         self.client = client
+         self.model_name = model_name
+         self.model_id = client.config.get_model_mapping(model_name)
+         self.generation_config = generation_config or GenerationConfig()
+         self.model_params = kwargs
+
+         # Validate model is available
+         self._validate_model()
+
+     def _validate_model(self):
+         """Validate that the model is available"""
+         try:
+             available_models = self.client.get_available_models()
+             model_ids = [model.get('id', model.get('modelId', '')) for model in available_models]
+
+             if self.model_id not in model_ids and self.model_name not in model_ids:
+                 raise ModelNotAvailableError(
+                     f"Model '{self.model_name}' (ID: {self.model_id}) is not available. "
+                     f"Available models: {', '.join(model_ids[:5])}..."
+                 )
+         except ModelNotAvailableError:
+             raise
+         except Exception as e:
+             # If we can't validate, warn but don't fail - the model might still work
+             print(f"Warning: Could not validate model availability: {e}")
+
+     def generate_content(
+         self,
+         prompt: Union[str, List[str]],
+         generation_config: Optional[GenerationConfig] = None,
+         **kwargs
+     ) -> GenerateContentResponse:
+         """
+         Generate content from a prompt.
+
+         Args:
+             prompt: Text prompt or list of prompts
+             generation_config: Generation configuration (overrides instance config)
+             **kwargs: Additional parameters
+
+         Returns:
+             GenerateContentResponse with the generated content
+
+         Example:
+             model = cost_katana.GenerativeModel('gemini-2.0-flash')
+             response = model.generate_content("Tell me about AI")
+             print(response.text)
+             print(f"Cost: ${response.usage_metadata.cost:.4f}")
+         """
+         # Handle multiple prompts
+         if isinstance(prompt, list):
+             prompt = "\n\n".join(str(p) for p in prompt)
+
+         # Use provided config or instance config
+         config = generation_config or self.generation_config
+
+         # Prepare parameters
+         params = {
+             'temperature': kwargs.get('temperature', config.temperature),
+             'max_tokens': kwargs.get('max_tokens', config.max_output_tokens),
+             'chat_mode': kwargs.get('chat_mode', 'balanced'),
+             'use_multi_agent': kwargs.get('use_multi_agent', False),
+         }
+
+         # Add any additional parameters from model_params or kwargs
+         params.update(self.model_params)
+         for key, value in kwargs.items():
+             if key not in params:
+                 params[key] = value
+
+         try:
+             response_data = self.client.send_message(
+                 message=prompt,
+                 model_id=self.model_id,
+                 **params
+             )
+
+             return GenerateContentResponse(response_data)
+
+         except Exception as e:
+             if isinstance(e, CostKatanaError):
+                 raise
+             raise CostKatanaError(f"Failed to generate content: {str(e)}")
+
+     def start_chat(
+         self,
+         history: Optional[List[Dict[str, Any]]] = None,
+         **kwargs
+     ) -> ChatSession:
+         """
+         Start a chat session.
+
+         Args:
+             history: Optional conversation history
+             **kwargs: Additional chat configuration
+
+         Returns:
+             ChatSession instance
+
+         Example:
+             model = cost_katana.GenerativeModel('gemini-2.0-flash')
+             chat = model.start_chat()
+             response = chat.send_message("Hello!")
+             print(response.text)
+         """
+         chat_session = ChatSession(
+             client=self.client,
+             model_id=self.model_id,
+             generation_config=self.generation_config,
+             **kwargs
+         )
+
+         # Add history if provided
+         if history:
+             chat_session.history = history
+
+         return chat_session
+
+     def count_tokens(self, prompt: str) -> Dict[str, int]:
+         """
+         Count tokens in a prompt (estimated).
+         Note: This is a client-side estimate. Actual tokenization happens on the server.
+         """
+         # Simple word-based estimation - not accurate but gives an idea
+         words = len(prompt.split())
+         estimated_tokens = int(words * 1.3)  # Rough approximation
+
+         return {
+             'total_tokens': estimated_tokens,
+             'prompt_tokens': estimated_tokens,
+             'completion_tokens': 0
+         }
+
+     def __repr__(self) -> str:
+         return f"GenerativeModel(model_name='{self.model_name}', model_id='{self.model_id}')"