indoxrouter 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxRouter/__init__.py +83 -0
- indoxRouter/client.py +564 -218
- indoxRouter/client_resourses/__init__.py +20 -0
- indoxRouter/client_resourses/base.py +67 -0
- indoxRouter/client_resourses/chat.py +144 -0
- indoxRouter/client_resourses/completion.py +138 -0
- indoxRouter/client_resourses/embedding.py +83 -0
- indoxRouter/client_resourses/image.py +116 -0
- indoxRouter/client_resourses/models.py +114 -0
- indoxRouter/config.py +151 -0
- indoxRouter/constants/__init__.py +81 -0
- indoxRouter/exceptions/__init__.py +70 -0
- indoxRouter/models/__init__.py +111 -0
- indoxRouter/providers/__init__.py +50 -50
- indoxRouter/providers/ai21labs.json +128 -0
- indoxRouter/providers/base_provider.py +62 -30
- indoxRouter/providers/claude.json +164 -0
- indoxRouter/providers/cohere.json +116 -0
- indoxRouter/providers/databricks.json +110 -0
- indoxRouter/providers/deepseek.json +110 -0
- indoxRouter/providers/google.json +128 -0
- indoxRouter/providers/meta.json +128 -0
- indoxRouter/providers/mistral.json +146 -0
- indoxRouter/providers/nvidia.json +110 -0
- indoxRouter/providers/openai.json +308 -0
- indoxRouter/providers/openai.py +471 -72
- indoxRouter/providers/qwen.json +110 -0
- indoxRouter/utils/__init__.py +240 -0
- indoxrouter-0.1.2.dist-info/LICENSE +21 -0
- indoxrouter-0.1.2.dist-info/METADATA +259 -0
- indoxrouter-0.1.2.dist-info/RECORD +33 -0
- indoxRouter/api_endpoints.py +0 -336
- indoxRouter/client_package.py +0 -138
- indoxRouter/init_db.py +0 -71
- indoxRouter/main.py +0 -711
- indoxRouter/migrations/__init__.py +0 -1
- indoxRouter/migrations/env.py +0 -98
- indoxRouter/migrations/versions/__init__.py +0 -1
- indoxRouter/migrations/versions/initial_schema.py +0 -84
- indoxRouter/providers/ai21.py +0 -268
- indoxRouter/providers/claude.py +0 -177
- indoxRouter/providers/cohere.py +0 -171
- indoxRouter/providers/databricks.py +0 -166
- indoxRouter/providers/deepseek.py +0 -166
- indoxRouter/providers/google.py +0 -216
- indoxRouter/providers/llama.py +0 -164
- indoxRouter/providers/meta.py +0 -227
- indoxRouter/providers/mistral.py +0 -182
- indoxRouter/providers/nvidia.py +0 -164
- indoxrouter-0.1.0.dist-info/METADATA +0 -179
- indoxrouter-0.1.0.dist-info/RECORD +0 -27
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/WHEEL +0 -0
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1 +0,0 @@
|
|
1
|
-
# IndoxRouter migrations package
|
indoxRouter/migrations/env.py
DELETED
@@ -1,98 +0,0 @@
|
|
1
|
-
from logging.config import fileConfig
|
2
|
-
|
3
|
-
from sqlalchemy import engine_from_config
|
4
|
-
from sqlalchemy import pool
|
5
|
-
|
6
|
-
from alembic import context
|
7
|
-
|
8
|
-
import os
|
9
|
-
import sys
|
10
|
-
|
11
|
-
# Add the parent directory to sys.path
|
12
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
13
|
-
|
14
|
-
# Import the SQLAlchemy models
|
15
|
-
from indoxRouter.utils.database import Base
|
16
|
-
from indoxRouter.utils.config import get_config
|
17
|
-
|
18
|
-
# This is the Alembic Config object, which provides access to the values
# within the .ini file in use.
config = context.config

