stratifyai 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
api/main.py ADDED
@@ -0,0 +1,763 @@
+ """FastAPI application for StratifyAI."""
+
+ import json
+ import logging
+ from typing import Optional, List, Dict
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse, FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+ import os
+ from dotenv import load_dotenv
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ from stratifyai import LLMClient, ChatRequest, Message, ProviderType
+ from stratifyai.cost_tracker import CostTracker
+ from stratifyai.config import MODEL_CATALOG
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="StratifyAI API",
+     description="Unified API for multiple LLM providers",
+     version="0.1.0",
+ )
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global cost tracker
+ cost_tracker = CostTracker()
+
+ # Mount static files
+ static_dir = os.path.join(os.path.dirname(__file__), "static")
+ if os.path.exists(static_dir):
+     app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+
+ # Request/Response models
+ class ChatCompletionRequest(BaseModel):
+     """Chat completion request model."""
+     provider: str
+     model: str
+     messages: List[dict]
+     temperature: Optional[float] = None
+     max_tokens: Optional[int] = None
+     stream: bool = False
+     file_content: Optional[str] = None # Base64 encoded file content or plain text
+     file_name: Optional[str] = None # Original filename for type detection
+     chunked: bool = False # Enable smart chunking and summarization
+     chunk_size: int = 50000 # Chunk size in characters
+
+
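# --- Illustrative sketch (not part of the package source above) ---
# A request payload matching ChatCompletionRequest; the values are assumptions
# for demonstration only. file_content may be base64 or plain text: the
# /api/chat handler tries base64 first and falls back to treating it as text.
example_chat_request = {
    "provider": "openai",
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Summarize the attached notes."}],
    "temperature": 0.7,
    "max_tokens": 512,
    "file_content": "Meeting notes: draft roadmap, open questions, action items.",
    "file_name": "notes.txt",
    "chunked": True,       # enable smart chunking and summarization
    "chunk_size": 50000,   # chunk size in characters
}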
+ class ChatCompletionResponse(BaseModel):
+     """Chat completion response model."""
+     id: str
+     provider: str
+     model: str
+     content: str
+     finish_reason: str
+     usage: dict
+     cost_usd: float
+
+
+ class ProviderInfo(BaseModel):
+     """Provider information model."""
+     name: str
+     models: List[str]
+
+
+ class ErrorResponse(BaseModel):
+     """Error response model."""
+     error: str
+     detail: str
+     error_type: str
+
+
+ @app.get("/")
+ async def root():
+     """Serve the frontend interface."""
+     static_dir = os.path.join(os.path.dirname(__file__), "static")
+     index_path = os.path.join(static_dir, "index.html")
+     if os.path.exists(index_path):
+         return FileResponse(index_path)
+     return {
+         "name": "StratifyAI API",
+         "version": "0.1.0",
+         "message": "Frontend not found. API endpoints available at /docs"
+     }
+
+
+ @app.get("/models")
+ async def models_page():
+     """Serve the models catalog page."""
+     static_dir = os.path.join(os.path.dirname(__file__), "static")
+     models_path = os.path.join(static_dir, "models.html")
+     if os.path.exists(models_path):
+         return FileResponse(models_path)
+     return {"error": "Models page not found"}
+
+
+ @app.get("/api/providers", response_model=List[str])
+ async def list_providers():
+     """List all available providers."""
+     return [
+         "openai",
+         "anthropic",
+         "google",
+         "deepseek",
+         "groq",
+         "grok",
+         "ollama",
+         "openrouter",
+         "bedrock",
+     ]
+
+
+ class ModelInfo(BaseModel):
+     """Model information."""
+     id: str # Model ID (e.g., 'gpt-4o')
+     display_name: str # Display name (e.g., 'GPT-4o')
+     description: str = "" # Description with labels
+     category: str = "" # Category for grouping
+     reasoning_model: bool = False
+     supports_vision: bool = False
+
+
+ class ModelListResponse(BaseModel):
+     """Model list response with validation metadata."""
+     models: List[ModelInfo]
+     validation: dict
+
+
+ @app.get("/api/models/{provider}", response_model=ModelListResponse)
+ async def list_models(provider: str):
+     """List validated models for a specific provider."""
+     from stratifyai.utils.provider_validator import get_validated_interactive_models
+
+     if provider not in MODEL_CATALOG:
+         raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found")
+
+     # Run validation in background thread to avoid blocking
+     loop = asyncio.get_event_loop()
+     with ThreadPoolExecutor() as pool:
+         validation_data = await loop.run_in_executor(
+             pool,
+             get_validated_interactive_models,
+             provider
+         )
+
+     validated_models = validation_data["models"]
+     validation_result = validation_data["validation_result"]
+
+     # Log validation result
+     if validation_result["error"]:
+         logger.warning(f"Model validation for {provider}: {validation_result['error']}")
+     else:
+         logger.info(f"Model validation for {provider}: {len(validated_models)} models in {validation_result['validation_time_ms']}ms")
+
+     # If validation succeeded: return only validated models with metadata
+     # If validation failed with error: fall back to catalog
+     if validation_result["error"]:
+         # Fallback to catalog when validation fails
+         model_ids = list(MODEL_CATALOG[provider].keys())
+         model_metadata = MODEL_CATALOG[provider]
+     else:
+         # Show only validated models on success
+         model_ids = list(validated_models.keys())
+         model_metadata = validated_models
+
+     # Build model info list with rich metadata
+     models_info = []
+     for model_id in model_ids:
+         meta = model_metadata.get(model_id, {})
+         models_info.append(ModelInfo(
+             id=model_id,
+             display_name=meta.get("display_name", model_id),
+             description=meta.get("description", ""),
+             category=meta.get("category", ""),
+             reasoning_model=meta.get("reasoning_model", False),
+             supports_vision=meta.get("supports_vision", False),
+         ))
+
+     return ModelListResponse(
+         models=models_info,
+         validation=validation_result
+     )
+
+
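# --- Illustrative client sketch (not part of the package source above) ---
# Fetching the validated model list for one provider. The localhost URL and the
# `requests` dependency are assumptions made for demonstration only.
import requests

resp = requests.get("http://localhost:8000/api/models/openai", timeout=30)
resp.raise_for_status()
payload = resp.json()
for m in payload["models"]:
    print(m["id"], "-", m["display_name"])
print("validation:", payload["validation"])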
+ @app.get("/api/model-info/{provider}/{model}")
+ async def get_model_info(provider: str, model: str):
+     """Get detailed information about a specific model."""
+     if provider not in MODEL_CATALOG:
+         raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found")
+
+     if model not in MODEL_CATALOG[provider]:
+         raise HTTPException(status_code=404, detail=f"Model '{model}' not found for provider '{provider}'")
+
+     model_info = MODEL_CATALOG[provider][model]
+
+     return {
+         "provider": provider,
+         "model": model,
+         "fixed_temperature": model_info.get("fixed_temperature"),
+         "reasoning_model": model_info.get("reasoning_model", False),
+         "supports_vision": model_info.get("supports_vision", False),
+         "supports_tools": model_info.get("supports_tools", False),
+         "supports_caching": model_info.get("supports_caching", False),
+         "context": model_info.get("context", 0),
+     }
+
+
+ @app.get("/api/provider-info", response_model=List[ProviderInfo])
+ async def get_provider_info():
+     """Get information about all providers and their models."""
+     providers = []
+     for provider_name, models in MODEL_CATALOG.items():
+         providers.append(ProviderInfo(
+             name=provider_name,
+             models=list(models.keys())
+         ))
+     return providers
+
+
+ @app.post("/api/chat", response_model=ChatCompletionResponse)
+ async def chat_completion(request: ChatCompletionRequest):
+     """
+     Execute a chat completion request.
+
+     Args:
+         request: Chat completion request
+
+     Returns:
+         Chat completion response with cost tracking
+     """
+     try:
+         # Convert messages to Message objects
+         messages = [
+             Message(role=msg["role"], content=msg["content"])
+             for msg in request.messages
+         ]
+
+         # Process file if provided
+         if request.file_content and request.file_name:
+             from stratifyai.summarization import summarize_file_async
+             from stratifyai.utils.file_analyzer import analyze_file
+             from pathlib import Path
+             import tempfile
+             import base64
+
+             # Detect if content is base64 encoded or plain text
+             try:
+                 # Try to decode as base64
+                 file_bytes = base64.b64decode(request.file_content)
+                 file_text = file_bytes.decode('utf-8')
+             except Exception:
+                 # If decoding fails, assume it's plain text
+                 file_text = request.file_content
+
+             # Apply chunking if enabled
+             if request.chunked:
+                 logger.info(f"Chunking file {request.file_name} (size: {len(file_text)} chars, chunk_size: {request.chunk_size})")
+
+                 # Create temporary file for analysis
+                 with tempfile.NamedTemporaryFile(mode='w', suffix=Path(request.file_name).suffix, delete=False) as tmp_file:
+                     tmp_file.write(file_text)
+                     tmp_path = Path(tmp_file.name) # Convert to Path object
+
+                 try:
+                     # Analyze file to determine if chunking is beneficial
+                     analysis = analyze_file(tmp_path, request.provider, request.model)
+                     logger.info(f"File analysis: type={analysis.file_type.value}, tokens={analysis.estimated_tokens}")
+
+                     # Perform chunking and summarization
+                     # Use a cheap model for summarization (gpt-4o-mini or similar)
+                     # Auto-select based on provider
+                     summarization_models = {
+                         "openai": "gpt-4o-mini",
+                         "anthropic": "claude-3-5-sonnet-20241022",
+                         "google": "gemini-2.5-flash",
+                         "deepseek": "deepseek-chat",
+                         "groq": "llama-3.1-8b-instant",
+                         "grok": "grok-beta",
+                         "openrouter": "google/gemini-2.5-flash",
+                         "ollama": "llama3.2",
+                         "bedrock": "anthropic.claude-3-5-haiku-20241022-v1:0",
+                     }
+                     summarization_model = summarization_models.get(request.provider, "gpt-4o-mini")
+
+                     client = LLMClient(provider=request.provider)
+
+                     # Get context from last user message if available
+                     context = None
+                     if messages and messages[-1].role == "user":
+                         context = messages[-1].content
+
+                     # Run async summarization with cheap model
+                     result = await summarize_file_async(
+                         file_text,
+                         client,
+                         request.chunk_size,
+                         summarization_model,
+                         context,
+                         False # show_progress=False for API
+                     )
+
+                     # Use summarized content
+                     file_content_to_use = result['summary']
+                     logger.info(f"Chunking complete: {result['reduction_percentage']}% reduction ({result['original_length']} -> {result['summary_length']} chars)")
+                 finally:
+                     # Clean up temp file
+                     import os
+                     os.unlink(tmp_path)
+             else:
+                 # Use file content as-is
+                 file_content_to_use = file_text
+
+             # Append file content to last user message or create new message
+             if messages and messages[-1].role == "user":
+                 # Combine with existing user message
+                 messages[-1].content = f"{messages[-1].content}\n\n[File: {request.file_name}]\n\n{file_content_to_use}"
+             else:
+                 # Create new user message with file content
+                 messages.append(Message(
+                     role="user",
+                     content=f"[File: {request.file_name}]\n\n{file_content_to_use}"
+                 ))
+
+         # Validate token count before making request
+         from stratifyai.utils.token_counter import count_tokens_for_messages, get_context_window
+         estimated_tokens = count_tokens_for_messages(messages, request.provider, request.model)
+
+         # Get context window and API limits
+         context_window = get_context_window(request.provider, request.model)
+         model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
+         api_max_input = model_info.get("api_max_input")
+         effective_limit = api_max_input if api_max_input and api_max_input < context_window else context_window
+
+         # Check if exceeds absolute maximum (1M tokens)
+         MAX_SYSTEM_LIMIT = 1_000_000
+         if estimated_tokens > MAX_SYSTEM_LIMIT:
+             raise HTTPException(
+                 status_code=413,
+                 detail={
+                     "error": "content_too_large",
+                     "message": f"File is too large to process. The content has approximately {estimated_tokens:,} tokens, which exceeds the system's maximum limit of {MAX_SYSTEM_LIMIT:,} tokens.",
+                     "estimated_tokens": estimated_tokens,
+                     "system_limit": MAX_SYSTEM_LIMIT,
+                     "provider": request.provider,
+                     "model": request.model,
+                     "suggestion": "Please split your file into smaller chunks or use a different processing approach."
+                 }
+             )
+
+         # Check if exceeds model's effective limit
+         if estimated_tokens > effective_limit:
+             # Determine if chunking could help
+             if api_max_input and context_window > api_max_input:
+                 # Model has larger context but API restricts input
+                 # Suggest chunking to reduce tokens OR switching to unrestricted model
+                 raise HTTPException(
+                     status_code=413,
+                     detail={
+                         "error": "input_too_long",
+                         "message": f"Input is too long for {request.model}. The content has approximately {estimated_tokens:,} tokens, but the API restricts input to {api_max_input:,} tokens (despite the model's {context_window:,} token context window).",
+                         "estimated_tokens": estimated_tokens,
+                         "api_limit": api_max_input,
+                         "context_window": context_window,
+                         "provider": request.provider,
+                         "model": request.model,
+                         "suggestion": "✓ Enable 'Smart Chunking' checkbox to reduce tokens by 40-90%\n✓ Switch to Google Gemini models (no API input limits): gemini-2.5-pro, gemini-2.5-flash\n✓ Switch to OpenRouter with google/gemini-2.5-pro or google/gemini-2.5-flash",
+                         "chunking_enabled": request.chunked
+                     }
+                 )
+             else:
+                 # Model simply can't handle this much input
+                 # Suggest switching to larger context model
+                 raise HTTPException(
+                     status_code=413,
+                     detail={
+                         "error": "input_too_long",
+                         "message": f"Input is too long for {request.model}. The content has approximately {estimated_tokens:,} tokens, which exceeds the model's maximum of {effective_limit:,} tokens.",
+                         "estimated_tokens": estimated_tokens,
+                         "model_limit": effective_limit,
+                         "provider": request.provider,
+                         "model": request.model,
+                         "suggestion": "✓ Switch to a model with larger context window:\n - Google Gemini 2.5 Pro (1M tokens, no API limits)\n - Google Gemini 2.5 Flash (1M tokens, cheaper)\n - Claude Opus 4.5 (1M context, 200k API limit)\n✓ Enable 'Smart Chunking' to reduce token usage",
+                         "chunking_enabled": request.chunked
+                     }
+                 )
+
+         # Determine temperature for reasoning models
+         model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check model name patterns for OpenAI and DeepSeek
+         if not is_reasoning_model and request.provider in ["openai", "deepseek"]:
+             model_lower = request.model.lower()
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 model_lower.startswith("gpt-5") or
+                 "reasoner" in model_lower or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Set temperature based on model type and user input
+         if is_reasoning_model:
+             temperature = 1.0
+             if request.temperature is not None and request.temperature != 1.0:
+                 logger.warning(f"Overriding temperature={request.temperature} to 1.0 for reasoning model {request.provider}/{request.model}")
+             else:
+                 logger.info(f"Using temperature=1.0 for reasoning model {request.provider}/{request.model}")
+         else:
+             # Use provided temperature or default to 0.7
+             temperature = request.temperature if request.temperature is not None else 0.7
+             logger.info(f"Using temperature={temperature} for model {request.provider}/{request.model}")
+
+         # Create chat request
+         chat_request = ChatRequest(
+             model=request.model,
+             messages=messages,
+             temperature=temperature,
+             max_tokens=request.max_tokens,
+         )
+
+         # Initialize client and make request (now using native async)
+         client = LLMClient(provider=request.provider)
+         response = await client.chat_completion(chat_request)
+
+         # Track cost
+         cost_tracker.add_entry(
+             provider=response.provider,
+             model=response.model,
+             prompt_tokens=response.usage.prompt_tokens,
+             completion_tokens=response.usage.completion_tokens,
+             total_tokens=response.usage.total_tokens,
+             cost_usd=response.usage.cost_usd,
+             request_id=response.id,
+             cached_tokens=response.usage.cached_tokens,
+             cache_creation_tokens=response.usage.cache_creation_tokens,
+             cache_read_tokens=response.usage.cache_read_tokens,
+         )
+
+         return ChatCompletionResponse(
+             id=response.id,
+             provider=response.provider,
+             model=response.model,
+             content=response.content,
+             finish_reason=response.finish_reason,
+             usage={
+                 "prompt_tokens": response.usage.prompt_tokens,
+                 "completion_tokens": response.usage.completion_tokens,
+                 "total_tokens": response.usage.total_tokens,
+             },
+             cost_usd=response.usage.cost_usd,
+         )
+     except HTTPException:
+         # Re-raise our custom HTTP exceptions (token limits, etc.)
+         raise
+     except Exception as e:
+         error_msg = str(e)
+         logger.error(f"Chat completion error: {error_msg}")
+
+         # Determine error type and status code
+         status_code = 500
+         error_type = "internal_error"
+         suggestion = None
+
+         if "insufficient balance" in error_msg.lower():
+             status_code = 402
+             error_type = "insufficient_balance_error"
+         elif "authentication" in error_msg.lower() or "api key" in error_msg.lower():
+             status_code = 401
+             error_type = "authentication_error"
+         elif "rate limit" in error_msg.lower():
+             status_code = 429
+             error_type = "rate_limit_error"
+         elif "not found" in error_msg.lower():
+             status_code = 404
+             error_type = "not_found_error"
+         elif "invalid model" in error_msg.lower():
+             status_code = 400
+             error_type = "invalid_model_error"
+         elif "temperature" in error_msg.lower() and "not support" in error_msg.lower():
+             status_code = 400
+             error_type = "invalid_parameter_error"
+         # Catch provider API token limit errors that slip through
+         elif "too long" in error_msg.lower() or "maximum" in error_msg.lower():
+             status_code = 413
+             error_type = "input_too_long"
+
+             # Extract token count from error if available
+             import re
+             token_match = re.search(r'(\d+)\s+tokens?\s+>\s+(\d+)', error_msg)
+             if token_match:
+                 actual_tokens = int(token_match.group(1))
+                 limit_tokens = int(token_match.group(2))
+
+                 # Get model info to provide smart suggestions
+                 model_info = MODEL_CATALOG.get(request.provider, {}).get(request.model, {})
+                 context_window = model_info.get("context", 0)
+                 api_max_input = model_info.get("api_max_input")
+
+                 if api_max_input and context_window > api_max_input:
+                     suggestion = f"✓ Enable 'Smart Chunking' checkbox to reduce tokens by 40-90%\n✓ Switch to Google Gemini models (no API input limits): gemini-2.5-pro, gemini-2.5-flash\n✓ Your input: {actual_tokens:,} tokens | API limit: {limit_tokens:,} tokens | Model context: {context_window:,} tokens"
+                 else:
+                     suggestion = f"✓ Switch to a model with larger context window (Google Gemini 2.5: 1M tokens)\n✓ Enable 'Smart Chunking' to reduce token usage\n✓ Your input: {actual_tokens:,} tokens | Model limit: {limit_tokens:,} tokens"
+
+         detail = {
+             "error": error_type,
+             "detail": error_msg,
+             "provider": request.provider,
+             "model": request.model
+         }
+
+         if suggestion:
+             detail["suggestion"] = suggestion
+
+         raise HTTPException(
+             status_code=status_code,
+             detail=detail
+         )
+
+
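# --- Illustrative client sketch (not part of the package source above) ---
# Calling POST /api/chat with the request model defined earlier. The base URL
# and the `requests` dependency are assumptions for demonstration; the 413
# handling mirrors the structured token-limit errors raised by the endpoint.
import requests

body = {
    "provider": "anthropic",
    "model": "claude-3-5-sonnet-20241022",
    "messages": [{"role": "user", "content": "Hello!"}],
}
resp = requests.post("http://localhost:8000/api/chat", json=body, timeout=120)
if resp.status_code == 413:
    detail = resp.json()["detail"]
    print("Input too large:", detail.get("message", detail.get("detail")))
    print(detail.get("suggestion", ""))
else:
    resp.raise_for_status()
    data = resp.json()
    print(data["content"])
    print("cost (USD):", data["cost_usd"])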
+ @app.websocket("/api/chat/stream")
+ async def chat_stream(websocket: WebSocket):
+     """
+     WebSocket endpoint for streaming chat completions.
+
+     Protocol:
+         Client sends JSON: {"provider": "openai", "model": "gpt-4", "messages": [...]}
+         Server streams JSON chunks: {"content": "...", "done": false}
+         Final message: {"content": "", "done": true, "usage": {...}}
+     """
+     await websocket.accept()
+
+     try:
+         # Receive request
+         data = await websocket.receive_text()
+         request_data = json.loads(data)
+
+         provider = request_data.get("provider")
+         model = request_data.get("model")
+         messages_data = request_data.get("messages", [])
+         requested_temperature = request_data.get("temperature")
+         max_tokens = request_data.get("max_tokens")
+
+         # Convert messages
+         messages = [
+             Message(role=msg["role"], content=msg["content"])
+             for msg in messages_data
+         ]
+
+         # Determine temperature for reasoning models
+         model_info = MODEL_CATALOG.get(provider, {}).get(model, {})
+         is_reasoning_model = model_info.get("reasoning_model", False)
+
+         # Also check model name patterns for OpenAI and DeepSeek
+         if not is_reasoning_model and provider in ["openai", "deepseek"]:
+             model_lower = model.lower()
+             is_reasoning_model = (
+                 model_lower.startswith("o1") or
+                 model_lower.startswith("o3") or
+                 "reasoner" in model_lower or
+                 "reasoning" in model_lower or
+                 (model_lower.startswith("o") and len(model_lower) > 1 and model_lower[1].isdigit())
+             )
+
+         # Set temperature based on model type and user input
+         if is_reasoning_model:
+             temperature = 1.0
+             if requested_temperature is not None and requested_temperature != 1.0:
+                 logger.warning(f"Overriding temperature={requested_temperature} to 1.0 for reasoning model {provider}/{model}")
+             else:
+                 logger.info(f"Using temperature=1.0 for reasoning model {provider}/{model}")
+         else:
+             # Use provided temperature or default to 0.7
+             temperature = requested_temperature if requested_temperature is not None else 0.7
+             logger.info(f"Using temperature={temperature} for model {provider}/{model}")
+
+         # Create request
+         chat_request = ChatRequest(
+             model=model,
+             messages=messages,
+             temperature=temperature,
+             max_tokens=max_tokens,
+         )
+
+         # Initialize client and stream (now using native async)
+         client = LLMClient(provider=provider)
+
+         full_content = ""
+         stream = client.chat_completion_stream(chat_request)
+         async for chunk in stream:
+             full_content += chunk.content
+             await websocket.send_json({
+                 "content": chunk.content,
+                 "done": False,
+             })
+
+         # Send final message
+         await websocket.send_json({
+             "content": "",
+             "done": True,
+             "full_content": full_content,
+         })
+
+     except WebSocketDisconnect:
+         logger.info("WebSocket disconnected")
+     except Exception as e:
+         logger.error(f"WebSocket error: {str(e)}")
+         await websocket.send_json({
+             "error": str(e),
+             "done": True,
+         })
+     finally:
+         await websocket.close()
+
+
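# --- Illustrative streaming client sketch (not part of the package source above) ---
# Follows the protocol documented in chat_stream. The `websockets` package and
# the localhost URL are assumptions made for demonstration only.
import asyncio
import json
import websockets

async def stream_chat() -> None:
    uri = "ws://localhost:8000/api/chat/stream"
    async with websockets.connect(uri) as ws:
        # Send the request payload as a single JSON message
        await ws.send(json.dumps({
            "provider": "openai",
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Stream a short haiku."}],
        }))
        # Print content chunks until the server signals completion or an error
        while True:
            msg = json.loads(await ws.recv())
            if msg.get("error"):
                print("error:", msg["error"])
                break
            print(msg.get("content", ""), end="", flush=True)
            if msg.get("done"):
                break

asyncio.run(stream_chat())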
+ @app.get("/api/cost")
+ async def get_cost_summary():
+     """Get cost tracking summary."""
+     return cost_tracker.get_summary()
+
+
+ @app.post("/api/cost/reset")
+ async def reset_cost_tracker():
+     """Reset cost tracker."""
+     cost_tracker.reset()
+     return {"message": "Cost tracker reset successfully"}
+
+
+ class ProviderModelsInfo(BaseModel):
+     """Models info for a single provider."""
+     models: List[dict]
+     active: bool
+     validation_error: Optional[str] = None
+     validation_time_ms: int = 0
+
+
+ class AllModelsResponse(BaseModel):
+     """Response model for all validated models."""
+     providers: Dict[str, ProviderModelsInfo]
+     summary: dict
+
+
+ @app.get("/api/all-models")
+ async def get_all_validated_models():
+     """
+     Get all validated models across all providers with detailed metadata.
+
+     Returns models with: provider, cost (input/output), context window,
+     capabilities (vision, reasoning, tools, caching), and active status.
+     """
+     from stratifyai.utils.provider_validator import validate_provider_models
+     from stratifyai.api_key_helper import APIKeyHelper
+
+     providers_list = [
+         "openai", "anthropic", "google", "deepseek",
+         "groq", "grok", "ollama", "openrouter", "bedrock"
+     ]
+
+     # Get API key availability
+     api_key_status = APIKeyHelper.check_available_providers()
+
+     result = {}
+     total_models = 0
+     active_providers = 0
+
+     # Run validation for each provider in parallel
+     loop = asyncio.get_event_loop()
+     with ThreadPoolExecutor() as pool:
+         validation_tasks = []
+         for provider in providers_list:
+             model_ids = list(MODEL_CATALOG.get(provider, {}).keys())
+             task = loop.run_in_executor(
+                 pool,
+                 validate_provider_models,
+                 provider,
+                 model_ids
+             )
+             validation_tasks.append((provider, task))
+
+         # Gather results
+         for provider, task in validation_tasks:
+             validation_result = await task
+
+             # Check if provider is active (has API key configured)
+             is_active = api_key_status.get(provider, False)
+             if is_active:
+                 active_providers += 1
+
+             models_list = []
+             catalog = MODEL_CATALOG.get(provider, {})
+
+             # Use valid models if available, otherwise use catalog
+             model_ids = validation_result["valid_models"] if not validation_result["error"] else list(catalog.keys())
+
+             for model_id in model_ids:
+                 model_info = catalog.get(model_id, {})
+
+                 models_list.append({
+                     "id": model_id,
+                     "provider": provider,
+                     "context_window": model_info.get("context", 0),
+                     "cost_input": model_info.get("cost_input", 0),
+                     "cost_output": model_info.get("cost_output", 0),
+                     "supports_vision": model_info.get("supports_vision", False),
+                     "supports_tools": model_info.get("supports_tools", False),
+                     "supports_caching": model_info.get("supports_caching", False),
+                     "reasoning_model": model_info.get("reasoning_model", False),
+                     "validated": model_id in validation_result["valid_models"],
+                 })
+
+             result[provider] = ProviderModelsInfo(
+                 models=models_list,
+                 active=is_active,
+                 validation_error=validation_result.get("error"),
+                 validation_time_ms=validation_result.get("validation_time_ms", 0),
+             )
+             total_models += len(models_list)
+
+     return AllModelsResponse(
+         providers=result,
+         summary={
+             "total_models": total_models,
+             "total_providers": len(providers_list),
+             "active_providers": active_providers,
+         }
+     )
+
+
+ @app.get("/api/health")
+ async def health_check():
+     """Health check endpoint."""
+     return {"status": "healthy", "version": "0.1.0"}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     # Allow high concurrency and a short keep-alive; note that no request body size limit is set here
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=8000,
+         limit_concurrency=1000,
+         timeout_keep_alive=5
+     )
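# --- Illustrative smoke test (not part of the package source above) ---
# Once the server above is running (python api/main.py), the health and cost
# endpoints can be checked like this. The localhost URL and the `requests`
# dependency are assumptions for demonstration only.
import requests

print(requests.get("http://localhost:8000/api/health", timeout=10).json())
print(requests.get("http://localhost:8000/api/cost", timeout=10).json())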