isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -7,51 +7,67 @@ import os
|
|
7
7
|
logger = logging.getLogger(__name__)
|
8
8
|
|
9
9
|
class OpenAIProvider(BaseProvider):
|
10
|
-
"""Provider for OpenAI API"""
|
10
|
+
"""Provider for OpenAI API with proper API key management"""
|
11
11
|
|
12
12
|
def __init__(self, config=None):
    """Set up the OpenAI provider and warn early when no API key is available.

    Args:
        config: Optional provider configuration, passed straight to the
            base class initializer.
    """
    super().__init__(config)
    self.name = "openai"

    base_url = self.config.get('base_url', 'https://api.openai.com/v1')
    logger.info(f"Initialized OpenAIProvider with URL: {base_url}")

    credentials_ok = self.has_valid_credentials()
    if not credentials_ok:
        logger.warning("OpenAI API key not found. Set OPENAI_API_KEY environment variable or pass api_key in config.")
|
21
|
+
|
22
|
+
def _load_provider_env_vars(self):
    """Fill in OpenAI defaults, then overlay values from the environment.

    Keys already present in ``self.config`` keep their value when defaults
    are applied. A non-empty environment variable always replaces the
    corresponding config entry.
    """
    fallback_settings = (
        ("base_url", "https://api.openai.com/v1"),
        ("timeout", 60),
        ("temperature", 0.7),
        ("top_p", 0.9),
        ("max_tokens", 1024),
    )
    for setting, fallback in fallback_settings:
        # setdefault leaves explicitly configured values untouched.
        self.config.setdefault(setting, fallback)

    # Environment wins over config, but only when the variable is non-empty.
    environment_sources = (
        ("api_key", "OPENAI_API_KEY"),
        ("base_url", "OPENAI_API_BASE"),
        ("organization", "OPENAI_ORGANIZATION"),
    )
    for setting, variable in environment_sources:
        from_env = os.environ.get(variable)
        if from_env:
            self.config[setting] = from_env
|
49
|
+
|
50
|
+
def _validate_config(self):
    """Emit a debug note when the configuration holds no OpenAI API key.

    Nothing is raised; a missing key is only logged.
    """
    api_key = self.config.get("api_key")
    if api_key:
        return
    logger.debug("OpenAI API key not set - some functionality may not work")
|
54
|
+
|
55
|
+
def get_model_pricing(self, model_name: str) -> Dict[str, float]:
    """Return pricing for an OpenAI model as reported by ModelManager.

    Args:
        model_name: Name of the model to look up.

    Returns:
        Whatever ``ModelManager.get_model_pricing("openai", model_name)``
        returns.
    """
    # Deferred import: a module-level import would be circular.
    from isa_model.core.model_manager import ModelManager

    return ModelManager().get_model_pricing("openai", model_name)
|
61
|
+
|
62
|
+
def calculate_cost(self, model_name: str, input_tokens: int, output_tokens: int) -> float:
    """Return the cost of one request as computed by ModelManager.

    Args:
        model_name: Name of the model that served the request.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.

    Returns:
        Whatever ``ModelManager.calculate_cost`` returns for provider
        ``"openai"``.
    """
    # Deferred import: a module-level import would be circular.
    from isa_model.core.model_manager import ModelManager

    manager = ModelManager()
    cost = manager.calculate_cost("openai", model_name, input_tokens, output_tokens)
    return cost
|
47
68
|
|
48
69
|
def set_api_key(self, api_key: str):
    """Store a new API key in the provider configuration.

    Args:
        api_key: OpenAI API key to use from now on.
    """
    self.config.update(api_key=api_key)
    logger.info("OpenAI API key updated")
