stratifyai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/__init__.py +3 -0
- api/main.py +763 -0
- api/static/index.html +1126 -0
- api/static/models.html +567 -0
- api/static/stratifyai_trans_logo.png +0 -0
- api/static/stratifyai_wide_logo.png +0 -0
- api/static/stratum_logo.png +0 -0
- cli/stratifyai_cli.py +574 -73
- stratifyai/api_key_helper.py +1 -1
- stratifyai/config.py +158 -24
- stratifyai/models.py +36 -1
- stratifyai/providers/anthropic.py +65 -5
- stratifyai/providers/bedrock.py +96 -9
- stratifyai/providers/grok.py +3 -2
- stratifyai/providers/openai.py +63 -8
- stratifyai/providers/openai_compatible.py +79 -7
- stratifyai/router.py +2 -2
- stratifyai/summarization.py +147 -3
- stratifyai/utils/model_selector.py +3 -3
- stratifyai/utils/provider_validator.py +4 -2
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/METADATA +9 -5
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/RECORD +26 -19
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/top_level.txt +1 -0
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/WHEEL +0 -0
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/entry_points.txt +0 -0
- {stratifyai-0.1.1.dist-info → stratifyai-0.1.3.dist-info}/licenses/LICENSE +0 -0
api/main.py
ADDED
@@ -0,0 +1,763 @@
"""FastAPI application for StratifyAI."""

import json
import logging
from typing import Optional, List, Dict
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Load environment variables from .env file
load_dotenv()

from stratifyai import LLMClient, ChatRequest, Message, ProviderType
from stratifyai.cost_tracker import CostTracker
from stratifyai.config import MODEL_CATALOG

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="StratifyAI API",
    description="Unified API for multiple LLM providers",
    version="0.1.0",
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global cost tracker
cost_tracker = CostTracker()

# Mount static files
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")


# Request/Response models
class ChatCompletionRequest(BaseModel):
    """Chat completion request model."""
    provider: str
    model: str
    messages: List[dict]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: bool = False
    file_content: Optional[str] = None  # Base64 encoded file content or plain text
    file_name: Optional[str] = None  # Original filename for type detection
    chunked: bool = False  # Enable smart chunking and summarization
    chunk_size: int = 50000  # Chunk size in characters


class ChatCompletionResponse(BaseModel):
    """Chat completion response model."""
    id: str
    provider: str
    model: str
    content: str
    finish_reason: str
    usage: dict
    cost_usd: float


class ProviderInfo(BaseModel):
    """Provider information model."""
    name: str
    models: List[str]


class ErrorResponse(BaseModel):
    """Error response model."""
    error: str
    detail: str
    error_type: str


@app.get("/")
async def root():
    """Serve the frontend interface."""
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    index_path = os.path.join(static_dir, "index.html")
    if os.path.exists(index_path):
        return FileResponse(index_path)
    return {
        "name": "StratifyAI API",
        "version": "0.1.0",
        "message": "Frontend not found. API endpoints available at /docs"
    }


@app.get("/models")
async def models_page():
    """Serve the models catalog page."""
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    models_path = os.path.join(static_dir, "models.html")
    if os.path.exists(models_path):
        return FileResponse(models_path)
    return {"error": "Models page not found"}


@app.get("/api/providers", response_model=List[str])
async def list_providers():
    """List all available providers."""
    return [
        "openai",
        "anthropic",
        "google",
        "deepseek",
        "groq",
        "grok",
        "ollama",
        "openrouter",
        "bedrock",
    ]


class ModelInfo(BaseModel):
    """Model information."""
    id: str  # Model ID (e.g., 'gpt-4o')
    display_name: str  # Display name (e.g., 'GPT-4o')
    description: str = ""  # Description with labels
    category: str = ""  # Category for grouping
    reasoning_model: bool = False
    supports_vision: bool = False


class ModelListResponse(BaseModel):
    """Model list response with validation metadata."""
    models: List[ModelInfo]
    validation: dict


