indoxrouter 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxrouter-0.1.3.dist-info/METADATA +188 -0
- indoxrouter-0.1.3.dist-info/RECORD +4 -0
- indoxrouter-0.1.3.dist-info/top_level.txt +1 -0
- indoxRouter/__init__.py +0 -0
- indoxRouter/api_endpoints.py +0 -336
- indoxRouter/client.py +0 -286
- indoxRouter/client_package.py +0 -138
- indoxRouter/init_db.py +0 -71
- indoxRouter/main.py +0 -711
- indoxRouter/migrations/__init__.py +0 -1
- indoxRouter/migrations/env.py +0 -98
- indoxRouter/migrations/versions/__init__.py +0 -1
- indoxRouter/migrations/versions/initial_schema.py +0 -84
- indoxRouter/providers/__init__.py +0 -108
- indoxRouter/providers/ai21.py +0 -268
- indoxRouter/providers/base_provider.py +0 -69
- indoxRouter/providers/claude.py +0 -177
- indoxRouter/providers/cohere.py +0 -171
- indoxRouter/providers/databricks.py +0 -166
- indoxRouter/providers/deepseek.py +0 -166
- indoxRouter/providers/google.py +0 -216
- indoxRouter/providers/llama.py +0 -164
- indoxRouter/providers/meta.py +0 -227
- indoxRouter/providers/mistral.py +0 -182
- indoxRouter/providers/nvidia.py +0 -164
- indoxRouter/providers/openai.py +0 -122
- indoxrouter-0.1.0.dist-info/METADATA +0 -179
- indoxrouter-0.1.0.dist-info/RECORD +0 -27
- indoxrouter-0.1.0.dist-info/top_level.txt +0 -1
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.3.dist-info}/WHEEL +0 -0
@@ -1 +0,0 @@
|
|
1
|
-
# IndoxRouter migrations package
|
indoxRouter/migrations/env.py
DELETED
@@ -1,98 +0,0 @@
|
|
1
|
-
from logging.config import fileConfig
|
2
|
-
|
3
|
-
from sqlalchemy import engine_from_config
|
4
|
-
from sqlalchemy import pool
|
5
|
-
|
6
|
-
from alembic import context
|
7
|
-
|
8
|
-
import os
|
9
|
-
import sys
|
10
|
-
|
11
|
-
# Add the parent directory to sys.path
|
12
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
13
|
-
|
14
|
-
# Import the SQLAlchemy models
|
15
|
-
from indoxRouter.utils.database import Base
|
16
|
-
from indoxRouter.utils.config import get_config
|
17
|
-
|
18
|
-
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# config_file_name is None when Alembic is driven programmatically without
# an .ini file; fileConfig() cannot be called in that case.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Get the database settings from the IndoxRouter config
indox_config = get_config()
db_config = indox_config.get_section("database")

db_type = db_config.get("type", "sqlite")
if db_type == "sqlite":
    db_url = f"sqlite:///{db_config.get('path', 'indoxrouter.db')}"
elif db_type == "postgres":
    # Imported here because only the postgres branch needs it.
    from urllib.parse import quote

    # Percent-encode the credentials so that characters such as "@", ":" or
    # "/" inside a user name or password cannot change the URL structure.
    pg_user = quote(str(db_config.get("user")), safe="")
    pg_password = quote(str(db_config.get("password")), safe="")
    db_url = (
        f"postgresql://{pg_user}:{pg_password}"
        f"@{db_config.get('host')}:{db_config.get('port')}"
        f"/{db_config.get('database')}"
    )
else:
    raise ValueError(f"Unsupported database type: {db_type}")

