indoxrouter-0.1.0-py3-none-any.whl → indoxrouter-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxRouter/__init__.py +83 -0
- indoxRouter/client.py +564 -218
- indoxRouter/client_resourses/__init__.py +20 -0
- indoxRouter/client_resourses/base.py +67 -0
- indoxRouter/client_resourses/chat.py +144 -0
- indoxRouter/client_resourses/completion.py +138 -0
- indoxRouter/client_resourses/embedding.py +83 -0
- indoxRouter/client_resourses/image.py +116 -0
- indoxRouter/client_resourses/models.py +114 -0
- indoxRouter/config.py +151 -0
- indoxRouter/constants/__init__.py +81 -0
- indoxRouter/exceptions/__init__.py +70 -0
- indoxRouter/models/__init__.py +111 -0
- indoxRouter/providers/__init__.py +50 -50
- indoxRouter/providers/ai21labs.json +128 -0
- indoxRouter/providers/base_provider.py +62 -30
- indoxRouter/providers/claude.json +164 -0
- indoxRouter/providers/cohere.json +116 -0
- indoxRouter/providers/databricks.json +110 -0
- indoxRouter/providers/deepseek.json +110 -0
- indoxRouter/providers/google.json +128 -0
- indoxRouter/providers/meta.json +128 -0
- indoxRouter/providers/mistral.json +146 -0
- indoxRouter/providers/nvidia.json +110 -0
- indoxRouter/providers/openai.json +308 -0
- indoxRouter/providers/openai.py +471 -72
- indoxRouter/providers/qwen.json +110 -0
- indoxRouter/utils/__init__.py +240 -0
- indoxrouter-0.1.2.dist-info/LICENSE +21 -0
- indoxrouter-0.1.2.dist-info/METADATA +259 -0
- indoxrouter-0.1.2.dist-info/RECORD +33 -0
- indoxRouter/api_endpoints.py +0 -336
- indoxRouter/client_package.py +0 -138
- indoxRouter/init_db.py +0 -71
- indoxRouter/main.py +0 -711
- indoxRouter/migrations/__init__.py +0 -1
- indoxRouter/migrations/env.py +0 -98
- indoxRouter/migrations/versions/__init__.py +0 -1
- indoxRouter/migrations/versions/initial_schema.py +0 -84
- indoxRouter/providers/ai21.py +0 -268
- indoxRouter/providers/claude.py +0 -177
- indoxRouter/providers/cohere.py +0 -171
- indoxRouter/providers/databricks.py +0 -166
- indoxRouter/providers/deepseek.py +0 -166
- indoxRouter/providers/google.py +0 -216
- indoxRouter/providers/llama.py +0 -164
- indoxRouter/providers/meta.py +0 -227
- indoxRouter/providers/mistral.py +0 -182
- indoxRouter/providers/nvidia.py +0 -164
- indoxrouter-0.1.0.dist-info/METADATA +0 -179
- indoxrouter-0.1.0.dist-info/RECORD +0 -27
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/WHEEL +0 -0
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/top_level.txt +0 -0
indoxRouter/providers/google.py
DELETED
@@ -1,216 +0,0 @@
import json
import os
from pathlib import Path
from typing import Dict, Any, List, Optional

import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from .base_provider import BaseProvider
from ..utils.exceptions import ModelNotFoundError, ProviderAPIError, RateLimitError


class Provider(BaseProvider):
    """
    Google Gemini provider implementation
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the Google provider

        Args:
            api_key: Google API key
            model_name: Model name (e.g., gemini-1.5-pro)
        """
        super().__init__(api_key, model_name)

        # Configure the Google API client
        genai.configure(api_key=api_key)

        # Load model configuration
        self.model_config = self._load_model_config(model_name)

        # Set default generation config
        self.generation_config = {
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 40,
            "max_output_tokens": 8192,
        }

        # Set default safety settings (moderate filtering)
        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }

    def _load_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Load model configuration from the JSON file

        Args:
            model_name: Model name to load configuration for

        Returns:
            Model configuration dictionary

        Raises:
            ModelNotFoundError: If the model is not found in the configuration
        """
        config_path = Path(__file__).parent / "google.json"

        try:
            with open(config_path, "r") as f:
                models = json.load(f)

            for model in models:
                if model.get("modelName") == model_name:
                    return model

            raise ModelNotFoundError(f"Model {model_name} not found in Google provider")

        except FileNotFoundError:
            raise ModelNotFoundError(f"Google provider configuration file not found")
        except json.JSONDecodeError:
            raise ModelNotFoundError(
                f"Invalid JSON in Google provider configuration file"
            )

    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion

        Args:
            prompt: Prompt text
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in USD
        """
        # Count tokens in the prompt
        prompt_tokens = self.count_tokens(prompt)

        # Calculate cost based on input and output pricing
        input_cost = (prompt_tokens / 1000) * self.model_config.get(
            "inputPricePer1KTokens", 0
        )
        output_cost = (max_tokens / 1000) * self.model_config.get(
            "outputPricePer1KTokens", 0
        )

        return input_cost + output_cost

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text

        Args:
            text: Text to count tokens for

        Returns:
            Number of tokens
        """
        try:
            # Use Google's tokenizer if available
            model = genai.GenerativeModel(self.model_name)
            return model.count_tokens(text).total_tokens
        except Exception:
            # Fallback to approximate token counting
            return len(text.split())

    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt

        Args:
            prompt: Prompt text
            **kwargs: Additional parameters for the generation

        Returns:
            Dictionary containing the generated text, cost, and usage statistics

        Raises:
            ProviderAPIError: If there's an error with the provider API
            RateLimitError: If the provider's rate limit is exceeded
        """
        try:
            # Get generation parameters
            max_tokens = kwargs.get("max_tokens", 1024)
            temperature = kwargs.get("temperature", 0.7)
            top_p = kwargs.get("top_p", 0.95)
            top_k = kwargs.get("top_k", 40)

            # Update generation config
            generation_config = {
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "max_output_tokens": max_tokens,
            }

            # Prepare system prompt if provided
            system_prompt = kwargs.get(
                "system_prompt", self.model_config.get("systemPrompt", "")
            )

            # Create the model
            model = genai.GenerativeModel(
                model_name=self.model_name,
                generation_config=generation_config,
                safety_settings=self.safety_settings,
            )

            # Format the prompt using the template if available
            prompt_template = self.model_config.get("promptTemplate", "")
            if prompt_template and "%1" in prompt_template:
                formatted_prompt = prompt_template.replace("%1", prompt)
            else:
                formatted_prompt = prompt

            # Generate the completion
            if system_prompt:
                response = model.generate_content([system_prompt, formatted_prompt])
            else:
                response = model.generate_content(formatted_prompt)

            # Extract the generated text
            generated_text = response.text

            # Get token counts
            prompt_tokens = self.count_tokens(prompt)
            completion_tokens = self.count_tokens(generated_text)
            total_tokens = prompt_tokens + completion_tokens

            # Calculate cost
            cost = self.estimate_cost(prompt, completion_tokens)

            # Prepare the response
            result = {
                "text": generated_text,
                "cost": cost,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }

            return self.validate_response(result)

        except Exception as e:
            error_message = str(e)

            # Handle rate limit errors
            if (
                "rate limit" in error_message.lower()
                or "quota" in error_message.lower()
            ):
                raise RateLimitError(f"Google API rate limit exceeded: {error_message}")

            # Handle other API errors
            raise ProviderAPIError(
                f"Error generating completion with Google API: {error_message}", e
            )
indoxRouter/providers/llama.py
DELETED
@@ -1,164 +0,0 @@
import os
import json
import requests
from typing import Dict, Any, Optional, List
from .base_provider import BaseProvider


class Provider(BaseProvider):
    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the Llama provider with API key and model name.

        Args:
            api_key (str): The API key for authentication
            model_name (str): The name of the model to use
        """
        super().__init__(api_key, model_name)
        self.base_url = os.environ.get("LLAMA_API_BASE", "https://llama-api.meta.ai/v1")
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        self.model_config = self._load_model_config()

    def _load_model_config(self) -> Dict[str, Any]:
        """
        Load the model configuration from the JSON file.

        Returns:
            Dict[str, Any]: The model configuration
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, "llama.json")

        with open(config_path, "r") as f:
            models = json.load(f)

        for model in models:
            if model["modelName"] == self.model_name:
                return model

        raise ValueError(f"Model {self.model_name} not found in configuration")

    def estimate_cost(self, prompt: str, max_tokens: int = 100) -> float:
        """
        Estimate the cost of generating a completion.

        Args:
            prompt (str): The prompt to generate a completion for
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 100.

        Returns:
            float: The estimated cost in USD
        """
        input_tokens = self.count_tokens(prompt)
        input_cost = (input_tokens / 1000) * self.model_config["inputPricePer1KTokens"]
        output_cost = (max_tokens / 1000) * self.model_config["outputPricePer1KTokens"]
        return input_cost + output_cost

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text.
        This is a simple approximation. For more accurate counts, consider using a tokenizer.

        Args:
            text (str): The text to count tokens for

        Returns:
            int: The number of tokens
        """
        # Simple approximation: 1 token ≈ 4 characters
        return len(text) // 4

    def generate(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt.

        Args:
            prompt (str): The prompt to generate a completion for
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 100.
            temperature (float, optional): The temperature for sampling. Defaults to 0.7.
            top_p (float, optional): The top-p value for nucleus sampling. Defaults to 1.0.
            frequency_penalty (float, optional): The frequency penalty. Defaults to 0.0.
            presence_penalty (float, optional): The presence penalty. Defaults to 0.0.
            stop (Optional[List[str]], optional): A list of stop sequences. Defaults to None.

        Returns:
            Dict[str, Any]: The generated completion
        """
        # Format the prompt according to the model's template
        prompt_template = self.model_config.get("promptTemplate", "%1%2")
        formatted_prompt = prompt_template.replace("%1", prompt).replace("%2", "")

        # Prepare the request payload
        payload = {
            "model": self.model_config.get("companyModelName", self.model_name),
            "prompt": formatted_prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
        }

        if stop:
            payload["stop"] = stop

        # Make the API request
        try:
            response = requests.post(
                f"{self.base_url}/completions", headers=self.headers, json=payload
            )
            response.raise_for_status()
            result = response.json()

            # Calculate the cost
            input_tokens = result.get("usage", {}).get(
                "prompt_tokens", self.count_tokens(prompt)
            )
            output_tokens = result.get("usage", {}).get("completion_tokens", 0)
            input_cost = (input_tokens / 1000) * self.model_config[
                "inputPricePer1KTokens"
            ]
            output_cost = (output_tokens / 1000) * self.model_config[
                "outputPricePer1KTokens"
            ]
            total_cost = input_cost + output_cost

            # Format the response
            return self.validate_response(
                {
                    "text": result.get("choices", [{}])[0].get("text", ""),
                    "cost": total_cost,
                    "usage": {
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens,
                        "input_cost": input_cost,
                        "output_cost": output_cost,
                    },
                    "raw_response": result,
                }
            )

        except requests.exceptions.RequestException as e:
            return {
                "text": f"Error: {str(e)}",
                "cost": 0,
                "usage": {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_cost": 0,
                    "output_cost": 0,
                },
                "error": str(e),
            }
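The deleted llama.py above and meta.py below use different rough token heuristics (len(text) // 4 versus len(text.split()) * 1.3). A small illustration, not part of the diff, of how far apart the two approximations land on the same arbitrary sample string:

text = "Estimate the cost of generating a completion."
# llama.py heuristic: roughly 4 characters per token
chars_estimate = len(text) // 4            # 45 // 4 = 11
# meta.py heuristic: roughly 1.3 tokens per whitespace-separated word
words_estimate = len(text.split()) * 1.3   # 7 * 1.3 ≈ 9.1
print(chars_estimate, words_estimate)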
indoxRouter/providers/meta.py
DELETED
@@ -1,227 +0,0 @@
import json
import os
from pathlib import Path
from typing import Dict, Any, List, Optional
import requests

from .base_provider import BaseProvider
from ..utils.exceptions import ModelNotFoundError, ProviderAPIError, RateLimitError


class Provider(BaseProvider):
    """
    Meta Llama provider implementation
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the Meta provider

        Args:
            api_key: Meta API key
            model_name: Model name (e.g., llama-3-70b-instruct)
        """
        super().__init__(api_key, model_name)

        # Load model configuration
        self.model_config = self._load_model_config(model_name)

        # Meta API base URL
        self.api_base = os.environ.get("META_API_BASE", "https://api.meta.ai/v1")

        # Default generation parameters
        self.default_params = {
            "temperature": 0.7,
            "top_p": 0.9,
            "max_tokens": 1024,
        }

    def _load_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Load model configuration from the JSON file

        Args:
            model_name: Model name to load configuration for

        Returns:
            Model configuration dictionary

        Raises:
            ModelNotFoundError: If the model is not found in the configuration
        """
        config_path = Path(__file__).parent / "meta.json"

        try:
            with open(config_path, "r") as f:
                models = json.load(f)

            for model in models:
                if model.get("modelName") == model_name:
                    return model

            raise ModelNotFoundError(f"Model {model_name} not found in Meta provider")

        except FileNotFoundError:
            raise ModelNotFoundError(f"Meta provider configuration file not found")
        except json.JSONDecodeError:
            raise ModelNotFoundError(
                f"Invalid JSON in Meta provider configuration file"
            )

    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion

        Args:
            prompt: Prompt text
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in USD
        """
        # Count tokens in the prompt
        prompt_tokens = self.count_tokens(prompt)

        # Calculate cost based on input and output pricing
        input_cost = (prompt_tokens / 1000) * self.model_config.get(
            "inputPricePer1KTokens", 0
        )
        output_cost = (max_tokens / 1000) * self.model_config.get(
            "outputPricePer1KTokens", 0
        )

        return input_cost + output_cost

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text

        Args:
            text: Text to count tokens for

        Returns:
            Number of tokens
        """
        # Meta doesn't provide a direct token counting API
        # This is a rough approximation - in production, consider using a tokenizer library
        return len(text.split()) * 1.3  # Rough approximation

    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt

        Args:
            prompt: Prompt text
            **kwargs: Additional parameters for the generation

        Returns:
            Dictionary containing the generated text, cost, and usage statistics

        Raises:
            ProviderAPIError: If there's an error with the provider API
            RateLimitError: If the provider's rate limit is exceeded
        """
        try:
            # Get generation parameters
            max_tokens = kwargs.get("max_tokens", self.default_params["max_tokens"])
            temperature = kwargs.get("temperature", self.default_params["temperature"])
            top_p = kwargs.get("top_p", self.default_params["top_p"])

            # Prepare system prompt if provided
            system_prompt = kwargs.get(
                "system_prompt", self.model_config.get("systemPrompt", "")
            )

            # Format the prompt using the template if available
            prompt_template = self.model_config.get("promptTemplate", "")
            if prompt_template and "%1" in prompt_template:
                formatted_prompt = prompt_template.replace("%1", prompt)
            else:
                formatted_prompt = prompt

            # Prepare the request payload
            payload = {
                "model": self.model_name,
                "messages": [],
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
            }

            # Add system message if provided
            if system_prompt:
                payload["messages"].append({"role": "system", "content": system_prompt})

            # Add user message
            payload["messages"].append({"role": "user", "content": formatted_prompt})

            # Make the API request
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=payload,
                timeout=60,
            )

            # Check for errors
            if response.status_code != 200:
                error_message = (
                    response.json().get("error", {}).get("message", "Unknown error")
                )

                if response.status_code == 429:
                    raise RateLimitError(
                        f"Meta API rate limit exceeded: {error_message}"
                    )
                else:
                    raise ProviderAPIError(f"Meta API error: {error_message}")

            # Parse the response
            response_data = response.json()

            # Extract the generated text
            generated_text = (
                response_data.get("choices", [{}])[0]
                .get("message", {})
                .get("content", "")
            )

            # Get token usage
            usage = response_data.get("usage", {})
            prompt_tokens = usage.get(
                "prompt_tokens", self.count_tokens(formatted_prompt)
            )
            completion_tokens = usage.get(
                "completion_tokens", self.count_tokens(generated_text)
            )
            total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)

            # Calculate cost
            cost = self.estimate_cost(formatted_prompt, completion_tokens)

            # Prepare the response
            result = {
                "text": generated_text,
                "cost": cost,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }

            return self.validate_response(result)

        except RateLimitError:
            # Re-raise rate limit errors
            raise
        except Exception as e:
            # Handle other errors
            raise ProviderAPIError(
                f"Error generating completion with Meta API: {str(e)}", e
            )
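All three deleted providers share the same per-1K-token cost formula, (tokens / 1000) * price, with prices read from the bundled *.json configuration files. A quick worked example using made-up prices (not values from meta.json or any other config in this diff):

prompt_tokens = 800
max_tokens = 200
input_price_per_1k = 0.0005    # hypothetical "inputPricePer1KTokens"
output_price_per_1k = 0.0015   # hypothetical "outputPricePer1KTokens"

input_cost = (prompt_tokens / 1000) * input_price_per_1k    # 0.8 * 0.0005 = 0.0004
output_cost = (max_tokens / 1000) * output_price_per_1k     # 0.2 * 0.0015 = 0.0003
estimated_cost = input_cost + output_cost                   # 0.0007 USD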