# Interpret the config file for Python logging (sets up loggers).
# ``config_file_name`` is None when Alembic is driven programmatically without
# an .ini file, and ``fileConfig(None)`` would fail, so guard it.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def _build_db_url(db_config):
    """Build a SQLAlchemy URL from the IndoxRouter ``database`` config section.

    Args:
        db_config: Mapping for the ``database`` section. ``type`` selects the
            backend (default ``"sqlite"``). SQLite reads ``path`` (default
            ``"indoxrouter.db"``). PostgreSQL (``"postgres"`` or
            ``"postgresql"``) reads ``user``, ``password``, ``host``, ``port``
            and ``database``.

    Returns:
        The database URL as a string.

    Raises:
        ValueError: If ``type`` is not a supported database type.
    """
    from urllib.parse import quote_plus

    db_type = db_config.get("type", "sqlite")
    if db_type == "sqlite":
        return f"sqlite:///{db_config.get('path', 'indoxrouter.db')}"
    if db_type in ("postgres", "postgresql"):
        # Credentials may contain URL-reserved characters ("@", ":", "/"),
        # which must be percent-encoded or the URL is parsed incorrectly.
        # Missing fields are omitted instead of rendering as the text "None".
        credentials = ""
        user = db_config.get("user")
        if user:
            credentials = quote_plus(str(user))
            password = db_config.get("password")
            if password:
                credentials += ":" + quote_plus(str(password))
            credentials += "@"
        host = db_config.get("host") or "localhost"
        port = db_config.get("port")
        netloc = f"{host}:{port}" if port else host
        database = db_config.get("database") or ""
        return f"postgresql://{credentials}{netloc}/{database}"
    raise ValueError(f"Unsupported database type: {db_type}")


# Get the database URL from the IndoxRouter config.
indox_config = get_config()
db_config = indox_config.get_section("database")
db_url = _build_db_url(db_config)

# Override the SQLAlchemy URL in the Alembic config. ``set_main_option`` stores
# the value through ConfigParser interpolation, so literal "%" characters
# (e.g. from percent-encoded credentials) must be escaped as "%%".
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))

