isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/model_registry.py +273 -46
  4. isa_model/core/storage/hf_storage.py +419 -0
  5. isa_model/deployment/__init__.py +52 -0
  6. isa_model/deployment/core/__init__.py +34 -0
  7. isa_model/deployment/core/deployment_config.py +356 -0
  8. isa_model/deployment/core/deployment_manager.py +549 -0
  9. isa_model/deployment/core/isa_deployment_service.py +401 -0
  10. isa_model/eval/factory.py +381 -140
  11. isa_model/inference/ai_factory.py +427 -236
  12. isa_model/inference/billing_tracker.py +406 -0
  13. isa_model/inference/providers/base_provider.py +51 -4
  14. isa_model/inference/providers/ml_provider.py +50 -0
  15. isa_model/inference/providers/ollama_provider.py +37 -18
  16. isa_model/inference/providers/openai_provider.py +65 -36
  17. isa_model/inference/providers/replicate_provider.py +42 -30
  18. isa_model/inference/services/audio/base_stt_service.py +21 -2
  19. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  20. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  21. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  22. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  23. isa_model/inference/services/base_service.py +36 -1
  24. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  25. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  26. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  27. isa_model/inference/services/llm/__init__.py +2 -0
  28. isa_model/inference/services/llm/base_llm_service.py +158 -86
  29. isa_model/inference/services/llm/llm_adapter.py +414 -0
  30. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  31. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  32. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  33. isa_model/inference/services/ml/base_ml_service.py +78 -0
  34. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  35. isa_model/inference/services/vision/__init__.py +3 -3
  36. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  37. isa_model/inference/services/vision/base_vision_service.py +177 -0
  38. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  39. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  40. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  41. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  42. isa_model/training/__init__.py +62 -32
  43. isa_model/training/cloud/__init__.py +22 -0
  44. isa_model/training/cloud/job_orchestrator.py +402 -0
  45. isa_model/training/cloud/runpod_trainer.py +454 -0
  46. isa_model/training/cloud/storage_manager.py +482 -0
  47. isa_model/training/core/__init__.py +23 -0
  48. isa_model/training/core/config.py +181 -0
  49. isa_model/training/core/dataset.py +222 -0
  50. isa_model/training/core/trainer.py +720 -0
  51. isa_model/training/core/utils.py +213 -0
  52. isa_model/training/factory.py +229 -198
  53. isa_model-0.3.1.dist-info/METADATA +465 -0
  54. isa_model-0.3.1.dist-info/RECORD +91 -0
  55. isa_model/core/model_router.py +0 -226
  56. isa_model/core/model_version.py +0 -0
  57. isa_model/core/resource_manager.py +0 -202
  58. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  59. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  60. isa_model/training/engine/llama_factory/__init__.py +0 -39
  61. isa_model/training/engine/llama_factory/config.py +0 -115
  62. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  63. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  64. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  65. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  66. isa_model/training/engine/llama_factory/factory.py +0 -331
  67. isa_model/training/engine/llama_factory/rl.py +0 -254
  68. isa_model/training/engine/llama_factory/trainer.py +0 -171
  69. isa_model/training/image_model/configs/create_config.py +0 -37
  70. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  71. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  72. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  73. isa_model/training/image_model/prepare_upload.py +0 -17
  74. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  75. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  76. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  77. isa_model/training/image_model/train/train.py +0 -42
  78. isa_model/training/image_model/train/train_flux.py +0 -41
  79. isa_model/training/image_model/train/train_lora.py +0 -57
  80. isa_model/training/image_model/train_main.py +0 -25
  81. isa_model-0.0.2.dist-info/METADATA +0 -327
  82. isa_model-0.0.2.dist-info/RECORD +0 -92
  83. isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
  84. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  91. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  92. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  93. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,406 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Billing Tracker for isA_Model Services
