isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/storage/hf_storage.py +419 -0
  4. isa_model/deployment/__init__.py +52 -0
  5. isa_model/deployment/core/__init__.py +34 -0
  6. isa_model/deployment/core/deployment_config.py +356 -0
  7. isa_model/deployment/core/deployment_manager.py +549 -0
  8. isa_model/deployment/core/isa_deployment_service.py +401 -0
  9. isa_model/eval/factory.py +381 -140
  10. isa_model/inference/ai_factory.py +427 -236
  11. isa_model/inference/billing_tracker.py +406 -0
  12. isa_model/inference/providers/base_provider.py +51 -4
  13. isa_model/inference/providers/ml_provider.py +50 -0
  14. isa_model/inference/providers/ollama_provider.py +37 -18
  15. isa_model/inference/providers/openai_provider.py +65 -36
  16. isa_model/inference/providers/replicate_provider.py +42 -30
  17. isa_model/inference/services/audio/base_stt_service.py +21 -2
  18. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  19. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  20. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  21. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  22. isa_model/inference/services/base_service.py +36 -1
  23. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  24. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  25. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  26. isa_model/inference/services/llm/__init__.py +2 -0
  27. isa_model/inference/services/llm/base_llm_service.py +158 -86
  28. isa_model/inference/services/llm/llm_adapter.py +414 -0
  29. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  30. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  31. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  32. isa_model/inference/services/ml/base_ml_service.py +78 -0
  33. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  34. isa_model/inference/services/vision/__init__.py +3 -3
  35. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  36. isa_model/inference/services/vision/base_vision_service.py +177 -0
  37. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  38. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  39. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  40. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  41. isa_model/training/__init__.py +62 -32
  42. isa_model/training/cloud/__init__.py +22 -0
  43. isa_model/training/cloud/job_orchestrator.py +402 -0
  44. isa_model/training/cloud/runpod_trainer.py +454 -0
  45. isa_model/training/cloud/storage_manager.py +482 -0
  46. isa_model/training/core/__init__.py +23 -0
  47. isa_model/training/core/config.py +181 -0
  48. isa_model/training/core/dataset.py +222 -0
  49. isa_model/training/core/trainer.py +720 -0
  50. isa_model/training/core/utils.py +213 -0
  51. isa_model/training/factory.py +229 -198
  52. isa_model-0.3.1.dist-info/METADATA +465 -0
  53. isa_model-0.3.1.dist-info/RECORD +91 -0
  54. isa_model/core/model_router.py +0 -226
  55. isa_model/core/model_version.py +0 -0
  56. isa_model/core/resource_manager.py +0 -202
  57. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  58. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  59. isa_model/training/engine/llama_factory/__init__.py +0 -39
  60. isa_model/training/engine/llama_factory/config.py +0 -115
  61. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  62. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  63. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  64. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  65. isa_model/training/engine/llama_factory/factory.py +0 -331
  66. isa_model/training/engine/llama_factory/rl.py +0 -254
  67. isa_model/training/engine/llama_factory/trainer.py +0 -171
  68. isa_model/training/image_model/configs/create_config.py +0 -37
  69. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  70. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  71. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  72. isa_model/training/image_model/prepare_upload.py +0 -17
  73. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  74. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  75. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  76. isa_model/training/image_model/train/train.py +0 -42
  77. isa_model/training/image_model/train/train_flux.py +0 -41
  78. isa_model/training/image_model/train/train_lora.py +0 -57
  79. isa_model/training/image_model/train_main.py +0 -25
  80. isa_model-0.2.0.dist-info/METADATA +0 -327
  81. isa_model-0.2.0.dist-info/RECORD +0 -92
  82. isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
  83. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  84. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  91. {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  92. {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -7,51 +7,67 @@ import os
7
7
  logger = logging.getLogger(__name__)
8
8
 
9
9
  class OpenAIProvider(BaseProvider):
10
- """Provider for OpenAI API"""
10
+ """Provider for OpenAI API with proper API key management"""
11
11
 
12
12
  def __init__(self, config=None):
13
- """
14
- Initialize the OpenAI Provider
13
+ """Initialize the OpenAI Provider with centralized config management"""
14
+ super().__init__(config)
15
+ self.name = "openai"
16
+
17
+ logger.info(f"Initialized OpenAIProvider with URL: {self.config.get('base_url', 'https://api.openai.com/v1')}")
15
18
 
16
- Args:
17
- config (dict, optional): Configuration for the provider
18
- - api_key: OpenAI API key (can be passed here or via environment variable)
19
- - api_base: Base URL for OpenAI API (default: https://api.openai.com/v1)
20
- - timeout: Timeout for API calls in seconds
21
- """
22
- default_config = {
23
- "api_key": "", # Will be set from config or environment
24
- "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
19
+ if not self.has_valid_credentials():
20
+ logger.warning("OpenAI API key not found. Set OPENAI_API_KEY environment variable or pass api_key in config.")
21
+
22
+ def _load_provider_env_vars(self):
23
+ """Load OpenAI-specific environment variables"""
24
+ # Set defaults first
25
+ defaults = {
26
+ "base_url": "https://api.openai.com/v1",
25
27
  "timeout": 60,
26
- "stream": True,
27
28
  "temperature": 0.7,
28
29
  "top_p": 0.9,
29
30
  "max_tokens": 1024
30
31
  }
31
32
 
32
- # Merge default config with provided config
33
- merged_config = {**default_config, **(config or {})}
33
+ # Apply defaults only if not already set
34
+ for key, value in defaults.items():
35
+ if key not in self.config:
36
+ self.config[key] = value
34
37
 
35
- # Set API key from config first, then fallback to environment variable
36
- if not merged_config["api_key"]:
37
- merged_config["api_key"] = os.environ.get("OPENAI_API_KEY", "")
38
-
39
- super().__init__(config=merged_config)
40
- self.name = "openai"
41
-
42
- logger.info(f"Initialized OpenAIProvider with URL: {self.config['api_base']}")
38
+ # Load from environment variables (override config if present)
39
+ env_mappings = {
40
+ "api_key": "OPENAI_API_KEY",
41
+ "base_url": "OPENAI_API_BASE",
42
+ "organization": "OPENAI_ORGANIZATION"
43
+ }
43
44
 
44
- # Only warn if no API key is provided at all
45
- if not self.config["api_key"]:
46
- logger.info("OpenAI API key not provided. You can set it via OPENAI_API_KEY environment variable or pass it in the config when creating services.")
45
+ for config_key, env_var in env_mappings.items():
46
+ env_value = os.getenv(env_var)
47
+ if env_value:
48
+ self.config[config_key] = env_value
49
+
50
def _validate_config(self):
    """Log a debug-level notice when the configuration holds no OpenAI API key."""
    if self.config.get("api_key"):
        return
    # A missing key is not an error at this point; it is only noted in the log.
    logger.debug("OpenAI API key not set - some functionality may not work")
54
+
55
def get_model_pricing(self, model_name: str) -> Dict[str, float]:
    """Return pricing for ``model_name`` by delegating to ModelManager.

    The lookup is made under the "openai" provider key.
    """
    # Imported at call time because a module-level import would be circular.
    from isa_model.core.model_manager import ModelManager

    return ModelManager().get_model_pricing("openai", model_name)
61
+
62
def calculate_cost(self, model_name: str, input_tokens: int, output_tokens: int) -> float:
    """Return the cost of one request by delegating to ModelManager.

    The calculation is requested under the "openai" provider key with the
    given input and output token counts.
    """
    # Imported at call time because a module-level import would be circular.
    from isa_model.core.model_manager import ModelManager

    return ModelManager().calculate_cost("openai", model_name, input_tokens, output_tokens)
47
68
 
48
69
  def set_api_key(self, api_key: str):
49
- """
50
- Set the API key after initialization
51
-
52
- Args:
53
- api_key: OpenAI API key
54
- """
70
+ """Set the API key after initialization"""
55
71
  self.config["api_key"] = api_key
56
72
  logger.info("OpenAI API key updated")
57
73
 
@@ -77,16 +93,29 @@ class OpenAIProvider(BaseProvider):
77
93
  def get_models(self, model_type: ModelType) -> List[str]:
78
94
  """Get available models for given type"""
79
95
  if model_type == ModelType.LLM:
80
- return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
96
+ return ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
81
97
  elif model_type == ModelType.EMBEDDING:
82
98
  return ["text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]
83
99
  elif model_type == ModelType.VISION:
84
- return ["gpt-4o", "gpt-4-vision-preview"]
100
+ return ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4o-mini", "gpt-4o", "gpt-4-vision-preview"]
85
101
  elif model_type == ModelType.AUDIO:
86
- return ["whisper-1"]
102
+ return ["whisper-1", "gpt-4o-transcribe", "tts-1", "tts-1-hd"]
87
103
  else:
88
104
  return []
89
105
 
106
def get_default_model(self, model_type: ModelType) -> str:
    """Return the default OpenAI model name for ``model_type``.

    An unrecognised model type yields an empty string.
    """
    if model_type == ModelType.LLM:
        # Cheapest and most cost-effective
        return "gpt-4.1-nano"
    if model_type == ModelType.EMBEDDING:
        return "text-embedding-3-small"
    if model_type == ModelType.VISION:
        return "gpt-4.1-nano"
    if model_type == ModelType.AUDIO:
        return "whisper-1"
    return ""
118
+
90
119
  def get_config(self) -> Dict[str, Any]:
91
120
  """Get provider configuration"""
92
121
  # Return a copy without sensitive information
@@ -97,5 +126,5 @@ class OpenAIProvider(BaseProvider):
97
126
 
98
127
  def is_reasoning_model(self, model_name: str) -> bool:
99
128
  """Check if the model is optimized for reasoning tasks"""
100
- reasoning_models = ["gpt-4", "gpt-4o", "gpt-4-turbo"]
129
+ reasoning_models = ["gpt-4", "gpt-4o", "gpt-4-turbo", "gpt-4.1"]
101
130
  return any(rm in model_name.lower() for rm in reasoning_models)
@@ -7,47 +7,56 @@ import os
7
7
  logger = logging.getLogger(__name__)
8
8
 
9
9
  class ReplicateProvider(BaseProvider):
10
- """Provider for Replicate API"""
10
+ """Provider for Replicate API with proper API key management"""
11
11
 
12
12
  def __init__(self, config=None):
13
- """
14
- Initialize the Replicate Provider
13
+ """Initialize the Replicate Provider with centralized config management"""
14
+ super().__init__(config)
15
+ self.name = "replicate"
16
+
17
+ logger.info("Initialized ReplicateProvider")
15
18
 
16
- Args:
17
- config (dict, optional): Configuration for the provider
18
- - api_token: Replicate API token (can be passed here or via environment variable)
19
- - timeout: Timeout for API calls in seconds
20
- """
21
- default_config = {
22
- "api_token": "", # Will be set from config or environment
19
+ if not self.has_valid_credentials():
20
+ logger.warning("Replicate API token not found. Set REPLICATE_API_TOKEN environment variable or pass api_token in config.")
21
+
22
+ def _load_provider_env_vars(self):
23
+ """Load Replicate-specific environment variables"""
24
+ # Set defaults first
25
+ defaults = {
23
26
  "timeout": 60,
24
- "stream": True,
25
27
  "max_tokens": 1024
26
28
  }
27
29
 
28
- # Merge default config with provided config
29
- merged_config = {**default_config, **(config or {})}
30
-
31
- # Set API token from config first, then fallback to environment variable
32
- if not merged_config["api_token"]:
33
- merged_config["api_token"] = os.environ.get("REPLICATE_API_TOKEN", "")
34
-
35
- super().__init__(config=merged_config)
36
- self.name = "replicate"
30
+ # Apply defaults only if not already set
31
+ for key, value in defaults.items():
32
+ if key not in self.config:
33
+ self.config[key] = value
37
34
 
38
- logger.info(f"Initialized ReplicateProvider")
35
+ # Load from environment variables (override config if present)
36
+ env_mappings = {
37
+ "api_token": "REPLICATE_API_TOKEN",
38
+ }
39
39
 
40
- # Only warn if no API token is provided at all
41
- if not self.config["api_token"]:
42
- logger.info("Replicate API token not provided. You can set it via REPLICATE_API_TOKEN environment variable or pass it in the config when creating services.")
40
+ for config_key, env_var in env_mappings.items():
41
+ env_value = os.getenv(env_var)
42
+ if env_value:
43
+ self.config[config_key] = env_value
44
+
45
def _validate_config(self):
    """Log a debug-level notice when the configuration holds no Replicate API token."""
    if self.config.get("api_token"):
        return
    # A missing token is not an error at this point; it is only noted in the log.
    logger.debug("Replicate API token not set - some functionality may not work")
49
+
50
def get_api_key(self) -> str:
    """Return the Replicate credential, which is stored under the "api_token" key.

    An empty string is returned when no token is configured.
    """
    token = self.config.get("api_token", "")
    return token
53
+
54
def has_valid_credentials(self) -> bool:
    """Report whether a non-empty value is stored under the "api_token" key."""
    return True if self.config.get("api_token") else False
43
57
 
44
58
  def set_api_token(self, api_token: str):
45
- """
46
- Set the API token after initialization
47
-
48
- Args:
49
- api_token: Replicate API token
50
- """
59
+ """Set the API token after initialization"""
51
60
  self.config["api_token"] = api_token
52
61
  logger.info("Replicate API token updated")
53
62
 
@@ -79,6 +88,8 @@ class ReplicateProvider(BaseProvider):
79
88
  ]
80
89
  elif model_type == ModelType.VISION:
81
90
  return [
91
+ "black-forest-labs/flux-schnell",
92
+ "black-forest-labs/flux-kontext-pro",
82
93
  "stability-ai/sdxl",
83
94
  "stability-ai/stable-diffusion-3-medium",
84
95
  "meta/llama-3-70b-vision",
@@ -87,6 +98,7 @@ class ReplicateProvider(BaseProvider):
87
98
  ]
88
99
  elif model_type == ModelType.AUDIO:
89
100
  return [
101
+ "jaaari/kokoro-82m",
90
102
  "openai/whisper",
91
103
  "suno-ai/bark"
92
104
  ]
@@ -6,7 +6,7 @@ class BaseSTTService(BaseService):
6
6
  """Base class for Speech-to-Text services"""
7
7
 
8
8
  @abstractmethod
9
- async def transcribe_audio(
9
+ async def transcribe(
10
10
  self,
11
11
  audio_file: Union[str, BinaryIO],
12
12
  language: Optional[str] = None,
@@ -30,7 +30,26 @@ class BaseSTTService(BaseService):
30
30
  pass
31
31
 
32
32
  @abstractmethod
33
- async def transcribe_audio_batch(
33
async def translate(
    self,
    audio_file: Union[str, BinaryIO]
) -> Dict[str, Any]:
    """
    Translate the speech in an audio file into English text.

    Args:
        audio_file: Path to the audio file, or an open file-like object

    Returns:
        Dict describing the translation, with keys:
        - text: The English translation
        - detected_language: Language detected in the original audio
        - confidence: Confidence score (if available)
    """
    ...
50
+
51
+ @abstractmethod
52
+ async def transcribe_batch(
34
53
  self,
35
54
  audio_files: List[Union[str, BinaryIO]],
36
55
  language: Optional[str] = None,
@@ -0,0 +1,353 @@
1
+ import logging
2
+ import json
3
+ import asyncio
4
+ from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
5
+ import aiohttp
6
+ from tenacity import retry, stop_after_attempt, wait_exponential
7
+
8
+ from isa_model.inference.services.base_service import BaseService
9
+ from isa_model.inference.providers.base_provider import BaseProvider
10
+ from isa_model.inference.billing_tracker import ServiceType
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class OpenAIRealtimeService(BaseService):
    """
    OpenAI Realtime API service for real-time audio conversations.
    Uses gpt-4o-mini-realtime-preview model for interactive audio chat.

    Sessions are created over REST (``create_session``) and then driven over
    a WebSocket (``connect_websocket``). Every WebSocket opened through this
    service owns an ``aiohttp.ClientSession``; release it with
    ``close_websocket`` (one socket) or ``close`` (everything still open).
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-4o-mini-realtime-preview"):
        super().__init__(provider, model_name)

        self.api_key = self.config.get('api_key')
        # OpenAIProvider stores the endpoint under "base_url"; "api_base" is
        # still honoured so configurations using the older key keep working.
        self.base_url = (
            self.config.get('base_url')
            or self.config.get('api_base')
            or 'https://api.openai.com/v1'
        )

        # id(websocket) -> ClientSession that owns it, so the session can be
        # closed together with the socket instead of leaking.
        self._ws_sessions: Dict[int, Any] = {}

        # Default session configuration
        self.default_config = {
            "model": self.model_name,
            "modalities": ["audio", "text"],
            "voice": "alloy",
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "input_audio_transcription": {
                "model": "whisper-1"
            },
            # No server-side turn detection: callers must commit the audio
            # buffer and request a response explicitly.
            "turn_detection": None,
            "tools": [],
            "tool_choice": "none",
            "temperature": 0.7,
            "max_response_output_tokens": 200,
            "speed": 1.1,
            "tracing": "auto"
        }

        logger.info(f"Initialized OpenAIRealtimeService with model '{self.model_name}'")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def create_session(
        self,
        instructions: str = "You are a friendly assistant.",
        modalities: Optional[List[str]] = None,
        voice: str = "alloy",
        **kwargs
    ) -> Dict[str, Any]:
        """Create a new realtime session via the REST endpoint.

        Args:
            instructions: System instructions for the session.
            modalities: Response modalities; defaults to ["audio", "text"].
            voice: Voice name for audio output.
            **kwargs: Extra session fields; they override the defaults.

        Returns:
            The decoded JSON body of the session-creation response.

        Raises:
            Exception: If the endpoint answers with a non-200 status. Any
                failure is logged and re-raised; the retry decorator makes
                up to three attempts.
        """
        try:
            # Prepare session configuration
            session_config = self.default_config.copy()
            session_config.update({
                "instructions": instructions,
                "modalities": modalities if modalities is not None else ["audio", "text"],
                "voice": voice,
                **kwargs
            })

            # Create session via REST API
            url = f"{self.base_url}/realtime/sessions"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=session_config) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"Failed to create session: {response.status} - {error_text}")

                    result = await response.json()

                    # Track usage for billing
                    self._track_usage(
                        service_type=ServiceType.AUDIO_STT,  # Realtime combines STT/TTS
                        operation="create_session",
                        metadata={
                            "session_id": result.get("id"),
                            "model": self.model_name,
                            "modalities": session_config["modalities"]
                        }
                    )

                    return result

        except Exception as e:
            logger.error(f"Error creating realtime session: {e}")
            raise

    async def connect_websocket(self, session_id: str) -> aiohttp.ClientWebSocketResponse:
        """Connect to the realtime WebSocket for a session.

        The ``aiohttp.ClientSession`` backing the socket is remembered by the
        service; close the socket with ``close_websocket`` (or call ``close``)
        so that session is released too.

        Args:
            session_id: Identifier returned by ``create_session``.

        Returns:
            The connected WebSocket.

        Raises:
            Exception: Any connection failure is logged and re-raised; the
                client session is closed first so nothing leaks.
        """
        # NOTE(review): the host is hard-coded and does not follow
        # self.base_url - confirm the path against the Realtime API docs.
        ws_url = f"wss://api.openai.com/v1/realtime/sessions/{session_id}/ws"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1"
        }

        session = None
        try:
            session = aiohttp.ClientSession()
            ws = await session.ws_connect(ws_url, headers=headers)
        except Exception as e:
            if session is not None:
                await session.close()
            logger.error(f"Error connecting to WebSocket: {e}")
            raise

        self._ws_sessions[id(ws)] = session
        logger.info(f"Connected to realtime WebSocket for session {session_id}")
        return ws

    async def close_websocket(self, ws: aiohttp.ClientWebSocketResponse):
        """Close a WebSocket from ``connect_websocket`` and the client session behind it."""
        try:
            await ws.close()
        finally:
            session = self._ws_sessions.pop(id(ws), None)
            if session is not None:
                await session.close()

    async def send_audio_message(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        audio_data: bytes,
        format: str = "pcm16"
    ):
        """Append audio to the session's input buffer and commit it.

        Args:
            ws: Connected realtime WebSocket.
            audio_data: Raw audio bytes in the session's input audio format.
            format: Audio format label. The payload is base64-encoded
                whatever the format, because raw bytes cannot be serialised
                to JSON.

        Raises:
            Exception: Any send failure is logged and re-raised.
        """
        import base64

        try:
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(audio_data).decode("ascii")
            }
            await ws.send_str(json.dumps(message))

            # Commit the audio buffer
            commit_message = {"type": "input_audio_buffer.commit"}
            await ws.send_str(json.dumps(commit_message))

        except Exception as e:
            logger.error(f"Error sending audio message: {e}")
            raise

    async def send_text_message(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        text: str
    ):
        """Add a user text message to the conversation and request a response.

        Args:
            ws: Connected realtime WebSocket.
            text: Message text sent with role "user".

        Raises:
            Exception: Any send failure is logged and re-raised.
        """
        try:
            message = {
                "type": "conversation.item.create",
                "item": {
                    "type": "message",
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": text
                        }
                    ]
                }
            }
            await ws.send_str(json.dumps(message))

            # Trigger response
            response_message = {"type": "response.create"}
            await ws.send_str(json.dumps(response_message))

        except Exception as e:
            logger.error(f"Error sending text message: {e}")
            raise

    async def listen_for_responses(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        message_handler: Optional[Callable] = None
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield simplified events parsed from the realtime WebSocket.

        Yields:
            - {"type": "audio", "data": <delta>, "format": "pcm16"} for
              "response.audio.delta" messages,
            - {"type": "text", "data": <delta>} for "response.text.delta",
            - {"type": "done", "usage": <dict>} for "response.done".
            Other message types produce no event.

        Args:
            ws: Connected realtime WebSocket.
            message_handler: Optional callable given every decoded message.
                It may be a plain function or a coroutine function. It runs
                before the matching event is yielded, so it still sees the
                final message when the consumer stops iterating on "done".

        Raises:
            Exception: Any failure while listening is logged and re-raised.
        """
        import inspect

        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {ws.exception()}")
                    break
                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError as e:
                    logger.error(f"Error parsing WebSocket message: {e}")
                    continue

                # Call custom message handler if provided
                if message_handler:
                    outcome = message_handler(data)
                    if inspect.isawaitable(outcome):
                        await outcome

                event_type = data.get("type")
                if event_type == "response.audio.delta":
                    # Audio response chunk
                    yield {
                        "type": "audio",
                        "data": data.get("delta", ""),
                        "format": "pcm16"
                    }
                elif event_type == "response.text.delta":
                    # Text response chunk
                    yield {
                        "type": "text",
                        "data": data.get("delta", "")
                    }
                elif event_type == "response.done":
                    # Response completed; tolerate a missing or null payload.
                    response_info = data.get("response") or {}
                    usage = response_info.get("usage") or {}

                    # Track usage for billing
                    self._track_usage(
                        service_type=ServiceType.AUDIO_STT,
                        operation="realtime_response",
                        input_tokens=usage.get("input_tokens", 0),
                        output_tokens=usage.get("output_tokens", 0),
                        metadata={
                            "response_id": response_info.get("id"),
                            "model": self.model_name
                        }
                    )

                    yield {
                        "type": "done",
                        "usage": usage
                    }

        except Exception as e:
            logger.error(f"Error listening for responses: {e}")
            raise

    @staticmethod
    def _join_audio_chunks(chunks: List[str]) -> str:
        """Merge base64 audio deltas into a single base64 string.

        Each delta is decoded and the bytes re-encoded as a whole, because
        concatenating the strings is not valid base64 when a chunk ends in
        padding. If a chunk is not valid base64 the chunks are concatenated
        unchanged.
        """
        import base64
        import binascii

        try:
            raw = b"".join(base64.b64decode(chunk, validate=True) for chunk in chunks)
        except (binascii.Error, ValueError):
            return "".join(chunks)
        return base64.b64encode(raw).decode("ascii")

    async def simple_audio_chat(
        self,
        audio_data: bytes,
        instructions: str = "You are a helpful assistant. Respond in audio.",
        voice: str = "alloy"
    ) -> Dict[str, Any]:
        """Simple audio chat - send audio, get audio response.

        Args:
            audio_data: Raw input audio bytes.
            instructions: System instructions for the session.
            voice: Voice name for the audio reply.

        Returns:
            Dict with "audio_response" (base64 string), "session_id" and
            "usage".

        Raises:
            Exception: Any failure is logged and re-raised; the WebSocket
                and its client session are closed either way.
        """
        try:
            # Create session
            session = await self.create_session(
                instructions=instructions,
                modalities=["audio"],
                voice=voice
            )
            session_id = session["id"]

            # Connect to WebSocket
            ws = await self.connect_websocket(session_id)

            try:
                # Send audio
                await self.send_audio_message(ws, audio_data)

                # turn_detection is disabled by default, so a response has to
                # be requested explicitly or no reply ever arrives.
                await ws.send_str(json.dumps({"type": "response.create"}))

                # Collect response
                audio_chunks = []
                usage_info = {}

                async for response in self.listen_for_responses(ws):
                    if response["type"] == "audio":
                        audio_chunks.append(response["data"])
                    elif response["type"] == "done":
                        usage_info = response["usage"]
                        break

                return {
                    "audio_response": self._join_audio_chunks(audio_chunks),
                    "session_id": session_id,
                    "usage": usage_info
                }

            finally:
                await self.close_websocket(ws)

        except Exception as e:
            logger.error(f"Error in simple audio chat: {e}")
            raise

    async def simple_text_chat(
        self,
        text: str,
        instructions: str = "You are a helpful assistant.",
        voice: str = "alloy"
    ) -> Dict[str, Any]:
        """Simple text chat - send text, get text and audio response.

        Args:
            text: User message.
            instructions: System instructions for the session.
            voice: Voice name for the audio reply.

        Returns:
            Dict with "text_response", "audio_response" (base64 string),
            "session_id" and "usage".

        Raises:
            Exception: Any failure is logged and re-raised; the WebSocket
                and its client session are closed either way.
        """
        try:
            # Create session
            session = await self.create_session(
                instructions=instructions,
                modalities=["text", "audio"],
                voice=voice
            )
            session_id = session["id"]

            # Connect to WebSocket
            ws = await self.connect_websocket(session_id)

            try:
                # Send text (this also requests the response)
                await self.send_text_message(ws, text)

                # Collect response
                text_parts = []
                audio_chunks = []
                usage_info = {}

                async for response in self.listen_for_responses(ws):
                    if response["type"] == "text":
                        text_parts.append(response["data"])
                    elif response["type"] == "audio":
                        audio_chunks.append(response["data"])
                    elif response["type"] == "done":
                        usage_info = response["usage"]
                        break

                return {
                    "text_response": "".join(text_parts),
                    "audio_response": self._join_audio_chunks(audio_chunks),
                    "session_id": session_id,
                    "usage": usage_info
                }

            finally:
                await self.close_websocket(ws)

        except Exception as e:
            logger.error(f"Error in simple text chat: {e}")
            raise

    def get_supported_voices(self) -> List[str]:
        """Get list of supported voice options"""
        return ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats"""
        return ["pcm16", "g711_ulaw", "g711_alaw"]

    async def close(self):
        """Close every client session still held for an open WebSocket."""
        sessions = list(self._ws_sessions.values())
        self._ws_sessions.clear()
        for session in sessions:
            await session.close()