paygent-sdk 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
paygent_sdk/client.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ Main client implementation for the Paygent SDK.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import time
8
+ from typing import Optional
9
+ from urllib.parse import urljoin
10
+
11
+ import requests
12
+ from requests.adapters import HTTPAdapter
13
+ from urllib3.util.retry import Retry
14
+
15
+ try:
16
+ import tiktoken
17
+ except ImportError:
18
+ tiktoken = None
19
+
20
+ from .models import UsageData, UsageDataWithStrings, APIRequest, ModelPricing, MODEL_PRICING
21
+
22
+
23
+ class Client:
24
+ """Paygent SDK client for tracking usage and costs for AI models."""
25
+
26
+ def __init__(self, api_key: str, base_url: str = "http://13.201.118.45:8080"):
27
+ """
28
+ Initialize the Paygent SDK client.
29
+
30
+ Args:
31
+ api_key: Your Paygent API key
32
+ base_url: Base URL for the Paygent API (default: http://13.201.118.45:8080)
33
+ """
34
+ self.api_key = api_key
35
+ self.base_url = base_url.rstrip('/')
36
+
37
+ # Setup logging
38
+ self.logger = logging.getLogger(f"paygent_sdk.{id(self)}")
39
+ self.logger.setLevel(logging.INFO)
40
+
41
+ # Add console handler if no handlers exist
42
+ if not self.logger.handlers:
43
+ handler = logging.StreamHandler()
44
+ formatter = logging.Formatter(
45
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
46
+ )
47
+ handler.setFormatter(formatter)
48
+ self.logger.addHandler(handler)
49
+
50
+ # Setup HTTP client with retry strategy
51
+ self.session = requests.Session()
52
+
53
+ # Configure retry strategy
54
+ retry_strategy = Retry(
55
+ total=3,
56
+ backoff_factor=1,
57
+ status_forcelist=[429, 500, 502, 503, 504],
58
+ )
59
+
60
+ adapter = HTTPAdapter(max_retries=retry_strategy)
61
+ self.session.mount("http://", adapter)
62
+ self.session.mount("https://", adapter)
63
+
64
+ # Set default timeout
65
+ self.session.timeout = 30
66
+
67
+ @classmethod
68
+ def new_client(cls, api_key: str) -> 'Client':
69
+ """
70
+ Create a new Paygent SDK client with the default API URL.
71
+
72
+ Args:
73
+ api_key: Your Paygent API key
74
+
75
+ Returns:
76
+ Client instance
77
+ """
78
+ return cls(api_key)
79
+
80
+ @classmethod
81
+ def new_client_with_url(cls, api_key: str, base_url: str) -> 'Client':
82
+ """
83
+ Create a new Paygent SDK client with a custom base URL.
84
+
85
+ Args:
86
+ api_key: Your Paygent API key
87
+ base_url: Custom base URL for the Paygent API
88
+
89
+ Returns:
90
+ Client instance
91
+ """
92
+ return cls(api_key, base_url)
93
+
94
+ def _calculate_cost(self, model: str, usage_data: UsageData) -> float:
95
+ """
96
+ Calculate the cost based on model and usage data.
97
+
98
+ Args:
99
+ model: The AI model name
100
+ usage_data: Usage data containing token counts
101
+
102
+ Returns:
103
+ Calculated cost in USD
104
+ """
105
+ pricing = MODEL_PRICING.get(model)
106
+ if not pricing:
107
+ self.logger.warning(f"Unknown model '{model}', using default pricing")
108
+ # Use default pricing for unknown models (per 1000 tokens)
109
+ pricing = ModelPricing(
110
+ prompt_tokens_cost=0.0001, # $0.10 per 1000 tokens
111
+ completion_tokens_cost=0.0001 # $0.10 per 1000 tokens
112
+ )
113
+
114
+ # Calculate cost per 1000 tokens
115
+ prompt_cost = (usage_data.prompt_tokens / 1000.0) * pricing.prompt_tokens_cost
116
+ completion_cost = (usage_data.completion_tokens / 1000.0) * pricing.completion_tokens_cost
117
+ total_cost = prompt_cost + completion_cost
118
+
119
+ self.logger.debug(
120
+ f"Cost calculation for model '{model}': "
121
+ f"prompt_tokens={usage_data.prompt_tokens} ({prompt_cost:.6f}), "
122
+ f"completion_tokens={usage_data.completion_tokens} ({completion_cost:.6f}), "
123
+ f"total={total_cost:.6f}"
124
+ )
125
+
126
+ return total_cost
127
+
128
+ def send_usage(
129
+ self,
130
+ agent_id: str,
131
+ customer_id: str,
132
+ indicator: str,
133
+ usage_data: UsageData
134
+ ) -> None:
135
+ """
136
+ Send usage data to the Paygent API.
137
+
138
+ Args:
139
+ agent_id: Unique identifier for the agent
140
+ customer_id: Unique identifier for the customer
141
+ indicator: Indicator for the usage event
142
+ usage_data: Usage data containing model and token information
143
+
144
+ Raises:
145
+ requests.RequestException: If the HTTP request fails
146
+ ValueError: If the usage data is invalid
147
+ """
148
+ self.logger.info(
149
+ f"Starting sendUsage for agentID={agent_id}, customerID={customer_id}, "
150
+ f"indicator={indicator}, model={usage_data.model}"
151
+ )
152
+
153
+ # Calculate cost
154
+ try:
155
+ cost = self._calculate_cost(usage_data.model, usage_data)
156
+ except Exception as e:
157
+ self.logger.error(f"Failed to calculate cost: {e}")
158
+ raise ValueError(f"Failed to calculate cost: {e}") from e
159
+
160
+ self.logger.info(f"Calculated cost: {cost:.6f} for model {usage_data.model}")
161
+
162
+ # Prepare API request
163
+ api_request = APIRequest(
164
+ agent_id=agent_id,
165
+ customer_id=customer_id,
166
+ indicator=indicator,
167
+ amount=cost
168
+ )
169
+
170
+ # Prepare request data
171
+ request_data = {
172
+ "agentId": api_request.agent_id,
173
+ "customerId": api_request.customer_id,
174
+ "indicator": api_request.indicator,
175
+ "amount": api_request.amount,
176
+ "inputToken": usage_data.prompt_tokens,
177
+ "outputToken": usage_data.completion_tokens,
178
+ "model": usage_data.model,
179
+ "serviceProvider": usage_data.service_provider
180
+ }
181
+
182
+ self.logger.debug(f"API request body: {json.dumps(request_data)}")
183
+
184
+ # Create HTTP request
185
+ url = urljoin(self.base_url, "/api/v1/usage")
186
+
187
+ headers = {
188
+ "Content-Type": "application/json",
189
+ "paygent-api-key": self.api_key
190
+ }
191
+
192
+ self.logger.debug(f"Making HTTP POST request to: {url}")
193
+
194
+ try:
195
+ # Make HTTP request
196
+ response = self.session.post(
197
+ url,
198
+ json=request_data,
199
+ headers=headers,
200
+ timeout=30
201
+ )
202
+
203
+ self.logger.debug(
204
+ f"API response status: {response.status_code}, "
205
+ f"body: {response.text}"
206
+ )
207
+
208
+ # Check response status
209
+ if 200 <= response.status_code < 300:
210
+ self.logger.info(
211
+ f"Successfully sent usage data for agentID={agent_id}, "
212
+ f"customerID={customer_id}, cost={cost:.6f}"
213
+ )
214
+ return
215
+
216
+ # Handle error response
217
+ self.logger.error(
218
+ f"API request failed with status {response.status_code}: {response.text}"
219
+ )
220
+ response.raise_for_status()
221
+
222
+ except requests.RequestException as e:
223
+ self.logger.error(f"HTTP request failed: {e}")
224
+ raise
225
+
226
+ def set_log_level(self, level: int) -> None:
227
+ """
228
+ Set the logging level for the client.
229
+
230
+ Args:
231
+ level: Logging level (e.g., logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR)
232
+ """
233
+ self.logger.setLevel(level)
234
+
235
+ def get_logger(self) -> logging.Logger:
236
+ """
237
+ Get the logger instance for custom logging.
238
+
239
+ Returns:
240
+ Logger instance
241
+ """
242
+ return self.logger
243
+
244
+ def _get_token_count(self, model: str, text: str) -> int:
245
+ """
246
+ Get token count for a given model and text.
247
+ Supports OpenAI, Anthropic, Google, Meta, AWS, Mistral, Cohere, DeepSeek
248
+
249
+ Args:
250
+ model: The AI model name
251
+ text: Text to count tokens for
252
+
253
+ Returns:
254
+ Number of tokens
255
+ """
256
+ if not text:
257
+ return 0
258
+
259
+ if not tiktoken:
260
+ self.logger.warning("tiktoken not available, using fallback token counting")
261
+ return self._fallback_token_count(text)
262
+
263
+ model_lower = model.lower()
264
+
265
+ try:
266
+ # OpenAI GPT models
267
+ if model_lower.startswith("gpt-"):
268
+ encoding = tiktoken.encoding_for_model(model)
269
+ return len(encoding.encode(text))
270
+
271
+ # Anthropic Claude models
272
+ elif model_lower.startswith("claude-"):
273
+ encoding = tiktoken.get_encoding("cl100k_base")
274
+ return len(encoding.encode(text))
275
+
276
+ # Google DeepMind Gemini models
277
+ elif model_lower.startswith("gemini-"):
278
+ encoding = tiktoken.get_encoding("cl100k_base")
279
+ return len(encoding.encode(text))
280
+
281
+ # Meta Llama models
282
+ elif model_lower.startswith("llama-"):
283
+ encoding = tiktoken.get_encoding("cl100k_base")
284
+ return len(encoding.encode(text))
285
+
286
+ # AWS Titan models
287
+ elif model_lower.startswith("titan-"):
288
+ encoding = tiktoken.get_encoding("cl100k_base")
289
+ return len(encoding.encode(text))
290
+
291
+ # Mistral models
292
+ elif model_lower.startswith("mistral-"):
293
+ encoding = tiktoken.get_encoding("cl100k_base")
294
+ return len(encoding.encode(text))
295
+
296
+ # Cohere models
297
+ elif model_lower.startswith("command"):
298
+ encoding = tiktoken.get_encoding("cl100k_base")
299
+ return len(encoding.encode(text))
300
+
301
+ # DeepSeek models
302
+ elif model_lower.startswith("deepseek-"):
303
+ encoding = tiktoken.get_encoding("cl100k_base")
304
+ return len(encoding.encode(text))
305
+
306
+ # Default fallback
307
+ else:
308
+ self.logger.warning(f"Unknown model '{model}', using cl100k_base encoding")
309
+ encoding = tiktoken.get_encoding("cl100k_base")
310
+ return len(encoding.encode(text))
311
+
312
+ except Exception as e:
313
+ self.logger.warning(f"Failed to get token count for model {model}: {e}, using fallback")
314
+ return self._fallback_token_count(text)
315
+
316
+ def _fallback_token_count(self, text: str) -> int:
317
+ """
318
+ Fallback token counting method using word-based estimation.
319
+
320
+ Args:
321
+ text: Text to count tokens for
322
+
323
+ Returns:
324
+ Estimated number of tokens
325
+ """
326
+ # Simple word-based estimation: 1.3 tokens per word
327
+ word_count = len(text.split())
328
+ return int(word_count * 1.3)
329
+
330
+ def _calculate_cost_from_strings(self, model: str, usage_data: UsageDataWithStrings) -> float:
331
+ """
332
+ Calculate the cost based on model and usage data with strings.
333
+
334
+ Args:
335
+ model: The AI model name
336
+ usage_data: Usage data containing prompt and output strings
337
+
338
+ Returns:
339
+ Calculated cost in USD
340
+ """
341
+ # Count tokens
342
+ prompt_tokens = self._get_token_count(model, usage_data.prompt_string)
343
+ completion_tokens = self._get_token_count(model, usage_data.output_string)
344
+ total_tokens = prompt_tokens + completion_tokens
345
+
346
+ self.logger.debug(
347
+ f"Token counting for model '{model}': "
348
+ f"prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}, "
349
+ f"total_tokens={total_tokens}"
350
+ )
351
+
352
+ # Create UsageData for cost calculation
353
+ usage_data_obj = UsageData(
354
+ service_provider=usage_data.service_provider,
355
+ model=model,
356
+ prompt_tokens=prompt_tokens,
357
+ completion_tokens=completion_tokens,
358
+ total_tokens=total_tokens
359
+ )
360
+
361
+ return self._calculate_cost(model, usage_data_obj)
362
+
363
+ def send_usage_with_token_string(
364
+ self,
365
+ agent_id: str,
366
+ customer_id: str,
367
+ indicator: str,
368
+ usage_data: UsageDataWithStrings
369
+ ) -> None:
370
+ """
371
+ Send usage data to the Paygent API using prompt and output strings.
372
+ The function automatically counts tokens using proper tokenizers for each model provider and calculates costs.
373
+
374
+ Args:
375
+ agent_id: Unique identifier for the agent
376
+ customer_id: Unique identifier for the customer
377
+ indicator: Indicator for the usage event
378
+ usage_data: Usage data containing prompt and output strings
379
+
380
+ Raises:
381
+ requests.RequestException: If the HTTP request fails
382
+ ValueError: If the usage data is invalid
383
+ """
384
+ self.logger.info(
385
+ f"Starting sendUsageWithTokenString for agentID={agent_id}, customerID={customer_id}, "
386
+ f"indicator={indicator}, serviceProvider={usage_data.service_provider}, model={usage_data.model}"
387
+ )
388
+
389
+ # Calculate cost from strings
390
+ try:
391
+ cost = self._calculate_cost_from_strings(usage_data.model, usage_data)
392
+ except Exception as e:
393
+ self.logger.error(f"Failed to calculate cost from strings: {e}")
394
+ raise ValueError(f"Failed to calculate cost from strings: {e}") from e
395
+
396
+ self.logger.info(f"Calculated cost: {cost:.6f} for model {usage_data.model} from strings")
397
+
398
+ # Calculate token counts for API request
399
+ prompt_tokens = self._get_token_count(usage_data.model, usage_data.prompt_string)
400
+ completion_tokens = self._get_token_count(usage_data.model, usage_data.output_string)
401
+
402
+ # Prepare API request
403
+ api_request = APIRequest(
404
+ agent_id=agent_id,
405
+ customer_id=customer_id,
406
+ indicator=indicator,
407
+ amount=cost
408
+ )
409
+
410
+ # Prepare request data
411
+ request_data = {
412
+ "agentId": api_request.agent_id,
413
+ "customerId": api_request.customer_id,
414
+ "indicator": api_request.indicator,
415
+ "amount": api_request.amount,
416
+ "inputToken": prompt_tokens,
417
+ "outputToken": completion_tokens,
418
+ "model": usage_data.model,
419
+ "serviceProvider": usage_data.service_provider
420
+ }
421
+
422
+ self.logger.debug(f"API request body: {json.dumps(request_data)}")
423
+
424
+ # Create HTTP request
425
+ url = urljoin(self.base_url, "/api/v1/usage")
426
+
427
+ headers = {
428
+ "Content-Type": "application/json",
429
+ "paygent-api-key": self.api_key
430
+ }
431
+
432
+ self.logger.debug(f"Making HTTP POST request to: {url}")
433
+
434
+ try:
435
+ # Make HTTP request
436
+ response = self.session.post(
437
+ url,
438
+ json=request_data,
439
+ headers=headers,
440
+ timeout=30
441
+ )
442
+
443
+ self.logger.debug(
444
+ f"API response status: {response.status_code}, "
445
+ f"body: {response.text}"
446
+ )
447
+
448
+ # Check response status
449
+ if 200 <= response.status_code < 300:
450
+ self.logger.info(
451
+ f"Successfully sent usage data from strings for agentID={agent_id}, "
452
+ f"customerID={customer_id}, cost={cost:.6f}"
453
+ )
454
+ return
455
+
456
+ # Handle error response
457
+ self.logger.error(
458
+ f"API request failed with status {response.status_code}: {response.text}"
459
+ )
460
+ response.raise_for_status()
461
+
462
+ except requests.RequestException as e:
463
+ self.logger.error(f"HTTP request failed: {e}")
464
+ raise
@@ -0,0 +1,217 @@
1
+ """
2
+ Model constants for the Paygent SDK.
3
+
4
+ This module provides predefined constants for AI models and service providers
5
+ to make it easier for users to reference models without hardcoding strings.
6
+
7
+ For the Go SDK equivalent, see: https://github.com/paygent/paygent-sdk-go
8
+ """
9
+
10
+ # Service Provider Constants
11
class ServiceProvider:
    """Service provider constants for external access.

    # NOTE(review): these look like the values passed as the
    # ``service_provider`` field of usage data — confirm against callers.
    """
    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    GOOGLE_DEEPMIND = "Google DeepMind"
    META = "Meta"
    AWS = "AWS"
    MISTRAL_AI = "Mistral AI"
    COHERE = "Cohere"
    DEEPSEEK = "DeepSeek"
    # Escape hatch for providers not listed above.
    CUSTOM = "Custom"