6
+ Tracks usage and costs across all AI providers
7
+ """
8
+
9
+ from typing import Dict, List, Optional, Any, Union
10
+ from datetime import datetime, timezone
11
+ from dataclasses import dataclass, asdict
12
+ import json
13
+ import logging
14
+ from pathlib import Path
15
+ from enum import Enum
16
+ import os
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
class ServiceType(Enum):
    """Category of AI service that a usage event belongs to.

    ``BillingTracker.track_usage`` stores a member's ``.value`` string in
    ``UsageRecord.service_type``.
    """

    # Text services
    LLM = "llm"
    EMBEDDING = "embedding"
    # Image services
    VISION = "vision"
    IMAGE_GENERATION = "image_generation"
    # Audio services (speech-to-text / text-to-speech)
    AUDIO_STT = "audio_stt"
    AUDIO_TTS = "audio_tts"
28
+
29
class Provider(Enum):
    """AI service vendors whose usage can be tracked.

    ``BillingTracker`` converts a member to its ``.value`` string before
    storing it or filtering records by provider.
    """

    # Hosted APIs
    OPENAI = "openai"
    REPLICATE = "replicate"
    # Local runtime
    OLLAMA = "ollama"
    # Additional hosted APIs
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
36
+
37
@dataclass
class UsageRecord:
    """
    Record of a single API usage event.

    Token counts apply to token-billed services. ``input_units`` /
    ``output_units`` are meant for services billed in other units; the
    ``track_usage`` docs mention audio seconds and image counts.
    """
    timestamp: str  # set from datetime.now(timezone.utc).isoformat() by BillingTracker.track_usage
    provider: str
    service_type: str
    model_name: str
    operation: str
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    input_units: Optional[float] = None  # For non-token based services (images, audio)
    output_units: Optional[float] = None
    cost_usd: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary (nested containers are copied by ``asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'UsageRecord':
        """
        Create a record from a dictionary, e.g. one produced by ``to_dict``.

        Keys that are not fields of this dataclass are ignored. Data written
        by a version with extra fields therefore still loads, instead of
        raising and causing the whole billing file to be discarded. A missing
        required field still raises ``TypeError``.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})