# Model MetaData object used for 'autogenerate' support.
target_metadata = Base.metadata
|
44
|
-
|
45
|
-
# other values from the config, defined by the needs of env.py,
|
46
|
-
# can be acquired:
|
47
|
-
# my_important_option = config.get_main_option("my_important_option")
|
48
|
-
# ... etc.
|
49
|
-
|
50
|
-
|
51
|
-
def run_migrations_offline():
    """Run migrations in 'offline' mode, emitting SQL instead of executing it.

    The migration context is configured from the ``sqlalchemy.url`` option
    alone, without creating an Engine, so no DBAPI driver needs to be
    available. Bound parameters are rendered inline (``literal_binds``), and
    calls to ``context.execute()`` write the given string to the script output.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
|
73
|
-
|
74
|
-
|
75
|
-
def run_migrations_online():
    """Run migrations in 'online' mode against a live database connection.

    An Engine is built from the ``sqlalchemy.*`` options of the main ini
    section, using ``NullPool`` (no connection pooling). A connection from it
    is bound to the migration context.
    """
    ini_options = config.get_section(config.config_ini_section)
    engine = engine_from_config(
        ini_options,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
|
93
|
-
|
94
|
-
|
95
|
-
# Choose the runner matching the mode Alembic was invoked in, then run it.
_run_migrations = (
    run_migrations_offline if context.is_offline_mode() else run_migrations_online
)
_run_migrations()
|
@@ -1 +0,0 @@
|
|
1
|
-
# IndoxRouter migrations versions package
|
@@ -1,84 +0,0 @@
|
|
1
|
-
"""Initial database schema
|
2
|
-
|
3
|
-
Revision ID: 001
|
4
|
-
Revises:
|
5
|
-
Create Date: 2023-03-01
|
6
|
-
|
7
|
-
"""
|
8
|
-
|
9
|
-
from alembic import op
|
10
|
-
import sqlalchemy as sa
|
11
|
-
|
12
|
-
|
13
|
-
# Revision identifiers used by Alembic. This is the first revision ("001"):
# it has no parent revision, no branch labels and no cross-branch dependencies.
revision: str = "001"
down_revision = None
branch_labels = None
depends_on = None
|
18
|
-
|
19
|
-
|
20
|
-
def upgrade():
    """Create the ``users``, ``api_keys`` and ``usage_logs`` tables and their indexes.

    NOTE: ``Column(default=...)`` and ``onupdate=...`` are client-side defaults
    and emit no DDL. For the database itself to fill NOT NULL columns that an
    INSERT omits, the column needs ``server_default``. The client-side values
    are kept alongside the server defaults.
    """
    # Create users table
    op.create_table(
        "users",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("email", sa.String(255), unique=True, nullable=False),
        sa.Column("name", sa.String(255), nullable=False),
        sa.Column(
            "balance",
            sa.Float,
            nullable=False,
            default=0.0,
            server_default=sa.text("0"),
        ),
        sa.Column(
            "is_active",
            sa.Boolean,
            nullable=False,
            default=True,
            # sa.true() renders a backend-appropriate literal (e.g. TRUE or 1).
            server_default=sa.true(),
        ),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
        sa.Column(
            "updated_at",
            sa.DateTime,
            nullable=False,
            server_default=sa.func.now(),
            # No DDL effect here: the application must refresh this column on update.
            onupdate=sa.func.now(),
        ),
    )

    # Create api_keys table
    op.create_table(
        "api_keys",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
        sa.Column("key", sa.String(64), unique=True, nullable=False),
        sa.Column(
            "is_active",
            sa.Boolean,
            nullable=False,
            default=True,
            server_default=sa.true(),
        ),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
        sa.Column("last_used_at", sa.DateTime, nullable=True),
    )

    # Create usage_logs table
    op.create_table(
        "usage_logs",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
        sa.Column("provider", sa.String(50), nullable=False),
        sa.Column("model", sa.String(100), nullable=False),
        sa.Column("prompt_tokens", sa.Integer, nullable=False),
        sa.Column("completion_tokens", sa.Integer, nullable=False),
        sa.Column("total_tokens", sa.Integer, nullable=False),
        sa.Column("cost", sa.Float, nullable=False),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
    )

    # Create indexes.
    # NOTE(review): users.email and api_keys.key already have unique
    # constraints, which most backends back with an index, so
    # idx_users_email / idx_api_keys_key are likely redundant. They are kept
    # unchanged here.
    op.create_index("idx_users_email", "users", ["email"])
    op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"])
    op.create_index("idx_api_keys_key", "api_keys", ["key"])
    op.create_index("idx_usage_logs_user_id", "usage_logs", ["user_id"])
    op.create_index(
        "idx_usage_logs_provider_model", "usage_logs", ["provider", "model"]
    )
    op.create_index("idx_usage_logs_created_at", "usage_logs", ["created_at"])
|
79
|
-
|
80
|
-
|
81
|
-
def downgrade():
    """Drop the tables created by ``upgrade``.

    The tables holding foreign keys to ``users`` (``usage_logs``, ``api_keys``)
    are dropped before ``users`` itself.
    """
    for table_name in ("usage_logs", "api_keys", "users"):
        op.drop_table(table_name)
|
indoxRouter/providers/ai21.py
DELETED
@@ -1,268 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import os
|
3
|
-
from pathlib import Path
|
4
|
-
from typing import Dict, Any, List, Optional
|
5
|
-
import requests
|
6
|
-
|
7
|
-
from .base_provider import BaseProvider
|
8
|
-
from ..utils.exceptions import ModelNotFoundError, ProviderAPIError, RateLimitError
|
9
|
-
|
10
|
-
|
11
|
-
class Provider(BaseProvider):
    """
    AI21 Labs provider implementation.

    Models whose name contains "jamba" use the ``/chat/completions`` endpoint.
    All other (Jurassic) models use ``/{model_name}/complete``.
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the AI21 Labs provider.

        Args:
            api_key: AI21 Labs API key
            model_name: Model name (e.g., j2-ultra, j2-mid, jamba-instruct)

        Raises:
            ModelNotFoundError: If the model (or the ai21.json config) cannot be loaded
        """
        super().__init__(api_key, model_name)

        # Load model configuration
        self.model_config = self._load_model_config(model_name)

        # AI21 API base URL, overridable via the AI21_API_BASE environment variable
        self.api_base = os.environ.get(
            "AI21_API_BASE", "https://api.ai21.com/studio/v1"
        )

        # Default generation parameters
        self.default_params = {
            "temperature": 0.7,
            "top_p": 0.9,
            "max_tokens": 1024,
        }

    def _load_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Load model configuration from the ai21.json file next to this module.

        Args:
            model_name: Model name to load configuration for

        Returns:
            The config entry whose "modelName" equals ``model_name``

        Raises:
            ModelNotFoundError: If the file is missing, is invalid JSON, or does
                not list the model
        """
        config_path = Path(__file__).parent / "ai21.json"

        try:
            # Explicit encoding: the default is locale-dependent.
            with open(config_path, "r", encoding="utf-8") as f:
                models = json.load(f)
        except FileNotFoundError:
            raise ModelNotFoundError(
                "AI21 Labs provider configuration file not found"
            ) from None
        except json.JSONDecodeError as e:
            raise ModelNotFoundError(
                "Invalid JSON in AI21 Labs provider configuration file"
            ) from e

        for model in models:
            if model.get("modelName") == model_name:
                return model

        raise ModelNotFoundError(
            f"Model {model_name} not found in AI21 Labs provider"
        )

    def _cost_for_tokens(self, prompt_tokens: float, completion_tokens: float) -> float:
        """Return the cost of the given token counts using the per-1K prices in the model config."""
        input_price = self.model_config.get("inputPricePer1KTokens", 0)
        output_price = self.model_config.get("outputPricePer1KTokens", 0)
        return (prompt_tokens / 1000) * input_price + (
            completion_tokens / 1000
        ) * output_price

    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion.

        Args:
            prompt: Prompt text
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in USD (prompt priced at the input rate, max_tokens
            at the output rate)
        """
        return self._cost_for_tokens(self.count_tokens(prompt), max_tokens)

    def count_tokens(self, text: str) -> int:
        """
        Approximate the number of tokens in a text.

        Args:
            text: Text to count tokens for

        Returns:
            ceil(1.3 * number of whitespace-separated words), as an int
        """
        # AI21 doesn't provide a direct token counting API in their Python client.
        # This is a rough approximation; in production, consider a tokenizer library.
        # Integer ceil(13 * n / 10) honours the declared ``int`` return type.
        return (len(text.split()) * 13 + 9) // 10

    @staticmethod
    def _error_message(response: requests.Response) -> str:
        """
        Extract an error message from a failed API response.

        Never raises. Non-JSON error bodies (e.g. HTML gateway pages) must not
        mask the HTTP status, such as a 429 rate-limit response.
        """
        try:
            body = response.json()
        except ValueError:
            return response.text or "Unknown error"
        if isinstance(body, dict):
            return body.get("detail", "Unknown error")
        return "Unknown error"

    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt.

        Args:
            prompt: Prompt text
            **kwargs: Additional parameters for the generation
                (max_tokens, temperature, top_p, system_prompt)

        Returns:
            Dictionary containing the generated text, cost, and usage statistics

        Raises:
            ProviderAPIError: If there's an error with the provider API
            RateLimitError: If the provider's rate limit is exceeded
        """
        try:
            # Get generation parameters
            max_tokens = kwargs.get("max_tokens", self.default_params["max_tokens"])
            temperature = kwargs.get("temperature", self.default_params["temperature"])
            top_p = kwargs.get("top_p", self.default_params["top_p"])

            # System prompt: explicit kwarg wins over the model config default
            system_prompt = kwargs.get(
                "system_prompt", self.model_config.get("systemPrompt", "")
            )

            # Format the prompt using the template ("%1" placeholder) if available
            prompt_template = self.model_config.get("promptTemplate", "")
            if prompt_template and "%1" in prompt_template:
                formatted_prompt = prompt_template.replace("%1", prompt)
            else:
                formatted_prompt = prompt

            # Completion-style models receive system + user text as one prompt
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{formatted_prompt}"
            else:
                full_prompt = formatted_prompt

            is_jamba = "jamba" in self.model_name.lower()

            if is_jamba:
                # Chat completion endpoint for Jamba models
                endpoint = f"{self.api_base}/chat/completions"
                messages = []
                if system_prompt:
                    messages.append({"role": "system", "content": system_prompt})
                messages.append({"role": "user", "content": formatted_prompt})
                payload = {
                    "model": self.model_name,
                    "messages": messages,
                    "temperature": temperature,
                    "top_p": top_p,
                    "max_tokens": max_tokens,
                }
            else:
                # Completion endpoint for Jurassic models (camelCase parameters)
                endpoint = f"{self.api_base}/{self.model_name}/complete"
                payload = {
                    "prompt": full_prompt,
                    "temperature": temperature,
                    "topP": top_p,
                    "maxTokens": max_tokens,
                }

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            response = requests.post(
                endpoint, headers=headers, json=payload, timeout=60
            )

            if response.status_code != 200:
                error_message = self._error_message(response)
                if response.status_code == 429:
                    raise RateLimitError(
                        f"AI21 Labs API rate limit exceeded: {error_message}"
                    )
                raise ProviderAPIError(f"AI21 Labs API error: {error_message}")

            response_data = response.json()

            # "or [{}]" also covers an empty list, which would otherwise raise IndexError
            if is_jamba:
                choices = response_data.get("choices") or [{}]
                generated_text = choices[0].get("message", {}).get("content", "")

                # Prefer the token usage reported by the API
                usage = response_data.get("usage") or {}
                prompt_tokens = usage.get(
                    "prompt_tokens", self.count_tokens(formatted_prompt)
                )
                completion_tokens = usage.get(
                    "completion_tokens", self.count_tokens(generated_text)
                )
                total_tokens = usage.get(
                    "total_tokens", prompt_tokens + completion_tokens
                )
            else:
                completions = response_data.get("completions") or [{}]
                generated_text = completions[0].get("data", {}).get("text", "")

                # For Jurassic models, token counts are not directly provided
                prompt_tokens = self.count_tokens(full_prompt)
                completion_tokens = self.count_tokens(generated_text)
                total_tokens = prompt_tokens + completion_tokens

            # Price the request from the token counts determined above, so the
            # reported cost matches the reported usage.
            cost = self._cost_for_tokens(prompt_tokens, completion_tokens)

            result = {
                "text": generated_text,
                "cost": cost,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }

            return self.validate_response(result)

        except (RateLimitError, ProviderAPIError):
            # Already classified; re-raise without double-wrapping the message
            raise
        except Exception as e:
            raise ProviderAPIError(
                f"Error generating completion with AI21 Labs API: {str(e)}", e
            ) from e