22
+
23
+
24
+ # OpenAI Models
25
class OpenAIModels:
    """OpenAI model constants.

    Values are the lowercase API model identifiers (e.g. ``"gpt-4o"``),
    unlike the display-name style used by the other provider classes here.
    """
    # GPT-5 Series
    GPT_5 = "gpt-5"
    GPT_5_MINI = "gpt-5-mini"
    GPT_5_NANO = "gpt-5-nano"
    GPT_5_CHAT_LATEST = "gpt-5-chat-latest"
    GPT_5_CODEX = "gpt-5-codex"
    GPT_5_PRO = "gpt-5-pro"
    GPT_5_SEARCH_API = "gpt-5-search-api"

    # GPT-4.1 Series
    GPT_4_1 = "gpt-4.1"
    GPT_4_1_MINI = "gpt-4.1-mini"
    GPT_4_1_NANO = "gpt-4.1-nano"

    # GPT-4o Series
    GPT_4O = "gpt-4o"
    GPT_4O_2024_05_13 = "gpt-4o-2024-05-13"
    GPT_4O_MINI = "gpt-4o-mini"

    # Realtime Models
    GPT_REALTIME = "gpt-realtime"
    GPT_REALTIME_MINI = "gpt-realtime-mini"
    GPT_4O_REALTIME_PREVIEW = "gpt-4o-realtime-preview"
    GPT_4O_MINI_REALTIME_PREVIEW = "gpt-4o-mini-realtime-preview"

    # Audio Models
    GPT_AUDIO = "gpt-audio"
    GPT_AUDIO_MINI = "gpt-audio-mini"
    GPT_4O_AUDIO_PREVIEW = "gpt-4o-audio-preview"
    GPT_4O_MINI_AUDIO_PREVIEW = "gpt-4o-mini-audio-preview"

    # O-Series Models (reasoning)
    O1 = "o1"
    O1_PRO = "o1-pro"
    O3_PRO = "o3-pro"
    O3 = "o3"
    O3_DEEP_RESEARCH = "o3-deep-research"
    O4_MINI = "o4-mini"
    O4_MINI_DEEP_RESEARCH = "o4-mini-deep-research"
    O3_MINI = "o3-mini"
    O1_MINI = "o1-mini"

    # Other Models
    CODEX_MINI_LATEST = "codex-mini-latest"
    GPT_4O_MINI_SEARCH_PREVIEW = "gpt-4o-mini-search-preview"
    GPT_4O_SEARCH_PREVIEW = "gpt-4o-search-preview"
    COMPUTER_USE_PREVIEW = "computer-use-preview"
    CHATGPT_4O_LATEST = "chatgpt-4o-latest"

    # GPT-4 Turbo
    GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
    GPT_4_0125_PREVIEW = "gpt-4-0125-preview"
    GPT_4_1106_PREVIEW = "gpt-4-1106-preview"
    GPT_4_1106_VISION_PREVIEW = "gpt-4-1106-vision-preview"
    GPT_4_0613 = "gpt-4-0613"
    GPT_4_0314 = "gpt-4-0314"
    GPT_4_32K = "gpt-4-32k"

    # GPT-3.5 Series
    GPT_3_5_TURBO = "gpt-3.5-turbo"
    GPT_3_5_TURBO_0125 = "gpt-3.5-turbo-0125"
    GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"
    GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"
    GPT_3_5_0301 = "gpt-3.5-0301"
    GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"
    GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"

    # Legacy Models (completions-era base models)
    DAVINCI_002 = "davinci-002"
    BABBAGE_002 = "babbage-002"