61
+
62
class BillingTracker:
    """
    Tracks billing and usage across all AI services.

    Each call to ``track_usage`` appends a ``UsageRecord`` to
    ``usage_records`` and rewrites the complete list to ``storage_path`` as
    JSON.
    """

    # Fallback prices consulted when ModelManager has no entry for a model.
    # _calculate_cost applies the "input"/"output" rates per 1M tokens.
    # NOTE(review): the whisper-1, tts-1*, and flux-schnell entries look like
    # per-unit prices (audio, characters, images). _calculate_cost only prices
    # tokens, so events tracked with input_units/output_units alone are
    # recorded at $0. Confirm the intended units before pricing them.
    _LEGACY_PRICING: Dict[str, Dict[str, Dict[str, float]]] = {
        "openai": {
            "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
            "gpt-4o": {"input": 5.0, "output": 15.0},
            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
            "text-embedding-3-small": {"input": 0.02, "output": 0.0},
            "text-embedding-3-large": {"input": 0.13, "output": 0.0},
            "whisper-1": {"input": 6.0, "output": 0.0},
            "tts-1": {"input": 15.0, "output": 0.0},
            "tts-1-hd": {"input": 30.0, "output": 0.0},
        },
        "ollama": {
            "default": {"input": 0.0, "output": 0.0}
        },
        "replicate": {
            "black-forest-labs/flux-schnell": {"input": 0.003, "output": 0.0},
            "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
        }
    }

    def __init__(self, storage_path: Optional[str] = None):
        """
        Initialize the billing tracker and load any previously saved records.

        Args:
            storage_path: Path of the JSON file holding billing data. When
                omitted, ``billing_data.json`` three directories above this
                module is used.
        """
        if storage_path is None:
            # NOTE(review): for isa_model/inference/billing_tracker.py this is
            # the directory that contains the isa_model package (site-packages
            # for an installed wheel). Confirm this default is intended.
            project_root = Path(__file__).parent.parent.parent
            self.storage_path = project_root / "billing_data.json"
        else:
            self.storage_path = Path(storage_path)
        self.usage_records: List['UsageRecord'] = []
        # ISO-8601 UTC timestamp. Records stamped at or after it belong to the
        # current session (see get_session_summary).
        self.session_start = datetime.now(timezone.utc).isoformat()

        # Load existing data
        self._load_data()

    def _load_data(self):
        """
        Load previously saved records from ``storage_path``, if the file exists.

        A malformed individual record is skipped with a warning instead of
        discarding every record in the file. Skipped records are not written
        back on the next save. If the file itself cannot be read or parsed,
        the tracker starts with no records.
        """
        try:
            if self.storage_path.exists():
                with open(self.storage_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                loaded = []
                for raw in data.get('usage_records', []):
                    try:
                        loaded.append(UsageRecord.from_dict(raw))
                    except (TypeError, ValueError, AttributeError, KeyError) as e:
                        logger.warning(f"Skipping malformed billing record: {e}")
                self.usage_records = loaded
                logger.info(f"Loaded {len(self.usage_records)} billing records")
        except Exception as e:
            logger.warning(f"Could not load billing data: {e}")
            self.usage_records = []

    def _save_data(self):
        """
        Write all records to ``storage_path`` as JSON.

        The JSON goes to a sibling temporary file first, which is then moved
        over the target with ``os.replace``. Writing the target directly would
        truncate it first. An interrupted or failing ``json.dump`` (e.g. on a
        non-serializable metadata value) would then leave a file that
        ``_load_data`` cannot parse, and the next save would overwrite the
        whole history. Failures are logged, not raised.
        """
        tmp_path = None
        try:
            # Ensure directory exists
            self.storage_path.parent.mkdir(parents=True, exist_ok=True)

            data = {
                "session_start": self.session_start,
                "last_updated": datetime.now(timezone.utc).isoformat(),
                "usage_records": [record.to_dict() for record in self.usage_records]
            }

            tmp_path = self.storage_path.with_name(
                f"{self.storage_path.name}.{os.getpid()}.tmp"
            )
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.storage_path)

        except Exception as e:
            logger.error(f"Could not save billing data: {e}")
            if tmp_path is not None:
                # Best-effort cleanup of a partially written temporary file.
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    def track_usage(
        self,
        provider: Union[str, 'Provider'],
        service_type: Union[str, 'ServiceType'],
        model_name: str,
        operation: str,
        input_tokens: Optional[int] = None,
        output_tokens: Optional[int] = None,
        input_units: Optional[float] = None,
        output_units: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> 'UsageRecord':
        """
        Track a usage event, price it, and persist it.

        Args:
            provider: AI provider name or ``Provider`` member
            service_type: Type of service used, as a string or ``ServiceType`` member
            model_name: Name of the model
            operation: Operation performed (e.g., 'chat', 'embedding', 'image_generation')
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens
            input_units: Input units for non-token services (e.g., audio seconds, image count)
            output_units: Output units for non-token services
            metadata: Additional metadata (must be JSON-serializable to be saved)

        Returns:
            The stored UsageRecord.
        """
        # Convert enums to strings
        if isinstance(provider, Provider):
            provider = provider.value
        if isinstance(service_type, ServiceType):
            service_type = service_type.value

        # total_tokens stays None only when neither count was supplied.
        total_tokens = None
        if input_tokens is not None or output_tokens is not None:
            total_tokens = (input_tokens or 0) + (output_tokens or 0)

        cost_usd = self._calculate_cost(
            provider, model_name, operation,
            input_tokens, output_tokens, input_units, output_units
        )

        record = UsageRecord(
            timestamp=datetime.now(timezone.utc).isoformat(),
            provider=provider,
            service_type=service_type,
            model_name=model_name,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            input_units=input_units,
            output_units=output_units,
            cost_usd=cost_usd,
            metadata=metadata or {}
        )

        # Add to records and save
        self.usage_records.append(record)
        self._save_data()

        logger.info(f"Tracked usage: {provider}/{model_name} - ${cost_usd:.6f}")
        return record

    def _get_model_pricing(self, provider: str, model_name: str) -> Optional[Dict[str, float]]:
        """
        Look up the pricing entry (``input``/``output`` rates) for a model.

        ``ModelManager.MODEL_PRICING`` is consulted first. The legacy table is
        used instead when ModelManager cannot be imported, lacks
        ``MODEL_PRICING``, or has no truthy entry for the model.
        """
        try:
            from isa_model.core.model_manager import ModelManager
            pricing_table = ModelManager.MODEL_PRICING
        except (ImportError, AttributeError):
            return self._get_legacy_pricing(provider, model_name)

        pricing = pricing_table.get(provider, {}).get(model_name)
        if pricing:
            return pricing

        # Fallback to legacy pricing for backward compatibility
        return self._get_legacy_pricing(provider, model_name)

    def _get_legacy_pricing(self, provider: str, model_name: str) -> Optional[Dict[str, float]]:
        """
        Return the legacy price entry for a model.

        Falls back to the provider's ``"default"`` entry. Returns ``None``
        when neither exists.
        """
        provider_pricing = self._LEGACY_PRICING.get(provider, {})
        entry = provider_pricing.get(model_name) or provider_pricing.get("default")
        # Return a copy so callers cannot modify the shared class-level table.
        return dict(entry) if entry is not None else None

    def _calculate_cost(
        self,
        provider: str,
        model_name: str,
        operation: str,
        input_tokens: Optional[int] = None,
        output_tokens: Optional[int] = None,
        input_units: Optional[float] = None,
        output_units: Optional[float] = None
    ) -> float:
        """
        Return the USD cost of one usage event.

        Token counts are priced per 1M tokens using the model's ``input`` and
        ``output`` rates. ``operation``, ``input_units`` and ``output_units``
        are accepted but not used in the calculation. Returns 0.0 when no
        pricing is found or an error occurs.
        """
        try:
            model_pricing = self._get_model_pricing(provider, model_name)

            if not model_pricing:
                logger.warning(f"No pricing found for {provider}/{model_name}")
                return 0.0

            cost = 0.0

            # Token-based pricing (per 1M tokens)
            if input_tokens is not None and "input" in model_pricing:
                cost += (input_tokens / 1_000_000) * model_pricing["input"]

            if output_tokens is not None and "output" in model_pricing:
                cost += (output_tokens / 1_000_000) * model_pricing["output"]

            return cost

        except Exception as e:
            logger.error(f"Error calculating cost: {e}")
            return 0.0

    def get_session_summary(self) -> Dict[str, Any]:
        """Summarize records stamped at or after this tracker's ``session_start``."""
        # Plain string comparison. Timestamps created here come from
        # datetime.now(timezone.utc).isoformat(), whose strings sort in time
        # order. NOTE(review): records loaded from disk are assumed to use the
        # same format.
        session_records = [
            record for record in self.usage_records
            if record.timestamp >= self.session_start
        ]

        return self._generate_summary(session_records, "Current Session")

    def get_total_summary(self) -> Dict[str, Any]:
        """Summarize every loaded and tracked record."""
        return self._generate_summary(self.usage_records, "Total Usage")

    def get_provider_summary(self, provider: Union[str, 'Provider']) -> Dict[str, Any]:
        """Summarize records for one provider (name string or ``Provider`` member)."""
        if isinstance(provider, Provider):
            provider = provider.value

        provider_records = [
            record for record in self.usage_records
            if record.provider == provider
        ]

        return self._generate_summary(provider_records, f"{provider.title()} Usage")

    def _generate_summary(self, records: List['UsageRecord'], title: str) -> Dict[str, Any]:
        """
        Aggregate records into a JSON-serializable summary.

        Returns a dict with ``title``, ``total_cost`` (rounded to 6 places),
        ``total_requests``, and per-provider, per-service and per-model
        breakdowns. ``period`` holds the first and last record timestamps and
        is omitted when ``records`` is empty. A missing ``cost_usd`` counts
        as 0.
        """
        if not records:
            return {
                "title": title,
                "total_cost": 0.0,
                "total_requests": 0,
                "providers": {},
                "services": {},
                "models": {}
            }

        providers: Dict[str, Dict[str, Any]] = {}
        services: Dict[str, Dict[str, Any]] = {}
        models: Dict[str, Dict[str, Any]] = {}
        total_cost = 0.0

        for record in records:
            cost = record.cost_usd or 0
            total_cost += cost

            by_provider = providers.setdefault(
                record.provider, {"cost": 0.0, "requests": 0, "models": {}}
            )
            by_provider["cost"] += cost
            by_provider["requests"] += 1
            # A dict is used as an insertion-ordered set, so model lists come
            # out in first-seen order (a set would give arbitrary order).
            by_provider["models"][record.model_name] = None

            by_service = services.setdefault(
                record.service_type, {"cost": 0.0, "requests": 0}
            )
            by_service["cost"] += cost
            by_service["requests"] += 1

            by_model = models.setdefault(
                f"{record.provider}/{record.model_name}",
                {"cost": 0.0, "requests": 0, "total_tokens": 0},
            )
            by_model["cost"] += cost
            by_model["requests"] += 1
            if record.total_tokens:
                by_model["total_tokens"] += record.total_tokens

        # Convert the ordered model sets to lists for JSON serialization
        for by_provider in providers.values():
            by_provider["models"] = list(by_provider["models"])

        return {
            "title": title,
            "total_cost": round(total_cost, 6),
            "total_requests": len(records),
            "providers": providers,
            "services": services,
            "models": models,
            "period": {
                "start": records[0].timestamp,
                "end": records[-1].timestamp
            }
        }

    def print_summary(self, summary_type: str = "session"):
        """
        Print a billing summary to stdout.

        Args:
            summary_type: ``"session"`` for the current session or ``"total"``
                for all records.

        Raises:
            ValueError: If ``summary_type`` is neither value.
        """
        if summary_type == "session":
            summary = self.get_session_summary()
        elif summary_type == "total":
            summary = self.get_total_summary()
        else:
            raise ValueError("summary_type must be 'session' or 'total'")

        print(f"\n💰 {summary['title']} Billing Summary")
        print("=" * 50)
        print(f"💵 Total Cost: ${summary['total_cost']:.6f}")
        print(f"📊 Total Requests: {summary['total_requests']}")

        if summary['providers']:
            print("\n📈 By Provider:")
            for provider, data in summary['providers'].items():
                print(f" {provider}: ${data['cost']:.6f} ({data['requests']} requests)")

        if summary['services']:
            print("\n🔧 By Service:")
            for service, data in summary['services'].items():
                print(f" {service}: ${data['cost']:.6f} ({data['requests']} requests)")

        if summary['models']:
            print("\n🤖 By Model:")
            for model, data in summary['models'].items():
                tokens_info = f" ({data['total_tokens']} tokens)" if data['total_tokens'] > 0 else ""
                print(f" {model}: ${data['cost']:.6f} ({data['requests']} requests){tokens_info}")
389
+
390
# Process-wide tracker, created lazily by get_billing_tracker().
_global_tracker: Optional[BillingTracker] = None


def get_billing_tracker() -> BillingTracker:
    """Return the shared BillingTracker, constructing it on the first call."""
    global _global_tracker
    tracker = _global_tracker
    if tracker is None:
        tracker = _global_tracker = BillingTracker()
    return tracker
399
+
400
def track_usage(**kwargs) -> UsageRecord:
    """Record one usage event on the shared tracker.

    Accepts only keyword arguments, which are forwarded unchanged to
    ``BillingTracker.track_usage``. Returns the created record.
    """
    tracker = get_billing_tracker()
    return tracker.track_usage(**kwargs)
403
+
404
def print_billing_summary(summary_type: str = "session"):
    """Print the shared tracker's summary (``"session"`` or ``"total"``)."""
    tracker = get_billing_tracker()
    tracker.print_summary(summary_type)
@@ -1,13 +1,51 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from typing import Dict, List, Any, Optional
3
+ import os
4
+ import logging
5
+ from pathlib import Path
6
+ import dotenv
3
7
 
4
8
  from isa_model.inference.base import ModelType, Capability
5
9
 
10
+ logger = logging.getLogger(__name__)
11
+
6
12
  class BaseProvider(ABC):
7
- """Base class for all AI providers"""
13
+ """Base class for all AI providers - handles API key management"""
8
14
 
9
15
  def __init__(self, config: Optional[Dict[str, Any]] = None):
10
16
  self.config = config or {}
17
+ self._load_environment_config()
18
+ self._validate_config()
19
+
20
+ def _load_environment_config(self):
21
+ """Load configuration from environment variables"""
22
+ # Load .env file if it exists
23
+ project_root = Path(__file__).parent.parent.parent.parent
24
+ env_path = project_root / ".env"
25
+
26
+ if env_path.exists():
27
+ dotenv.load_dotenv(env_path)
28
+
29
+ # Subclasses should override this to load provider-specific env vars
30
+ self._load_provider_env_vars()
31
+
32
+ @abstractmethod
33
+ def _load_provider_env_vars(self):
34
+ """Load provider-specific environment variables"""
35
+ pass
36
+
37
+ def _validate_config(self):
38
+ """Validate that required configuration is present"""
39
+ # Subclasses should override this to validate provider-specific config
40
+ pass
41
+
42
+ def get_api_key(self) -> Optional[str]:
43
+ """Get the API key for this provider"""
44
+ return self.config.get("api_key")
45
+
46
+ def has_valid_credentials(self) -> bool:
47
+ """Check if provider has valid credentials"""
48
+ return bool(self.get_api_key())
11
49
 
12
50
  @abstractmethod
13
51
  def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
@@ -19,10 +57,19 @@ class BaseProvider(ABC):
19
57
  """Get available models for given type"""
20
58
  pass
21
59
 
22
- @abstractmethod
23
60
  def get_config(self) -> Dict[str, Any]:
24
- """Get provider configuration"""
25
- return self.config
61
+ """Get provider configuration (without sensitive data)"""
62
+ # Return a copy without sensitive information
63
+ config_copy = self.config.copy()
64
+ if "api_key" in config_copy:
65
+ config_copy["api_key"] = "***" if config_copy["api_key"] else ""
66
+ if "api_token" in config_copy:
67
+ config_copy["api_token"] = "***" if config_copy["api_token"] else ""
68
+ return config_copy
69
+
70
+ def get_full_config(self) -> Dict[str, Any]:
71
+ """Get full provider configuration (including sensitive data) - for internal use only"""
72
+ return self.config.copy()
26
73
 
27
74
  @abstractmethod
28
75
  def is_reasoning_model(self, model_name: str) -> bool:
@@ -0,0 +1,50 @@
1
+ from isa_model.inference.providers.base_provider import BaseProvider
2
+ from isa_model.inference.base import ModelType, Capability
3
+ from typing import Dict, List, Any
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
class MLProvider(BaseProvider):
    """Provider for traditional (non-LLM) machine-learning models.

    Default configuration, overridable via ``config``:
        model_directory: "./models/ml"
        cache_models: True
        max_cache_size: 5
    """

    def __init__(self, config=None):
        default_config = {
            "model_directory": "./models/ml",
            "cache_models": True,
            "max_cache_size": 5
        }

        # Caller-supplied values take precedence over the defaults.
        merged_config = {**default_config, **(config or {})}
        super().__init__(config=merged_config)
        self.name = "ml"

        logger.info(f"Initialized MLProvider with model directory: {self.config['model_directory']}")

    def _load_provider_env_vars(self):
        """Environment hook called from ``BaseProvider.__init__``.

        The ML provider reads no environment variables. Implementing the hook
        keeps the class instantiable regardless of whether the base declares
        it abstract.
        """
        pass

    def is_reasoning_model(self, model_name: str) -> bool:
        """Traditional ML models are never reasoning (LLM) models."""
        return False

    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
        """Get provider capabilities"""
        # NOTE(review): the "ML" key and its string capabilities are not
        # ModelType/Capability members, despite the return annotation.
        return {
            ModelType.LLM: [],  # ML models are not LLMs
            ModelType.EMBEDDING: [],
            ModelType.VISION: [],
            "ML": [  # Custom model type for traditional ML
                "CLASSIFICATION",
                "REGRESSION",
                "CLUSTERING",
                "FEATURE_EXTRACTION"
            ]
        }

    def get_models(self, model_type: str = "ML") -> List[str]:
        """Get available ML models.

        Placeholder: returns a fixed list and ignores ``model_type``.
        """
        # In practice, this would scan the model directory
        return [
            "fraud_detection_rf",
            "customer_churn_xgb",
            "price_prediction_lr",
            "recommendation_kmeans"
        ]

    def get_config(self) -> Dict[str, Any]:
        """Get provider configuration as a redacted copy.

        Delegates to ``BaseProvider.get_config`` rather than returning the
        live ``self.config`` dict. Callers therefore cannot mutate provider
        state, and any secret keys are masked as for other providers.
        """
        return super().get_config()
@@ -2,38 +2,57 @@ from isa_model.inference.providers.base_provider import BaseProvider
2
2
  from isa_model.inference.base import ModelType, Capability
3
3
  from typing import Dict, List, Any
4
4
  import logging
5
+ import os
5
6
 
6
7
  logger = logging.getLogger(__name__)
7
8
 
8
9
  class OllamaProvider(BaseProvider):
9
- """Provider for Ollama API"""
10
+ """Provider for Ollama API with proper configuration management"""
10
11
 
11
12
  def __init__(self, config=None):
12
- """
13
- Initialize the Ollama Provider
13
+ """Initialize the Ollama Provider with centralized config management"""
14
+ super().__init__(config)
15
+ self.name = "ollama"
14
16
 
15
- Args:
16
- config (dict, optional): Configuration for the provider
17
- - base_url: Base URL for Ollama API (default: http://localhost:11434)
18
- - timeout: Timeout for API calls in seconds
19
- """
20
- default_config = {
17
+ logger.info(f"Initialized OllamaProvider with URL: {self.config.get('base_url', 'http://localhost:11434')}")
18
+
19
+ def _load_provider_env_vars(self):
20
+ """Load Ollama-specific environment variables"""
21
+ # Set defaults first
22
+ defaults = {
21
23
  "base_url": "http://localhost:11434",
22
24
  "timeout": 60,
23
- "stream": True,
24
25
  "temperature": 0.7,
25
26
  "top_p": 0.9,
26
27
  "max_tokens": 2048,
27
28
  "keep_alive": "5m"
28
29
  }
29
30
 
30
- # Merge default config with provided config
31
- merged_config = {**default_config, **(config or {})}
31
+ # Apply defaults only if not already set
32
+ for key, value in defaults.items():
33
+ if key not in self.config:
34
+ self.config[key] = value
32
35
 
33
- super().__init__(config=merged_config)
34
- self.name = "ollama"
36
+ # Load from environment variables (override config if present)
37
+ env_mappings = {
38
+ "base_url": "OLLAMA_BASE_URL",
39
+ }
35
40
 
36
- logger.info(f"Initialized OllamaProvider with URL: {self.config['base_url']}")
41
+ for config_key, env_var in env_mappings.items():
42
+ env_value = os.getenv(env_var)
43
+ if env_value:
44
+ self.config[config_key] = env_value
45
+
46
+ def _validate_config(self):
47
+ """Validate Ollama configuration"""
48
+ # Ollama doesn't require API keys, just validate base_url is set
49
+ if not self.config.get("base_url"):
50
+ logger.warning("Ollama base_url not set, using default: http://localhost:11434")
51
+ self.config["base_url"] = "http://localhost:11434"
52
+
53
+ def has_valid_credentials(self) -> bool:
54
+ """Check if provider has valid credentials (Ollama doesn't need API keys)"""
55
+ return True # Ollama typically doesn't require authentication
37
56
 
38
57
  def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
39
58
  """Get provider capabilities by model type"""
@@ -46,7 +65,7 @@ class OllamaProvider(BaseProvider):
46
65
  Capability.EMBEDDING
47
66
  ],
48
67
  ModelType.VISION: [
49
- Capability.IMAGE_UNDERSTANDING
68
+ Capability.MULTIMODAL_UNDERSTANDING
50
69
  ]
51
70
  }
52
71
 
@@ -54,11 +73,11 @@ class OllamaProvider(BaseProvider):
54
73
  """Get available models for given type"""
55
74
  # Placeholder: In real implementation, this would query Ollama API
56
75
  if model_type == ModelType.LLM:
57
- return ["llama3", "mistral", "phi3", "llama3.1", "codellama", "gemma"]
76
+ return ["llama3.2:3b", "llama3", "mistral", "phi3", "llama3.1", "codellama", "gemma"]
58
77
  elif model_type == ModelType.EMBEDDING:
59
78
  return ["bge-m3", "nomic-embed-text"]
60
79
  elif model_type == ModelType.VISION:
61
- return ["llava", "bakllava", "llama3-vision"]
80
+ return ["gemma3:4b", "llava", "bakllava", "llama3-vision"]
62
81
  else:
63
82
  return []
64
83