|
indoxRouter/providers/claude.py
DELETED
@@ -1,177 +0,0 @@
|
|
1
|
-
from typing import Dict, Any, Optional, List
|
2
|
-
import json
|
3
|
-
import os
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
try:
|
7
|
-
import anthropic
|
8
|
-
from anthropic import Anthropic
|
9
|
-
except ImportError:
|
10
|
-
raise ImportError(
|
11
|
-
"Anthropic package not installed. Install it with 'pip install anthropic'"
|
12
|
-
)
|
13
|
-
|
14
|
-
from ..utils.exceptions import RateLimitError, ModelNotFoundError
|
15
|
-
from .base_provider import BaseProvider
|
16
|
-
|
17
|
-
|
18
|
-
class Provider(BaseProvider):
    """
    Anthropic (Claude) provider implementation using the Messages API.
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key
            model_name: Model name to use (e.g., 'claude-3-opus-20240229')

        Raises:
            ModelNotFoundError: If the model is not listed in claude.json
        """
        super().__init__(api_key, model_name)

        # Initialize Anthropic client
        self.client = Anthropic(api_key=api_key)

        # Load model configuration
        self.model_config = self._load_model_config(model_name)

    def _load_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Load model configuration from the claude.json file next to this module.

        Args:
            model_name: Model name

        Returns:
            The config entry whose "modelName" equals ``model_name``

        Raises:
            ModelNotFoundError: If the model is not listed in the file
        """
        config_path = Path(__file__).parent / "claude.json"

        # Explicit encoding: the default is locale-dependent.
        with open(config_path, "r", encoding="utf-8") as f:
            models = json.load(f)

        for model in models:
            if model.get("modelName") == model_name:
                return model

        raise ModelNotFoundError(
            f"Model {model_name} not found in Anthropic configuration"
        )

    def _cost_for_tokens(self, prompt_tokens: float, completion_tokens: float) -> float:
        """Return the cost of the given token counts using the per-1K prices in the model config."""
        input_price = self.model_config.get("inputPricePer1KTokens", 0)
        output_price = self.model_config.get("outputPricePer1KTokens", 0)
        return (prompt_tokens / 1000) * input_price + (
            completion_tokens / 1000
        ) * output_price

    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion.

        Args:
            prompt: The prompt to generate a completion for
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in credits (prompt priced at the input rate,
            max_tokens at the output rate)
        """
        return self._cost_for_tokens(self.count_tokens(prompt), max_tokens)

    def count_tokens(self, text: str) -> int:
        """
        Count (or approximate) the number of tokens in a text.

        Args:
            text: Text to count tokens for

        Returns:
            Token count from ``anthropic.count_tokens`` when available,
            otherwise ceil(1.3 * number of whitespace-separated words)
        """
        try:
            # NOTE(review): a module-level ``anthropic.count_tokens`` may not exist
            # in current SDK versions; the fallback below covers that case.
            return anthropic.count_tokens(text)
        except Exception:
            # Deliberate best-effort fallback. Unlike the former bare ``except:``,
            # this does not swallow KeyboardInterrupt/SystemExit.
            return (len(text.split()) * 13 + 9) // 10

    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion using Anthropic.

        Args:
            prompt: The prompt to generate a completion for
            **kwargs: Additional parameters for the generation
                - max_tokens: Maximum number of tokens to generate
                - temperature: Sampling temperature (0.0 to 1.0)
                - top_p: Nucleus sampling parameter (0.0 to 1.0)
                - system_prompt: Overrides the model config's "systemPrompt"

        Returns:
            Dictionary containing the response text, cost, usage and model name

        Raises:
            RateLimitError: If Anthropic reports a rate limit
            Exception: For any other API or client error
        """
        try:
            max_tokens = kwargs.get("max_tokens", 1024)
            temperature = kwargs.get("temperature", 0.7)
            top_p = kwargs.get("top_p", 1.0)

            # The Messages API takes the raw prompt as a user turn; the legacy
            # "Human: ... Assistant:" promptTemplate is intentionally not applied.
            system_prompt = kwargs.get(
                "system_prompt", self.model_config.get("systemPrompt", "")
            )

            request = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }
            # Only send ``system`` when there is one. ``None`` is not the SDK's
            # "omit" sentinel, so passing it would send an explicit null.
            if system_prompt:
                request["system"] = system_prompt

            response = self.client.messages.create(**request)

            # Join the text of all content blocks; an empty content list yields "".
            text = "".join(
                getattr(block, "text", "") for block in (response.content or [])
            )

            prompt_tokens = response.usage.input_tokens
            completion_tokens = response.usage.output_tokens
            total_cost = self._cost_for_tokens(prompt_tokens, completion_tokens)

            return {
                "text": text,
                "cost": total_cost,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                "model": self.model_name,
            }

        except anthropic.RateLimitError as e:
            raise RateLimitError(f"Anthropic rate limit exceeded: {str(e)}") from e
        except anthropic.APIError as e:
            raise Exception(f"Anthropic API error: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Error generating completion: {str(e)}") from e
|