# Override the SQLAlchemy URL in the Alembic config.
# Alembic keeps main options in a ConfigParser with interpolation enabled,
# so every literal "%" (e.g. from percent-encoding) has to be doubled.
config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
|
44
|
-
|
45
|
-
# other values from the config, defined by the needs of env.py,
|
46
|
-
# can be acquired:
|
47
|
-
# my_important_option = config.get_main_option("my_important_option")
|
48
|
-
# ... etc.
|
49
|
-
|
50
|
-
|
51
|
-
def run_migrations_offline():
    """Run migrations in 'offline' mode.

    The migration context is configured from the ``sqlalchemy.url`` main
    option only; this function does not create an Engine or open a
    connection. The configured options request literal bind rendering and
    the "named" paramstyle.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)

    with context.begin_transaction():
        context.run_migrations()
|
73
|
-
|
74
|
-
|
75
|
-
def run_migrations_online():
    """Run migrations in 'online' mode.

    Builds an Engine from the ``sqlalchemy.``-prefixed options of the active
    .ini section (using ``NullPool``), opens a connection, and runs the
    migrations inside a transaction bound to that connection.
    """
    ini_section = config.get_section(config.config_ini_section)
    engine = engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as db_connection:
        context.configure(
            connection=db_connection,
            target_metadata=target_metadata,
        )

        with context.begin_transaction():
            context.run_migrations()
|
93
|
-
|
94
|
-
|
95
|
-
# Module-level entry point: run the migrations in whichever mode the
# Alembic context reports.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
|
@@ -1 +0,0 @@
|
|
1
|
-
# IndoxRouter migrations versions package
|
@@ -1,84 +0,0 @@
|
|
1
|
-
"""Initial database schema
|
2
|
-
|
3
|
-
Revision ID: 001
|
4
|
-
Revises:
|
5
|
-
Create Date: 2023-03-01
|
6
|
-
|
7
|
-
"""
|
8
|
-
|
9
|
-
from alembic import op
|
10
|
-
import sqlalchemy as sa
|
11
|
-
|
12
|
-
|
13
|
-
# revision identifiers, used by Alembic.
|
14
|
-
revision = "001"
|
15
|
-
down_revision = None
|
16
|
-
branch_labels = None
|
17
|
-
depends_on = None
|
18
|
-
|
19
|
-
|
20
|
-
def upgrade():
    """Create the initial schema: ``users``, ``api_keys``, ``usage_logs`` and their indexes."""
    # NOTE: Column(default=...) is a client-side default that SQLAlchemy
    # applies when it builds an INSERT; it is not rendered into CREATE TABLE.
    # The NOT NULL columns below therefore use server_default so the database
    # itself fills them in.

    # Create users table
    op.create_table(
        "users",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("email", sa.String(255), unique=True, nullable=False),
        sa.Column("name", sa.String(255), nullable=False),
        sa.Column("balance", sa.Float, nullable=False, server_default=sa.text("0")),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default=sa.true()),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
        sa.Column(
            "updated_at",
            sa.DateTime,
            nullable=False,
            server_default=sa.func.now(),
            # Like default=, onupdate= is applied client-side on UPDATE
            # statements and adds nothing to the table DDL; kept as in the
            # original definition.
            onupdate=sa.func.now(),
        ),
    )

    # Create api_keys table
    op.create_table(
        "api_keys",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
        sa.Column("key", sa.String(64), unique=True, nullable=False),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default=sa.true()),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
        sa.Column("last_used_at", sa.DateTime, nullable=True),
    )

    # Create usage_logs table
    op.create_table(
        "usage_logs",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("user_id", sa.String(36), sa.ForeignKey("users.id"), nullable=False),
        sa.Column("provider", sa.String(50), nullable=False),
        sa.Column("model", sa.String(100), nullable=False),
        sa.Column("prompt_tokens", sa.Integer, nullable=False),
        sa.Column("completion_tokens", sa.Integer, nullable=False),
        sa.Column("total_tokens", sa.Integer, nullable=False),
        sa.Column("cost", sa.Float, nullable=False),
        sa.Column(
            "created_at", sa.DateTime, nullable=False, server_default=sa.func.now()
        ),
    )

    # Create indexes
    op.create_index("idx_users_email", "users", ["email"])
    op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"])
    op.create_index("idx_api_keys_key", "api_keys", ["key"])
    op.create_index("idx_usage_logs_user_id", "usage_logs", ["user_id"])
    op.create_index(
        "idx_usage_logs_provider_model", "usage_logs", ["provider", "model"]
    )
    op.create_index("idx_usage_logs_created_at", "usage_logs", ["created_at"])
|
79
|
-
|
80
|
-
|
81
|
-
def downgrade():
    """Drop the tables created by ``upgrade``.

    The tables holding foreign keys to ``users`` are dropped before
    ``users`` itself.
    """
    for table_name in ("usage_logs", "api_keys", "users"):
        op.drop_table(table_name)
|
@@ -1,108 +0,0 @@
|
|
1
|
-
# Import all provider modules
import importlib
import logging

# Configure logging
logger = logging.getLogger(__name__)

# Dictionary to store provider modules, keyed by provider name
PROVIDERS = {}

# Provider submodule name -> display name used in the "not available" warning.
# Insertion order is the order in which the providers are imported.
_PROVIDER_LABELS = {
    "openai": "OpenAI",
    "claude": "Claude",
    "mistral": "Mistral",
    "cohere": "Cohere",
    "google": "Google",
    "meta": "Meta",
    "ai21": "AI21",
    "llama": "Llama",
    "nvidia": "NVIDIA",
    "deepseek": "Deepseek",
    "databricks": "Databricks",
}

# Import providers with graceful error handling: a provider whose import
# fails (for example because an optional dependency is missing) is logged and
# left out of PROVIDERS instead of breaking the whole package import.
for _module_name, _label in _PROVIDER_LABELS.items():
    try:
        # Relative import of the sibling submodule; like ``from . import x``
        # this also makes the submodule an attribute of this package.
        PROVIDERS[_module_name] = importlib.import_module(
            f".{_module_name}", __name__
        )
    except ImportError as e:
        logger.warning("%s provider not available: %s", _label, e)

# Do not leave the loop variables behind in the package namespace.
del _module_name, _label
|
87
|
-
|
88
|
-
|
89
|
-
def get_provider(provider_name, api_key, model_name):
    """
    Build a provider instance for a registered provider.

    Args:
        provider_name (str): Key of the provider in ``PROVIDERS``
        api_key (str): The API key passed to the provider
        model_name (str): The model name passed to the provider

    Returns:
        The object returned by the provider module's ``Provider(api_key, model_name)``

    Raises:
        ValueError: If ``provider_name`` is not a key of ``PROVIDERS``
    """
    if provider_name in PROVIDERS:
        provider_cls = PROVIDERS[provider_name].Provider
        return provider_cls(api_key, model_name)

    raise ValueError(f"Provider {provider_name} not found or not available")
|
indoxRouter/providers/ai21.py
DELETED
@@ -1,268 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import os
|
3
|
-
from pathlib import Path
|
4
|
-
from typing import Dict, Any, List, Optional
|
5
|
-
import requests
|
6
|
-
|
7
|
-
from .base_provider import BaseProvider
|
8
|
-
from ..utils.exceptions import ModelNotFoundError, ProviderAPIError, RateLimitError
|
9
|
-
|
10
|
-
|
11
|
-
class Provider(BaseProvider):
    """
    AI21 Labs provider implementation

    Sends HTTP requests to the AI21 API with ``requests``. Models whose name
    contains "jamba" are called through the chat-completions endpoint; every
    other model goes through the per-model ``/complete`` endpoint.
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the AI21 Labs provider

        Args:
            api_key: AI21 Labs API key
            model_name: Model name (e.g., j2-ultra, j2-mid, jamba-instruct)

        Raises:
            ModelNotFoundError: If the model configuration cannot be loaded
        """
        super().__init__(api_key, model_name)

        # Load model configuration
        self.model_config = self._load_model_config(model_name)

        # AI21 API base URL (overridable through the AI21_API_BASE env var)
        self.api_base = os.environ.get(
            "AI21_API_BASE", "https://api.ai21.com/studio/v1"
        )

        # Default generation parameters
        self.default_params = {
            "temperature": 0.7,
            "top_p": 0.9,
            "max_tokens": 1024,
        }

    def _load_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Load model configuration from the ``ai21.json`` file next to this module

        Args:
            model_name: Model name to load configuration for

        Returns:
            The entry of the JSON list whose ``modelName`` equals ``model_name``

        Raises:
            ModelNotFoundError: If the file is missing, is not valid JSON, or
                contains no entry for ``model_name``
        """
        config_path = Path(__file__).parent / "ai21.json"

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                models = json.load(f)
        except FileNotFoundError:
            raise ModelNotFoundError("AI21 Labs provider configuration file not found")
        except json.JSONDecodeError:
            raise ModelNotFoundError(
                "Invalid JSON in AI21 Labs provider configuration file"
            )

        for model in models:
            if model.get("modelName") == model_name:
                return model

        raise ModelNotFoundError(
            f"Model {model_name} not found in AI21 Labs provider"
        )

    def _cost_for_tokens(self, prompt_tokens: float, completion_tokens: float) -> float:
        """Price the given token counts with the model's per-1K-token rates."""
        input_cost = (prompt_tokens / 1000) * self.model_config.get(
            "inputPricePer1KTokens", 0
        )
        output_cost = (completion_tokens / 1000) * self.model_config.get(
            "outputPricePer1KTokens", 0
        )
        return input_cost + output_cost

    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion

        Args:
            prompt: Prompt text
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in USD
        """
        return self._cost_for_tokens(self.count_tokens(prompt), max_tokens)

    def count_tokens(self, text: str) -> int:
        """
        Approximate the number of tokens in a text

        Args:
            text: Text to count tokens for

        Returns:
            Whitespace-separated word count scaled by 1.3, truncated to an int
        """
        # This is a rough approximation - in production, consider using a
        # tokenizer library. The result is converted to int so that token
        # counts match the declared return type.
        return int(len(text.split()) * 1.3)

    @staticmethod
    def _error_detail(response) -> str:
        """Extract a human-readable error message from a failed HTTP response."""
        try:
            body = response.json()
        except ValueError:
            # The error body is not JSON; fall back to the raw text so the
            # status-specific exception below can still be raised.
            return response.text or "Unknown error"
        if isinstance(body, dict):
            return body.get("detail", "Unknown error")
        return "Unknown error"

    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt

        Args:
            prompt: Prompt text
            **kwargs: Optional ``max_tokens``, ``temperature``, ``top_p`` and
                ``system_prompt`` overrides

        Returns:
            Dictionary containing the generated text, cost, and usage statistics

        Raises:
            ProviderAPIError: If there's an error with the provider API
            RateLimitError: If the provider's rate limit is exceeded
        """
        try:
            # Get generation parameters
            max_tokens = kwargs.get("max_tokens", self.default_params["max_tokens"])
            temperature = kwargs.get("temperature", self.default_params["temperature"])
            top_p = kwargs.get("top_p", self.default_params["top_p"])

            # Prepare system prompt if provided
            system_prompt = kwargs.get(
                "system_prompt", self.model_config.get("systemPrompt", "")
            )

            # Format the prompt using the template if available
            prompt_template = self.model_config.get("promptTemplate", "")
            if prompt_template and "%1" in prompt_template:
                formatted_prompt = prompt_template.replace("%1", prompt)
            else:
                formatted_prompt = prompt

            # Combine system prompt and user prompt if needed
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{formatted_prompt}"
            else:
                full_prompt = formatted_prompt

            # Check if this is a Jamba model (chat model)
            is_jamba = "jamba" in self.model_name.lower()

            if is_jamba:
                # Chat completion endpoint for Jamba models
                endpoint = f"{self.api_base}/chat/completions"

                messages = []
                if system_prompt:
                    messages.append({"role": "system", "content": system_prompt})
                messages.append({"role": "user", "content": formatted_prompt})

                payload = {
                    "model": self.model_name,
                    "messages": messages,
                    "temperature": temperature,
                    "top_p": top_p,
                    "max_tokens": max_tokens,
                }
            else:
                # Completion endpoint for Jurassic models
                endpoint = f"{self.api_base}/{self.model_name}/complete"

                payload = {
                    "prompt": full_prompt,
                    "temperature": temperature,
                    "topP": top_p,
                    "maxTokens": max_tokens,
                }

            # Make the API request
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            response = requests.post(
                endpoint, headers=headers, json=payload, timeout=60
            )

            # Check for errors
            if response.status_code != 200:
                error_message = self._error_detail(response)

                if response.status_code == 429:
                    raise RateLimitError(
                        f"AI21 Labs API rate limit exceeded: {error_message}"
                    )
                raise ProviderAPIError(f"AI21 Labs API error: {error_message}")

            # Parse the response
            response_data = response.json()

            # Extract the generated text based on model type
            if is_jamba:
                # "or [{}]" also covers an empty list, not only a missing key
                choices = response_data.get("choices") or [{}]
                generated_text = choices[0].get("message", {}).get("content", "")

                # Prefer the usage reported by the API; fall back to the local
                # approximation over everything that was sent.
                usage = response_data.get("usage", {})
                prompt_tokens = usage.get(
                    "prompt_tokens", self.count_tokens(full_prompt)
                )
                completion_tokens = usage.get(
                    "completion_tokens", self.count_tokens(generated_text)
                )
                total_tokens = usage.get(
                    "total_tokens", prompt_tokens + completion_tokens
                )
            else:
                completions = response_data.get("completions") or [{}]
                generated_text = completions[0].get("data", {}).get("text", "")

                # For Jurassic models the token counts are approximated locally
                prompt_tokens = self.count_tokens(full_prompt)
                completion_tokens = self.count_tokens(generated_text)
                total_tokens = prompt_tokens + completion_tokens

            # Calculate cost from the same token counts that are reported in
            # "usage", so the two can never disagree.
            cost = self._cost_for_tokens(prompt_tokens, completion_tokens)

            # Prepare the response
            result = {
                "text": generated_text,
                "cost": cost,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }

            return self.validate_response(result)

        except (RateLimitError, ProviderAPIError):
            # Already the right exception type; re-raise without wrapping it
            # in a second ProviderAPIError.
            raise
        except Exception as e:
            # Handle other errors
            raise ProviderAPIError(
                f"Error generating completion with AI21 Labs API: {str(e)}", e
            )