@app.get("/api/models/{provider}", response_model=ModelListResponse)
async def list_models(provider: str):
    """List validated models for a specific provider."""
    from stratifyai.utils.provider_validator import get_validated_interactive_models

    if provider not in MODEL_CATALOG:
        raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found")

    # Run validation in background thread to avoid blocking
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor() as pool:
        validation_data = await loop.run_in_executor(
            pool,
            get_validated_interactive_models,
            provider
        )

    validated_models = validation_data["models"]
    validation_result = validation_data["validation_result"]

    # Log validation result
    if validation_result["error"]:
        logger.warning(f"Model validation for {provider}: {validation_result['error']}")
    else:
        logger.info(f"Model validation for {provider}: {len(validated_models)} models in {validation_result['validation_time_ms']}ms")

    # If validation succeeded: return only validated models with metadata
    # If validation failed with error: fall back to catalog
    if validation_result["error"]:
        # Fallback to catalog when validation fails
        model_ids = list(MODEL_CATALOG[provider].keys())
        model_metadata = MODEL_CATALOG[provider]
    else:
        # Show only validated models on success
        model_ids = list(validated_models.keys())
        model_metadata = validated_models

    # Build model info list with rich metadata
    models_info = []
    for model_id in model_ids:
        meta = model_metadata.get(model_id, {})
        models_info.append(ModelInfo(
            id=model_id,
            display_name=meta.get("display_name", model_id),
            description=meta.get("description", ""),
            category=meta.get("category", ""),
            reasoning_model=meta.get("reasoning_model", False),
            supports_vision=meta.get("supports_vision", False),
        ))

    return ModelListResponse(
        models=models_info,
        validation=validation_result
    )


@app.get("/api/model-info/{provider}/{model}")
async def get_model_info(provider: str, model: str):
    """Get detailed information about a specific model."""
    if provider not in MODEL_CATALOG:
        raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found")

    if model not in MODEL_CATALOG[provider]:
        raise HTTPException(status_code=404, detail=f"Model '{model}' not found for provider '{provider}'")

    model_info = MODEL_CATALOG[provider][model]

    return {
        "provider": provider,
        "model": model,
        "fixed_temperature": model_info.get("fixed_temperature"),
        "reasoning_model": model_info.get("reasoning_model", False),
        "supports_vision": model_info.get("supports_vision", False),
        "supports_tools": model_info.get("supports_tools", False),
        "supports_caching": model_info.get("supports_caching", False),
        "context": model_info.get("context", 0),
    }


@app.get("/api/provider-info", response_model=List[ProviderInfo])
async def get_provider_info():
    """Get information about all providers and their models."""
    providers = []
    for provider_name, models in MODEL_CATALOG.items():
        providers.append(ProviderInfo(
            name=provider_name,
            models=list(models.keys())
        ))
    return providers


