isa-model 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.2.0.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/top_level.txt +0 -0
isa_model/inference/backends/container_services.py
@@ -1,104 +0,0 @@
-"""
-Container Services Backend - Docker/K8s deployed services.
-Examples: Triton Inference Server, vLLM, TensorFlow Serving
-"""
-
-import aiohttp
-import json
-from typing import Dict, Any, List, Optional
-from .base_backend_client import BaseBackendClient
-from .triton_client import TritonBackendClient  # Re-export existing Triton client
-
-
-class VLLMBackendClient(BaseBackendClient):
-    """Pure connection client for vLLM service deployed in containers"""
-
-    def __init__(self, base_url: str, api_key: Optional[str] = None):
-        self.base_url = base_url.rstrip('/')
-        self.api_key = api_key
-        self.headers = {"Content-Type": "application/json"}
-        if api_key:
-            self.headers["Authorization"] = f"Bearer {api_key}"
-        self._session = None
-
-    async def _get_session(self):
-        """Get or create HTTP session"""
-        if self._session is None:
-            self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60))
-        return self._session
-
-    async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """Make POST request to vLLM API"""
-        session = await self._get_session()
-        async with session.post(f"{self.base_url}{endpoint}", json=payload, headers=self.headers) as response:
-            response.raise_for_status()
-            return await response.json()
-
-    async def get(self, endpoint: str) -> Dict[str, Any]:
-        """Make GET request to vLLM API"""
-        session = await self._get_session()
-        async with session.get(f"{self.base_url}{endpoint}", headers=self.headers) as response:
-            response.raise_for_status()
-            return await response.json()
-
-    async def health_check(self) -> bool:
-        """Check if vLLM service is healthy"""
-        try:
-            await self.get("/health")
-            return True
-        except Exception:
-            return False
-
-    async def close(self):
-        """Close the HTTP session"""
-        if self._session:
-            await self._session.close()
-            self._session = None
-
-
-class TensorFlowServingClient(BaseBackendClient):
-    """Backend client for TensorFlow Serving in containers"""
-
-    def __init__(self, base_url: str, model_name: str, version: Optional[str] = None):
-        self.base_url = base_url.rstrip('/')
-        self.model_name = model_name
-        self.version = version or "latest"
-
-    async def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """Make prediction using TensorFlow Serving"""
-        async with aiohttp.ClientSession() as session:
-            url = f"{self.base_url}/v1/models/{self.model_name}"
-            if self.version != "latest":
-                url += f"/versions/{self.version}"
-            url += ":predict"
-
-            payload = {"instances": [inputs]}
-            async with session.post(url, json=payload) as response:
-                return await response.json()
-
-    async def health_check(self) -> bool:
-        """Check if TensorFlow Serving is healthy"""
-        try:
-            async with aiohttp.ClientSession() as session:
-                url = f"{self.base_url}/v1/models/{self.model_name}"
-                async with session.get(url) as response:
-                    return response.status == 200
-        except Exception:
-            return False
-
-
-class KubernetesServiceClient(BaseBackendClient):
-    """Generic client for services deployed in Kubernetes"""
-
-    def __init__(self, service_url: str, namespace: str = "default"):
-        self.service_url = service_url.rstrip('/')
-        self.namespace = namespace
-
-    async def health_check(self) -> bool:
-        """Check if K8s service is healthy"""
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(f"{self.service_url}/health") as response:
-                    return response.status == 200
-        except Exception:
-            return False
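For context, a minimal usage sketch of the removed VLLMBackendClient as it shipped in 0.1.0. The base URL, model name, endpoint, and payload fields are illustrative assumptions (vLLM's OpenAI-compatible completions route), not something defined by this package.

# Hypothetical 0.1.0-era usage of the removed VLLMBackendClient.
import asyncio
from isa_model.inference.backends.container_services import VLLMBackendClient

async def main():
    client = VLLMBackendClient("http://vllm.default.svc:8000")  # assumed in-cluster URL
    if await client.health_check():
        # Assumes the container runs vLLM's OpenAI-compatible server.
        result = await client.post("/v1/completions", {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
            "prompt": "Hello",
            "max_tokens": 32,
        })
        print(result)
    await client.close()

asyncio.run(main())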
isa_model/inference/backends/local_services.py
@@ -1,72 +0,0 @@
-"""
-Local Services Backend - Services running locally on the same machine.
-Examples: Ollama, Local model servers
-"""
-
-import aiohttp
-import json
-from typing import Dict, Any, List, Optional
-from .base_backend_client import BaseBackendClient
-
-
-class OllamaBackendClient(BaseBackendClient):
-    """Pure connection client for local Ollama service"""
-
-    def __init__(self, host: str = "localhost", port: int = 11434):
-        self.base_url = f"http://{host}:{port}"
-        self._session = None
-
-    async def _get_session(self):
-        """Get or create HTTP session"""
-        if self._session is None:
-            self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
-        return self._session
-
-    async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """Make POST request to Ollama API"""
-        session = await self._get_session()
-        async with session.post(f"{self.base_url}{endpoint}", json=payload) as response:
-            response.raise_for_status()
-            return await response.json()
-
-    async def get(self, endpoint: str) -> Dict[str, Any]:
-        """Make GET request to Ollama API"""
-        session = await self._get_session()
-        async with session.get(f"{self.base_url}{endpoint}") as response:
-            response.raise_for_status()
-            return await response.json()
-
-    async def health_check(self) -> bool:
-        """Check if Ollama service is healthy"""
-        try:
-            await self.get("/api/tags")
-            return True
-        except Exception:
-            return False
-
-    async def close(self):
-        """Close the HTTP session"""
-        if self._session:
-            await self._session.close()
-            self._session = None
-
-
-class LocalModelServerClient(BaseBackendClient):
-    """Generic client for local model servers"""
-
-    def __init__(self, base_url: str):
-        self.base_url = base_url.rstrip('/')
-
-    async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
-        """Generate completion using generic local server"""
-        # Implementation depends on local server API
-        raise NotImplementedError("Implement based on your local server API")
-
-    async def health_check(self) -> bool:
-        """Check if local server is healthy"""
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(f"{self.base_url}/health") as response:
-                    return response.status == 200
-        except Exception:
-            return False
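For context, a minimal usage sketch of the removed OllamaBackendClient (0.1.0 only). The /api/generate payload follows Ollama's public REST API, not anything defined by this package, and the model name is illustrative.

# Hypothetical 0.1.0-era usage of the removed OllamaBackendClient.
import asyncio
from isa_model.inference.backends.local_services import OllamaBackendClient

async def main():
    client = OllamaBackendClient(host="localhost", port=11434)
    if await client.health_check():
        result = await client.post("/api/generate", {
            "model": "llama3",            # illustrative; any locally pulled model
            "prompt": "Why is the sky blue?",
            "stream": False,
        })
        print(result.get("response"))
    await client.close()

asyncio.run(main())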
isa_model/inference/backends/openai_client.py
@@ -1,130 +0,0 @@
-import httpx
-import logging
-from typing import Dict, Any, Optional, AsyncGenerator
-import json
-import asyncio
-
-logger = logging.getLogger(__name__)
-
-class OpenAIBackendClient:
-    """Client for interacting with OpenAI API"""
-
-    def __init__(self, api_key: str, api_base: str = "https://api.openai.com/v1", timeout: int = 60):
-        """
-        Initialize the OpenAI client
-
-        Args:
-            api_key: OpenAI API key
-            api_base: Base URL for OpenAI API
-            timeout: Timeout for API calls in seconds
-        """
-        self.api_key = api_key
-        self.api_base = api_base
-        self.timeout = timeout
-        self.client = httpx.AsyncClient(timeout=timeout)
-        self._headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {api_key}"
-        }
-
-        logger.info(f"Initialized OpenAI client with API base: {api_base}")
-
-    async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Send a POST request to the OpenAI API
-
-        Args:
-            endpoint: API endpoint (e.g., /chat/completions)
-            payload: Request payload
-
-        Returns:
-            Response from the API
-        """
-        url = f"{self.api_base}{endpoint}"
-        try:
-            response = await self.client.post(url, json=payload, headers=self._headers)
-            response.raise_for_status()
-            return response.json()
-        except httpx.HTTPStatusError as e:
-            error_detail = {}
-            try:
-                error_detail = e.response.json()
-            except Exception:
-                error_detail = {"status": e.response.status_code, "text": e.response.text}
-
-            logger.error(f"OpenAI API error: {error_detail}")
-            raise ValueError(f"OpenAI API error: {error_detail}")
-        except Exception as e:
-            logger.error(f"Error communicating with OpenAI API: {e}")
-            raise
-
-    async def stream_chat(self, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
-        """
-        Stream responses from the chat completion API
-
-        Args:
-            payload: Request payload (must include 'stream': True)
-
-        Yields:
-            Response chunks from the API
-        """
-        url = f"{self.api_base}/chat/completions"
-        payload["stream"] = True
-
-        try:
-            async with self.client.stream("POST", url, json=payload, headers=self._headers) as response:
-                response.raise_for_status()
-                async for chunk in response.aiter_lines():
-                    if not chunk.strip():
-                        continue
-                    if chunk.startswith("data: "):
-                        chunk = chunk[6:]
-                    if chunk == "[DONE]":
-                        break
-                    try:
-                        content = json.loads(chunk)
-                        if content.get("choices") and len(content["choices"]) > 0:
-                            delta = content["choices"][0].get("delta", {})
-                            if "content" in delta:
-                                yield delta["content"]
-                    except json.JSONDecodeError:
-                        logger.warning(f"Failed to parse chunk: {chunk}")
-                        continue
-                    except Exception as e:
-                        logger.error(f"Error processing stream chunk: {e}")
-                        continue
-        except httpx.HTTPStatusError as e:
-            error_detail = {}
-            try:
-                error_detail = e.response.json()
-            except Exception:
-                error_detail = {"status": e.response.status_code, "text": e.response.text}
-
-            logger.error(f"OpenAI API streaming error: {error_detail}")
-            raise ValueError(f"OpenAI API streaming error: {error_detail}")
-        except Exception as e:
-            logger.error(f"Error communicating with OpenAI API: {e}")
-            raise
-
-    async def get_embedding(self, text: str, model: str = "text-embedding-3-small") -> list:
-        """
-        Get embedding for a text
-
-        Args:
-            text: Text to embed
-            model: Embedding model to use
-
-        Returns:
-            List of embedding values
-        """
-        payload = {
-            "input": text,
-            "model": model
-        }
-
-        result = await self.post("/embeddings", payload)
-        return result["data"][0]["embedding"]
-
-    async def close(self):
-        """Close the HTTP client"""
-        await self.client.aclose()
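For context, a minimal usage sketch of the removed OpenAIBackendClient (0.1.0 only), exercising the stream_chat() helper shown above. The model name is an illustrative assumption; the API key is read from the environment.

# Hypothetical 0.1.0-era usage of the removed OpenAIBackendClient.
import asyncio
import os
from isa_model.inference.backends.openai_client import OpenAIBackendClient

async def main():
    client = OpenAIBackendClient(api_key=os.environ["OPENAI_API_KEY"])
    # Stream tokens from a chat completion; stream_chat() sets "stream": True itself.
    async for token in client.stream_chat({
        "model": "gpt-4o-mini",  # illustrative model name
        "messages": [{"role": "user", "content": "Say hello."}],
    }):
        print(token, end="", flush=True)
    await client.close()

asyncio.run(main())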
isa_model/inference/backends/replicate_client.py
@@ -1,197 +0,0 @@
-import httpx
-import logging
-from typing import Dict, Any, Optional, AsyncGenerator, List
-import json
-import asyncio
-import time
-
-logger = logging.getLogger(__name__)
-
-class ReplicateBackendClient:
-    """Client for interacting with Replicate API"""
-
-    def __init__(self, api_token: str, timeout: int = 120):
-        """
-        Initialize the Replicate client
-
-        Args:
-            api_token: Replicate API token
-            timeout: Timeout for API calls in seconds
-        """
-        self.api_token = api_token
-        self.api_base = "https://api.replicate.com/v1"
-        self.timeout = timeout
-        self.client = httpx.AsyncClient(timeout=timeout)
-        self._headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Token {api_token}"
-        }
-
-        logger.info(f"Initialized Replicate client")
-
-    async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Send a POST request to the Replicate API
-
-        Args:
-            endpoint: API endpoint (e.g., /predictions)
-            payload: Request payload
-
-        Returns:
-            Response from the API
-        """
-        url = f"{self.api_base}{endpoint}"
-        try:
-            response = await self.client.post(url, json=payload, headers=self._headers)
-            response.raise_for_status()
-            return response.json()
-        except httpx.HTTPStatusError as e:
-            error_detail = {}
-            try:
-                error_detail = e.response.json()
-            except Exception:
-                error_detail = {"status": e.response.status_code, "text": e.response.text}
-
-            logger.error(f"Replicate API error: {error_detail}")
-            raise ValueError(f"Replicate API error: {error_detail}")
-        except Exception as e:
-            logger.error(f"Error communicating with Replicate API: {e}")
-            raise
-
-    async def get(self, endpoint: str) -> Dict[str, Any]:
-        """
-        Send a GET request to the Replicate API
-
-        Args:
-            endpoint: API endpoint (e.g., /predictions/{id})
-
-        Returns:
-            Response from the API
-        """
-        url = f"{self.api_base}{endpoint}"
-        try:
-            response = await self.client.get(url, headers=self._headers)
-            response.raise_for_status()
-            return response.json()
-        except httpx.HTTPStatusError as e:
-            error_detail = {}
-            try:
-                error_detail = e.response.json()
-            except Exception:
-                error_detail = {"status": e.response.status_code, "text": e.response.text}
-
-            logger.error(f"Replicate API error: {error_detail}")
-            raise ValueError(f"Replicate API error: {error_detail}")
-        except Exception as e:
-            logger.error(f"Error communicating with Replicate API: {e}")
-            raise
-
-    async def create_prediction(self, model: str, version: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Create a prediction with Replicate
-
-        Args:
-            model: Model identifier (e.g., "meta/llama-3-8b-instruct")
-            version: Model version or None to use latest
-            input_data: Input data for the model
-
-        Returns:
-            Prediction result
-        """
-        model_identifier = f"{model}:{version}" if version else model
-
-        payload = {
-            "version": model_identifier,
-            "input": input_data
-        }
-
-        # Create prediction
-        prediction = await self.post("/predictions", payload)
-        prediction_id = prediction["id"]
-
-        # Poll for completion
-        status = prediction["status"]
-        max_attempts = 120  # 10 minutes with 5 second intervals
-        attempts = 0
-
-        while status in ["starting", "processing"] and attempts < max_attempts:
-            await asyncio.sleep(5)
-            prediction = await self.get(f"/predictions/{prediction_id}")
-            status = prediction["status"]
-            attempts += 1
-
-        if status != "succeeded":
-            error = prediction.get("error", "Unknown error")
-            raise ValueError(f"Prediction failed: {error}")
-
-        return prediction
-
-    async def stream_prediction(self, model: str, version: str, input_data: Dict[str, Any]) -> AsyncGenerator[str, None]:
-        """
-        Stream a prediction from Replicate
-
-        Args:
-            model: Model identifier (e.g., "meta/llama-3-8b-instruct")
-            version: Model version or None to use latest
-            input_data: Input data for the model
-
-        Yields:
-            Output tokens from the model
-        """
-        # Set streaming in input data
-        input_data["stream"] = True
-
-        model_identifier = f"{model}:{version}" if version else model
-
-        payload = {
-            "version": model_identifier,
-            "input": input_data
-        }
-
-        # Create prediction
-        prediction = await self.post("/predictions", payload)
-        prediction_id = prediction["id"]
-
-        # Poll for the start of processing
-        status = prediction["status"]
-        while status in ["starting"]:
-            await asyncio.sleep(1)
-            prediction = await self.get(f"/predictions/{prediction_id}")
-            status = prediction["status"]
-
-        # Stream the output
-        if status in ["processing", "succeeded"]:
-            current_outputs = []
-            last_output_len = 0
-            max_attempts = 180  # 15 minutes with 5 second intervals
-            attempts = 0
-
-            while status in ["processing"] and attempts < max_attempts:
-                prediction = await self.get(f"/predictions/{prediction_id}")
-                status = prediction["status"]
-
-                # Get outputs
-                outputs = prediction.get("output", [])
-
-                # If we have new tokens, yield them
-                if isinstance(outputs, list) and len(outputs) > last_output_len:
-                    for i in range(last_output_len, len(outputs)):
-                        yield outputs[i]
-                    last_output_len = len(outputs)
-
-                await asyncio.sleep(0.5)
-                attempts += 1
-
-            # Final check for any remaining output
-            if status == "succeeded":
-                outputs = prediction.get("output", [])
-                if isinstance(outputs, list) and len(outputs) > last_output_len:
-                    for i in range(last_output_len, len(outputs)):
-                        yield outputs[i]
-        else:
-            error = prediction.get("error", "Unknown error")
-            raise ValueError(f"Prediction failed: {error}")
-
-    async def close(self):
-        """Close the HTTP client"""
-        await self.client.aclose()
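For context, a minimal usage sketch of the removed ReplicateBackendClient (0.1.0 only). The model slug and input keys are illustrative assumptions, and version handling follows the removed method's own docstring (pass a version id, or None as written there), not a statement about what the live Replicate API accepts.

# Hypothetical 0.1.0-era usage of the removed ReplicateBackendClient.
import asyncio
import os
from isa_model.inference.backends.replicate_client import ReplicateBackendClient

async def main():
    client = ReplicateBackendClient(api_token=os.environ["REPLICATE_API_TOKEN"])
    prediction = await client.create_prediction(
        model="meta/llama-3-8b-instruct",   # illustrative model slug
        version=None,                        # per the docstring above; a version id may be required
        input_data={"prompt": "Write a haiku about package diffs."},
    )
    print(prediction.get("output"))
    await client.close()

asyncio.run(main())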