97
+
98
+
99
+ # Anthropic Models
100
class AnthropicModels:
    """Anthropic model constants.

    # NOTE(review): these are display names ("Sonnet 4.5"), not API ids, so
    # they do not match the ``"claude-"`` prefix that the client's token
    # counter checks — such models take the unknown-model path; confirm
    # this is intended.
    """
    SONNET_4_5 = "Sonnet 4.5"
    HAIKU_4_5 = "Haiku 4.5"
    OPUS_4_1 = "Opus 4.1"
    SONNET_4 = "Sonnet 4"
    OPUS_4 = "Opus 4"
    SONNET_3_7 = "Sonnet 3.7"
    HAIKU_3_5 = "Haiku 3.5"
    OPUS_3 = "Opus 3"
    HAIKU_3 = "Haiku 3"
111
+
112
+
113
+ # Google DeepMind Models
114
class GoogleDeepMindModels:
    """Google DeepMind model constants.

    # NOTE(review): display names ("Gemini 2.5 Pro") rather than API ids;
    # they do not match the ``"gemini-"`` prefix used by the client's token
    # counter — verify intended.
    """
    GEMINI_2_5_PRO = "Gemini 2.5 Pro"
    GEMINI_2_5_FLASH = "Gemini 2.5 Flash"
    GEMINI_2_5_FLASH_PREVIEW = "Gemini 2.5 Flash Preview"
    GEMINI_2_5_FLASH_LITE = "Gemini 2.5 Flash-Lite"
    GEMINI_2_5_FLASH_LITE_PREVIEW = "Gemini 2.5 Flash-Lite Preview"
    GEMINI_2_5_FLASH_NATIVE_AUDIO = "Gemini 2.5 Flash Native Audio"
    GEMINI_2_5_FLASH_IMAGE = "Gemini 2.5 Flash Image"
    GEMINI_2_5_FLASH_PREVIEW_TTS = "Gemini 2.5 Flash Preview TTS"
    GEMINI_2_5_PRO_PREVIEW_TTS = "Gemini 2.5 Pro Preview TTS"
    GEMINI_2_5_COMPUTER_USE_PREVIEW = "Gemini 2.5 Computer Use Preview"