|
57
73
|
|
@@ -77,16 +93,29 @@ class OpenAIProvider(BaseProvider):
|
|
77
93
|
def get_models(self, model_type: ModelType) -> List[str]:
    """List the OpenAI model names offered for ``model_type``.

    Args:
        model_type: Category of model being requested.

    Returns:
        A fresh list of model names, empty for an unrecognised type.
    """
    models: List[str] = []
    if model_type == ModelType.LLM:
        models = ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
    elif model_type == ModelType.EMBEDDING:
        models = ["text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]
    elif model_type == ModelType.VISION:
        models = ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4o-mini", "gpt-4o", "gpt-4-vision-preview"]
    elif model_type == ModelType.AUDIO:
        models = ["whisper-1", "gpt-4o-transcribe", "tts-1", "tts-1-hd"]
    return models
|
89
105
|
|
106
|
+
def get_default_model(self, model_type: ModelType) -> str:
    """Return the default OpenAI model name for ``model_type``.

    Args:
        model_type: Category of model being requested.

    Returns:
        The default model name, or an empty string for an unrecognised type.
    """
    if model_type == ModelType.LLM:
        # Chosen as the cheapest, most cost-effective option.
        return "gpt-4.1-nano"
    if model_type == ModelType.EMBEDDING:
        return "text-embedding-3-small"
    if model_type == ModelType.VISION:
        return "gpt-4.1-nano"
    if model_type == ModelType.AUDIO:
        return "whisper-1"
    return ""
|
118
|
+
|
90
119
|
def get_config(self) -> Dict[str, Any]:
|
91
120
|
"""Get provider configuration"""
|
92
121
|
# Return a copy without sensitive information
|
@@ -97,5 +126,5 @@ class OpenAIProvider(BaseProvider):
|
|
97
126
|
|
98
127
|
def is_reasoning_model(self, model_name: str) -> bool:
    """Report whether ``model_name`` belongs to a reasoning-capable family.

    The check is a case-insensitive substring match against the known
    family names.

    Args:
        model_name: Model name to classify.

    Returns:
        True when any family name occurs inside the lower-cased model name.
    """
    lowered = model_name.lower()
    for family in ("gpt-4", "gpt-4o", "gpt-4-turbo", "gpt-4.1"):
        if family in lowered:
            return True
    return False
|
@@ -7,47 +7,56 @@ import os
|
|
7
7
|
logger = logging.getLogger(__name__)
|
8
8
|
|
9
9
|
class ReplicateProvider(BaseProvider):
|
10
|
-
"""Provider for Replicate API"""
|
10
|
+
"""Provider for Replicate API with proper API key management"""
|
11
11
|
|
12
12
|
def __init__(self, config=None):
    """Set up the Replicate provider and warn early when no token is available.

    Args:
        config: Optional provider configuration, passed straight to the
            base class initializer.
    """
    super().__init__(config)
    self.name = "replicate"

    logger.info("Initialized ReplicateProvider")

    if self.has_valid_credentials():
        return
    logger.warning("Replicate API token not found. Set REPLICATE_API_TOKEN environment variable or pass api_token in config.")
|
21
|
+
|
22
|
+
def _load_provider_env_vars(self):
    """Fill in Replicate defaults, then overlay values from the environment.

    Keys already present in ``self.config`` keep their value when defaults
    are applied. A non-empty ``REPLICATE_API_TOKEN`` always replaces the
    configured ``api_token``.
    """
    for setting, fallback in (("timeout", 60), ("max_tokens", 1024)):
        # setdefault leaves explicitly configured values untouched.
        self.config.setdefault(setting, fallback)

    # Environment wins over config, but only when the variable is non-empty.
    for setting, variable in (("api_token", "REPLICATE_API_TOKEN"),):
        from_env = os.environ.get(variable)
        if from_env:
            self.config[setting] = from_env
|
44
|
+
|
45
|
+
def _validate_config(self):
    """Emit a debug note when the configuration holds no Replicate token.

    Nothing is raised; a missing token is only logged.
    """
    api_token = self.config.get("api_token")
    if api_token:
        return
    logger.debug("Replicate API token not set - some functionality may not work")
|
49
|
+
|
50
|
+
def get_api_key(self) -> str:
    """Return the Replicate credential, which is stored under ``api_token``.

    Returns:
        The configured token, or an empty string when none is set.
    """
    token = self.config.get("api_token", "")
    return token
|
53
|
+
|
54
|
+
def has_valid_credentials(self) -> bool:
    """Report whether a non-empty ``api_token`` is configured.

    Returns:
        True when the configuration holds a truthy ``api_token``.
    """
    token = self.config.get("api_token")
    return True if token else False
|
43
57
|
|
44
58
|
def set_api_token(self, api_token: str):
    """Store a new API token in the provider configuration.

    Args:
        api_token: Replicate API token to use from now on.
    """
    self.config.update(api_token=api_token)
    logger.info("Replicate API token updated")
|
53
62
|
|
@@ -79,6 +88,8 @@ class ReplicateProvider(BaseProvider):
|
|
79
88
|
]
|
80
89
|
elif model_type == ModelType.VISION:
|
81
90
|
return [
|
91
|
+
"black-forest-labs/flux-schnell",
|
92
|
+
"black-forest-labs/flux-kontext-pro",
|
82
93
|
"stability-ai/sdxl",
|
83
94
|
"stability-ai/stable-diffusion-3-medium",
|
84
95
|
"meta/llama-3-70b-vision",
|
@@ -87,6 +98,7 @@ class ReplicateProvider(BaseProvider):
|
|
87
98
|
]
|
88
99
|
elif model_type == ModelType.AUDIO:
|
89
100
|
return [
|
101
|
+
"jaaari/kokoro-82m",
|
90
102
|
"openai/whisper",
|
91
103
|
"suno-ai/bark"
|
92
104
|
]
|
@@ -6,7 +6,7 @@ class BaseSTTService(BaseService):
|
|
6
6
|
"""Base class for Speech-to-Text services"""
|
7
7
|
|
8
8
|
@abstractmethod
|
9
|
-
async def
|
9
|
+
async def transcribe(
|
10
10
|
self,
|
11
11
|
audio_file: Union[str, BinaryIO],
|
12
12
|
language: Optional[str] = None,
|
@@ -30,7 +30,26 @@ class BaseSTTService(BaseService):
|
|
30
30
|
pass
|
31
31
|
|
32
32
|
@abstractmethod
|
33
|
-
async def
|
33
|
+
async def translate(
    self,
    audio_file: Union[str, BinaryIO]
) -> Dict[str, Any]:
    """
    Produce an English translation of the speech in an audio file.

    Subclasses supply the implementation; this base version does nothing.

    Args:
        audio_file: Path to audio file or file-like object

    Returns:
        Dict describing the translation, with keys:
        - text: The translated text (in English)
        - detected_language: Original language detected
        - confidence: Confidence score (if available)
    """
    ...
|
50
|
+
|
51
|
+
@abstractmethod
|
52
|
+
async def transcribe_batch(
|
34
53
|
self,
|
35
54
|
audio_files: List[Union[str, BinaryIO]],
|
36
55
|
language: Optional[str] = None,
|
@@ -0,0 +1,353 @@
|
|
1
|
+
import logging
|
2
|
+
import json
|
3
|
+
import asyncio
|
4
|
+
from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
|
5
|
+
import aiohttp
|
6
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
7
|
+
|
8
|
+
from isa_model.inference.services.base_service import BaseService
|
9
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
10
|
+
from isa_model.inference.billing_tracker import ServiceType
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
class OpenAIRealtimeService(BaseService):
    """
    OpenAI Realtime API service for real-time audio conversations.
    Uses gpt-4o-mini-realtime-preview model for interactive audio chat.

    A session is created over REST (``create_session``) and then driven over
    a WebSocket (``connect_websocket``, ``send_audio_message``,
    ``send_text_message``, ``listen_for_responses``). Every WebSocket opened
    by ``connect_websocket`` owns an ``aiohttp.ClientSession``; both are
    released by ``simple_audio_chat`` / ``simple_text_chat`` and by
    ``close()``.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-4o-mini-realtime-preview"):
        """Initialize the service and its default session configuration.

        Args:
            provider: Provider supplying the configuration (API key, base URL).
            model_name: Realtime model used for new sessions.
        """
        super().__init__(provider, model_name)

        self.api_key = self.config.get('api_key')
        # The provider stores the endpoint under "base_url"; "api_base" is
        # kept as a fallback for configurations that still use the old key.
        self.base_url = (
            self.config.get('base_url')
            or self.config.get('api_base')
            or 'https://api.openai.com/v1'
        ).rstrip('/')

        # (websocket, owning ClientSession) pairs that still need closing.
        self._connections: List[Any] = []

        # Default session configuration
        self.default_config = {
            "model": self.model_name,
            "modalities": ["audio", "text"],
            "voice": "alloy",
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "input_audio_transcription": {
                "model": "whisper-1"
            },
            # No server-side turn detection: the client must commit the
            # audio buffer and request a response explicitly.
            "turn_detection": None,
            "tools": [],
            "tool_choice": "none",
            "temperature": 0.7,
            "max_response_output_tokens": 200,
            "speed": 1.1,
            "tracing": "auto"
        }

        logger.info(f"Initialized OpenAIRealtimeService with model '{self.model_name}'")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def create_session(
        self,
        instructions: str = "You are a friendly assistant.",
        modalities: Optional[List[str]] = None,
        voice: str = "alloy",
        **kwargs
    ) -> Dict[str, Any]:
        """Create a new realtime session via the REST API.

        Args:
            instructions: System instructions for the session.
            modalities: Response modalities; defaults to ["audio", "text"].
            voice: Voice used for audio output.
            **kwargs: Extra session fields; they override the defaults.

        Returns:
            The decoded JSON body of the session-creation response.

        Raises:
            Exception: If the API answers with a status other than 200
                (raised after the retry policy is exhausted).
        """
        try:
            # Prepare session configuration
            session_config = self.default_config.copy()
            session_config.update({
                "instructions": instructions,
                "modalities": modalities if modalities is not None else ["audio", "text"],
                "voice": voice,
                **kwargs
            })

            # Create session via REST API
            url = f"{self.base_url}/realtime/sessions"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=session_config) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"Failed to create session: {response.status} - {error_text}")

                    result = await response.json()

                    # Track usage for billing
                    self._track_usage(
                        service_type=ServiceType.AUDIO_STT,  # Realtime combines STT/TTS
                        operation="create_session",
                        metadata={
                            "session_id": result.get("id"),
                            "model": self.model_name,
                            "modalities": session_config["modalities"]
                        }
                    )

                    return result

        except Exception as e:
            logger.error(f"Error creating realtime session: {e}")
            raise

    def _websocket_base_url(self) -> str:
        """Return ``self.base_url`` with its http(s) scheme swapped for ws(s)."""
        if self.base_url.startswith("https://"):
            return "wss://" + self.base_url[len("https://"):]
        if self.base_url.startswith("http://"):
            return "ws://" + self.base_url[len("http://"):]
        return self.base_url

    async def connect_websocket(self, session_id: str) -> aiohttp.ClientWebSocketResponse:
        """Connect to the realtime WebSocket for a session.

        The HTTP session backing the socket is tracked internally and is
        closed by ``close()`` (or by the ``simple_*_chat`` helpers).

        Args:
            session_id: Identifier returned by ``create_session``.

        Returns:
            The connected WebSocket.
        """
        ws_url = f"{self._websocket_base_url()}/realtime/sessions/{session_id}/ws"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1"
        }

        session = aiohttp.ClientSession()
        try:
            ws = await session.ws_connect(ws_url, headers=headers)
        except Exception as e:
            # Do not leak the HTTP session when the handshake fails.
            await session.close()
            logger.error(f"Error connecting to WebSocket: {e}")
            raise

        self._connections.append((ws, session))
        logger.info(f"Connected to realtime WebSocket for session {session_id}")
        return ws

    async def _close_websocket(self, ws: aiohttp.ClientWebSocketResponse) -> None:
        """Close ``ws`` and, if it was opened here, its owning HTTP session."""
        owner = None
        for index, (tracked_ws, tracked_session) in enumerate(self._connections):
            if tracked_ws is ws:
                owner = tracked_session
                del self._connections[index]
                break
        try:
            await ws.close()
        finally:
            if owner is not None:
                await owner.close()

    async def send_audio_message(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        audio_data: bytes,
        format: str = "pcm16"
    ):
        """Append audio to the session's input buffer and commit it.

        Args:
            ws: Connected realtime WebSocket.
            audio_data: Raw audio bytes in the session's input audio format.
            format: Name of the audio format. Kept for interface
                compatibility; the payload is base64-encoded for every format.
        """
        # Local import keeps this block self-contained.
        import base64

        try:
            message = {
                "type": "input_audio_buffer.append",
                # The Realtime API expects base64 text; raw bytes are not
                # JSON-serializable and hex is not a supported encoding.
                "audio": base64.b64encode(audio_data).decode("ascii")
            }

            await ws.send_str(json.dumps(message))

            # Commit the audio buffer
            commit_message = {"type": "input_audio_buffer.commit"}
            await ws.send_str(json.dumps(commit_message))

        except Exception as e:
            logger.error(f"Error sending audio message: {e}")
            raise

    async def send_text_message(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        text: str
    ):
        """Send a user text message and ask the model to respond.

        Args:
            ws: Connected realtime WebSocket.
            text: Text of the user message.
        """
        try:
            message = {
                "type": "conversation.item.create",
                "item": {
                    "type": "message",
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": text
                        }
                    ]
                }
            }

            await ws.send_str(json.dumps(message))

            # Trigger response
            response_message = {"type": "response.create"}
            await ws.send_str(json.dumps(response_message))

        except Exception as e:
            logger.error(f"Error sending text message: {e}")
            raise

    async def listen_for_responses(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        message_handler: Optional[Callable] = None
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield simplified events parsed from the realtime WebSocket.

        Args:
            ws: Connected realtime WebSocket.
            message_handler: Optional callback invoked with every decoded
                server event; may be a plain function or a coroutine function.

        Yields:
            Dicts with a "type" key:
            - "audio": "data" holds the audio delta, "format" is "pcm16"
            - "text": "data" holds the text delta
            - "error": "error" holds the server's error payload
            - "done": "usage" holds the usage block of the finished response
        """
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                    except json.JSONDecodeError as e:
                        logger.error(f"Error parsing WebSocket message: {e}")
                        continue

                    event_type = data.get("type")

                    # Handle different message types
                    if event_type == "response.audio.delta":
                        # Audio response chunk
                        yield {
                            "type": "audio",
                            "data": data.get("delta", ""),
                            "format": "pcm16"
                        }
                    elif event_type == "response.text.delta":
                        # Text response chunk
                        yield {
                            "type": "text",
                            "data": data.get("delta", "")
                        }
                    elif event_type == "error":
                        # Surface server-side errors so callers waiting for
                        # "done" are not left waiting forever.
                        logger.error(f"Realtime API error event: {data.get('error')}")
                        yield {
                            "type": "error",
                            "error": data.get("error", {})
                        }
                    elif event_type == "response.done":
                        # Response completed
                        usage = data.get("response", {}).get("usage", {})

                        # Track usage for billing
                        self._track_usage(
                            service_type=ServiceType.AUDIO_STT,
                            operation="realtime_response",
                            input_tokens=usage.get("input_tokens", 0),
                            output_tokens=usage.get("output_tokens", 0),
                            metadata={
                                "response_id": data.get("response", {}).get("id"),
                                "model": self.model_name
                            }
                        )

                        yield {
                            "type": "done",
                            "usage": usage
                        }

                    # Call custom message handler if provided
                    if message_handler:
                        outcome = message_handler(data)
                        # Only await when the handler actually returned an
                        # awaitable, so plain callables work too.
                        if hasattr(outcome, "__await__"):
                            await outcome

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {ws.exception()}")
                    break

        except Exception as e:
            logger.error(f"Error listening for responses: {e}")
            raise

    async def _collect_response(self, ws: aiohttp.ClientWebSocketResponse):
        """Drain one response; return (text, joined audio deltas, usage)."""
        text_parts = []
        audio_chunks = []
        usage_info = {}

        async for response in self.listen_for_responses(ws):
            kind = response["type"]
            if kind == "text":
                text_parts.append(response["data"])
            elif kind == "audio":
                audio_chunks.append(response["data"])
            elif kind == "error":
                raise Exception(f"Realtime API error: {response['error']}")
            elif kind == "done":
                usage_info = response["usage"]
                break

        return "".join(text_parts), "".join(audio_chunks), usage_info

    async def simple_audio_chat(
        self,
        audio_data: bytes,
        instructions: str = "You are a helpful assistant. Respond in audio.",
        voice: str = "alloy"
    ) -> Dict[str, Any]:
        """Simple audio chat - send audio, get audio response.

        Args:
            audio_data: Raw input audio bytes.
            instructions: System instructions for the session.
            voice: Voice used for the audio answer.

        Returns:
            Dict with "audio_response" (the audio deltas concatenated in
            arrival order), "session_id" and "usage".
        """
        try:
            # Create session
            session = await self.create_session(
                instructions=instructions,
                modalities=["audio"],
                voice=voice
            )
            session_id = session["id"]

            # Connect to WebSocket
            ws = await self.connect_websocket(session_id)

            try:
                # Send audio
                await self.send_audio_message(ws, audio_data)

                # Turn detection is disabled in the default configuration,
                # so a response has to be requested explicitly.
                await ws.send_str(json.dumps({"type": "response.create"}))

                # Collect response
                _, full_audio, usage_info = await self._collect_response(ws)

                return {
                    "audio_response": full_audio,
                    "session_id": session_id,
                    "usage": usage_info
                }

            finally:
                await self._close_websocket(ws)

        except Exception as e:
            logger.error(f"Error in simple audio chat: {e}")
            raise

    async def simple_text_chat(
        self,
        text: str,
        instructions: str = "You are a helpful assistant.",
        voice: str = "alloy"
    ) -> Dict[str, Any]:
        """Simple text chat - send text, get text and audio response.

        Args:
            text: User message.
            instructions: System instructions for the session.
            voice: Voice used for the audio answer.

        Returns:
            Dict with "text_response", "audio_response" (the audio deltas
            concatenated in arrival order), "session_id" and "usage".
        """
        try:
            # Create session
            session = await self.create_session(
                instructions=instructions,
                modalities=["text", "audio"],
                voice=voice
            )
            session_id = session["id"]

            # Connect to WebSocket
            ws = await self.connect_websocket(session_id)

            try:
                # Send text (this also requests a response)
                await self.send_text_message(ws, text)

                # Collect response
                text_response, full_audio, usage_info = await self._collect_response(ws)

                return {
                    "text_response": text_response,
                    "audio_response": full_audio,
                    "session_id": session_id,
                    "usage": usage_info
                }

            finally:
                await self._close_websocket(ws)

        except Exception as e:
            logger.error(f"Error in simple text chat: {e}")
            raise

    def get_supported_voices(self) -> List[str]:
        """Get list of supported voice options"""
        return ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats"""
        return ["pcm16", "g711_ulaw", "g711_alaw"]

    async def close(self):
        """Close every WebSocket and HTTP session opened by connect_websocket."""
        while self._connections:
            ws, session = self._connections.pop()
            try:
                await ws.close()
            finally:
                await session.close()
|