isa-model 0.2.0__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +243 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.9.dist-info/METADATA +465 -0
- isa_model-0.2.9.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.2.9.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.2.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,481 @@
|
|
1
|
+
"""
|
2
|
+
Triton LLM Service
|
3
|
+
|
4
|
+
Provides LLM-specific functionality using Triton Inference Server as the backend.
|
5
|
+
Integrates with the existing TritonProvider for low-level operations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
from typing import Dict, Any, List, Optional, Union, AsyncGenerator
|
10
|
+
import json
|
11
|
+
import asyncio
|
12
|
+
|
13
|
+
from ..base_service import BaseService
|
14
|
+
from ...providers.triton_provider import TritonProvider
|
15
|
+
from ...base import ModelType, Capability
|
16
|
+
|
17
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class TritonLLMService(BaseService):
    """
    LLM service using Triton Inference Server.

    This service provides high-level LLM operations like text generation,
    chat completion, and streaming responses using Triton as the backend.
    Low-level calls are delegated to ``TritonProvider``.

    Features:
    - Text generation with customizable sampling parameters
    - Chat completion with model-specific prompt templates (Gemma, Llama 2,
      and a generic fallback), selected from the model name
    - Streaming responses (currently simulated by chunking a complete
      response; see ``generate_text_stream``)
    - Per-call model selection with a configurable default model

    Configuration keys read by this class (the whole ``config`` dict is also
    forwarded to ``TritonProvider``):
    - ``default_model`` (default ``"model"``)
    - ``max_tokens_limit`` (default ``2048``): upper bound applied to ``max_tokens``
    - ``temperature_default`` / ``top_p_default`` / ``top_k_default``
      (defaults ``0.7`` / ``0.9`` / ``50``)
    - ``stream_chunk_size`` (default ``10``): characters per simulated stream chunk
    - ``stream_chunk_delay`` (default ``0.05``): seconds to wait between chunks

    Example:
        ```python
        from isa_model.inference.services.llm import TritonLLMService

        # Initialize service
        service = TritonLLMService({
            "triton_url": "localhost:8001",
            "default_model": "gemma-4b-alpaca"
        })

        # Generate text
        response = await service.generate_text(
            prompt="What is artificial intelligence?",
            model_name="gemma-4b-alpaca",
            max_tokens=100
        )

        # Chat completion
        messages = [
            {"role": "user", "content": "Hello, how are you?"}
        ]
        response = await service.chat_completion(
            messages=messages,
            model_name="gemma-4b-alpaca"
        )

        # Streaming generation
        async for chunk in service.generate_text_stream(
            prompt="Tell me a story",
            model_name="gemma-4b-alpaca"
        ):
            print(chunk["text"], end="")
        ```
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize Triton LLM service.

        Args:
            config: Service configuration including Triton connection details.
                The dict is passed to ``BaseService`` and ``TritonProvider``
                unchanged; the keys this class reads itself are listed in the
                class docstring.
        """
        super().__init__(config)

        # Initialize Triton provider
        self.triton_provider = TritonProvider(config)

        # Service configuration
        self.default_model = config.get("default_model", "model")
        self.max_tokens_limit = config.get("max_tokens_limit", 2048)
        self.temperature_default = config.get("temperature_default", 0.7)
        self.top_p_default = config.get("top_p_default", 0.9)
        self.top_k_default = config.get("top_k_default", 50)

        # Simulated-streaming knobs; defaults equal the previously hard-coded
        # values. Chunk size is clamped to >= 1 so range() never gets a 0 step.
        self.stream_chunk_size = max(1, int(config.get("stream_chunk_size", 10)))
        self.stream_chunk_delay = float(config.get("stream_chunk_delay", 0.05))

        # Chat templates, keyed by the family detected in _format_chat_messages
        self.chat_templates = {
            "gemma": self._format_gemma_chat,
            "llama": self._format_llama_chat,
            "default": self._format_default_chat
        }

        logger.info(f"TritonLLMService initialized with default model: {self.default_model}")

    async def initialize(self) -> bool:
        """
        Initialize the service and check Triton connectivity.

        Returns:
            False if the server is not live or the check raises; True otherwise.
            A default model that is not ready only logs a warning and still
            returns True.
        """
        try:
            # Check if Triton server is live
            if not self.triton_provider.is_server_live():
                logger.error("Triton server is not live")
                return False

            # Check if default model is ready (non-fatal)
            if not self.triton_provider.is_model_ready(self.default_model):
                logger.warning(f"Default model {self.default_model} is not ready")

            logger.info("TritonLLMService initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize TritonLLMService: {e}")
            return False

    async def generate_text(self,
                            prompt: str,
                            model_name: Optional[str] = None,
                            max_tokens: int = 100,
                            temperature: Optional[float] = None,
                            top_p: Optional[float] = None,
                            top_k: Optional[int] = None,
                            stop_sequences: Optional[List[str]] = None,
                            system_prompt: Optional[str] = None,
                            **kwargs) -> Dict[str, Any]:
        """
        Generate text using the specified model.

        Args:
            prompt: Input text prompt
            model_name: Name of the model to use (uses default if not specified)
            max_tokens: Maximum number of tokens to generate; capped at
                ``max_tokens_limit``
            temperature: Sampling temperature (``temperature_default`` if None)
            top_p: Top-p sampling parameter (``top_p_default`` if None)
            top_k: Top-k sampling parameter (``top_k_default`` if None)
            stop_sequences: List of sequences to stop generation
            system_prompt: System prompt, forwarded to the provider as
                ``params["system_prompt"]`` when non-empty
            **kwargs: Additional generation parameters merged into ``params``

        Returns:
            On success: ``{"success": True, "text", "model_name", "usage",
            "parameters"}``. On failure (provider error or any exception):
            ``{"success": False, "error", "model_name"}``. Never raises.
        """
        try:
            # Use default model if not specified
            model_name = model_name or self.default_model

            # Resolve parameters against configured defaults
            max_tokens = min(max_tokens, self.max_tokens_limit)
            temperature = temperature if temperature is not None else self.temperature_default
            top_p = top_p if top_p is not None else self.top_p_default
            top_k = top_k if top_k is not None else self.top_k_default

            # Prepare generation parameters
            params = {
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "top_k": top_k,
                **kwargs
            }

            if system_prompt:
                params["system_prompt"] = system_prompt

            if stop_sequences:
                params["stop_sequences"] = stop_sequences

            logger.debug(f"Generating text with model {model_name}, prompt length: {len(prompt)}")

            # Call Triton provider
            result = await self.triton_provider.completions(
                prompt=prompt,
                model_name=model_name,
                params=params
            )

            if "error" in result:
                logger.error(f"Text generation failed: {result['error']}")
                return {
                    "success": False,
                    "error": result["error"],
                    "model_name": model_name
                }

            # `or {}` also covers a provider that returns "metadata": None
            metadata = result.get("metadata") or {}

            # Format response
            response = {
                "success": True,
                "text": result["completion"],
                "model_name": model_name,
                "usage": metadata.get("token_usage", {}),
                "parameters": {
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "top_p": top_p,
                    "top_k": top_k
                }
            }

            logger.debug(f"Text generation completed, output length: {len(response['text'])}")
            return response

        except Exception as e:
            logger.error(f"Error in generate_text: {e}")
            return {
                "success": False,
                "error": str(e),
                "model_name": model_name or self.default_model
            }

    async def chat_completion(self,
                              messages: List[Dict[str, str]],
                              model_name: Optional[str] = None,
                              max_tokens: int = 100,
                              temperature: Optional[float] = None,
                              top_p: Optional[float] = None,
                              top_k: Optional[int] = None,
                              stop_sequences: Optional[List[str]] = None,
                              **kwargs) -> Dict[str, Any]:
        """
        Generate chat completion using conversation messages.

        The messages are rendered into a single prompt with a template chosen
        from the model name (see ``_format_chat_messages``), then passed to
        ``generate_text``.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model_name: Name of the model to use
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            top_k: Top-k sampling parameter
            stop_sequences: List of sequences to stop generation
            **kwargs: Additional parameters forwarded to ``generate_text``

        Returns:
            On success: ``{"success": True, "message": {"role": "assistant",
            "content"}, "model_name", "usage", "parameters"}``. On failure the
            error dict from ``generate_text`` (or one built here). Never raises.
        """
        try:
            # Use default model if not specified
            model_name = model_name or self.default_model

            # Format messages into a prompt
            prompt = self._format_chat_messages(messages, model_name)

            logger.debug(f"Chat completion with {len(messages)} messages, model: {model_name}")

            # Generate response
            result = await self.generate_text(
                prompt=prompt,
                model_name=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                **kwargs
            )

            if not result["success"]:
                return result

            # Format as chat completion response
            response = {
                "success": True,
                "message": {
                    "role": "assistant",
                    "content": result["text"]
                },
                "model_name": model_name,
                "usage": result.get("usage", {}),
                "parameters": result.get("parameters", {})
            }

            logger.debug("Chat completion completed successfully")
            return response

        except Exception as e:
            logger.error(f"Error in chat_completion: {e}")
            return {
                "success": False,
                "error": str(e),
                "model_name": model_name or self.default_model
            }

    async def generate_text_stream(self,
                                   prompt: str,
                                   model_name: Optional[str] = None,
                                   max_tokens: int = 100,
                                   temperature: Optional[float] = None,
                                   top_p: Optional[float] = None,
                                   top_k: Optional[int] = None,
                                   stop_sequences: Optional[List[str]] = None,
                                   **kwargs) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Generate text with a (simulated) streaming response.

        The full completion is generated first via ``generate_text`` and then
        yielded in ``stream_chunk_size``-character pieces with
        ``stream_chunk_delay`` seconds between them.

        Args:
            prompt: Input text prompt
            model_name: Name of the model to use
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            top_k: Top-k sampling parameter
            stop_sequences: List of sequences to stop generation
            **kwargs: Additional parameters forwarded to ``generate_text``

        Yields:
            ``{"success": True, "text", "is_complete", "model_name"}`` chunks.
            Exactly one chunk has ``is_complete=True`` (the last), including
            for an empty completion. On failure a single
            ``{"success": False, "error", "model_name"}`` dict is yielded.
        """
        resolved_model = model_name or self.default_model
        try:
            # TODO: Implement true streaming when Triton supports it
            result = await self.generate_text(
                prompt=prompt,
                model_name=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                **kwargs
            )

            if not result["success"]:
                yield {
                    "success": False,
                    "error": result["error"],
                    "model_name": resolved_model
                }
                return

            text = result["text"] or ""

            # Without this, an empty completion would yield nothing, and a
            # consumer waiting for is_complete=True would never see it.
            if not text:
                yield {
                    "success": True,
                    "text": "",
                    "is_complete": True,
                    "model_name": resolved_model
                }
                return

            chunk_size = self.stream_chunk_size
            for start in range(0, len(text), chunk_size):
                is_complete = start + chunk_size >= len(text)
                yield {
                    "success": True,
                    "text": text[start:start + chunk_size],
                    "is_complete": is_complete,
                    "model_name": resolved_model
                }

                # Small delay to simulate streaming (not needed after the last chunk)
                if not is_complete:
                    await asyncio.sleep(self.stream_chunk_delay)

        except Exception as e:
            logger.error(f"Error in generate_text_stream: {e}")
            yield {
                "success": False,
                "error": str(e),
                "model_name": resolved_model
            }

    async def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get metadata and config for a model from the Triton provider.

        Returns:
            ``{"success": True, "model_name", "metadata", "config", "is_ready"}``
            when the model is ready, otherwise ``{"success": False, "error",
            "model_name"}``. Never raises.
        """
        try:
            if not self.triton_provider.is_model_ready(model_name):
                return {
                    "success": False,
                    "error": f"Model {model_name} is not ready",
                    "model_name": model_name
                }

            metadata = self.triton_provider.get_model_metadata(model_name)
            config = self.triton_provider.get_model_config(model_name)

            return {
                "success": True,
                "model_name": model_name,
                "metadata": metadata,
                "config": config,
                "is_ready": True
            }

        except Exception as e:
            logger.error(f"Error getting model info for {model_name}: {e}")
            return {
                "success": False,
                "error": str(e),
                "model_name": model_name
            }

    async def list_available_models(self) -> List[str]:
        """List LLM models reported by the Triton provider ([] on any error)."""
        try:
            return self.triton_provider.get_models(ModelType.LLM)
        except Exception as e:
            logger.error(f"Error listing models: {e}")
            return []

    def _format_chat_messages(self, messages: List[Dict[str, str]], model_name: str) -> str:
        """Render chat messages into a prompt using a template picked by a substring match on the model name."""
        lowered = model_name.lower()
        template_key = "default"
        if "gemma" in lowered:
            template_key = "gemma"
        elif "llama" in lowered:
            template_key = "llama"

        formatter = self.chat_templates.get(template_key, self.chat_templates["default"])
        return formatter(messages)

    def _format_gemma_chat(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages with Gemma turn markers.

        Messages whose role is not system/user/assistant are dropped. The
        returned prompt ends with an open ``model`` turn for the reply.
        """
        formatted = ""

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                # NOTE(review): Gemma's published chat format defines only
                # `user` and `model` turns; confirm the deployed model accepts
                # a `system` turn, or fold it into the first user turn.
                formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
            elif role == "user":
                formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
            elif role == "assistant":
                formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"

        # Add the start token for the assistant response
        formatted += "<start_of_turn>model\n"

        return formatted

    def _format_llama_chat(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages in the Llama 2 chat format.

        Each user turn becomes ``<s>[INST] {user} [/INST]`` and each assistant
        turn ``" {reply} </s>"``. System messages are buffered and embedded
        inside the next ``[INST]`` as ``<<SYS>>\\n...\\n<</SYS>>\\n\\n``,
        matching the reference layout. (Previously the user turn after a
        system message opened a second ``<s>[INST]``.) Consecutive system
        messages are joined with newlines; a trailing system message with no
        following user turn is emitted as its own ``[INST]`` block. Unknown
        roles are dropped. An empty conversation yields ``"<s>"``.
        """
        parts: List[str] = []
        pending_system: List[str] = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                pending_system.append(content)
            elif role == "user":
                if pending_system:
                    system_text = "\n".join(pending_system)
                    content = f"<<SYS>>\n{system_text}\n<</SYS>>\n\n{content}"
                    pending_system = []
                parts.append(f"<s>[INST] {content} [/INST]")
            elif role == "assistant":
                # An assistant turn before any other output still needs the BOS token
                if not parts:
                    parts.append("<s>")
                parts.append(f" {content} </s>")

        if pending_system:
            system_text = "\n".join(pending_system)
            parts.append(f"<s>[INST] <<SYS>>\n{system_text}\n<</SYS>>\n\n [/INST]")

        return "".join(parts) or "<s>"

    def _format_default_chat(self, messages: List[Dict[str, str]]) -> str:
        """Generic "Role: content" formatting, ending with an open "Assistant:" line; unknown roles are dropped."""
        formatted = ""

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "system":
                formatted += f"System: {content}\n\n"
            elif role == "user":
                formatted += f"User: {content}\n\n"
            elif role == "assistant":
                formatted += f"Assistant: {content}\n\n"

        # Add prompt for assistant response
        formatted += "Assistant:"

        return formatted

    def get_capabilities(self) -> List[Capability]:
        """Return the capabilities this service advertises (chat and completion)."""
        return [
            Capability.CHAT,
            Capability.COMPLETION
        ]

    def get_supported_models(self) -> List[str]:
        """Return a static list of model names this service advertises; "custom-models" stands for custom deployments."""
        return [
            "gemma-2-2b-it",
            "gemma-2-4b-it",
            "gemma-2-7b-it",
            "llama-2-7b-chat",
            "llama-2-13b-chat",
            "mistral-7b-instruct",
            "custom-models"  # Support for custom deployed models
        ]
|
@@ -0,0 +1,78 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Dict, Any, List, Union, Optional, Tuple
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
from pathlib import Path
|
6
|
+
import joblib
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
|
10
|
+
from isa_model.inference.services.base_service import BaseService
|
11
|
+
|
12
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
class BaseMLService(BaseService, ABC):
    """
    Base class for traditional (non-LLM) ML model services.

    Subclasses implement ``load_model``, ``predict`` and ``predict_proba``.
    This base supplies sequential batch prediction, a default explanation
    stub, model-info reporting and shared input preprocessing.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str):
        """
        Args:
            provider: Provider backing this service; forwarded to ``BaseService``.
            model_name: Model identifier; forwarded to ``BaseService``.
        """
        super().__init__(provider, model_name)
        # The loaded model object; None until a subclass's load_model() sets it.
        self.model = None
        # Free-form model metadata (subclasses may fill it when loading).
        self.model_info = {}
        # Ordered input feature names, if known.
        self.feature_names = []
        # Target/output names, if known.
        self.target_names = []
        # Optional object whose .transform() is applied to inputs before prediction.
        self.preprocessing_pipeline = None

    @abstractmethod
    async def load_model(self, model_path: str) -> None:
        """Load the ML model from file"""
        pass

    @abstractmethod
    async def predict(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """Make predictions"""
        pass

    @abstractmethod
    async def predict_proba(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """Get prediction probabilities (for classification models)"""
        pass

    async def batch_predict(self, features_batch: List[Union[np.ndarray, pd.DataFrame, List]]) -> List[Dict[str, Any]]:
        """Run ``predict`` on each input in order and return the results in the same order."""
        results = []
        for features in features_batch:
            result = await self.predict(features)
            results.append(result)
        return results

    async def explain_prediction(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """Explain model predictions; this default returns a "not implemented" message."""
        return {"explanation": "Feature importance explanation not implemented for this model"}

    def get_model_info(self) -> Dict[str, Any]:
        """
        Summarize this service's model.

        ``self.model_name`` and ``self.provider`` are read here; they are
        presumably set by ``BaseService.__init__`` (not visible in this file).
        """
        model = self.model
        provider = self.provider
        return {
            "name": self.model_name,
            "type": "traditional_ml",
            # Explicit None checks: plain truthiness would call __len__ on
            # objects that define it (scikit-learn Pipeline and ensemble
            # estimators do, and an unfitted ensemble raises from it).
            "provider": provider.name if provider is not None else "unknown",
            "feature_count": len(self.feature_names),
            "model_info": self.model_info,
            "supports_probability": model is not None and hasattr(model, 'predict_proba')
        }

    def _preprocess_features(self, features: Union[np.ndarray, pd.DataFrame, List]) -> np.ndarray:
        """
        Convert input features to an array and apply the preprocessing pipeline.

        Lists become ``np.array``; pandas DataFrames and Series become their
        underlying numpy array. Other inputs pass through unchanged. If
        ``preprocessing_pipeline`` is set, its ``transform`` output is returned.
        """
        if isinstance(features, list):
            features = np.array(features)
        elif isinstance(features, (pd.DataFrame, pd.Series)):
            features = features.to_numpy()

        if self.preprocessing_pipeline:
            features = self.preprocessing_pipeline.transform(features)

        return features

    async def close(self):
        """Release the loaded model reference."""
        self.model = None
        logger.info(f"ML service {self.model_name} closed")
|
@@ -0,0 +1,140 @@
|
|
1
|
+
import asyncio
|
2
|
+
import joblib
|
3
|
+
import numpy as np
|
4
|
+
import pandas as pd
|
5
|
+
from typing import Dict, Any, List, Union
|
6
|
+
from pathlib import Path
|
7
|
+
import logging
|
8
|
+
|
9
|
+
from .base_ml_service import BaseMLService
|
10
|
+
|
11
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
class SklearnService(BaseMLService):
    """Service for scikit-learn (or scikit-learn-compatible) models persisted with joblib."""

    async def load_model(self, model_path: str) -> None:
        """
        Load a scikit-learn model, plus optional sidecar files, from disk.

        Besides the model file itself, two optional files in the same
        directory are read when present:
        - ``<stem>_metadata.json``: ``feature_names``, ``target_names`` and
          ``model_info`` keys populate the matching attributes.
        - ``<stem>_preprocessing.joblib``: stored as ``self.preprocessing_pipeline``.

        Security: ``joblib.load`` unpickles arbitrary objects and can execute
        code; only load files from trusted sources.

        Raises:
            Any exception from reading or unpickling is logged and re-raised.
        """
        import json

        try:
            model_path = Path(model_path)

            # Load model (pickle-based; trusted input only)
            self.model = joblib.load(model_path)

            # Try to load additional metadata if available
            metadata_path = model_path.parent / f"{model_path.stem}_metadata.json"
            if metadata_path.exists():
                # JSON is UTF-8 by spec; don't depend on the platform default encoding
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                self.feature_names = metadata.get('feature_names', [])
                self.target_names = metadata.get('target_names', [])
                self.model_info = metadata.get('model_info', {})

            # Try to load preprocessing pipeline
            preprocessing_path = model_path.parent / f"{model_path.stem}_preprocessing.joblib"
            if preprocessing_path.exists():
                self.preprocessing_pipeline = joblib.load(preprocessing_path)

            logger.info(f"Loaded sklearn model: {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to load sklearn model {model_path}: {e}")
            raise

    async def predict(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """
        Make predictions with the sklearn model.

        Returns:
            ``{"predictions", "model_name", "feature_count"}`` plus
            ``"feature_names"`` when known. ``predictions`` is always a plain
            Python list: scalars become a one-element list, and multi-output
            (2-D) predictions become a list of lists.

        Raises:
            ValueError: if no model is loaded. Prediction errors are logged
            and re-raised.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        try:
            # Preprocess features
            processed_features = self._preprocess_features(features)

            # atleast_1d + tolist gives a uniform list for 0-D, 1-D and 2-D
            # outputs (2-D results previously leaked out as a raw ndarray).
            raw_prediction = self.model.predict(processed_features)
            predictions = np.atleast_1d(np.asarray(raw_prediction)).tolist()

            result = {
                "predictions": predictions,
                "model_name": self.model_name,
                "feature_count": processed_features.shape[1] if processed_features.ndim > 1 else len(processed_features)
            }

            # Add feature names if available
            if self.feature_names:
                result["feature_names"] = self.feature_names

            return result

        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise

    async def predict_proba(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """
        Get prediction probabilities.

        Returns:
            ``{"probabilities", "classes", "model_name"}`` as plain lists.
            For multi-output classifiers, scikit-learn returns one probability
            array and one ``classes_`` array per output, so both are
            lists-of-lists then. ``classes`` is ``[]`` when the model has no
            ``classes_`` attribute.

        Raises:
            ValueError: if no model is loaded or it lacks ``predict_proba``.
            Prediction errors are logged and re-raised.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        if not hasattr(self.model, 'predict_proba'):
            raise ValueError("Model does not support probability predictions")

        try:
            processed_features = self._preprocess_features(features)
            raw_probabilities = self.model.predict_proba(processed_features)

            if isinstance(raw_probabilities, list):
                # Multi-output: one (n_samples, n_classes) array per output
                probabilities = [np.asarray(p).tolist() for p in raw_probabilities]
            else:
                raw_probabilities = np.asarray(raw_probabilities)
                if raw_probabilities.ndim == 1:
                    probabilities = [raw_probabilities.tolist()]
                else:
                    probabilities = raw_probabilities.tolist()

            # The old getattr(..., []).tolist() raised AttributeError when
            # classes_ was missing, because lists have no .tolist().
            raw_classes = getattr(self.model, 'classes_', None)
            if raw_classes is None:
                classes = []
            elif isinstance(raw_classes, list):
                classes = [np.asarray(c).tolist() for c in raw_classes]
            else:
                classes = np.asarray(raw_classes).tolist()

            return {
                "probabilities": probabilities,
                "classes": classes,
                "model_name": self.model_name
            }

        except Exception as e:
            logger.error(f"Probability prediction failed: {e}")
            raise

    async def explain_prediction(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        """
        Explain predictions using global feature importance or linear coefficients.

        The explanation describes the model as a whole; ``features`` is only
        run through preprocessing so that malformed input is reported.

        Returns:
            ``{"model_name", "explanation_type"}`` plus ``feature_importance``
            (tree-style models) or ``coefficients`` (linear models), each with a
            ``*_named`` mapping when ``feature_names`` is known. Any error,
            including a missing model, is returned as ``{"error": message}``
            instead of raising.
        """
        try:
            if self.model is None:
                raise ValueError("Model not loaded. Call load_model() first.")

            # Validates/preprocesses the input; the result is not used below.
            self._preprocess_features(features)

            explanation = {
                "model_name": self.model_name,
                "explanation_type": "feature_importance"
            }

            # Get feature importance if available
            if hasattr(self.model, 'feature_importances_'):
                importance = np.asarray(self.model.feature_importances_).tolist()
                explanation["feature_importance"] = importance

                if self.feature_names:
                    explanation["feature_importance_named"] = dict(zip(self.feature_names, importance))

            # Get coefficients for linear models
            elif hasattr(self.model, 'coef_'):
                coef = np.asarray(self.model.coef_)
                if coef.ndim > 1:
                    # First row only: the sole row for binary models; for
                    # multiclass models only class 0's coefficients are reported.
                    coef = coef[0]
                coefficients = coef.tolist()
                explanation["coefficients"] = coefficients

                if self.feature_names:
                    # Plain floats (not numpy scalars), matching feature_importance_named
                    explanation["coefficients_named"] = dict(zip(self.feature_names, coefficients))

            return explanation

        except Exception as e:
            logger.error(f"Explanation failed: {e}")
            return {"error": str(e)}
|