isa-model 0.3.91__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- isa_model/client.py +1166 -584
- isa_model/core/cache/redis_cache.py +410 -0
- isa_model/core/config/config_manager.py +282 -12
- isa_model/core/config.py +91 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/direct_db_client.py +114 -0
- isa_model/core/database/migration_manager.py +563 -0
- isa_model/core/database/migrations.py +297 -0
- isa_model/core/database/supabase_client.py +258 -0
- isa_model/core/dependencies.py +316 -0
- isa_model/core/discovery/__init__.py +19 -0
- isa_model/core/discovery/consul_discovery.py +190 -0
- isa_model/core/logging/__init__.py +54 -0
- isa_model/core/logging/influx_logger.py +523 -0
- isa_model/core/logging/loki_logger.py +160 -0
- isa_model/core/models/__init__.py +46 -0
- isa_model/core/models/config_models.py +625 -0
- isa_model/core/models/deployment_billing_tracker.py +430 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +66 -25
- isa_model/core/models/model_metadata.py +690 -0
- isa_model/core/models/model_repo.py +217 -55
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/models/system_models.py +857 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/repositories/__init__.py +9 -0
- isa_model/core/repositories/config_repository.py +912 -0
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +479 -370
- isa_model/core/storage/hf_storage.py +2 -2
- isa_model/core/types.py +8 -0
- isa_model/deployment/__init__.py +5 -48
- isa_model/deployment/core/__init__.py +2 -31
- isa_model/deployment/core/deployment_manager.py +1278 -368
- isa_model/deployment/local/__init__.py +31 -0
- isa_model/deployment/local/config.py +248 -0
- isa_model/deployment/local/gpu_gateway.py +607 -0
- isa_model/deployment/local/health_checker.py +428 -0
- isa_model/deployment/local/provider.py +586 -0
- isa_model/deployment/local/tensorrt_service.py +621 -0
- isa_model/deployment/local/transformers_service.py +644 -0
- isa_model/deployment/local/vllm_service.py +527 -0
- isa_model/deployment/modal/__init__.py +8 -0
- isa_model/deployment/modal/config.py +136 -0
- isa_model/deployment/modal/deployer.py +894 -0
- isa_model/deployment/modal/services/__init__.py +3 -0
- isa_model/deployment/modal/services/audio/__init__.py +1 -0
- isa_model/deployment/modal/services/audio/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/modal/services/audio/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/modal/services/audio/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/modal/services/embedding/__init__.py +1 -0
- isa_model/deployment/modal/services/embedding/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/modal/services/llm/__init__.py +1 -0
- isa_model/deployment/modal/services/llm/isa_llm_service.py +424 -0
- isa_model/deployment/modal/services/video/__init__.py +1 -0
- isa_model/deployment/modal/services/video/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/modal/services/vision/__init__.py +1 -0
- isa_model/deployment/modal/services/vision/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/modal/services/vision/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/modal/services/vision/isa_vision_table_service.py +676 -0
- isa_model/deployment/modal/services/vision/isa_vision_ui_service.py +833 -0
- isa_model/deployment/modal/services/vision/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/models/org-org-acme-corp-tenant-a-service-llm-20250825-225822/tenant-a-service_modal_service.py +48 -0
- isa_model/deployment/models/org-test-org-123-prefix-test-service-llm-20250825-225822/prefix-test-service_modal_service.py +48 -0
- isa_model/deployment/models/test-llm-service-llm-20250825-204442/test-llm-service_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-212906/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-213009/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/storage/__init__.py +5 -0
- isa_model/deployment/storage/deployment_repository.py +824 -0
- isa_model/deployment/triton/__init__.py +10 -0
- isa_model/deployment/triton/config.py +196 -0
- isa_model/deployment/triton/configs/__init__.py +1 -0
- isa_model/deployment/triton/provider.py +512 -0
- isa_model/deployment/triton/scripts/__init__.py +1 -0
- isa_model/deployment/triton/templates/__init__.py +1 -0
- isa_model/inference/__init__.py +47 -1
- isa_model/inference/ai_factory.py +179 -16
- isa_model/inference/legacy_services/__init__.py +21 -0
- isa_model/inference/legacy_services/model_evaluation.py +637 -0
- isa_model/inference/legacy_services/model_service.py +573 -0
- isa_model/inference/legacy_services/model_serving.py +717 -0
- isa_model/inference/legacy_services/model_training.py +561 -0
- isa_model/inference/models/__init__.py +21 -0
- isa_model/inference/models/inference_config.py +551 -0
- isa_model/inference/models/inference_record.py +675 -0
- isa_model/inference/models/performance_models.py +714 -0
- isa_model/inference/repositories/__init__.py +9 -0
- isa_model/inference/repositories/inference_repository.py +828 -0
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/base_stt_service.py +184 -11
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +53 -11
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/custom_model_manager.py +277 -0
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +15 -3
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/resilient_embed_service.py +285 -0
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/__init__.py +10 -2
- isa_model/inference/services/llm/base_llm_service.py +361 -26
- isa_model/inference/services/llm/cerebras_llm_service.py +628 -0
- isa_model/inference/services/llm/helpers/llm_adapter.py +71 -12
- isa_model/inference/services/llm/helpers/llm_prompts.py +342 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +321 -23
- isa_model/inference/services/llm/huggingface_llm_service.py +581 -0
- isa_model/inference/services/llm/local_llm_service.py +747 -0
- isa_model/inference/services/llm/ollama_llm_service.py +11 -3
- isa_model/inference/services/llm/openai_llm_service.py +670 -56
- isa_model/inference/services/llm/yyds_llm_service.py +10 -3
- isa_model/inference/services/vision/__init__.py +27 -6
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/blip_vision_service.py +359 -0
- isa_model/inference/services/vision/helpers/image_utils.py +19 -10
- isa_model/inference/services/vision/isa_vision_service.py +634 -0
- isa_model/inference/services/vision/openai_vision_service.py +19 -10
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/inference/services/vision/vgg16_vision_service.py +257 -0
- isa_model/serving/api/cache_manager.py +245 -0
- isa_model/serving/api/dependencies/__init__.py +1 -0
- isa_model/serving/api/dependencies/auth.py +194 -0
- isa_model/serving/api/dependencies/database.py +139 -0
- isa_model/serving/api/error_handlers.py +284 -0
- isa_model/serving/api/fastapi_server.py +240 -18
- isa_model/serving/api/middleware/auth.py +317 -0
- isa_model/serving/api/middleware/security.py +268 -0
- isa_model/serving/api/middleware/tenant_context.py +414 -0
- isa_model/serving/api/routes/analytics.py +489 -0
- isa_model/serving/api/routes/config.py +645 -0
- isa_model/serving/api/routes/deployment_billing.py +315 -0
- isa_model/serving/api/routes/deployments.py +475 -0
- isa_model/serving/api/routes/gpu_gateway.py +440 -0
- isa_model/serving/api/routes/health.py +32 -12
- isa_model/serving/api/routes/inference_monitoring.py +486 -0
- isa_model/serving/api/routes/local_deployments.py +448 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/tenants.py +575 -0
- isa_model/serving/api/routes/unified.py +992 -171
- isa_model/serving/api/routes/webhooks.py +479 -0
- isa_model/serving/api/startup.py +318 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/utils/gpu_utils.py +311 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/METADATA +76 -22
- isa_model-0.4.3.dist-info/RECORD +193 -0
- isa_model/deployment/cloud/__init__.py +0 -9
- isa_model/deployment/cloud/modal/__init__.py +0 -10
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +0 -532
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +0 -406
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/deployment/core/deployment_config.py +0 -356
- isa_model/deployment/core/isa_deployment_service.py +0 -401
- isa_model/deployment/gpu_int8_ds8/app/server.py +0 -66
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -43
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -35
- isa_model/deployment/runtime/deployed_service.py +0 -338
- isa_model/deployment/services/__init__.py +0 -9
- isa_model/deployment/services/auto_deploy_vision_service.py +0 -538
- isa_model/deployment/services/model_service.py +0 -332
- isa_model/deployment/services/service_monitor.py +0 -356
- isa_model/deployment/services/service_registry.py +0 -527
- isa_model/eval/__init__.py +0 -92
- isa_model/eval/benchmarks.py +0 -469
- isa_model/eval/config/__init__.py +0 -10
- isa_model/eval/config/evaluation_config.py +0 -108
- isa_model/eval/evaluators/__init__.py +0 -18
- isa_model/eval/evaluators/base_evaluator.py +0 -503
- isa_model/eval/evaluators/llm_evaluator.py +0 -472
- isa_model/eval/factory.py +0 -531
- isa_model/eval/infrastructure/__init__.py +0 -24
- isa_model/eval/infrastructure/experiment_tracker.py +0 -466
- isa_model/eval/metrics.py +0 -798
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model/training/__init__.py +0 -74
- isa_model/training/annotation/annotation_schema.py +0 -47
- isa_model/training/annotation/processors/annotation_processor.py +0 -126
- isa_model/training/annotation/storage/dataset_manager.py +0 -131
- isa_model/training/annotation/storage/dataset_schema.py +0 -44
- isa_model/training/annotation/tests/test_annotation_flow.py +0 -109
- isa_model/training/annotation/tests/test_minio copy.py +0 -113
- isa_model/training/annotation/tests/test_minio_upload.py +0 -43
- isa_model/training/annotation/views/annotation_controller.py +0 -158
- isa_model/training/cloud/__init__.py +0 -22
- isa_model/training/cloud/job_orchestrator.py +0 -402
- isa_model/training/cloud/runpod_trainer.py +0 -454
- isa_model/training/cloud/storage_manager.py +0 -482
- isa_model/training/core/__init__.py +0 -23
- isa_model/training/core/config.py +0 -181
- isa_model/training/core/dataset.py +0 -222
- isa_model/training/core/trainer.py +0 -720
- isa_model/training/core/utils.py +0 -213
- isa_model/training/factory.py +0 -424
- isa_model-0.3.91.dist-info/RECORD +0 -138
- /isa_model/{core/storage/minio_storage.py → deployment/modal/services/audio/isa_audio_fish_service.py} +0 -0
- /isa_model/deployment/{services → modal/services/vision}/simple_auto_deploy_vision_service.py +0 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/WHEEL +0 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/cerebras_llm_service.py
@@ -0,0 +1,628 @@

import logging
import json
import asyncio
from typing import Dict, Any, List, Union, AsyncGenerator, Optional

# Conditional import for Cerebras SDK
try:
    from cerebras.cloud.sdk import Cerebras
    CEREBRAS_AVAILABLE = True
except ImportError:
    CEREBRAS_AVAILABLE = False
    Cerebras = None

from isa_model.inference.services.llm.base_llm_service import BaseLLMService
from isa_model.core.types import ServiceType
from isa_model.core.dependencies import DependencyChecker

logger = logging.getLogger(__name__)

class CerebrasLLMService(BaseLLMService):
    """
    Cerebras LLM service implementation with tool calling emulation.

    Cerebras provides ultra-fast inference but doesn't natively support function calling.
    This implementation uses prompt engineering to emulate tool calling capabilities.

    Supported models:
    - llama-4-scout-17b-16e-instruct (109B params, ~2600 tokens/sec)
    - llama3.1-8b (8B params, ~2200 tokens/sec)
    - llama-3.3-70b (70B params, ~2100 tokens/sec)
    - gpt-oss-120b (120B params, ~3000 tokens/sec)
    - qwen-3-32b (32B params, ~2600 tokens/sec)
    """

    def __init__(self, model_name: str = "llama-3.3-70b", provider_name: str = "cerebras", **kwargs):
        # Check if Cerebras SDK is available
        if not CEREBRAS_AVAILABLE:
            install_cmd = DependencyChecker.get_install_command(packages=["cerebras-cloud-sdk"])
            raise ImportError(
                f"Cerebras SDK is not installed. This is required for using Cerebras models.\n"
                f"Install with: {install_cmd}"
            )

        super().__init__(provider_name, model_name, **kwargs)

        # Check if this is a reasoning model (gpt-oss-120b supports CoT reasoning)
        self.is_reasoning_model = "gpt-oss" in model_name.lower()

        # Get configuration from centralized config manager
        provider_config = self.get_provider_config()

        # Initialize Cerebras client
        try:
            if not provider_config.get("api_key"):
                raise ValueError("Cerebras API key not found in provider configuration")

            self.client = Cerebras(
                api_key=provider_config["api_key"],
            )

            logger.info(f"Initialized CerebrasLLMService with model {self.model_name}")
            if self.is_reasoning_model:
                logger.info(f"Model {self.model_name} is a reasoning model with CoT support")

        except Exception as e:
            logger.error(f"Failed to initialize Cerebras client: {e}")
            raise ValueError(f"Failed to initialize Cerebras client. Check your API key configuration: {e}") from e

        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}

        # Tool calling emulation flag
        self._emulate_tool_calling = True

    def _create_bound_copy(self) -> 'CerebrasLLMService':
        """Create a copy of this service for tool binding"""
        bound_service = object.__new__(CerebrasLLMService)

        # Copy all essential attributes
        bound_service.model_name = self.model_name
        bound_service.provider_name = self.provider_name
        bound_service.client = self.client
        bound_service.last_token_usage = self.last_token_usage.copy()
        bound_service.total_token_usage = self.total_token_usage.copy()
        bound_service._bound_tools = self._bound_tools.copy() if self._bound_tools else []
        bound_service.adapter_manager = self.adapter_manager

        # Copy base class attributes
        bound_service.streaming = self.streaming
        bound_service.max_tokens = self.max_tokens
        bound_service.temperature = self.temperature
        bound_service._tool_mappings = {}
        bound_service._emulate_tool_calling = self._emulate_tool_calling

        # Copy BaseService attributes
        bound_service.config_manager = self.config_manager
        bound_service.model_manager = self.model_manager

        return bound_service

    def bind_tools(self, tools: List[Any], **kwargs) -> 'CerebrasLLMService':
        """
        Bind tools to this LLM service for emulated function calling

        Args:
            tools: List of tools (functions, dicts, or LangChain tools)
            **kwargs: Additional arguments for tool binding

        Returns:
            New LLM service instance with tools bound
        """
        bound_service = self._create_bound_copy()
        bound_service._bound_tools = tools

        return bound_service

    def _create_tool_calling_prompt(self, messages: List[Dict[str, str]], tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """
        Create a prompt that instructs the model to use tools via structured output.

        This emulates OpenAI-style function calling using prompt engineering.
        """
        # Build tool descriptions
        tool_descriptions = []
        for tool in tool_schemas:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                name = func.get("name", "unknown")
                description = func.get("description", "")
                parameters = func.get("parameters", {})

                tool_desc = f"- {name}: {description}"
                if parameters.get("properties"):
                    props = []
                    for prop_name, prop_info in parameters["properties"].items():
                        prop_type = prop_info.get("type", "any")
                        prop_desc = prop_info.get("description", "")
                        props.append(f"  - {prop_name} ({prop_type}): {prop_desc}")
                    tool_desc += "\n" + "\n".join(props)

                tool_descriptions.append(tool_desc)

        # Create system message with tool instructions
        tool_system_msg = f"""You have access to the following tools:

{chr(10).join(tool_descriptions)}

IMPORTANT INSTRUCTIONS:
1. Analyze the user's request carefully
2. Select ALL APPROPRIATE tools needed to fulfill the request
3. You can call MULTIPLE tools in a single response if needed
4. When tools are needed, respond ONLY with a JSON object (no other text)
5. Choose tools based on their description and purpose

When you need to use tool(s), respond with ONLY this JSON format:
{{
    "tool_calls": [
        {{
            "id": "call_1",
            "type": "function",
            "function": {{
                "name": "<exact_tool_name_from_list>",
                "arguments": "{{\\"param1\\": \\"value1\\", \\"param2\\": \\"value2\\"}}"
            }}
        }},
        {{
            "id": "call_2",
            "type": "function",
            "function": {{
                "name": "<another_tool_if_needed>",
                "arguments": "{{\\"param1\\": \\"value1\\"}}"
            }}
        }}
    ]
}}

Examples:
- Single tool: "Calculate 5+5" → use calculate tool once
- Multiple tools: "Weather in Paris and Tokyo" → use get_weather twice
- Multiple different tools: "Book flight and check weather" → use book_flight AND get_weather

Only respond normally WITHOUT JSON if the request does NOT require any of the available tools."""

        # Prepend or merge with existing system message
        modified_messages = messages.copy()
        if modified_messages and modified_messages[0].get("role") == "system":
            # Merge with existing system message
            modified_messages[0]["content"] = tool_system_msg + "\n\n" + modified_messages[0]["content"]
        else:
            # Add new system message
            modified_messages.insert(0, {"role": "system", "content": tool_system_msg})

        return modified_messages

    def _add_reasoning_instruction(self, messages: List[Dict[str, str]], level: str = "high") -> List[Dict[str, str]]:
        """
        Add reasoning level instruction to system message for gpt-oss-120b.

        Args:
            messages: List of message dicts
            level: Reasoning level - "low", "medium", or "high"

        Returns:
            Modified messages with reasoning instruction
        """
        # Create reasoning instruction
        reasoning_instruction = f"Reasoning: {level}"

        # Find or create system message
        modified_messages = messages.copy()

        if modified_messages and modified_messages[0].get("role") == "system":
            # Append to existing system message
            modified_messages[0]["content"] = f"{reasoning_instruction}\n\n{modified_messages[0]['content']}"
        else:
            # Insert new system message at the beginning
            modified_messages.insert(0, {
                "role": "system",
                "content": reasoning_instruction
            })

        logger.info(f"Added reasoning level '{level}' for gpt-oss-120b")
        return modified_messages

    def _parse_tool_calling_response(self, content: str) -> tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
        """
        Parse the model's response to extract tool calls.

        Returns:
            (text_content, tool_calls) where tool_calls is None if no tools were called
        """
        content = content.strip()

        # Try to parse as JSON
        try:
            # Check if response contains JSON
            if content.startswith("{") and "tool_calls" in content:
                data = json.loads(content)
                tool_calls = data.get("tool_calls", [])

                if tool_calls:
                    # Convert to OpenAI format
                    formatted_calls = []
                    for call in tool_calls:
                        formatted_calls.append({
                            "id": call.get("id", f"call_{len(formatted_calls)}"),
                            "type": "function",
                            "function": {
                                "name": call.get("function", {}).get("name", ""),
                                "arguments": call.get("function", {}).get("arguments", "{}")
                            }
                        })

                    return None, formatted_calls
        except json.JSONDecodeError:
            pass

        # No tool calls found, return content as-is
        return content, None

    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any], show_reasoning: bool = False, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        """
        True streaming method - yields tokens one by one as they arrive

        Args:
            input_data: Same as ainvoke
            show_reasoning: If True and model supports it, enable chain-of-thought reasoning
            **kwargs: Additional parameters

        Yields:
            Individual tokens as they arrive from the API, plus final result with tool_calls
        """
        try:
            # Use adapter manager to prepare messages
            messages = self._prepare_messages(input_data)

            # Add reasoning configuration for gpt-oss-120b
            if show_reasoning and self.is_reasoning_model:
                messages = self._add_reasoning_instruction(messages, level="high")

            # Check if we have bound tools
            tool_schemas = await self._prepare_tools_for_request()
            has_tools = bool(tool_schemas)

            # Modify messages for tool calling emulation
            if has_tools:
                messages = self._create_tool_calling_prompt(messages, tool_schemas)

            # Prepare request kwargs
            provider_config = self.get_provider_config()

            # Stream tokens
            content_chunks = []

            try:
                stream = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    stream=True,
                    temperature=provider_config.get("temperature", 0.7),
                    max_tokens=provider_config.get("max_tokens", 1024),
                )

                for chunk in stream:
                    if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                        content = chunk.choices[0].delta.content
                        content_chunks.append(content)

                        # Only stream content if no tools (tool responses should be complete JSON)
                        if not has_tools:
                            yield content

                # Process complete response
                full_content = "".join(content_chunks)

                # Parse for tool calls if tools are bound
                text_content, tool_calls = None, None
                if has_tools:
                    text_content, tool_calls = self._parse_tool_calling_response(full_content)
                else:
                    text_content = full_content

                # Track usage
                self._track_streaming_usage(messages, full_content)
                await asyncio.sleep(0.01)

                # Create response object
                if tool_calls:
                    # Create mock message with tool calls
                    class MockMessage:
                        def __init__(self):
                            self.content = text_content or ""
                            self.tool_calls = []
                            for tc in tool_calls:
                                mock_tc = type('MockToolCall', (), {
                                    'id': tc['id'],
                                    'function': type('MockFunction', (), {
                                        'name': tc['function']['name'],
                                        'arguments': tc['function']['arguments']
                                    })()
                                })()
                                self.tool_calls.append(mock_tc)

                    final_result = self._format_response(MockMessage(), input_data)
                else:
                    # Stream the content if we haven't already
                    if has_tools and text_content:
                        yield text_content
                    final_result = self._format_response(text_content or "", input_data)

                # Yield final result
                yield {
                    "result": final_result,
                    "billing": self._get_streaming_billing_info(),
                    "api_used": "cerebras"
                }

            except Exception as e:
                logger.error(f"Error in Cerebras streaming: {e}")
                raise

        except Exception as e:
            logger.error(f"Error in astream: {e}")
            raise

    async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any], show_reasoning: bool = False, **kwargs) -> Union[str, Any]:
        """
        Unified invoke method for all input types with tool calling emulation

        Args:
            input_data: Input messages or text
            show_reasoning: If True and model supports it, enable chain-of-thought reasoning
            **kwargs: Additional parameters
        """
        try:
            # Use adapter manager to prepare messages
            messages = self._prepare_messages(input_data)

            # Add reasoning configuration for gpt-oss-120b
            if show_reasoning and self.is_reasoning_model:
                messages = self._add_reasoning_instruction(messages, level="high")

            # Check if we have bound tools
            tool_schemas = await self._prepare_tools_for_request()
            has_tools = bool(tool_schemas)

            # Modify messages for tool calling emulation
            if has_tools:
                messages = self._create_tool_calling_prompt(messages, tool_schemas)

            # Prepare request kwargs
            provider_config = self.get_provider_config()

            # Handle streaming vs non-streaming
            if self.streaming:
                # Streaming mode - collect all chunks
                content_chunks = []
                async for token in self.astream(input_data, show_reasoning=show_reasoning, **kwargs):
                    if isinstance(token, str):
                        content_chunks.append(token)
                    elif isinstance(token, dict) and "result" in token:
                        return token["result"]

                # Fallback
                content = "".join(content_chunks)
                return self._format_response(content, input_data)
            else:
                # Non-streaming mode
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=provider_config.get("temperature", 0.7),
                    max_tokens=provider_config.get("max_tokens", 1024),
                )

                content = response.choices[0].message.content or ""

                # Update usage tracking
                if response.usage:
                    self._update_token_usage(response.usage)
                    await self._track_billing(response.usage)

                # Parse for tool calls if tools are bound
                if has_tools:
                    text_content, tool_calls = self._parse_tool_calling_response(content)

                    if tool_calls:
                        # Create mock message with tool calls
                        class MockMessage:
                            def __init__(self):
                                self.content = text_content or ""
                                self.tool_calls = []
                                for tc in tool_calls:
                                    mock_tc = type('MockToolCall', (), {
                                        'id': tc['id'],
                                        'function': type('MockFunction', (), {
                                            'name': tc['function']['name'],
                                            'arguments': tc['function']['arguments']
                                        })()
                                    })()
                                    self.tool_calls.append(mock_tc)

                        return self._format_response(MockMessage(), input_data)

                # No tool calls, return content
                return self._format_response(content, input_data)

        except Exception as e:
            logger.error(f"Error in ainvoke: {e}")
            raise

    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
        """Track usage for streaming requests (estimated)"""
        class MockUsage:
            def __init__(self):
                # Rough estimate: ~4 characters per token
                self.prompt_tokens = len(str(messages)) // 4
                self.completion_tokens = len(content) // 4
                self.total_tokens = self.prompt_tokens + self.completion_tokens

        usage = MockUsage()
        self._update_token_usage(usage)

        # Fire-and-forget async billing tracking; skip silently if no loop is running
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._track_billing(usage))
        except RuntimeError:
            pass

    def _update_token_usage(self, usage):
        """Update token usage statistics"""
        self.last_token_usage = {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens
        }

        # Update total usage
        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
        self.total_token_usage["requests_count"] += 1

    async def _track_billing(self, usage):
        """Track billing information"""
        provider_config = self.get_provider_config()

        await self._track_usage(
            service_type=ServiceType.LLM,
            operation="chat",
            input_tokens=usage.prompt_tokens,
            output_tokens=usage.completion_tokens,
            metadata={
                "temperature": provider_config.get("temperature", 0.7),
                "max_tokens": provider_config.get("max_tokens", 1024),
                "inference_speed": "ultra-fast"
            }
        )

    def get_token_usage(self) -> Dict[str, Any]:
        """Get total token usage statistics"""
        return self.total_token_usage

    def get_last_token_usage(self) -> Dict[str, int]:
        """Get token usage from last request"""
        return self.last_token_usage

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model"""
        provider_config = self.get_provider_config()

        # Model specifications
        model_specs = {
            "llama-4-scout-17b-16e-instruct": {
                "params": "109B",
                "speed_tokens_per_sec": 2600,
                "description": "Llama 4 Scout - High performance instruction following"
            },
            "llama3.1-8b": {
                "params": "8B",
                "speed_tokens_per_sec": 2200,
                "description": "Llama 3.1 8B - Fast and efficient"
            },
            "llama-3.3-70b": {
                "params": "70B",
                "speed_tokens_per_sec": 2100,
                "description": "Llama 3.3 70B - Powerful reasoning"
            },
            "gpt-oss-120b": {
                "params": "120B",
                "speed_tokens_per_sec": 3000,
                "description": "OpenAI GPT OSS - Ultra-fast inference"
            },
            "qwen-3-32b": {
                "params": "32B",
                "speed_tokens_per_sec": 2600,
                "description": "Qwen 3 32B - Balanced performance"
            }
        }

        specs = model_specs.get(self.model_name, {
            "params": "Unknown",
            "speed_tokens_per_sec": 2000,
            "description": "Cerebras model"
        })

        return {
            "name": self.model_name,
            "max_tokens": provider_config.get("max_tokens", 1024),
            "supports_streaming": True,
            "supports_functions": True,  # Emulated via prompt engineering
            "supports_reasoning": self.is_reasoning_model,  # Native CoT for gpt-oss-120b
            "provider": "cerebras",
            "inference_speed_tokens_per_sec": specs["speed_tokens_per_sec"],
            "parameters": specs["params"],
            "description": specs["description"],
            "tool_calling_method": "emulated_via_prompt",
            "reasoning_method": "native_system_message" if self.is_reasoning_model else "none"
        }

    async def chat(
        self,
        input_data: Union[str, List[Dict[str, str]], Any],
        max_tokens: Optional[int] = None,
        show_reasoning: bool = False
    ) -> Dict[str, Any]:
        """
        Chat method that wraps ainvoke for compatibility with base class
        """
        try:
            # Forward show_reasoning so reasoning models actually receive the instruction
            response = await self.ainvoke(input_data, show_reasoning=show_reasoning)

            return {
                "message": response,
                "success": True,
                "metadata": {
                    "model": self.model_name,
                    "provider": self.provider_name,
                    "max_tokens": max_tokens or self.max_tokens,
                    "ultra_fast_inference": True
                }
            }
        except Exception as e:
            logger.error(f"Chat method failed: {e}")
            return {
                "message": None,
                "success": False,
                "error": str(e),
                "metadata": {
                    "model": self.model_name,
                    "provider": self.provider_name
                }
            }

    async def close(self):
        """Close the backend client"""
        # Cerebras SDK client doesn't have a close method
        pass

    def _get_streaming_billing_info(self) -> Dict[str, Any]:
        """Get billing information for streaming requests"""
        try:
            last_usage = self.get_last_token_usage()
            estimated_cost = 0.0

            if hasattr(self, 'model_manager'):
                estimated_cost = self.model_manager.calculate_cost(
                    provider=self.provider_name,
                    model_name=self.model_name,
                    input_tokens=last_usage.get("prompt_tokens", 0),
                    output_tokens=last_usage.get("completion_tokens", 0)
                )

            return {
                "cost_usd": estimated_cost,
                "input_tokens": last_usage.get("prompt_tokens", 0),
                "output_tokens": last_usage.get("completion_tokens", 0),
                "total_tokens": last_usage.get("total_tokens", 0),
                "operation": "chat",
                "currency": "USD"
            }
        except Exception as e:
            logger.warning(f"Failed to get streaming billing info: {e}")
            return {
                "cost_usd": 0.0,
                "error": str(e),
                "currency": "USD"
            }
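For context, a minimal usage sketch of the new service is shown below. It is not part of the diff: the `get_weather` tool and the `main()` wiring are hypothetical, and it assumes the package's provider configuration can resolve a Cerebras API key. Per the class docstring, `bind_tools()` accepts plain functions, dicts, or LangChain tools.

```python
import asyncio

from isa_model.inference.services.llm.cerebras_llm_service import CerebrasLLMService


# Hypothetical tool used to exercise the emulated function calling.
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"Sunny in {city}"


async def main():
    # Assumes the centralized config manager supplies the Cerebras API key.
    llm = CerebrasLLMService(model_name="llama-3.3-70b")

    # bind_tools() returns a copy of the service with tools attached;
    # calls are emulated via the JSON-format prompt shown in the source above.
    llm_with_tools = llm.bind_tools([get_weather])

    # The response may carry OpenAI-style tool_calls parsed from the model's JSON.
    result = await llm_with_tools.ainvoke("What's the weather in Paris and Tokyo?")
    print(result)

    # Plain chat without tools; show_reasoning only applies to gpt-oss models.
    reply = await llm.chat("Explain quantum entanglement briefly.")
    print(reply["success"], reply["metadata"])

    await llm.close()


asyncio.run(main())
```

The design choice worth noting is that `_parse_tool_calling_response()` normalizes the emulated output into OpenAI-format `tool_calls`, so downstream adapters can treat Cerebras like any native function-calling provider.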