|
@@ -1,69 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
from typing import Dict, Any, Optional
|
3
|
-
|
4
|
-
|
5
|
-
class BaseProvider(ABC):
    """
    Base class for all LLM providers

    All provider implementations should inherit from this class
    and implement the required methods.
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the provider

        Args:
            api_key: Provider API key
            model_name: Model name to use
        """
        self.api_key = api_key
        self.model_name = model_name

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion for the given prompt

        Args:
            prompt: The prompt to generate a completion for
            **kwargs: Additional parameters for the generation

        Returns:
            Dictionary containing the response text, cost, and other metadata
        """
        pass

    @abstractmethod
    def estimate_cost(self, prompt: str, max_tokens: int) -> float:
        """
        Estimate the cost of generating a completion

        Args:
            prompt: The prompt to generate a completion for
            max_tokens: Maximum number of tokens to generate

        Returns:
            Estimated cost in credits
        """
        pass

    def validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate the response from the provider

        Args:
            response: Raw response from the provider

        Returns:
            The same ``response`` object, unchanged, once it has passed validation

        Raises:
            ValueError: If ``response`` is not a dictionary or lacks the
                ``text`` or ``cost`` key
        """
        # A bare ``in`` test would accept any container: a string such as
        # "text cost" passes by substring match, and None raises TypeError.
        if not isinstance(response, dict):
            raise ValueError("Provider response must be a dictionary")

        # Ensure the response has the required fields
        for required_field in ("text", "cost"):
            if required_field not in response:
                raise ValueError(f"Provider response missing '{required_field}' field")

        return response
|