isa-model 0.0.3__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/core/storage/hf_storage.py +419 -0
  4. isa_model/deployment/__init__.py +52 -0
  5. isa_model/deployment/core/__init__.py +34 -0
  6. isa_model/deployment/core/deployment_config.py +356 -0
  7. isa_model/deployment/core/deployment_manager.py +549 -0
  8. isa_model/deployment/core/isa_deployment_service.py +401 -0
  9. isa_model/eval/factory.py +381 -140
  10. isa_model/inference/ai_factory.py +142 -240
  11. isa_model/inference/providers/ml_provider.py +50 -0
  12. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  13. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  14. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  15. isa_model/inference/services/llm/__init__.py +2 -0
  16. isa_model/inference/services/llm/base_llm_service.py +111 -1
  17. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  18. isa_model/inference/services/llm/openai_llm_service.py +180 -26
  19. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  20. isa_model/inference/services/ml/base_ml_service.py +78 -0
  21. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  22. isa_model/inference/services/vision/__init__.py +3 -3
  23. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  24. isa_model/inference/services/vision/base_vision_service.py +177 -0
  25. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  26. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  27. isa_model/training/__init__.py +62 -32
  28. isa_model/training/cloud/__init__.py +22 -0
  29. isa_model/training/cloud/job_orchestrator.py +402 -0
  30. isa_model/training/cloud/runpod_trainer.py +454 -0
  31. isa_model/training/cloud/storage_manager.py +482 -0
  32. isa_model/training/core/__init__.py +23 -0
  33. isa_model/training/core/config.py +181 -0
  34. isa_model/training/core/dataset.py +222 -0
  35. isa_model/training/core/trainer.py +720 -0
  36. isa_model/training/core/utils.py +213 -0
  37. isa_model/training/factory.py +229 -198
  38. isa_model-0.0.8.dist-info/METADATA +465 -0
  39. isa_model-0.0.8.dist-info/RECORD +86 -0
  40. isa_model/core/model_router.py +0 -226
  41. isa_model/core/model_version.py +0 -0
  42. isa_model/core/resource_manager.py +0 -202
  43. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  44. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  45. isa_model/training/engine/llama_factory/__init__.py +0 -39
  46. isa_model/training/engine/llama_factory/config.py +0 -115
  47. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  48. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  49. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  50. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  51. isa_model/training/engine/llama_factory/factory.py +0 -331
  52. isa_model/training/engine/llama_factory/rl.py +0 -254
  53. isa_model/training/engine/llama_factory/trainer.py +0 -171
  54. isa_model/training/image_model/configs/create_config.py +0 -37
  55. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  56. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  57. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  58. isa_model/training/image_model/prepare_upload.py +0 -17
  59. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  60. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  61. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  62. isa_model/training/image_model/train/train.py +0 -42
  63. isa_model/training/image_model/train/train_flux.py +0 -41
  64. isa_model/training/image_model/train/train_lora.py +0 -57
  65. isa_model/training/image_model/train_main.py +0 -25
  66. isa_model-0.0.3.dist-info/METADATA +0 -327
  67. isa_model-0.0.3.dist-info/RECORD +0 -92
  68. isa_model-0.0.3.dist-info/licenses/LICENSE +0 -21
  69. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  76. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  77. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/WHEEL +0 -0
  78. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/triton_llm_service.py
@@ -0,0 +1,481 @@
+"""
+Triton LLM Service
+
+Provides LLM-specific functionality using Triton Inference Server as the backend.
+Integrates with the existing TritonProvider for low-level operations.
+"""
+
+import logging
+from typing import Dict, Any, List, Optional, Union, AsyncGenerator
+import json
+import asyncio
+
+from ..base_service import BaseService
+from ...providers.triton_provider import TritonProvider
+from ...base import ModelType, Capability
+
+logger = logging.getLogger(__name__)
+
+
+class TritonLLMService(BaseService):
+    """
+    LLM service using Triton Inference Server.
+
+    This service provides high-level LLM operations like text generation,
+    chat completion, and streaming responses using Triton as the backend.
+
+    Features:
+    - Text generation with customizable parameters
+    - Chat completion with conversation context
+    - Streaming responses for real-time interaction
+    - Multiple model support
+    - Automatic model loading and management
+    - Integration with model registry
+
+    Example:
+        ```python
+        from isa_model.inference.services.llm import TritonLLMService
+
+        # Initialize service
+        service = TritonLLMService({
+            "triton_url": "localhost:8001",
+            "default_model": "gemma-4b-alpaca"
+        })
+
+        # Generate text
+        response = await service.generate_text(
+            prompt="What is artificial intelligence?",
+            model_name="gemma-4b-alpaca",
+            max_tokens=100
+        )
+
+        # Chat completion
+        messages = [
+            {"role": "user", "content": "Hello, how are you?"}
+        ]
+        response = await service.chat_completion(
+            messages=messages,
+            model_name="gemma-4b-alpaca"
+        )
+
+        # Streaming generation
+        async for chunk in service.generate_text_stream(
+            prompt="Tell me a story",
+            model_name="gemma-4b-alpaca"
+        ):
+            print(chunk["text"], end="")
+        ```
+    """
+
+    def __init__(self, config: Dict[str, Any]):
+        """
+        Initialize Triton LLM service.
+
+        Args:
+            config: Service configuration including Triton connection details
+        """
+        super().__init__(config)
+
+        # Initialize Triton provider
+        self.triton_provider = TritonProvider(config)
+
+        # Service configuration
+        self.default_model = config.get("default_model", "model")
+        self.max_tokens_limit = config.get("max_tokens_limit", 2048)
+        self.temperature_default = config.get("temperature_default", 0.7)
+        self.top_p_default = config.get("top_p_default", 0.9)
+        self.top_k_default = config.get("top_k_default", 50)
+
+        # Chat templates
+        self.chat_templates = {
+            "gemma": self._format_gemma_chat,
+            "llama": self._format_llama_chat,
+            "default": self._format_default_chat
+        }
+
+        logger.info(f"TritonLLMService initialized with default model: {self.default_model}")
+
+    async def initialize(self) -> bool:
+        """Initialize the service and check Triton connectivity"""
+        try:
+            # Check if Triton server is live
+            if not self.triton_provider.is_server_live():
+                logger.error("Triton server is not live")
+                return False
+
+            # Check if default model is ready
+            if not self.triton_provider.is_model_ready(self.default_model):
+                logger.warning(f"Default model {self.default_model} is not ready")
+
+            logger.info("TritonLLMService initialized successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to initialize TritonLLMService: {e}")
+            return False
+
+    async def generate_text(self,
+                            prompt: str,
+                            model_name: Optional[str] = None,
+                            max_tokens: int = 100,
+                            temperature: float = None,
+                            top_p: float = None,
+                            top_k: int = None,
+                            stop_sequences: Optional[List[str]] = None,
+                            system_prompt: Optional[str] = None,
+                            **kwargs) -> Dict[str, Any]:
+        """
+        Generate text using the specified model.
+
+        Args:
+            prompt: Input text prompt
+            model_name: Name of the model to use (uses default if not specified)
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature (0.0 to 1.0)
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+            stop_sequences: List of sequences to stop generation
+            system_prompt: System prompt for instruction-following models
+            **kwargs: Additional generation parameters
+
+        Returns:
+            Dictionary containing generated text and metadata
+        """
+        try:
+            # Use default model if not specified
+            model_name = model_name or self.default_model
+
+            # Validate parameters
+            max_tokens = min(max_tokens, self.max_tokens_limit)
+            temperature = temperature if temperature is not None else self.temperature_default
+            top_p = top_p if top_p is not None else self.top_p_default
+            top_k = top_k if top_k is not None else self.top_k_default
+
+            # Prepare generation parameters
+            params = {
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "top_p": top_p,
+                "top_k": top_k,
+                **kwargs
+            }
+
+            if system_prompt:
+                params["system_prompt"] = system_prompt
+
+            if stop_sequences:
+                params["stop_sequences"] = stop_sequences
+
+            logger.debug(f"Generating text with model {model_name}, prompt length: {len(prompt)}")
+
+            # Call Triton provider
+            result = await self.triton_provider.completions(
+                prompt=prompt,
+                model_name=model_name,
+                params=params
+            )
+
+            if "error" in result:
+                logger.error(f"Text generation failed: {result['error']}")
+                return {
+                    "success": False,
+                    "error": result["error"],
+                    "model_name": model_name
+                }
+
+            # Format response
+            response = {
+                "success": True,
+                "text": result["completion"],
+                "model_name": model_name,
+                "usage": result.get("metadata", {}).get("token_usage", {}),
+                "parameters": {
+                    "temperature": temperature,
+                    "max_tokens": max_tokens,
+                    "top_p": top_p,
+                    "top_k": top_k
+                }
+            }
+
+            logger.debug(f"Text generation completed, output length: {len(response['text'])}")
+            return response
+
+        except Exception as e:
+            logger.error(f"Error in generate_text: {e}")
+            return {
+                "success": False,
+                "error": str(e),
+                "model_name": model_name or self.default_model
+            }
+
+    async def chat_completion(self,
+                              messages: List[Dict[str, str]],
+                              model_name: Optional[str] = None,
+                              max_tokens: int = 100,
+                              temperature: float = None,
+                              top_p: float = None,
+                              top_k: int = None,
+                              stop_sequences: Optional[List[str]] = None,
+                              **kwargs) -> Dict[str, Any]:
+        """
+        Generate chat completion using conversation messages.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            model_name: Name of the model to use
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+            stop_sequences: List of sequences to stop generation
+            **kwargs: Additional parameters
+
+        Returns:
+            Dictionary containing the assistant's response and metadata
+        """
+        try:
+            # Use default model if not specified
+            model_name = model_name or self.default_model
+
+            # Format messages into a prompt
+            prompt = self._format_chat_messages(messages, model_name)
+
+            logger.debug(f"Chat completion with {len(messages)} messages, model: {model_name}")
+
+            # Generate response
+            result = await self.generate_text(
+                prompt=prompt,
+                model_name=model_name,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stop_sequences=stop_sequences,
+                **kwargs
+            )
+
+            if not result["success"]:
+                return result
+
+            # Format as chat completion response
+            response = {
+                "success": True,
+                "message": {
+                    "role": "assistant",
+                    "content": result["text"]
+                },
+                "model_name": model_name,
+                "usage": result.get("usage", {}),
+                "parameters": result.get("parameters", {})
+            }
+
+            logger.debug("Chat completion completed successfully")
+            return response
+
+        except Exception as e:
+            logger.error(f"Error in chat_completion: {e}")
+            return {
+                "success": False,
+                "error": str(e),
+                "model_name": model_name or self.default_model
+            }
+
+    async def generate_text_stream(self,
+                                   prompt: str,
+                                   model_name: Optional[str] = None,
+                                   max_tokens: int = 100,
+                                   temperature: float = None,
+                                   top_p: float = None,
+                                   top_k: int = None,
+                                   stop_sequences: Optional[List[str]] = None,
+                                   **kwargs) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Generate text with streaming response.
+
+        Args:
+            prompt: Input text prompt
+            model_name: Name of the model to use
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+            stop_sequences: List of sequences to stop generation
+            **kwargs: Additional parameters
+
+        Yields:
+            Dictionary chunks containing partial text and metadata
+        """
+        try:
+            # For now, simulate streaming by chunking the complete response
+            # TODO: Implement true streaming when Triton supports it
+
+            result = await self.generate_text(
+                prompt=prompt,
+                model_name=model_name,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stop_sequences=stop_sequences,
+                **kwargs
+            )
+
+            if not result["success"]:
+                yield {
+                    "success": False,
+                    "error": result["error"],
+                    "model_name": model_name or self.default_model
+                }
+                return
+
+            # Simulate streaming by yielding chunks
+            text = result["text"]
+            chunk_size = 10  # Characters per chunk
+
+            for i in range(0, len(text), chunk_size):
+                chunk = text[i:i + chunk_size]
+
+                yield {
+                    "success": True,
+                    "text": chunk,
+                    "is_complete": i + chunk_size >= len(text),
+                    "model_name": model_name or self.default_model
+                }
+
+                # Small delay to simulate streaming
+                await asyncio.sleep(0.05)
+
+        except Exception as e:
+            logger.error(f"Error in generate_text_stream: {e}")
+            yield {
+                "success": False,
+                "error": str(e),
+                "model_name": model_name or self.default_model
+            }
+
+    async def get_model_info(self, model_name: str) -> Dict[str, Any]:
+        """Get information about a specific model"""
+        try:
+            if not self.triton_provider.is_model_ready(model_name):
+                return {
+                    "success": False,
+                    "error": f"Model {model_name} is not ready"
+                }
+
+            metadata = self.triton_provider.get_model_metadata(model_name)
+            config = self.triton_provider.get_model_config(model_name)
+
+            return {
+                "success": True,
+                "model_name": model_name,
+                "metadata": metadata,
+                "config": config,
+                "is_ready": True
+            }
+
+        except Exception as e:
+            logger.error(f"Error getting model info for {model_name}: {e}")
+            return {
+                "success": False,
+                "error": str(e),
+                "model_name": model_name
+            }
+
+    async def list_available_models(self) -> List[str]:
+        """List all available models"""
+        try:
+            return self.triton_provider.get_models(ModelType.LLM)
+        except Exception as e:
+            logger.error(f"Error listing models: {e}")
+            return []
+
+    def _format_chat_messages(self, messages: List[Dict[str, str]], model_name: str) -> str:
+        """Format chat messages into a prompt based on model type"""
+        # Determine chat template based on model name
+        template_key = "default"
+        if "gemma" in model_name.lower():
+            template_key = "gemma"
+        elif "llama" in model_name.lower():
+            template_key = "llama"
+
+        formatter = self.chat_templates.get(template_key, self.chat_templates["default"])
+        return formatter(messages)
+
+    def _format_gemma_chat(self, messages: List[Dict[str, str]]) -> str:
+        """Format messages for Gemma models"""
+        formatted = ""
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role == "system":
+                formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
+            elif role == "user":
+                formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+            elif role == "assistant":
+                formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
+
+        # Add the start token for the assistant response
+        formatted += "<start_of_turn>model\n"
+
+        return formatted
+
+    def _format_llama_chat(self, messages: List[Dict[str, str]]) -> str:
+        """Format messages for Llama models"""
+        formatted = "<s>"
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role == "system":
+                formatted += f"[INST] <<SYS>>\n{content}\n<</SYS>>\n\n"
+            elif role == "user":
+                if formatted.endswith("<s>"):
+                    formatted += f"[INST] {content} [/INST]"
+                else:
+                    formatted += f"<s>[INST] {content} [/INST]"
+            elif role == "assistant":
+                formatted += f" {content} </s>"
+
+        return formatted
+
+    def _format_default_chat(self, messages: List[Dict[str, str]]) -> str:
+        """Default chat formatting"""
+        formatted = ""
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role == "system":
+                formatted += f"System: {content}\n\n"
+            elif role == "user":
+                formatted += f"User: {content}\n\n"
+            elif role == "assistant":
+                formatted += f"Assistant: {content}\n\n"
+
+        # Add prompt for assistant response
+        formatted += "Assistant:"
+
+        return formatted
+
+    def get_capabilities(self) -> List[Capability]:
+        """Get service capabilities"""
+        return [
+            Capability.CHAT,
+            Capability.COMPLETION
+        ]
+
+    def get_supported_models(self) -> List[str]:
+        """Get list of supported model types"""
+        return [
+            "gemma-2-2b-it",
+            "gemma-2-4b-it",
+            "gemma-2-7b-it",
+            "llama-2-7b-chat",
+            "llama-2-13b-chat",
+            "mistral-7b-instruct",
+            "custom-models"  # Support for custom deployed models
+        ]
isa_model/inference/services/ml/base_ml_service.py
@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional, Tuple
+import asyncio
+import logging
+from pathlib import Path
+import joblib
+import numpy as np
+import pandas as pd
+
+from isa_model.inference.services.base_service import BaseService
+
+logger = logging.getLogger(__name__)
+
+class BaseMLService(BaseService, ABC):
+    """Base class for traditional ML model services"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        super().__init__(provider, model_name)
+        self.model = None
+        self.model_info = {}
+        self.feature_names = []
+        self.target_names = []
+        self.preprocessing_pipeline = None
+
+    @abstractmethod
+    async def load_model(self, model_path: str) -> None:
+        """Load the ML model from file"""
+        pass
+
+    @abstractmethod
+    async def predict(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Make predictions"""
+        pass
+
+    @abstractmethod
+    async def predict_proba(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Get prediction probabilities (for classification models)"""
+        pass
+
+    async def batch_predict(self, features_batch: List[Union[np.ndarray, pd.DataFrame, List]]) -> List[Dict[str, Any]]:
+        """Batch predictions for multiple inputs"""
+        results = []
+        for features in features_batch:
+            result = await self.predict(features)
+            results.append(result)
+        return results
+
+    async def explain_prediction(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Explain model predictions (if supported)"""
+        return {"explanation": "Feature importance explanation not implemented for this model"}
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get model information"""
+        return {
+            "name": self.model_name,
+            "type": "traditional_ml",
+            "provider": self.provider.name if self.provider else "unknown",
+            "feature_count": len(self.feature_names),
+            "model_info": self.model_info,
+            "supports_probability": hasattr(self.model, 'predict_proba') if self.model else False
+        }
+
+    def _preprocess_features(self, features: Union[np.ndarray, pd.DataFrame, List]) -> np.ndarray:
+        """Preprocess input features"""
+        if isinstance(features, list):
+            features = np.array(features)
+        elif isinstance(features, pd.DataFrame):
+            features = features.values
+
+        if self.preprocessing_pipeline:
+            features = self.preprocessing_pipeline.transform(features)
+
+        return features
+
+    async def close(self):
+        """Cleanup resources"""
+        self.model = None
+        logger.info(f"ML service {self.model_name} closed")
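
BaseMLService fixes the async contract (load_model, predict, predict_proba) and supplies shared helpers such as _preprocess_features and batch_predict. A minimal sketch of what a concrete subclass has to provide, using a made-up JoblibRegressorService purely for illustration:

```python
from typing import Any, Dict, List, Union

import joblib
import numpy as np
import pandas as pd

from isa_model.inference.services.ml.base_ml_service import BaseMLService

class JoblibRegressorService(BaseMLService):
    """Hypothetical regressor service; the class name and behavior are illustrative."""

    async def load_model(self, model_path: str) -> None:
        # Load a pickled scikit-learn style regressor from disk
        self.model = joblib.load(model_path)

    async def predict(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        x = self._preprocess_features(features)  # list / DataFrame -> np.ndarray
        return {"predictions": self.model.predict(x).tolist(), "model_name": self.model_name}

    async def predict_proba(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
        raise ValueError("Regression models do not expose class probabilities")
```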
isa_model/inference/services/ml/sklearn_ml_service.py
@@ -0,0 +1,140 @@
+import asyncio
+import joblib
+import numpy as np
+import pandas as pd
+from typing import Dict, Any, List, Union
+from pathlib import Path
+import logging
+
+from .base_ml_service import BaseMLService
+
+logger = logging.getLogger(__name__)
+
+class SklearnService(BaseMLService):
+    """Service for scikit-learn models"""
+
+    async def load_model(self, model_path: str) -> None:
+        """Load scikit-learn model from joblib file"""
+        try:
+            model_path = Path(model_path)
+
+            # Load model
+            self.model = joblib.load(model_path)
+
+            # Try to load additional metadata if available
+            metadata_path = model_path.parent / f"{model_path.stem}_metadata.json"
+            if metadata_path.exists():
+                import json
+                with open(metadata_path, 'r') as f:
+                    metadata = json.load(f)
+                    self.feature_names = metadata.get('feature_names', [])
+                    self.target_names = metadata.get('target_names', [])
+                    self.model_info = metadata.get('model_info', {})
+
+            # Try to load preprocessing pipeline
+            preprocessing_path = model_path.parent / f"{model_path.stem}_preprocessing.joblib"
+            if preprocessing_path.exists():
+                self.preprocessing_pipeline = joblib.load(preprocessing_path)
+
+            logger.info(f"Loaded sklearn model: {self.model_name}")
+
+        except Exception as e:
+            logger.error(f"Failed to load sklearn model {model_path}: {e}")
+            raise
+
+    async def predict(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Make predictions with the sklearn model"""
+        if self.model is None:
+            raise ValueError("Model not loaded. Call load_model() first.")
+
+        try:
+            # Preprocess features
+            processed_features = self._preprocess_features(features)
+
+            # Make prediction
+            prediction = self.model.predict(processed_features)
+
+            # Handle single vs batch predictions
+            if prediction.ndim == 0:
+                prediction = [prediction.item()]
+            elif prediction.ndim == 1:
+                prediction = prediction.tolist()
+
+            result = {
+                "predictions": prediction,
+                "model_name": self.model_name,
+                "feature_count": processed_features.shape[1] if processed_features.ndim > 1 else len(processed_features)
+            }
+
+            # Add feature names if available
+            if self.feature_names:
+                result["feature_names"] = self.feature_names
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Prediction failed: {e}")
+            raise
+
+    async def predict_proba(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Get prediction probabilities"""
+        if self.model is None:
+            raise ValueError("Model not loaded. Call load_model() first.")
+
+        if not hasattr(self.model, 'predict_proba'):
+            raise ValueError("Model does not support probability predictions")
+
+        try:
+            processed_features = self._preprocess_features(features)
+            probabilities = self.model.predict_proba(processed_features)
+
+            if probabilities.ndim == 1:
+                probabilities = [probabilities.tolist()]
+            else:
+                probabilities = probabilities.tolist()
+
+            result = {
+                "probabilities": probabilities,
+                "classes": getattr(self.model, 'classes_', []).tolist(),
+                "model_name": self.model_name
+            }
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Probability prediction failed: {e}")
+            raise
+
+    async def explain_prediction(self, features: Union[np.ndarray, pd.DataFrame, List]) -> Dict[str, Any]:
+        """Explain predictions using feature importance"""
+        try:
+            processed_features = self._preprocess_features(features)
+
+            explanation = {
+                "model_name": self.model_name,
+                "explanation_type": "feature_importance"
+            }
+
+            # Get feature importance if available
+            if hasattr(self.model, 'feature_importances_'):
+                importance = self.model.feature_importances_.tolist()
+                explanation["feature_importance"] = importance
+
+                if self.feature_names:
+                    explanation["feature_importance_named"] = dict(zip(self.feature_names, importance))
+
+            # Get coefficients for linear models
+            elif hasattr(self.model, 'coef_'):
+                coef = self.model.coef_
+                if coef.ndim > 1:
+                    coef = coef[0]  # Take first class for binary classification
+                explanation["coefficients"] = coef.tolist()
+
+                if self.feature_names:
+                    explanation["coefficients_named"] = dict(zip(self.feature_names, coef))
+
+            return explanation
+
+        except Exception as e:
+            logger.error(f"Explanation failed: {e}")
+            return {"error": str(e)}
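
SklearnService.load_model() looks for two optional sidecar files next to the model file: <stem>_metadata.json (feature/target names, model info) and <stem>_preprocessing.joblib (a fitted transformer). A sketch of producing that layout with scikit-learn, assuming hypothetical file names and a provider object supplied by the surrounding framework; only the sidecar naming convention comes from the code above:

```python
import json

import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

X = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
y = [0, 1, 1, 0]

scaler = StandardScaler().fit(X)
clf = LogisticRegression().fit(scaler.transform(X), y)

joblib.dump(clf, "churn.joblib")                   # the model itself
joblib.dump(scaler, "churn_preprocessing.joblib")  # optional preprocessing pipeline
with open("churn_metadata.json", "w") as f:        # optional metadata sidecar
    json.dump({"feature_names": ["f1", "f2"],
               "target_names": ["no", "yes"],
               "model_info": {"algorithm": "logistic_regression"}}, f)

# Later, inside the isa_model service layer (provider construction is assumed):
#     service = SklearnService(provider, "churn")
#     await service.load_model("churn.joblib")
#     result = await service.predict([[0.5, 0.5]])
```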