126
+
127
+
128
+ # Meta Models
129
class MetaModels:
    """Meta model constants (display-name style values)."""
    # Llama 4 Series
    LLAMA_4_MAVERICK = "Llama 4 Maverick"
    LLAMA_4_SCOUT = "Llama 4 Scout"

    # Llama 3.3 Series
    LLAMA_3_3_70B_INSTRUCT_TURBO = "Llama 3.3 70B Instruct-Turbo"

    # Llama 3.2 Series
    LLAMA_3_2_3B_INSTRUCT_TURBO = "Llama 3.2 3B Instruct Turbo"

    # Llama 3.1 Series
    LLAMA_3_1_405B_INSTRUCT_TURBO = "Llama 3.1 405B Instruct Turbo"
    LLAMA_3_1_70B_INSTRUCT_TURBO = "Llama 3.1 70B Instruct Turbo"
    LLAMA_3_1_8B_INSTRUCT_TURBO = "Llama 3.1 8B Instruct Turbo"

    # Llama 3 Series
    LLAMA_3_70B_INSTRUCT_TURBO = "Llama 3 70B Instruct Turbo"
    LLAMA_3_70B_INSTRUCT_REFERENCE = "Llama 3 70B Instruct Reference"
    LLAMA_3_8B_INSTRUCT_LITE = "Llama 3 8B Instruct Lite"

    # Llama 2
    LLAMA_2 = "LLaMA-2"

    # Llama Guard Series (safety/moderation models)
    LLAMA_GUARD_4_12B = "Llama Guard 4 12B"
    LLAMA_GUARD_3_11B_VISION_TURBO = "Llama Guard 3 11B Vision Turbo"
    LLAMA_GUARD_3_8B = "Llama Guard 3 8B"
    LLAMA_GUARD_2_8B = "Llama Guard 2 8B"

    # Salesforce fine-tune hosted alongside the Llama family
    SALESFORCE_LLAMA_RANK_V1_8B = "Salesforce Llama Rank V1 (8B)"
