isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0

isa_model/inference/engine/triton/model_repository/bge/1/model.py (removed)
@@ -1,174 +0,0 @@
-import os
-import json
-import numpy as np
-import triton_python_backend_utils as pb_utils
-import sys
-import logging
-import torch
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("bge_triton_model")
-
-class TritonPythonModel:
-    """
-    Python model for BGE embedding.
-    """
-
-    def initialize(self, args):
-        """
-        Initialize the model.
-        """
-        self.model_config = json.loads(args['model_config'])
-        self.model_path = os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
-        # Always use CPU for testing
-        self.device = "cpu"
-        self.model = None
-        self.tokenizer = None
-        self._loaded = False
-
-        # Default configuration
-        self.config = {
-            "normalize": True,
-            "max_length": 512,
-            "pooling_method": "cls"  # Use CLS token for sentence embedding
-        }
-
-        self._load_model()
-
-        logger.info(f"Initialized BGE embedding model on {self.device}")
-
-    def _load_model(self):
-        """Load the BGE model and tokenizer"""
-        if self._loaded:
-            return
-
-        try:
-            from transformers import AutoModel, AutoTokenizer
-
-            # Load tokenizer
-            logger.info(f"Loading BGE tokenizer from {self.model_path}")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-
-            # Load model on CPU for testing
-            logger.info(f"Loading BGE model on {self.device}")
-            self.model = AutoModel.from_pretrained(
-                self.model_path,
-                torch_dtype=torch.float32,
-                device_map="cpu"  # Force CPU
-            )
-
-            self.model.eval()
-            self._loaded = True
-            logger.info("BGE model loaded successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to load BGE model: {str(e)}")
-            # Fall back to dummy model for testing if model loading fails
-            self._loaded = False
-
-    def _generate_embeddings(self, texts, normalize=True):
-        """Generate embeddings for the given texts"""
-        if not self._loaded:
-            # Return random embeddings for testing
-            logger.warning("Model not loaded, returning random embeddings")
-            # Generate random embeddings with dimension 1024 (typical for BGE)
-            num_texts = len(texts) if isinstance(texts, list) else 1
-            return np.random.randn(num_texts, 1024).astype(np.float32)
-
-        try:
-            # Ensure texts is a list
-            if isinstance(texts, str):
-                texts = [texts]
-
-            # Tokenize the texts
-            inputs = self.tokenizer(
-                texts,
-                padding=True,
-                truncation=True,
-                max_length=self.config["max_length"],
-                return_tensors="pt"
-            ).to(self.device)
-
-            # Generate embeddings
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Use [CLS] token embedding as the sentence embedding
-            embeddings = outputs.last_hidden_state[:, 0, :]
-
-            # Normalize embeddings if required
-            if normalize:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-            # Convert to numpy array
-            embeddings_np = embeddings.cpu().numpy()
-
-            return embeddings_np
-
-        except Exception as e:
-            logger.error(f"Error during BGE embedding generation: {str(e)}")
-            # Return random embeddings in case of error
-            num_texts = len(texts) if isinstance(texts, list) else 1
-            return np.random.randn(num_texts, 1024).astype(np.float32)
-
-    def execute(self, requests):
-        """
-        Process inference requests.
-        """
-        responses = []
-
-        for request in requests:
-            # Get input tensors
-            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
-
-            # Get texts from input tensor
-            if input_tensor is not None:
-                input_data = input_tensor.as_numpy()
-                texts = []
-
-                # Convert bytes to strings
-                for i in range(len(input_data)):
-                    if input_data[i].dtype.type is np.bytes_:
-                        texts.append(input_data[i].decode('utf-8'))
-                    else:
-                        texts.append(str(input_data[i]))
-
-                # Generate embeddings
-                embeddings = self._generate_embeddings(texts)
-
-                # Create output tensor
-                output_tensor = pb_utils.Tensor(
-                    "embedding_output",
-                    embeddings.astype(np.float32)
-                )
-            else:
-                # If no input is provided, return empty tensor
-                output_tensor = pb_utils.Tensor(
-                    "embedding_output",
-                    np.array([], dtype=np.float32)
-                )
-
-            # Create inference response
-            inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
-            responses.append(inference_response)
-
-        return responses
-
-    def finalize(self):
-        """
-        Clean up resources when the model is unloaded.
-        """
-        if self.model is not None:
-            self.model = None
-            self.tokenizer = None
-            self._loaded = False
-
-        # Force garbage collection
-        import gc
-        gc.collect()
-
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
-
-        logger.info("BGE embedding model unloaded")
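
For context, the removed file exposed a simple text-in/embeddings-out interface. A minimal client sketch follows; it assumes a Triton server at localhost:8000 serving this file under the model name "bge" (taken from the repository path), while the "text_input" and "embedding_output" tensor names come from the code above.

# Hypothetical client for the removed BGE Triton model (illustrative sketch only).
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # assumed server address

# execute() iterates over a 1-D array of UTF-8 byte strings named "text_input".
texts = np.array([b"hello world", b"triton embeddings"], dtype=np.object_)
text_input = httpclient.InferInput("text_input", texts.shape, "BYTES")
text_input.set_data_from_numpy(texts)

result = client.infer(
    model_name="bge",  # assumed from model_repository/bge/1/model.py
    inputs=[text_input],
    outputs=[httpclient.InferRequestedOutput("embedding_output")],
)
embeddings = result.as_numpy("embedding_output")  # float32, 1024-dim in the fallback path
print(embeddings.shape)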

isa_model/inference/engine/triton/model_repository/gemma/1/model.py (removed)
@@ -1,250 +0,0 @@
-import os
-import json
-import numpy as np
-import triton_python_backend_utils as pb_utils
-import sys
-import base64
-import torch
-from PIL import Image
-import io
-import logging
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("gemma_triton_model")
-
-class TritonPythonModel:
-    """
-    Python model for Gemma Vision LLM (VLM).
-    """
-
-    def initialize(self, args):
-        """
-        Initialize the model.
-        """
-        self.model_config = json.loads(args['model_config'])
-
-        # Try different possible model paths
-        possible_paths = [
-            "/models/Gemma3-4B",  # Original path
-            "/models/gemma",      # Alternative path
-        ]
-
-        # Find the first path that exists
-        self.model_path = None
-        for path in possible_paths:
-            if os.path.exists(path):
-                self.model_path = path
-                logger.info(f"Found model at path: {path}")
-                break
-
-        if self.model_path is None:
-            logger.error("Could not find model path!")
-            self.model_path = "/models/Gemma3-4B"  # Default, will fail later
-
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = None
-        self.tokenizer = None
-        self._loaded = False
-
-        # Default generation config
-        self.default_config = {
-            "max_new_tokens": 512,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "top_k": 50,
-            "repetition_penalty": 1.1,
-            "do_sample": True
-        }
-
-        logger.info(f"Initializing Gemma Vision model at {self.model_path} on {self.device}")
-        self._load_model()
-
-        if self._loaded:
-            logger.info("Gemma Vision model initialized successfully")
-        else:
-            logger.error("Failed to initialize Gemma Vision model")
-
-    def _load_model(self):
-        """Load the Gemma model and tokenizer"""
-        if self._loaded:
-            return
-
-        try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-
-            # Log environment information
-            logger.info(f"Possible model paths:")
-            logger.info(f"Current dir: {os.getcwd()}")
-            logger.info(f"Model path exists: {os.path.exists(self.model_path)}")
-            logger.info(f"Directory listing of /models:")
-            if os.path.exists("/models"):
-                logger.info(", ".join(os.listdir("/models")))
-
-            # Load tokenizer
-            logger.info(f"Loading Gemma tokenizer from {self.model_path}")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-
-            # Load model
-            logger.info(f"Loading Gemma model on {self.device}")
-            if self.device == "cpu":
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.model_path,
-                    torch_dtype=torch.float32,
-                    low_cpu_mem_usage=True,
-                    device_map="auto"
-                )
-            else:  # cuda
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.model_path,
-                    torch_dtype=torch.float16,  # Use half precision on GPU
-                    device_map="auto"
-                )
-
-            self.model.eval()
-            self._loaded = True
-            logger.info("Gemma model loaded successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to load Gemma model: {str(e)}")
-            # Fall back to dummy model for testing if model loading fails
-            self._loaded = False
-
-    def _process_image(self, image_data):
-        """Process base64 image data"""
-        try:
-            # Extract the base64 part if it's a data URL
-            if isinstance(image_data, str) and image_data.startswith("data:image"):
-                # Extract the base64 part
-                image_data = image_data.split(",")[1]
-
-            # Decode base64
-            image_bytes = base64.b64decode(image_data)
-
-            # Open as PIL Image
-            image = Image.open(io.BytesIO(image_bytes))
-
-            # Process for model input if needed
-            # For now, we're just returning the image for text description
-            return image
-
-        except Exception as e:
-            logger.error(f"Error processing image: {str(e)}")
-            return None
-
-    def _generate_text(self, prompt, system_prompt=None, generation_config=None):
-        """Generate text using the Gemma model"""
-        if not self._loaded:
-            return "Model not loaded. Using fallback response: I can see an image but cannot analyze it properly as the vision model is not loaded."
-
-        try:
-            # Get generation config
-            config = self.default_config.copy()
-            if generation_config:
-                config.update(generation_config)
-
-            # Format the prompt with system prompt if provided
-            if system_prompt:
-                # Gemma uses a specific format for system prompts
-                formatted_prompt = f"<bos><start_of_turn>system\n{system_prompt}<end_of_turn>\n<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model"
-            else:
-                formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model"
-
-            # Tokenize the prompt
-            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
-
-            # Generate text
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    inputs.input_ids,
-                    attention_mask=inputs.attention_mask,
-                    pad_token_id=self.tokenizer.eos_token_id,
-                    **config
-                )
-
-            # Decode the generated text
-            generated_text = self.tokenizer.decode(
-                outputs[0][inputs.input_ids.shape[1]:],
-                skip_special_tokens=True
-            )
-
-            return generated_text.strip()
-
-        except Exception as e:
-            logger.error(f"Error during Gemma text generation: {str(e)}")
-            return f"Error generating response: {str(e)}"
-
-    def execute(self, requests):
-        """
-        Process inference requests.
-        """
-        responses = []
-
-        for request in requests:
-            # Get input tensors
-            input_tensor = pb_utils.get_input_tensor_by_name(request, "prompt")
-
-            # Convert to string
-            if input_tensor is not None:
-                input_data = input_tensor.as_numpy()
-                if input_data.dtype.type is np.bytes_:
-                    input_text = input_data[0][0].decode('utf-8')
-                else:
-                    input_text = str(input_data[0][0])
-
-                # Check if the input contains an image (base64)
-                if "data:image" in input_text:
-                    # Extract image description query
-                    query = "Describe this image in detail."
-                    if "?" in input_text:
-                        parts = input_text.split("?")
-                        query = parts[0] + "?"
-
-                    # For image inputs
-                    if self._loaded:
-                        response_text = self._generate_text(input_text)
-                    else:
-                        # Fallback if model not loaded
-                        response_text = "Model not loaded. Using fallback response: I can see an image but cannot analyze it properly as the vision model is not loaded."
-                else:
-                    # For text-only prompts
-                    system_prompt_tensor = pb_utils.get_input_tensor_by_name(request, "system_prompt")
-                    system_prompt = None
-                    if system_prompt_tensor is not None:
-                        system_prompt_data = system_prompt_tensor.as_numpy()
-                        if system_prompt_data.dtype.type is np.bytes_:
-                            system_prompt = system_prompt_data[0].decode('utf-8')
-
-                    response_text = self._generate_text(input_text, system_prompt)
-            else:
-                response_text = "No input provided"
-
-            # Create output tensor
-            output_tensor = pb_utils.Tensor(
-                "text_output",
-                np.array([[response_text]], dtype=np.object_)
-            )
-
-            # Create inference response
-            inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
-            responses.append(inference_response)
-
-        return responses
-
-    def finalize(self):
-        """
-        Clean up resources when the model is unloaded.
-        """
-        if self.model is not None:
-            self.model = None
-            self.tokenizer = None
-            self._loaded = False
-
-        # Force garbage collection
-        import gc
-        gc.collect()
-
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
-
-        logger.info("Gemma Vision model unloaded")
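
The removed Gemma model followed the same pattern, taking a text prompt and returning generated text. A comparable sketch is shown below; the model name "gemma" is assumed from the repository path, and the [1, 1] input shape mirrors the way execute() reads the prompt from element [0][0].

# Hypothetical client for the removed Gemma Triton model (illustrative sketch only).
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # assumed server address

prompt = np.array([[b"What is a language model?"]], dtype=np.object_)  # shape [1, 1]
prompt_input = httpclient.InferInput("prompt", prompt.shape, "BYTES")
prompt_input.set_data_from_numpy(prompt)

result = client.infer(
    model_name="gemma",  # assumed from model_repository/gemma/1/model.py
    inputs=[prompt_input],
    outputs=[httpclient.InferRequestedOutput("text_output")],
)
print(result.as_numpy("text_output")[0][0].decode("utf-8"))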

isa_model/inference/engine/triton/model_repository/llama/1/model.py (removed)
@@ -1,76 +0,0 @@
-import json
-import numpy as np
-import triton_python_backend_utils as pb_utils
-import random
-
-class TritonPythonModel:
-    """
-    Simulated Llama model for testing.
-    """
-
-    def initialize(self, args):
-        """
-        Initialize the model.
-        """
-        self.model_config = json.loads(args['model_config'])
-        self.responses = {
-            "artificial intelligence": "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines that are programmed to think and learn like humans. AI encompasses various subfields including machine learning, natural language processing, computer vision, and robotics. Modern AI systems can perform tasks such as understanding natural language, recognizing images, making decisions, and solving complex problems.",
-            "language model": "A language model is a type of artificial intelligence model that's trained to understand and generate human language. Large Language Models (LLMs) like myself are trained on vast amounts of text data to predict the next word in a sequence, enabling them to generate coherent and contextually relevant text, answer questions, translate languages, and perform various text-based tasks.",
-            "machine learning": "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data, learn from it, and make predictions or decisions. Common types include supervised learning, unsupervised learning, and reinforcement learning.",
-            "default": "As an AI language model, I'm designed to process and generate human-like text based on the input I receive. I can assist with a wide range of tasks including answering questions, providing explanations, generating creative content, and engaging in conversations on various topics. How else can I help you today?"
-        }
-        print(f"Initialized Simulated Llama model")
-
-    def execute(self, requests):
-        """
-        Process inference requests.
-        """
-        responses = []
-
-        for request in requests:
-            # Get input tensors
-            prompt = pb_utils.get_input_tensor_by_name(request, "prompt")
-            prompt_data = prompt.as_numpy()
-
-            # Decode prompt from bytes
-            if prompt_data.dtype == np.object_:
-                prompt_str = prompt_data[0][0].decode('utf-8')
-            else:
-                prompt_str = prompt_data[0][0]
-
-            # Generate a relevant response based on the prompt
-            generated_text = self._generate_response(prompt_str)
-
-            # Create output tensor
-            output_tensor = pb_utils.Tensor(
-                "text_output",
-                np.array([[generated_text]], dtype=np.object_)
-            )
-
-            # Create and append response
-            inference_response = pb_utils.InferenceResponse(
-                output_tensors=[output_tensor]
-            )
-            responses.append(inference_response)
-
-        return responses
-
-    def _generate_response(self, prompt):
-        """
-        Generate a response based on the prompt keywords.
-        """
-        prompt_lower = prompt.lower()
-
-        # Check for keywords in the prompt
-        for key in self.responses:
-            if key in prompt_lower:
-                return self.responses[key]
-
-        # If no keywords match, return the default response
-        return self.responses["default"]
-
-    def finalize(self):
-        """
-        Clean up resources when the model is unloaded.
-        """
-        print("Simulated Llama model unloaded")

isa_model/inference/engine/triton/model_repository/whisper/1/model.py (removed)
@@ -1,195 +0,0 @@
-import json
-import numpy as np
-import triton_python_backend_utils as pb_utils
-import os
-import sys
-import logging
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("whisper_triton_model")
-
-class TritonPythonModel:
-    """
-    Python model for Whisper speech-to-text using a simplified approach.
-    """
-
-    def initialize(self, args):
-        """
-        Initialize the model.
-        """
-        self.model_config = json.loads(args['model_config'])
-
-        # Get model name from config
-        self.model_name = "/models/Whisper-tiny"
-        if 'parameters' in self.model_config:
-            parameters = self.model_config['parameters']
-            if 'model_name' in parameters:
-                self.model_name = parameters['model_name']['string_value']
-
-        logger.info(f"Initializing simplified Whisper model: {self.model_name}")
-
-        # This is a simple mock model for testing
-        # In production, you would use an actual Whisper model
-        self.languages = {
-            "en": "English",
-            "fr": "French",
-            "es": "Spanish",
-            "de": "German",
-            "zh": "Chinese",
-            "ja": "Japanese"
-        }
-
-        # Try loading the model
-        try:
-            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-            import torch
-
-            logger.info(f"Attempting to load Whisper model from {self.model_name}")
-            self.processor = AutoProcessor.from_pretrained(self.model_name)
-            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                self.model_name,
-                torch_dtype=torch.float32,
-                device_map="cpu"
-            )
-            self.model.eval()
-            self.model_loaded = True
-            logger.info("Whisper model loaded successfully")
-        except Exception as e:
-            logger.warning(f"Failed to load Whisper model: {e}")
-            self.model_loaded = False
-            logger.info("Using fallback mock transcription")
-
-        logger.info("Simplified Whisper model initialized successfully")
-
-    def execute(self, requests):
-        """
-        Process inference requests.
-        """
-        responses = []
-
-        for request in requests:
-            try:
-                # Get input tensors
-                audio_input = pb_utils.get_input_tensor_by_name(request, "audio_input")
-                language_input = pb_utils.get_input_tensor_by_name(request, "language")
-
-                # Get language or use default
-                language = "en"
-                if language_input is not None:
-                    # Fix for decoding language input
-                    lang_np = language_input.as_numpy()
-                    if lang_np.dtype.type is np.bytes_:
-                        language = lang_np[0][0].decode('utf-8')
-                    elif lang_np.dtype.type is np.object_:
-                        language = str(lang_np[0][0])
-                    else:
-                        language = str(lang_np[0][0])
-
-                # Process audio input
-                if audio_input is not None:
-                    audio_data = audio_input.as_numpy()
-
-                    # Handle input shape [1, -1]
-                    logger.info(f"Original audio data shape: {audio_data.shape}")
-
-                    # If the model is loaded, use it for transcription
-                    if hasattr(self, 'model_loaded') and self.model_loaded:
-                        try:
-                            import torch
-
-                            # Reshape if needed
-                            if len(audio_data.shape) > 2:  # [batch, 1, length]
-                                audio_data = audio_data.reshape(audio_data.shape[0], -1)
-
-                            # Process audio with transformers
-                            inputs = self.processor(
-                                audio_data,
-                                sampling_rate=16000,
-                                return_tensors="pt"
-                            )
-
-                            # Generate transcription
-                            with torch.no_grad():
-                                generated_ids = self.model.generate(
-                                    inputs.input_features,
-                                    language=language,
-                                    task="transcribe"
-                                )
-
-                            # Decode transcription
-                            transcription = self.processor.batch_decode(
-                                generated_ids,
-                                skip_special_tokens=True
-                            )[0]
-
-                            logger.info(f"Generated transcription using model: {transcription}")
-                        except Exception as e:
-                            logger.error(f"Error using loaded model: {e}")
-                            # Fall back to mock transcription
-                            audio_length = audio_data.size
-                            transcription = self._get_mock_transcription(audio_length, language)
-                    else:
-                        # Use mock transcription
-                        audio_length = audio_data.size
-                        transcription = self._get_mock_transcription(audio_length, language)
-
-                else:
-                    transcription = "No audio input provided."
-
-                # Create output tensor
-                transcription_tensor = pb_utils.Tensor(
-                    "text_output",
-                    np.array([transcription], dtype=np.object_)
-                )
-
-                # Create and append response
-                inference_response = pb_utils.InferenceResponse(
-                    output_tensors=[transcription_tensor]
-                )
-                responses.append(inference_response)
-
-            except Exception as e:
-                logger.error(f"Error processing request: {e}")
-                # Return error response
-                error_response = pb_utils.InferenceResponse(
-                    output_tensors=[
-                        pb_utils.Tensor(
-                            "text_output",
-                            np.array([f"Error: {str(e)}"], dtype=np.object_)
-                        )
-                    ]
-                )
-                responses.append(error_response)
-
-        return responses
-
-    def _get_mock_transcription(self, audio_length, language):
-        """Generate a mock transcription based on audio length"""
-        if audio_length > 100000:
-            return f"This is a test transcription in {self.languages.get(language, 'English')}. The audio is quite long with {audio_length} samples."
-        elif audio_length > 50000:
-            return f"This is a medium length test transcription in {self.languages.get(language, 'English')}."
-        else:
-            return f"Short test transcription in {self.languages.get(language, 'English')}."
-
-    def finalize(self):
-        """
-        Clean up resources when the model is unloaded.
-        """
-        if hasattr(self, 'model') and self.model is not None:
-            self.model = None
-            self.processor = None
-
-        # Force garbage collection
-        import gc
-        gc.collect()
-
-        try:
-            import torch
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-        except ImportError:
-            pass
-
-        logger.info("Whisper model unloaded")
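
The removed Whisper model accepted raw FP32 audio samples plus an optional language code. A minimal client sketch follows; the model name "whisper" is assumed from the repository path, while the "audio_input", "language", and "text_output" tensor names come from the code above.

# Hypothetical client for the removed Whisper Triton model (illustrative sketch only).
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # assumed server address

# One second of 16 kHz audio (silence here); execute() expects FP32 samples shaped [1, num_samples].
audio = np.zeros((1, 16000), dtype=np.float32)
audio_input = httpclient.InferInput("audio_input", audio.shape, "FP32")
audio_input.set_data_from_numpy(audio)

language = np.array([[b"en"]], dtype=np.object_)  # execute() reads lang_np[0][0]
language_input = httpclient.InferInput("language", language.shape, "BYTES")
language_input.set_data_from_numpy(language)

result = client.infer(
    model_name="whisper",  # assumed from model_repository/whisper/1/model.py
    inputs=[audio_input, language_input],
    outputs=[httpclient.InferRequestedOutput("text_output")],
)
print(result.as_numpy("text_output")[0].decode("utf-8"))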