isa_model-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/engine/triton/model_repository/bge/1/model.py
@@ -0,0 +1,174 @@
+import os
+import json
+import numpy as np
+import triton_python_backend_utils as pb_utils
+import sys
+import logging
+import torch
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("bge_triton_model")
+
+class TritonPythonModel:
+    """
+    Python model for BGE embedding.
+    """
+
+    def initialize(self, args):
+        """
+        Initialize the model.
+        """
+        self.model_config = json.loads(args['model_config'])
+        self.model_path = os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
+        # Always use CPU for testing
+        self.device = "cpu"
+        self.model = None
+        self.tokenizer = None
+        self._loaded = False
+
+        # Default configuration
+        self.config = {
+            "normalize": True,
+            "max_length": 512,
+            "pooling_method": "cls"  # Use CLS token for sentence embedding
+        }
+
+        self._load_model()
+
+        logger.info(f"Initialized BGE embedding model on {self.device}")
+
+    def _load_model(self):
+        """Load the BGE model and tokenizer"""
+        if self._loaded:
+            return
+
+        try:
+            from transformers import AutoModel, AutoTokenizer
+
+            # Load tokenizer
+            logger.info(f"Loading BGE tokenizer from {self.model_path}")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+
+            # Load model on CPU for testing
+            logger.info(f"Loading BGE model on {self.device}")
+            self.model = AutoModel.from_pretrained(
+                self.model_path,
+                torch_dtype=torch.float32,
+                device_map="cpu"  # Force CPU
+            )
+
+            self.model.eval()
+            self._loaded = True
+            logger.info("BGE model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load BGE model: {str(e)}")
+            # Fall back to dummy model for testing if model loading fails
+            self._loaded = False
+
+    def _generate_embeddings(self, texts, normalize=True):
+        """Generate embeddings for the given texts"""
+        if not self._loaded:
+            # Return random embeddings for testing
+            logger.warning("Model not loaded, returning random embeddings")
+            # Generate random embeddings with dimension 1024 (typical for BGE)
+            num_texts = len(texts) if isinstance(texts, list) else 1
+            return np.random.randn(num_texts, 1024).astype(np.float32)
+
+        try:
+            # Ensure texts is a list
+            if isinstance(texts, str):
+                texts = [texts]
+
+            # Tokenize the texts
+            inputs = self.tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                max_length=self.config["max_length"],
+                return_tensors="pt"
+            ).to(self.device)
+
+            # Generate embeddings
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+
+            # Use [CLS] token embedding as the sentence embedding
+            embeddings = outputs.last_hidden_state[:, 0, :]
+
+            # Normalize embeddings if required
+            if normalize:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+            # Convert to numpy array
+            embeddings_np = embeddings.cpu().numpy()
+
+            return embeddings_np
+
+        except Exception as e:
+            logger.error(f"Error during BGE embedding generation: {str(e)}")
+            # Return random embeddings in case of error
+            num_texts = len(texts) if isinstance(texts, list) else 1
+            return np.random.randn(num_texts, 1024).astype(np.float32)
+
+    def execute(self, requests):
+        """
+        Process inference requests.
+        """
+        responses = []
+
+        for request in requests:
+            # Get input tensors
+            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
+
+            # Get texts from input tensor
+            if input_tensor is not None:
+                input_data = input_tensor.as_numpy()
+                texts = []
+
+                # Convert bytes to strings
+                for i in range(len(input_data)):
+                    if input_data[i].dtype.type is np.bytes_:
+                        texts.append(input_data[i].decode('utf-8'))
+                    else:
+                        texts.append(str(input_data[i]))
+
+                # Generate embeddings
+                embeddings = self._generate_embeddings(texts)
+
+                # Create output tensor
+                output_tensor = pb_utils.Tensor(
+                    "embedding_output",
+                    embeddings.astype(np.float32)
+                )
+            else:
+                # If no input is provided, return empty tensor
+                output_tensor = pb_utils.Tensor(
+                    "embedding_output",
+                    np.array([], dtype=np.float32)
+                )
+
+            # Create inference response
+            inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        """
+        Clean up resources when the model is unloaded.
+        """
+        if self.model is not None:
+            self.model = None
+            self.tokenizer = None
+            self._loaded = False
+
+            # Force garbage collection
+            import gc
+            gc.collect()
+
+            if self.device == "cuda":
+                torch.cuda.empty_cache()
+
+        logger.info("BGE embedding model unloaded")
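For reference, a minimal client-side sketch of how this backend could be queried over Triton's HTTP API, assuming the backend is served under the model name "bge" and its config exposes the "text_input" (BYTES) and "embedding_output" (FP32) tensors used in the code above; the server URL and model name are assumptions, not taken from the package.

    import numpy as np
    import tritonclient.http as httpclient

    # Hypothetical client call; model name "bge" and server URL are assumptions.
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # execute() iterates over the input rows and decodes each bytes element.
    texts = np.array([b"What is machine learning?"], dtype=np.object_)
    text_input = httpclient.InferInput("text_input", texts.shape, "BYTES")
    text_input.set_data_from_numpy(texts)

    result = client.infer(
        model_name="bge",
        inputs=[text_input],
        outputs=[httpclient.InferRequestedOutput("embedding_output")],
    )
    embeddings = result.as_numpy("embedding_output")  # (num_texts, 1024) float32
    print(embeddings.shape)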
isa_model/inference/engine/triton/model_repository/gemma/1/model.py
@@ -0,0 +1,250 @@
+import os
+import json
+import numpy as np
+import triton_python_backend_utils as pb_utils
+import sys
+import base64
+import torch
+from PIL import Image
+import io
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("gemma_triton_model")
+
+class TritonPythonModel:
+    """
+    Python model for Gemma Vision LLM (VLM).
+    """
+
+    def initialize(self, args):
+        """
+        Initialize the model.
+        """
+        self.model_config = json.loads(args['model_config'])
+
+        # Try different possible model paths
+        possible_paths = [
+            "/models/Gemma3-4B",  # Original path
+            "/models/gemma",  # Alternative path
+        ]
+
+        # Find the first path that exists
+        self.model_path = None
+        for path in possible_paths:
+            if os.path.exists(path):
+                self.model_path = path
+                logger.info(f"Found model at path: {path}")
+                break
+
+        if self.model_path is None:
+            logger.error("Could not find model path!")
+            self.model_path = "/models/Gemma3-4B"  # Default, will fail later
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = None
+        self.tokenizer = None
+        self._loaded = False
+
+        # Default generation config
+        self.default_config = {
+            "max_new_tokens": 512,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "do_sample": True
+        }
+
+        logger.info(f"Initializing Gemma Vision model at {self.model_path} on {self.device}")
+        self._load_model()
+
+        if self._loaded:
+            logger.info("Gemma Vision model initialized successfully")
+        else:
+            logger.error("Failed to initialize Gemma Vision model")
+
+    def _load_model(self):
+        """Load the Gemma model and tokenizer"""
+        if self._loaded:
+            return
+
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            # Log environment information
+            logger.info(f"Possible model paths:")
+            logger.info(f"Current dir: {os.getcwd()}")
+            logger.info(f"Model path exists: {os.path.exists(self.model_path)}")
+            logger.info(f"Directory listing of /models:")
+            if os.path.exists("/models"):
+                logger.info(", ".join(os.listdir("/models")))
+
+            # Load tokenizer
+            logger.info(f"Loading Gemma tokenizer from {self.model_path}")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+
+            # Load model
+            logger.info(f"Loading Gemma model on {self.device}")
+            if self.device == "cpu":
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    torch_dtype=torch.float32,
+                    low_cpu_mem_usage=True,
+                    device_map="auto"
+                )
+            else:  # cuda
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    torch_dtype=torch.float16,  # Use half precision on GPU
+                    device_map="auto"
+                )
+
+            self.model.eval()
+            self._loaded = True
+            logger.info("Gemma model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load Gemma model: {str(e)}")
+            # Fall back to dummy model for testing if model loading fails
+            self._loaded = False
+
+    def _process_image(self, image_data):
+        """Process base64 image data"""
+        try:
+            # Extract the base64 part if it's a data URL
+            if isinstance(image_data, str) and image_data.startswith("data:image"):
+                # Extract the base64 part
+                image_data = image_data.split(",")[1]
+
+            # Decode base64
+            image_bytes = base64.b64decode(image_data)
+
+            # Open as PIL Image
+            image = Image.open(io.BytesIO(image_bytes))
+
+            # Process for model input if needed
+            # For now, we're just returning the image for text description
+            return image
+
+        except Exception as e:
+            logger.error(f"Error processing image: {str(e)}")
+            return None
+
+    def _generate_text(self, prompt, system_prompt=None, generation_config=None):
+        """Generate text using the Gemma model"""
+        if not self._loaded:
+            return "Model not loaded. Using fallback response: I can see an image but cannot analyze it properly as the vision model is not loaded."
+
+        try:
+            # Get generation config
+            config = self.default_config.copy()
+            if generation_config:
+                config.update(generation_config)
+
+            # Format the prompt with system prompt if provided
+            if system_prompt:
+                # Gemma uses a specific format for system prompts
+                formatted_prompt = f"<bos><start_of_turn>system\n{system_prompt}<end_of_turn>\n<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model"
+            else:
+                formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model"
+
+            # Tokenize the prompt
+            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
+
+            # Generate text
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    inputs.input_ids,
+                    attention_mask=inputs.attention_mask,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    **config
+                )
+
+            # Decode the generated text
+            generated_text = self.tokenizer.decode(
+                outputs[0][inputs.input_ids.shape[1]:],
+                skip_special_tokens=True
+            )
+
+            return generated_text.strip()
+
+        except Exception as e:
+            logger.error(f"Error during Gemma text generation: {str(e)}")
+            return f"Error generating response: {str(e)}"
+
+    def execute(self, requests):
+        """
+        Process inference requests.
+        """
+        responses = []
+
+        for request in requests:
+            # Get input tensors
+            input_tensor = pb_utils.get_input_tensor_by_name(request, "prompt")
+
+            # Convert to string
+            if input_tensor is not None:
+                input_data = input_tensor.as_numpy()
+                if input_data.dtype.type is np.bytes_:
+                    input_text = input_data[0][0].decode('utf-8')
+                else:
+                    input_text = str(input_data[0][0])

+                # Check if the input contains an image (base64)
+                if "data:image" in input_text:
+                    # Extract image description query
+                    query = "Describe this image in detail."
+                    if "?" in input_text:
+                        parts = input_text.split("?")
+                        query = parts[0] + "?"
+
+                    # For image inputs
+                    if self._loaded:
+                        response_text = self._generate_text(input_text)
+                    else:
+                        # Fallback if model not loaded
+                        response_text = "Model not loaded. Using fallback response: I can see an image but cannot analyze it properly as the vision model is not loaded."
+                else:
+                    # For text-only prompts
+                    system_prompt_tensor = pb_utils.get_input_tensor_by_name(request, "system_prompt")
+                    system_prompt = None
+                    if system_prompt_tensor is not None:
+                        system_prompt_data = system_prompt_tensor.as_numpy()
+                        if system_prompt_data.dtype.type is np.bytes_:
+                            system_prompt = system_prompt_data[0].decode('utf-8')
+
+                    response_text = self._generate_text(input_text, system_prompt)
+            else:
+                response_text = "No input provided"
+
+            # Create output tensor
+            output_tensor = pb_utils.Tensor(
+                "text_output",
+                np.array([[response_text]], dtype=np.object_)
+            )
+
+            # Create inference response
+            inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        """
+        Clean up resources when the model is unloaded.
+        """
+        if self.model is not None:
+            self.model = None
+            self.tokenizer = None
+            self._loaded = False
+
+            # Force garbage collection
+            import gc
+            gc.collect()
+
+            if self.device == "cuda":
+                torch.cuda.empty_cache()
+
+        logger.info("Gemma Vision model unloaded")
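A similar client-side sketch for the Gemma backend, which reads the prompt from element [0][0] of a BYTES "prompt" tensor and returns a [1, 1] "text_output" tensor; the served model name "gemma" and the server URL are assumptions.

    import numpy as np
    import tritonclient.http as httpclient

    # Hypothetical client call; model name "gemma" and server URL are assumptions.
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # execute() reads input_data[0][0], so the prompt is sent as a [1, 1] BYTES tensor.
    prompt = np.array([[b"Explain what a language model is."]], dtype=np.object_)
    prompt_input = httpclient.InferInput("prompt", prompt.shape, "BYTES")
    prompt_input.set_data_from_numpy(prompt)

    result = client.infer(
        model_name="gemma",
        inputs=[prompt_input],
        outputs=[httpclient.InferRequestedOutput("text_output")],
    )
    print(result.as_numpy("text_output")[0][0].decode("utf-8"))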
isa_model/inference/engine/triton/model_repository/llama/1/model.py
@@ -0,0 +1,76 @@
+import json
+import numpy as np
+import triton_python_backend_utils as pb_utils
+import random
+
+class TritonPythonModel:
+    """
+    Simulated Llama model for testing.
+    """
+
+    def initialize(self, args):
+        """
+        Initialize the model.
+        """
+        self.model_config = json.loads(args['model_config'])
+        self.responses = {
+            "artificial intelligence": "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines that are programmed to think and learn like humans. AI encompasses various subfields including machine learning, natural language processing, computer vision, and robotics. Modern AI systems can perform tasks such as understanding natural language, recognizing images, making decisions, and solving complex problems.",
+            "language model": "A language model is a type of artificial intelligence model that's trained to understand and generate human language. Large Language Models (LLMs) like myself are trained on vast amounts of text data to predict the next word in a sequence, enabling them to generate coherent and contextually relevant text, answer questions, translate languages, and perform various text-based tasks.",
+            "machine learning": "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data, learn from it, and make predictions or decisions. Common types include supervised learning, unsupervised learning, and reinforcement learning.",
+            "default": "As an AI language model, I'm designed to process and generate human-like text based on the input I receive. I can assist with a wide range of tasks including answering questions, providing explanations, generating creative content, and engaging in conversations on various topics. How else can I help you today?"
+        }
+        print(f"Initialized Simulated Llama model")
+
+    def execute(self, requests):
+        """
+        Process inference requests.
+        """
+        responses = []
+
+        for request in requests:
+            # Get input tensors
+            prompt = pb_utils.get_input_tensor_by_name(request, "prompt")
+            prompt_data = prompt.as_numpy()
+
+            # Decode prompt from bytes
+            if prompt_data.dtype == np.object_:
+                prompt_str = prompt_data[0][0].decode('utf-8')
+            else:
+                prompt_str = prompt_data[0][0]
+
+            # Generate a relevant response based on the prompt
+            generated_text = self._generate_response(prompt_str)
+
+            # Create output tensor
+            output_tensor = pb_utils.Tensor(
+                "text_output",
+                np.array([[generated_text]], dtype=np.object_)
+            )
+
+            # Create and append response
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[output_tensor]
+            )
+            responses.append(inference_response)
+
+        return responses
+
+    def _generate_response(self, prompt):
+        """
+        Generate a response based on the prompt keywords.
+        """
+        prompt_lower = prompt.lower()
+
+        # Check for keywords in the prompt
+        for key in self.responses:
+            if key in prompt_lower:
+                return self.responses[key]
+
+        # If no keywords match, return the default response
+        return self.responses["default"]
+
+    def finalize(self):
+        """
+        Clean up resources when the model is unloaded.
+        """
+        print("Simulated Llama model unloaded")
isa_model/inference/engine/triton/model_repository/whisper/1/model.py
@@ -0,0 +1,195 @@
+import json
+import numpy as np
+import triton_python_backend_utils as pb_utils
+import os
+import sys
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("whisper_triton_model")
+
+class TritonPythonModel:
+    """
+    Python model for Whisper speech-to-text using a simplified approach.
+    """
+
+    def initialize(self, args):
+        """
+        Initialize the model.
+        """
+        self.model_config = json.loads(args['model_config'])
+
+        # Get model name from config
+        self.model_name = "/models/Whisper-tiny"
+        if 'parameters' in self.model_config:
+            parameters = self.model_config['parameters']
+            if 'model_name' in parameters:
+                self.model_name = parameters['model_name']['string_value']
+
+        logger.info(f"Initializing simplified Whisper model: {self.model_name}")
+
+        # This is a simple mock model for testing
+        # In production, you would use an actual Whisper model
+        self.languages = {
+            "en": "English",
+            "fr": "French",
+            "es": "Spanish",
+            "de": "German",
+            "zh": "Chinese",
+            "ja": "Japanese"
+        }
+
+        # Try loading the model
+        try:
+            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+            import torch
+
+            logger.info(f"Attempting to load Whisper model from {self.model_name}")
+            self.processor = AutoProcessor.from_pretrained(self.model_name)
+            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float32,
+                device_map="cpu"
+            )
+            self.model.eval()
+            self.model_loaded = True
+            logger.info("Whisper model loaded successfully")
+        except Exception as e:
+            logger.warning(f"Failed to load Whisper model: {e}")
+            self.model_loaded = False
+            logger.info("Using fallback mock transcription")
+
+        logger.info("Simplified Whisper model initialized successfully")
+
+    def execute(self, requests):
+        """
+        Process inference requests.
+        """
+        responses = []
+
+        for request in requests:
+            try:
+                # Get input tensors
+                audio_input = pb_utils.get_input_tensor_by_name(request, "audio_input")
+                language_input = pb_utils.get_input_tensor_by_name(request, "language")
+
+                # Get language or use default
+                language = "en"
+                if language_input is not None:
+                    # Fix for decoding language input
+                    lang_np = language_input.as_numpy()
+                    if lang_np.dtype.type is np.bytes_:
+                        language = lang_np[0][0].decode('utf-8')
+                    elif lang_np.dtype.type is np.object_:
+                        language = str(lang_np[0][0])
+                    else:
+                        language = str(lang_np[0][0])
+
+                # Process audio input
+                if audio_input is not None:
+                    audio_data = audio_input.as_numpy()
+
+                    # Handle input shape [1, -1]
+                    logger.info(f"Original audio data shape: {audio_data.shape}")
+
+                    # If the model is loaded, use it for transcription
+                    if hasattr(self, 'model_loaded') and self.model_loaded:
+                        try:
+                            import torch
+
+                            # Reshape if needed
+                            if len(audio_data.shape) > 2:  # [batch, 1, length]
+                                audio_data = audio_data.reshape(audio_data.shape[0], -1)
+
+                            # Process audio with transformers
+                            inputs = self.processor(
+                                audio_data,
+                                sampling_rate=16000,
+                                return_tensors="pt"
+                            )
+
+                            # Generate transcription
+                            with torch.no_grad():
+                                generated_ids = self.model.generate(
+                                    inputs.input_features,
+                                    language=language,
+                                    task="transcribe"
+                                )
+
+                            # Decode transcription
+                            transcription = self.processor.batch_decode(
+                                generated_ids,
+                                skip_special_tokens=True
+                            )[0]
+
+                            logger.info(f"Generated transcription using model: {transcription}")
+                        except Exception as e:
+                            logger.error(f"Error using loaded model: {e}")
+                            # Fall back to mock transcription
+                            audio_length = audio_data.size
+                            transcription = self._get_mock_transcription(audio_length, language)
+                    else:
+                        # Use mock transcription
+                        audio_length = audio_data.size
+                        transcription = self._get_mock_transcription(audio_length, language)
+
+                else:
+                    transcription = "No audio input provided."
+
+                # Create output tensor
+                transcription_tensor = pb_utils.Tensor(
+                    "text_output",
+                    np.array([transcription], dtype=np.object_)
+                )
+
+                # Create and append response
+                inference_response = pb_utils.InferenceResponse(
+                    output_tensors=[transcription_tensor]
+                )
+                responses.append(inference_response)
+
+            except Exception as e:
+                logger.error(f"Error processing request: {e}")
+                # Return error response
+                error_response = pb_utils.InferenceResponse(
+                    output_tensors=[
+                        pb_utils.Tensor(
+                            "text_output",
+                            np.array([f"Error: {str(e)}"], dtype=np.object_)
+                        )
+                    ]
+                )
+                responses.append(error_response)
+
+        return responses
+
+    def _get_mock_transcription(self, audio_length, language):
+        """Generate a mock transcription based on audio length"""
+        if audio_length > 100000:
+            return f"This is a test transcription in {self.languages.get(language, 'English')}. The audio is quite long with {audio_length} samples."
+        elif audio_length > 50000:
+            return f"This is a medium length test transcription in {self.languages.get(language, 'English')}."
+        else:
+            return f"Short test transcription in {self.languages.get(language, 'English')}."
+
+    def finalize(self):
+        """
+        Clean up resources when the model is unloaded.
+        """
+        if hasattr(self, 'model') and self.model is not None:
+            self.model = None
+            self.processor = None
+
+        # Force garbage collection
+        import gc
+        gc.collect()
+
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+        except ImportError:
+            pass
+
+        logger.info("Whisper model unloaded")
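A matching client-side sketch for the Whisper backend, assuming it is served as "whisper" with an FP32 "audio_input" waveform, an optional BYTES "language" tensor read from element [0][0], and a "text_output" response as declared in the code above; the 16 kHz silent waveform, model name, and server URL are illustrative assumptions.

    import numpy as np
    import tritonclient.http as httpclient

    # Hypothetical client call; model name "whisper", server URL, and audio are assumptions.
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # One second of 16 kHz audio shaped [1, -1], matching the backend's shape handling.
    audio = np.zeros((1, 16000), dtype=np.float32)
    audio_input = httpclient.InferInput("audio_input", audio.shape, "FP32")
    audio_input.set_data_from_numpy(audio)

    # execute() reads lang_np[0][0], so the language code is a [1, 1] BYTES tensor.
    language = np.array([[b"en"]], dtype=np.object_)
    language_input = httpclient.InferInput("language", language.shape, "BYTES")
    language_input.set_data_from_numpy(language)

    result = client.infer(
        model_name="whisper",
        inputs=[audio_input, language_input],
        outputs=[httpclient.InferRequestedOutput("text_output")],
    )
    print(result.as_numpy("text_output")[0].decode("utf-8"))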