162
+
163
+
164
+ # AWS Models
165
class AWSModels:
    """AWS model constants (Amazon Nova family, display-name style values)."""
    AMAZON_NOVA_MICRO = "Amazon Nova Micro"
    AMAZON_NOVA_LITE = "Amazon Nova Lite"
    AMAZON_NOVA_PRO = "Amazon Nova Pro"
170
+
171
+
172
+ # Mistral AI Models
173
class MistralAIModels:
    """Mistral AI model constants (display-name style values)."""
    MISTRAL_7B_INSTRUCT = "Mistral 7B Instruct"
    MISTRAL_LARGE = "Mistral Large"
    MISTRAL_SMALL = "Mistral Small"
    MISTRAL_MEDIUM = "Mistral Medium"
179
+
180
+
181
+ # Cohere Models
182
class CohereModels:
    """Cohere model constants (display-name style values)."""
    COMMAND_R7B = "Command R7B"
    COMMAND_R = "Command R"
    COMMAND_R_PLUS = "Command R+"
    COMMAND_A = "Command A"
    # Multilingual Aya family; one constant covers both parameter sizes.
    AYA_EXPANSE_8B_32B = "Aya Expanse (8B/32B)"
189
+
190
+
191
+ # DeepSeek Models
192
class DeepSeekModels:
    """DeepSeek model constants (display-name style values)."""
    DEEPSEEK_CHAT = "DeepSeek Chat"
    DEEPSEEK_REASONER = "DeepSeek Reasoner"
    DEEPSEEK_R1_GLOBAL = "DeepSeek R1 Global"
    DEEPSEEK_R1_DATAZONE = "DeepSeek R1 DataZone"
    DEEPSEEK_V3_2_EXP = "DeepSeek V3.2-Exp"
199
+
200
+
201
+
202
+
203
+
204
+
205
def is_model_supported(model: str) -> bool:
    """
    Check if a model is supported by the SDK.

    A model counts as supported when a pricing entry exists for it.

    Args:
        model: The model name

    Returns:
        True if the model is supported, False otherwise
    """
    # Imported lazily: importing at module scope would create a circular
    # dependency with .models.
    from .models import MODEL_PRICING

    pricing_table = MODEL_PRICING
    return model in pricing_table