@app.post("/api/chat", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    """
    Execute a chat completion request.

    Args:
        request: Chat completion request

    Returns:
        Chat completion response with cost tracking
    """
    try:
        # Convert messages to Message objects
        messages = [
            Message(role=msg["role"], content=msg["content"])
            for msg in request.messages
        ]

        # Process file if provided
        if request.file_content and request.file_name:
            from stratifyai.summarization import summarize_file_async
            from stratifyai.utils.file_analyzer import analyze_file
            from pathlib import Path
            import tempfile
            import base64

            # Detect if content is base64 encoded or plain text
            try:
                # Try to decode as base64
                file_bytes = base64.b64decode(request.file_content)
                file_text = file_bytes.decode('utf-8')
            except Exception:
                # If decoding fails, assume it's plain text
                file_text = request.file_content

            # Apply chunking if enabled
            if request.chunked:
                logger.info(f"Chunking file {request.file_name} (size: {len(file_text)} chars, chunk_size: {request.chunk_size})")

                # Create temporary file for analysis
                with tempfile.NamedTemporaryFile(mode='w', suffix=Path(request.file_name).suffix, delete=False) as tmp_file:
                    tmp_file.write(file_text)
                    tmp_path = Path(tmp_file.name)  # Convert to Path object

                try:
                    # Analyze file to determine if chunking is beneficial
                    analysis = analyze_file(tmp_path, request.provider, request.model)
                    logger.info(f"File analysis: type={analysis.file_type.value}, tokens={analysis.estimated_tokens}")

                    # Perform chunking and summarization
                    # Use a cheap model for summarization (gpt-4o-mini or similar)
                    # Auto-select based on provider
                    summarization_models = {
                        "openai": "gpt-4o-mini",
                        "anthropic": "claude-3-5-sonnet-20241022",
                        "google": "gemini-2.5-flash",
                        "deepseek": "deepseek-chat",
                        "groq": "llama-3.1-8b-instant",
                        "grok": "grok-beta",
                        "openrouter": "google/gemini-2.5-flash",
                        "ollama": "llama3.2",
                        "bedrock": "anthropic.claude-3-5-haiku-20241022-v1:0",
                    }
                    summarization_model = summarization_models.get(request.provider, "gpt-4o-mini")

                    client = LLMClient(provider=request.provider)

                    # Get context from last user message if available
                    context = None
                    if messages and messages[-1].role == "user":
                        context = messages[-1].content

                    # Run async summarization with cheap model
                    result = await summarize_file_async(
                        file_text,
                        client,
                        request.chunk_size,
                        summarization_model,
                        context,
                        False  # show_progress=False for API
                    )

                    # Use summarized content
                    file_content_to_use = result['summary']
                    logger.info(f"Chunking complete: {result['reduction_percentage']}% reduction ({result['original_length']} -> {result['summary_length']} chars)")
                finally:
                    # Clean up temp file
                    import os
                    os.unlink(tmp_path)
            else:
                # Use file content as-is
                file_content_to_use = file_text

            # Append file content to last user message or create new message
            if messages and messages[-1].role == "user":
                # Combine with existing user message
                messages[-1].content = f"{messages[-1].content}\n\n[File: {request.file_name}]\n\n{file_content_to_use}"
            else:
                # Create new user message with file content
                messages.append(Message(
                    role="user",
                    content=f"[File: {request.file_name}]\n\n{file_content_to_use}"
                ))

        # Validate token count before making request
        from stratifyai.utils.token_counter import count_tokens_for_messages, get_context_window
        estimated_tokens = count_tokens_for_messages(messages, request.provider, request.model)

        # Get context window and API limits
        context_window = get_context_window(request.provider, request.model)
        model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
        api_max_input = model_info.get("api_max_input")
        effective_limit = api_max_input if api_max_input and api_max_input < context_window else context_window

        # Check if exceeds absolute maximum (1M tokens)
        MAX_SYSTEM_LIMIT = 1_000_000
        if estimated_tokens > MAX_SYSTEM_LIMIT:
            raise HTTPException(
                status_code=413,
                detail={
                    "error": "content_too_large",
                    "message": f"File is too large to process. The content has approximately {estimated_tokens:,} tokens, which exceeds the system's maximum limit of {MAX_SYSTEM_LIMIT:,} tokens.",
                    "estimated_tokens": estimated_tokens,
                    "system_limit": MAX_SYSTEM_LIMIT,
                    "provider": request.provider,
                    "model": request.model,
                    "suggestion": "Please split your file into smaller chunks or use a different processing approach."
                }
            )

        # Check if exceeds model's effective limit
        if estimated_tokens > effective_limit:
            # Determine if chunking could help
            if api_max_input and context_window > api_max_input:
                # Model has larger context but API restricts input
                # Suggest chunking to reduce tokens OR switching to unrestricted model
                raise HTTPException(
                    status_code=413,
                    detail={
                        "error": "input_too_long",
                        "message": f"Input is too long for {request.model}. The content has approximately {estimated_tokens:,} tokens, but the API restricts input to {api_max_input:,} tokens (despite the model's {context_window:,} token context window).",
                        "estimated_tokens": estimated_tokens,
                        "api_limit": api_max_input,
                        "context_window": context_window,
                        "provider": request.provider,
                        "model": request.model,
                        "suggestion": "✓ Enable 'Smart Chunking' checkbox to reduce tokens by 40-90%\n✓ Switch to Google Gemini models (no API input limits): gemini-2.5-pro, gemini-2.5-flash\n✓ Switch to OpenRouter with google/gemini-2.5-pro or google/gemini-2.5-flash",
                        "chunking_enabled": request.chunked
                    }
                )
            else:
                # Model simply can't handle this much input
                # Suggest switching to larger context model
                raise HTTPException(
                    status_code=413,
                    detail={
                        "error": "input_too_long",
                        "message": f"Input is too long for {request.model}. The content has approximately {estimated_tokens:,} tokens, which exceeds the model's maximum of {effective_limit:,} tokens.",
                        "estimated_tokens": estimated_tokens,
                        "model_limit": effective_limit,
                        "provider": request.provider,
                        "model": request.model,
                        "suggestion": "✓ Switch to a model with larger context window:\n - Google Gemini 2.5 Pro (1M tokens, no API limits)\n - Google Gemini 2.5 Flash (1M tokens, cheaper)\n - Claude Opus 4.5 (1M context, 200k API limit)\n✓ Enable 'Smart Chunking' to reduce token usage",
                        "chunking_enabled": request.chunked
                    }
                )

        # Determine temperature for reasoning models
        model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
        is_reasoning_model = model_info.get("reasoning_model", False)

        # Also check model name patterns for OpenAI and DeepSeek
        if not is_reasoning_model and request.provider in ["openai", "deepseek"]:
            model_lower = request.model.lower()
            is_reasoning_model = (
                model_lower.startswith("o1") or
                model_lower.startswith("o3") or
                model_lower.startswith("gpt-5") or
                "reasoner" in model_lower or
                "reasoning" in model_lower or
                (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
            )

        # Set temperature based on model type and user input
        if is_reasoning_model:
            temperature = 1.0
            if request.temperature is not None and request.temperature != 1.0:
                logger.warning(f"Overriding temperature={request.temperature} to 1.0 for reasoning model {request.provider}/{request.model}")
            else:
                logger.info(f"Using temperature=1.0 for reasoning model {request.provider}/{request.model}")
        else:
            # Use provided temperature or default to 0.7
            temperature = request.temperature if request.temperature is not None else 0.7
            logger.info(f"Using temperature={temperature} for model {request.provider}/{request.model}")

        # Create chat request
        chat_request = ChatRequest(
            model=request.model,
            messages=messages,
            temperature=temperature,
            max_tokens=request.max_tokens,
        )

        # Initialize client and make request (now using native async)
        client = LLMClient(provider=request.provider)
        response = await client.chat_completion(chat_request)

        # Track cost
        cost_tracker.add_entry(
            provider=response.provider,
            model=response.model,
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            total_tokens=response.usage.total_tokens,
            cost_usd=response.usage.cost_usd,
            request_id=response.id,
            cached_tokens=response.usage.cached_tokens,
            cache_creation_tokens=response.usage.cache_creation_tokens,
            cache_read_tokens=response.usage.cache_read_tokens,
        )

        return ChatCompletionResponse(
            id=response.id,
            provider=response.provider,
            model=response.model,
            content=response.content,
            finish_reason=response.finish_reason,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            cost_usd=response.usage.cost_usd,
        )
    except HTTPException:
        # Re-raise our custom HTTP exceptions (token limits, etc.)
        raise
    except Exception as e:
        error_msg = str(e)
        logger.error(f"Chat completion error: {error_msg}")

        # Determine error type and status code
        status_code = 500
        error_type = "internal_error"
        suggestion = None

        if "insufficient balance" in error_msg.lower():
            status_code = 402
            error_type = "insufficient_balance_error"
        elif "authentication" in error_msg.lower() or "api key" in error_msg.lower():
            status_code = 401
            error_type = "authentication_error"
        elif "rate limit" in error_msg.lower():
            status_code = 429
            error_type = "rate_limit_error"
        elif "not found" in error_msg.lower():
            status_code = 404
            error_type = "not_found_error"
        elif "invalid model" in error_msg.lower():
            status_code = 400
            error_type = "invalid_model_error"
        elif "temperature" in error_msg.lower() and "not support" in error_msg.lower():
            status_code = 400
            error_type = "invalid_parameter_error"
        # Catch provider API token limit errors that slip through
        elif "too long" in error_msg.lower() or "maximum" in error_msg.lower():
            status_code = 413
            error_type = "input_too_long"

            # Extract token count from error if available
            import re
            token_match = re.search(r'(\d+)\s+tokens?\s+>\s+(\d+)', error_msg)
            if token_match:
                actual_tokens = int(token_match.group(1))
                limit_tokens = int(token_match.group(2))

                # Get model info to provide smart suggestions
                model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
                context_window = model_info.get("context", 0)
                api_max_input = model_info.get("api_max_input")

                if api_max_input and context_window > api_max_input:
                    suggestion = f"✓ Enable 'Smart Chunking' checkbox to reduce tokens by 40-90%\n✓ Switch to Google Gemini models (no API input limits): gemini-2.5-pro, gemini-2.5-flash\n✓ Your input: {actual_tokens:,} tokens | API limit: {limit_tokens:,} tokens | Model context: {context_window:,} tokens"
                else:
                    suggestion = f"✓ Switch to a model with larger context window (Google Gemini 2.5: 1M tokens)\n✓ Enable 'Smart Chunking' to reduce token usage\n✓ Your input: {actual_tokens:,} tokens | Model limit: {limit_tokens:,} tokens"

        detail = {
            "error": error_type,
            "detail": error_msg,
            "provider": request.provider,
            "model": request.model
        }

        if suggestion:
            detail["suggestion"] = suggestion

        raise HTTPException(
            status_code=status_code,
            detail=detail
        )


@app.websocket("/api/chat/stream")
async def chat_stream(websocket: WebSocket):
    """
    WebSocket endpoint for streaming chat completions.

    Protocol:
        Client sends JSON: {"provider": "openai", "model": "gpt-4", "messages": [...]}
        Server streams JSON chunks: {"content": "...", "done": false}
        Final message: {"content": "", "done": true, "usage": {...}}
    """
    await websocket.accept()

    try:
        # Receive request
        data = await websocket.receive_text()
        request_data = json.loads(data)

        provider = request_data.get("provider")
        model = request_data.get("model")
        messages_data = request_data.get("messages", [])
        requested_temperature = request_data.get("temperature")
        max_tokens = request_data.get("max_tokens")

        # Convert messages
        messages = [
            Message(role=msg["role"], content=msg["content"])
            for msg in messages_data
        ]

        # Determine temperature for reasoning models
        model_info = MODEL_CATALOG.get(provider, {}).get(model, {})
        is_reasoning_model = model_info.get("reasoning_model", False)

        # Also check model name patterns for OpenAI and DeepSeek
        if not is_reasoning_model and provider in ["openai", "deepseek"]:
            model_lower = model.lower()
            is_reasoning_model = (
                model_lower.startswith("o1") or
                model_lower.startswith("o3") or
                "reasoner" in model_lower or
                "reasoning" in model_lower or
                (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
            )

        # Set temperature based on model type and user input
        if is_reasoning_model:
            temperature = 1.0
            if requested_temperature is not None and requested_temperature != 1.0:
                logger.warning(f"Overriding temperature={requested_temperature} to 1.0 for reasoning model {provider}/{model}")
            else:
                logger.info(f"Using temperature=1.0 for reasoning model {provider}/{model}")
        else:
            # Use provided temperature or default to 0.7
            temperature = requested_temperature if requested_temperature is not None else 0.7
            logger.info(f"Using temperature={temperature} for model {provider}/{model}")

        # Create request
        chat_request = ChatRequest(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Initialize client and stream (now using native async)
        client = LLMClient(provider=provider)

        full_content = ""
        stream = client.chat_completion_stream(chat_request)
        async for chunk in stream:
            full_content += chunk.content
            await websocket.send_json({
                "content": chunk.content,
                "done": False,
            })

        # Send final message
        await websocket.send_json({
            "content": "",
            "done": True,
            "full_content": full_content,
        })

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        await websocket.send_json({
            "error": str(e),
            "done": True,
        })
    finally:
        await websocket.close()


@app.get("/api/cost")
async def get_cost_summary():
    """Get cost tracking summary."""
    return cost_tracker.get_summary()


@app.post("/api/cost/reset")
async def reset_cost_tracker():
    """Reset cost tracker."""
    cost_tracker.reset()
    return {"message": "Cost tracker reset successfully"}


class ProviderModelsInfo(BaseModel):
    """Models info for a single provider."""
    models: List[dict]
    active: bool
    validation_error: Optional[str] = None
    validation_time_ms: int = 0


class AllModelsResponse(BaseModel):
    """Response model for all validated models."""
    providers: Dict[str, ProviderModelsInfo]
    summary: dict


@app.get("/api/all-models")
async def get_all_validated_models():
    """
    Get all validated models across all providers with detailed metadata.

    Returns models with: provider, cost (input/output), context window,
    capabilities (vision, reasoning, tools, caching), and active status.
    """
    from stratifyai.utils.provider_validator import validate_provider_models
    from stratifyai.api_key_helper import APIKeyHelper

    providers_list = [
        "openai", "anthropic", "google", "deepseek",
        "groq", "grok", "ollama", "openrouter", "bedrock"
    ]

    # Get API key availability
    api_key_status = APIKeyHelper.check_available_providers()

    result = {}
    total_models = 0
    active_providers = 0

    # Run validation for each provider in parallel
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor() as pool:
        validation_tasks = []
        for provider in providers_list:
            model_ids = list(MODEL_CATALOG.get(provider, {}).keys())
            task = loop.run_in_executor(
                pool,
                validate_provider_models,
                provider,
                model_ids
            )
            validation_tasks.append((provider, task))

        # Gather results
        for provider, task in validation_tasks:
            validation_result = await task

            # Check if provider is active (has API key configured)
            is_active = api_key_status.get(provider, False)
            if is_active:
                active_providers += 1

            models_list = []
            catalog = MODEL_CATALOG.get(provider, {})

            # Use valid models if available, otherwise use catalog
            model_ids = validation_result["valid_models"] if not validation_result["error"] else list(catalog.keys())

            for model_id in model_ids:
                model_info = catalog.get(model_id, {})

                models_list.append({
                    "id": model_id,
                    "provider": provider,
                    "context_window": model_info.get("context", 0),
                    "cost_input": model_info.get("cost_input", 0),
                    "cost_output": model_info.get("cost_output", 0),
                    "supports_vision": model_info.get("supports_vision", False),
                    "supports_tools": model_info.get("supports_tools", False),
                    "supports_caching": model_info.get("supports_caching", False),
                    "reasoning_model": model_info.get("reasoning_model", False),
                    "validated": model_id in validation_result["valid_models"],
                })

            result[provider] = ProviderModelsInfo(
                models=models_list,
                active=is_active,
                validation_error=validation_result.get("error"),
                validation_time_ms=validation_result.get("validation_time_ms", 0),
            )
            total_models += len(models_list)

    return AllModelsResponse(
        providers=result,
        summary={
            "total_models": total_models,
            "total_providers": len(providers_list),
            "active_providers": active_providers,
        }
    )


@app.get("/api/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "version": "0.1.0"}


if __name__ == "__main__":
    import uvicorn
    # Increase body size limit to 100MB for large file uploads
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        limit_concurrency=1000,
        timeout_keep_alive=5
    )