isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
isa_model/inference/billing_tracker.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Billing Tracker for isA_Model Services
+Tracks usage and costs across all AI providers
+"""
+
+from typing import Dict, List, Optional, Any, Union
+from datetime import datetime, timezone
+from dataclasses import dataclass, asdict
+import json
+import logging
+from pathlib import Path
+from enum import Enum
+import os
+
+logger = logging.getLogger(__name__)
+
+class ServiceType(Enum):
+    """Types of AI services"""
+    LLM = "llm"
+    EMBEDDING = "embedding"
+    VISION = "vision"
+    IMAGE_GENERATION = "image_generation"
+    AUDIO_STT = "audio_stt"
+    AUDIO_TTS = "audio_tts"
+
+class Provider(Enum):
+    """AI service providers"""
+    OPENAI = "openai"
+    REPLICATE = "replicate"
+    OLLAMA = "ollama"
+    ANTHROPIC = "anthropic"
+    GOOGLE = "google"
+
+@dataclass
+class UsageRecord:
+    """Record of a single API usage"""
+    timestamp: str
+    provider: str
+    service_type: str
+    model_name: str
+    operation: str
+    input_tokens: Optional[int] = None
+    output_tokens: Optional[int] = None
+    total_tokens: Optional[int] = None
+    input_units: Optional[float] = None  # For non-token based services (images, audio)
+    output_units: Optional[float] = None
+    cost_usd: Optional[float] = None
+    metadata: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary"""
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> 'UsageRecord':
+        """Create from dictionary"""
+        return cls(**data)
+
+class BillingTracker:
+    """
+    Tracks billing and usage across all AI services
+    """
+
+    def __init__(self, storage_path: Optional[str] = None):
+        """
+        Initialize billing tracker
+
+        Args:
+            storage_path: Path to store billing data (defaults to project root)
+        """
+        if storage_path is None:
+            project_root = Path(__file__).parent.parent.parent
+            self.storage_path = project_root / "billing_data.json"
+        else:
+            self.storage_path = Path(storage_path)
+        self.usage_records: List[UsageRecord] = []
+        self.session_start = datetime.now(timezone.utc).isoformat()
+
+        # Load existing data
+        self._load_data()
+
+    def _load_data(self):
+        """Load existing billing data"""
+        try:
+            if self.storage_path.exists():
+                with open(self.storage_path, 'r') as f:
+                    data = json.load(f)
+                    self.usage_records = [
+                        UsageRecord.from_dict(record)
+                        for record in data.get('usage_records', [])
+                    ]
+                logger.info(f"Loaded {len(self.usage_records)} billing records")
+        except Exception as e:
+            logger.warning(f"Could not load billing data: {e}")
+            self.usage_records = []
+
+    def _save_data(self):
+        """Save billing data to storage"""
+        try:
+            # Ensure directory exists
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data = {
+                "session_start": self.session_start,
+                "last_updated": datetime.now(timezone.utc).isoformat(),
+                "usage_records": [record.to_dict() for record in self.usage_records]
+            }
+
+            with open(self.storage_path, 'w') as f:
+                json.dump(data, f, indent=2)
+
+        except Exception as e:
+            logger.error(f"Could not save billing data: {e}")
+
+    def track_usage(
+        self,
+        provider: Union[str, Provider],
+        service_type: Union[str, ServiceType],
+        model_name: str,
+        operation: str,
+        input_tokens: Optional[int] = None,
+        output_tokens: Optional[int] = None,
+        input_units: Optional[float] = None,
+        output_units: Optional[float] = None,
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> UsageRecord:
+        """
+        Track a usage event
+
+        Args:
+            provider: AI provider name
+            service_type: Type of service used
+            model_name: Name of the model
+            operation: Operation performed (e.g., 'chat', 'embedding', 'image_generation')
+            input_tokens: Number of input tokens
+            output_tokens: Number of output tokens
+            input_units: Input units for non-token services (e.g., audio seconds, image count)
+            output_units: Output units for non-token services
+            metadata: Additional metadata
+
+        Returns:
+            UsageRecord object
+        """
+        # Convert enums to strings
+        if isinstance(provider, Provider):
+            provider = provider.value
+        if isinstance(service_type, ServiceType):
+            service_type = service_type.value
+
+        # Calculate total tokens
+        total_tokens = None
+        if input_tokens is not None or output_tokens is not None:
+            total_tokens = (input_tokens or 0) + (output_tokens or 0)
+
+        # Calculate cost
+        cost_usd = self._calculate_cost(
+            provider, model_name, operation,
+            input_tokens, output_tokens, input_units, output_units
+        )
+
+        # Create usage record
+        record = UsageRecord(
+            timestamp=datetime.now(timezone.utc).isoformat(),
+            provider=provider,
+            service_type=service_type,
+            model_name=model_name,
+            operation=operation,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            input_units=input_units,
+            output_units=output_units,
+            cost_usd=cost_usd,
+            metadata=metadata or {}
+        )
+
+        # Add to records and save
+        self.usage_records.append(record)
+        self._save_data()
+
+        logger.info(f"Tracked usage: {provider}/{model_name} - ${cost_usd:.6f}")
+        return record
+
+    def _get_model_pricing(self, provider: str, model_name: str) -> Optional[Dict[str, float]]:
+        """Get pricing information from ModelManager"""
+        try:
+            from isa_model.core.model_manager import ModelManager
+            pricing = ModelManager.MODEL_PRICING.get(provider, {}).get(model_name)
+            if pricing:
+                return pricing
+
+            # Fallback to legacy pricing for backward compatibility
+            legacy_pricing = self._get_legacy_pricing(provider, model_name)
+            if legacy_pricing:
+                return legacy_pricing
+
+            return None
+        except ImportError:
+            # Fallback to legacy pricing if ModelManager is not available
+            return self._get_legacy_pricing(provider, model_name)
+
+    def _get_legacy_pricing(self, provider: str, model_name: str) -> Optional[Dict[str, float]]:
+        """Legacy pricing information for backward compatibility"""
+        LEGACY_PRICING = {
+            "openai": {
+                "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
+                "gpt-4o": {"input": 5.0, "output": 15.0},
+                "gpt-4o-mini": {"input": 0.15, "output": 0.6},
+                "text-embedding-3-small": {"input": 0.02, "output": 0.0},
+                "text-embedding-3-large": {"input": 0.13, "output": 0.0},
+                "whisper-1": {"input": 6.0, "output": 0.0},
+                "tts-1": {"input": 15.0, "output": 0.0},
+                "tts-1-hd": {"input": 30.0, "output": 0.0},
+            },
+            "ollama": {
+                "default": {"input": 0.0, "output": 0.0}
+            },
+            "replicate": {
+                "black-forest-labs/flux-schnell": {"input": 0.003, "output": 0.0},
+                "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
+            }
+        }
+
+        provider_pricing = LEGACY_PRICING.get(provider, {})
+        return provider_pricing.get(model_name) or provider_pricing.get("default")
+
+    def _calculate_cost(
+        self,
+        provider: str,
+        model_name: str,
+        operation: str,
+        input_tokens: Optional[int] = None,
+        output_tokens: Optional[int] = None,
+        input_units: Optional[float] = None,
+        output_units: Optional[float] = None
+    ) -> float:
+        """Calculate cost for a usage event"""
+        try:
+            # Get pricing using unified model manager
+            model_pricing = self._get_model_pricing(provider, model_name)
+
+            if not model_pricing:
+                logger.warning(f"No pricing found for {provider}/{model_name}")
+                return 0.0
+
+            cost = 0.0
+
+            # Token-based pricing (per 1M tokens)
+            if input_tokens is not None and "input" in model_pricing:
+                cost += (input_tokens / 1000000) * model_pricing["input"]
+
+            if output_tokens is not None and "output" in model_pricing:
+                cost += (output_tokens / 1000000) * model_pricing["output"]
+
+            return cost
+
+        except Exception as e:
+            logger.error(f"Error calculating cost: {e}")
+            return 0.0
+
+    def get_session_summary(self) -> Dict[str, Any]:
+        """Get billing summary for current session"""
+        session_records = [
+            record for record in self.usage_records
+            if record.timestamp >= self.session_start
+        ]
+
+        return self._generate_summary(session_records, "Current Session")
+
+    def get_total_summary(self) -> Dict[str, Any]:
+        """Get total billing summary"""
+        return self._generate_summary(self.usage_records, "Total Usage")
+
+    def get_provider_summary(self, provider: Union[str, Provider]) -> Dict[str, Any]:
+        """Get billing summary for a specific provider"""
+        if isinstance(provider, Provider):
+            provider = provider.value
+
+        provider_records = [
+            record for record in self.usage_records
+            if record.provider == provider
+        ]
+
+        return self._generate_summary(provider_records, f"{provider.title()} Usage")
+
+    def _generate_summary(self, records: List[UsageRecord], title: str) -> Dict[str, Any]:
+        """Generate billing summary from records"""
+        if not records:
+            return {
+                "title": title,
+                "total_cost": 0.0,
+                "total_requests": 0,
+                "providers": {},
+                "services": {},
+                "models": {}
+            }
+
+        total_cost = sum(record.cost_usd or 0 for record in records)
+        total_requests = len(records)
+
+        # Group by provider
+        providers = {}
+        for record in records:
+            if record.provider not in providers:
+                providers[record.provider] = {
+                    "cost": 0.0,
+                    "requests": 0,
+                    "models": set()
+                }
+            providers[record.provider]["cost"] += record.cost_usd or 0
+            providers[record.provider]["requests"] += 1
+            providers[record.provider]["models"].add(record.model_name)
+
+        # Convert sets to lists for JSON serialization
+        for provider_data in providers.values():
+            provider_data["models"] = list(provider_data["models"])
+
+        # Group by service type
+        services = {}
+        for record in records:
+            if record.service_type not in services:
+                services[record.service_type] = {
+                    "cost": 0.0,
+                    "requests": 0
+                }
+            services[record.service_type]["cost"] += record.cost_usd or 0
+            services[record.service_type]["requests"] += 1
+
+        # Group by model
+        models = {}
+        for record in records:
+            model_key = f"{record.provider}/{record.model_name}"
+            if model_key not in models:
+                models[model_key] = {
+                    "cost": 0.0,
+                    "requests": 0,
+                    "total_tokens": 0
+                }
+            models[model_key]["cost"] += record.cost_usd or 0
+            models[model_key]["requests"] += 1
+            if record.total_tokens:
+                models[model_key]["total_tokens"] += record.total_tokens
+
+        return {
+            "title": title,
+            "total_cost": round(total_cost, 6),
+            "total_requests": total_requests,
+            "providers": providers,
+            "services": services,
+            "models": models,
+            "period": {
+                "start": records[0].timestamp if records else None,
+                "end": records[-1].timestamp if records else None
+            }
+        }
+
+    def print_summary(self, summary_type: str = "session"):
+        """Print billing summary to console"""
+        if summary_type == "session":
+            summary = self.get_session_summary()
+        elif summary_type == "total":
+            summary = self.get_total_summary()
+        else:
+            raise ValueError("summary_type must be 'session' or 'total'")
+
+        print(f"\n💰 {summary['title']} Billing Summary")
+        print("=" * 50)
+        print(f"💵 Total Cost: ${summary['total_cost']:.6f}")
+        print(f"📊 Total Requests: {summary['total_requests']}")
+
+        if summary['providers']:
+            print("\n📈 By Provider:")
+            for provider, data in summary['providers'].items():
+                print(f"  {provider}: ${data['cost']:.6f} ({data['requests']} requests)")
+
+        if summary['services']:
+            print("\n🔧 By Service:")
+            for service, data in summary['services'].items():
+                print(f"  {service}: ${data['cost']:.6f} ({data['requests']} requests)")
+
+        if summary['models']:
+            print("\n🤖 By Model:")
+            for model, data in summary['models'].items():
+                tokens_info = f" ({data['total_tokens']} tokens)" if data['total_tokens'] > 0 else ""
+                print(f"  {model}: ${data['cost']:.6f} ({data['requests']} requests){tokens_info}")
+
+# Global billing tracker instance
+_global_tracker: Optional[BillingTracker] = None
+
+def get_billing_tracker() -> BillingTracker:
+    """Get the global billing tracker instance"""
+    global _global_tracker
+    if _global_tracker is None:
+        _global_tracker = BillingTracker()
+    return _global_tracker
+
+def track_usage(**kwargs) -> UsageRecord:
+    """Convenience function to track usage"""
+    return get_billing_tracker().track_usage(**kwargs)
+
+def print_billing_summary(summary_type: str = "session"):
+    """Convenience function to print billing summary"""
+    get_billing_tracker().print_summary(summary_type)
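
The tracker above persists every `UsageRecord` to a JSON file and prices token usage per 1M tokens. A minimal usage sketch, based only on the code in this hunk; the storage path and token counts are illustrative, and `ModelManager.MODEL_PRICING` may override the legacy rates shown above:

```python
# Illustrative sketch of the new billing tracker; not taken from package docs.
from isa_model.inference.billing_tracker import BillingTracker, Provider, ServiceType

tracker = BillingTracker(storage_path="./billing_data.json")  # hypothetical path

# With the legacy gpt-4o-mini rates above ($0.15 input / $0.60 output per 1M tokens),
# 1000 input + 500 output tokens cost 1000/1e6 * 0.15 + 500/1e6 * 0.60 = $0.00045.
record = tracker.track_usage(
    provider=Provider.OPENAI,
    service_type=ServiceType.LLM,
    model_name="gpt-4o-mini",
    operation="chat",
    input_tokens=1000,
    output_tokens=500,
)

print(record.cost_usd)            # 0.00045, assuming the legacy pricing applies
tracker.print_summary("session")  # console breakdown by provider/service/model
```
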
isa_model/inference/providers/base_provider.py
@@ -1,13 +1,51 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Any, Optional
+import os
+import logging
+from pathlib import Path
+import dotenv
 
 from isa_model.inference.base import ModelType, Capability
 
+logger = logging.getLogger(__name__)
+
 class BaseProvider(ABC):
-    """Base class for all AI providers"""
+    """Base class for all AI providers - handles API key management"""
 
     def __init__(self, config: Optional[Dict[str, Any]] = None):
         self.config = config or {}
+        self._load_environment_config()
+        self._validate_config()
+
+    def _load_environment_config(self):
+        """Load configuration from environment variables"""
+        # Load .env file if it exists
+        project_root = Path(__file__).parent.parent.parent.parent
+        env_path = project_root / ".env"
+
+        if env_path.exists():
+            dotenv.load_dotenv(env_path)
+
+        # Subclasses should override this to load provider-specific env vars
+        self._load_provider_env_vars()
+
+    @abstractmethod
+    def _load_provider_env_vars(self):
+        """Load provider-specific environment variables"""
+        pass
+
+    def _validate_config(self):
+        """Validate that required configuration is present"""
+        # Subclasses should override this to validate provider-specific config
+        pass
+
+    def get_api_key(self) -> Optional[str]:
+        """Get the API key for this provider"""
+        return self.config.get("api_key")
+
+    def has_valid_credentials(self) -> bool:
+        """Check if provider has valid credentials"""
+        return bool(self.get_api_key())
 
     @abstractmethod
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
@@ -19,10 +57,19 @@ class BaseProvider(ABC):
         """Get available models for given type"""
         pass
 
-    @abstractmethod
     def get_config(self) -> Dict[str, Any]:
-        """Get provider configuration"""
-
+        """Get provider configuration (without sensitive data)"""
+        # Return a copy without sensitive information
+        config_copy = self.config.copy()
+        if "api_key" in config_copy:
+            config_copy["api_key"] = "***" if config_copy["api_key"] else ""
+        if "api_token" in config_copy:
+            config_copy["api_token"] = "***" if config_copy["api_token"] else ""
+        return config_copy
+
+    def get_full_config(self) -> Dict[str, Any]:
+        """Get full provider configuration (including sensitive data) - for internal use only"""
+        return self.config.copy()
 
     @abstractmethod
     def is_reasoning_model(self, model_name: str) -> bool:
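
Because `_load_provider_env_vars` is now an `@abstractmethod`, every concrete provider must supply it. A minimal sketch of a conforming subclass, implementing only the abstract hooks visible in this diff; the class name and `EXAMPLE_API_KEY` variable are hypothetical:

```python
import os
from typing import Dict, List

from isa_model.inference.base import Capability, ModelType
from isa_model.inference.providers.base_provider import BaseProvider

class ExampleProvider(BaseProvider):
    """Hypothetical provider illustrating the 0.3.1 BaseProvider contract."""

    def _load_provider_env_vars(self):
        # Map a made-up env var into config; real providers define their own.
        api_key = os.getenv("EXAMPLE_API_KEY")
        if api_key:
            self.config["api_key"] = api_key

    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
        return {ModelType.EMBEDDING: [Capability.EMBEDDING]}

    def get_models(self, model_type: ModelType) -> List[str]:
        return ["example-embedder"]

    def is_reasoning_model(self, model_name: str) -> bool:
        return False

provider = ExampleProvider({"api_key": "sk-from-config"})
print(provider.get_config())             # api_key masked as "***"
print(provider.has_valid_credentials())  # True
```

Note that `get_config()` now masks `api_key` and `api_token`, while the new `get_full_config()` returns them unmasked for internal use.
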
isa_model/inference/providers/ml_provider.py
@@ -0,0 +1,50 @@
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MLProvider(BaseProvider):
+    """Provider for traditional ML models"""
+
+    def __init__(self, config=None):
+        default_config = {
+            "model_directory": "./models/ml",
+            "cache_models": True,
+            "max_cache_size": 5
+        }
+
+        merged_config = {**default_config, **(config or {})}
+        super().__init__(config=merged_config)
+        self.name = "ml"
+
+        logger.info(f"Initialized MLProvider with model directory: {self.config['model_directory']}")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities"""
+        return {
+            ModelType.LLM: [],  # ML models are not LLMs
+            ModelType.EMBEDDING: [],
+            ModelType.VISION: [],
+            "ML": [  # Custom model type for traditional ML
+                "CLASSIFICATION",
+                "REGRESSION",
+                "CLUSTERING",
+                "FEATURE_EXTRACTION"
+            ]
+        }
+
+    def get_models(self, model_type: str = "ML") -> List[str]:
+        """Get available ML models"""
+        # In practice, this would scan the model directory
+        return [
+            "fraud_detection_rf",
+            "customer_churn_xgb",
+            "price_prediction_lr",
+            "recommendation_kmeans"
+        ]
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        return self.config
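
One caveat worth noting: `MLProvider` overrides neither the new abstract `_load_provider_env_vars` hook nor the pre-existing abstract `is_reasoning_model`, so instantiating it directly should raise `TypeError`. A hypothetical sketch of a concrete subclass that fills in those hooks:

```python
from isa_model.inference.providers.ml_provider import MLProvider

class LocalMLProvider(MLProvider):
    """Hypothetical subclass supplying the hooks MLProvider leaves abstract."""

    def _load_provider_env_vars(self):
        pass  # local ML models need no API keys or env vars

    def is_reasoning_model(self, model_name: str) -> bool:
        return False  # traditional ML models are not reasoning LLMs

provider = LocalMLProvider(config={"model_directory": "/opt/models/ml"})  # illustrative path
print(provider.get_models())  # placeholder model names from the diff above
```

Its `get_config()` also returns the raw config dict, bypassing the credential masking added to the base class.
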
isa_model/inference/providers/ollama_provider.py
@@ -2,38 +2,57 @@ from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.base import ModelType, Capability
 from typing import Dict, List, Any
 import logging
+import os
 
 logger = logging.getLogger(__name__)
 
 class OllamaProvider(BaseProvider):
-    """Provider for Ollama API"""
+    """Provider for Ollama API with proper configuration management"""
 
     def __init__(self, config=None):
-        """
-
+        """Initialize the Ollama Provider with centralized config management"""
+        super().__init__(config)
+        self.name = "ollama"
 
-
-
-
-
-
-
+        logger.info(f"Initialized OllamaProvider with URL: {self.config.get('base_url', 'http://localhost:11434')}")
+
+    def _load_provider_env_vars(self):
+        """Load Ollama-specific environment variables"""
+        # Set defaults first
+        defaults = {
             "base_url": "http://localhost:11434",
             "timeout": 60,
-            "stream": True,
             "temperature": 0.7,
             "top_p": 0.9,
             "max_tokens": 2048,
             "keep_alive": "5m"
         }
 
-        #
-
+        # Apply defaults only if not already set
+        for key, value in defaults.items():
+            if key not in self.config:
+                self.config[key] = value
 
-
-
+        # Load from environment variables (override config if present)
+        env_mappings = {
+            "base_url": "OLLAMA_BASE_URL",
+        }
 
-
+        for config_key, env_var in env_mappings.items():
+            env_value = os.getenv(env_var)
+            if env_value:
+                self.config[config_key] = env_value
+
+    def _validate_config(self):
+        """Validate Ollama configuration"""
+        # Ollama doesn't require API keys, just validate base_url is set
+        if not self.config.get("base_url"):
+            logger.warning("Ollama base_url not set, using default: http://localhost:11434")
+            self.config["base_url"] = "http://localhost:11434"
+
+    def has_valid_credentials(self) -> bool:
+        """Check if provider has valid credentials (Ollama doesn't need API keys)"""
+        return True  # Ollama typically doesn't require authentication
 
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
         """Get provider capabilities by model type"""
@@ -46,7 +65,7 @@ class OllamaProvider(BaseProvider):
                 Capability.EMBEDDING
             ],
             ModelType.VISION: [
-                Capability.
+                Capability.MULTIMODAL_UNDERSTANDING
             ]
         }
 
@@ -54,11 +73,11 @@ class OllamaProvider(BaseProvider):
         """Get available models for given type"""
         # Placeholder: In real implementation, this would query Ollama API
         if model_type == ModelType.LLM:
-            return ["llama3", "mistral", "phi3", "llama3.1", "codellama", "gemma"]
+            return ["llama3.2:3b", "llama3", "mistral", "phi3", "llama3.1", "codellama", "gemma"]
         elif model_type == ModelType.EMBEDDING:
             return ["bge-m3", "nomic-embed-text"]
         elif model_type == ModelType.VISION:
-            return ["llava", "bakllava", "llama3-vision"]
+            return ["gemma3:4b", "llava", "bakllava", "llama3-vision"]
         else:
             return []
 
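
With the new `_load_provider_env_vars` hook, the Ollama base URL can be supplied through the environment instead of the config dict; `OLLAMA_BASE_URL` is the one mapping defined in this hunk. A short sketch, assuming the rest of the file implements the remaining abstract methods and using an illustrative remote host:

```python
import os
from isa_model.inference.providers.ollama_provider import OllamaProvider

os.environ["OLLAMA_BASE_URL"] = "http://gpu-box:11434"  # illustrative host

provider = OllamaProvider()
print(provider.get_config()["base_url"])  # "http://gpu-box:11434", env overrides defaults
print(provider.has_valid_credentials())   # True, since Ollama needs no API key
```
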