isa-model 0.1.0 (isa_model-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/services/embedding/ollama_rerank_service.py
@@ -0,0 +1,118 @@
+from typing import Dict, Any, List, Optional
+from ollama import AsyncClient
+from ...base_service import BaseRerankService
+from ...base_provider import BaseProvider
+from app.config.config_manager import config_manager
+import httpx
+import asyncio
+from functools import wraps
+
+logger = config_manager.get_logger(__name__)
+
+def retry_on_connection_error(max_retries=3, delay=1):
+    """Decorator to retry on connection errors"""
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(*args, **kwargs):
+            last_error = None
+            for attempt in range(max_retries):
+                try:
+                    return await func(*args, **kwargs)
+                except (httpx.RemoteProtocolError, httpx.ConnectError) as e:
+                    last_error = e
+                    if attempt < max_retries - 1:
+                        logger.warning(f"Connection error on attempt {attempt + 1}, retrying in {delay}s: {str(e)}")
+                        await asyncio.sleep(delay)
+                        continue
+                    raise last_error
+        return wrapper
+    return decorator
+
+class OllamaRerankService(BaseRerankService):
+    """Reranking service wrapper around Ollama"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        super().__init__(provider, model_name)
+
+        # Initialize the Ollama client for reranking
+        self.client = AsyncClient(
+            host=self.config.get('base_url', 'http://localhost:11434')
+        )
+        self.model_name = model_name
+
+    @retry_on_connection_error()
+    async def rerank(
+        self,
+        query: str,
+        documents: List[Dict],
+        top_k: int = 5
+    ) -> List[Dict]:
+        """Rerank documents based on query relevance"""
+        try:
+            if not query:
+                raise ValueError("Query cannot be empty")
+            if not documents:
+                return []
+
+            results = []
+            for doc in documents:
+                if "content" not in doc:
+                    raise ValueError("Each document must have a 'content' field")
+
+                # Format prompt for relevance scoring
+                prompt = f"""Rate the relevance of the following text to the query on a scale of 0-100.
+                Query: {query}
+                Text: {doc['content']}
+                Only respond with a number between 0 and 100."""
+
+                # Get relevance score using direct Ollama API
+                response = await self.client.generate(
+                    model=self.model_name,
+                    prompt=prompt,
+                    stream=False
+                )
+                try:
+                    score = float(response.response.strip())
+                    score = max(0.0, min(100.0, score)) / 100.0  # Normalize to 0-1
+                except ValueError:
+                    logger.warning(f"Could not parse score from response: {response.response}")
+                    score = 0.0
+
+                # Update document with rerank score
+                doc_copy = doc.copy()
+                doc_copy["rerank_score"] = score
+                doc_copy["final_score"] = doc.get("score", 1.0) * score
+                results.append(doc_copy)
+
+            # Sort by final score in descending order
+            results.sort(key=lambda x: x["final_score"], reverse=True)
+            return results[:top_k]
+
+        except Exception as e:
+            logger.error(f"Error in rerank: {e}")
+            raise
+
+    @retry_on_connection_error()
+    async def rerank_texts(
+        self,
+        query: str,
+        texts: List[str]
+    ) -> List[Dict]:
+        """Rerank raw texts based on query relevance"""
+        try:
+            if not query:
+                raise ValueError("Query cannot be empty")
+            if not texts:
+                return []
+
+            # Convert texts to document format
+            documents = [{"content": text, "score": 1.0} for text in texts]
+            return await self.rerank(query, documents)
+
+        except Exception as e:
+            logger.error(f"Error in rerank_texts: {str(e)}")
+            raise
+
+    async def close(self):
+        """Cleanup resources"""
+        await self.client.aclose()
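
To make the scoring and sorting behavior above concrete, here is a minimal standalone sketch of the same prompt-based reranking technique, written against the public ollama client rather than the package classes. It assumes a local Ollama server on the default port and an arbitrary model name ("llama3.1"); the response field access mirrors the service code above.

import asyncio
from ollama import AsyncClient

async def rerank(query: str, texts: list[str], top_k: int = 5) -> list[dict]:
    client = AsyncClient(host="http://localhost:11434")
    results = []
    for text in texts:
        prompt = (
            "Rate the relevance of the following text to the query on a scale of 0-100.\n"
            f"Query: {query}\n"
            f"Text: {text}\n"
            "Only respond with a number between 0 and 100."
        )
        response = await client.generate(model="llama3.1", prompt=prompt, stream=False)
        try:
            # Clamp to 0-100 and normalize to 0-1, as the service does.
            score = max(0.0, min(100.0, float(response.response.strip()))) / 100.0
        except ValueError:
            score = 0.0  # unparseable reply is treated as irrelevant
        results.append({"content": text, "rerank_score": score})
    results.sort(key=lambda d: d["rerank_score"], reverse=True)
    return results[:top_k]

if __name__ == "__main__":
    docs = ["Paris is the capital of France.", "Bananas are yellow."]
    print(asyncio.run(rerank("What is the capital of France?", docs)))
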
isa_model/inference/services/embedding/onnx_rerank_service.py
@@ -0,0 +1,73 @@
+from typing import Dict, Any, List, Union, Optional
+from ...base_service import BaseService
+from ...base_provider import BaseProvider
+from transformers import AutoTokenizer
+import onnxruntime as ort
+import numpy as np
+import torch
+import os
+from pathlib import Path
+
+class ONNXRerankService(BaseService):
+    """ONNX Reranker service for BGE models"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        super().__init__(provider, model_name)
+        self.model_path = self._get_model_path(model_name)
+        self.session = provider.get_session(self.model_path)
+
+        # Initialize tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
+        self.max_length = 512
+
+    def _get_model_path(self, model_name: str) -> str:
+        """Get path to ONNX model file"""
+        base_dir = Path(__file__).parent
+        model_path = base_dir / "model_converted" / model_name / "model.onnx"
+        if not model_path.exists():
+            raise FileNotFoundError(f"ONNX model not found at {model_path}. Please run the conversion script first.")
+        return str(model_path)
+
+    async def compute_score(self,
+                            pairs: Union[List[str], List[List[str]]],
+                            normalize: bool = False) -> Union[float, List[float]]:
+        """Compute reranking scores for query-passage pairs"""
+        try:
+            # Handle single pair case
+            if isinstance(pairs[0], str):
+                pairs = [pairs]
+
+            # Tokenize inputs
+            inputs = self.tokenizer(
+                pairs,
+                padding=True,
+                truncation=True,
+                return_tensors='np',
+                max_length=self.max_length
+            )
+
+            # Run inference
+            ort_inputs = {
+                'input_ids': inputs['input_ids'],
+                'attention_mask': inputs['attention_mask']
+            }
+
+            scores = self.session.run(
+                None,  # output names, None means all
+                ort_inputs
+            )[0]
+
+            # Convert to float and optionally normalize
+            scores = scores.flatten().tolist()
+            if normalize:
+                scores = [self._sigmoid(score) for score in scores]
+
+            # Return single score for single pair
+            return scores[0] if len(scores) == 1 else scores
+
+        except Exception as e:
+            raise RuntimeError(f"ONNX reranking failed: {e}")
+
+    def _sigmoid(self, x: float) -> float:
+        """Apply sigmoid function to score"""
+        return 1 / (1 + np.exp(-x))
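
The ONNX path above reduces reranking to three steps: tokenize a (query, passage) pair for BAAI/bge-reranker-v2-m3, run the exported graph, and optionally squash the raw logit with a sigmoid. A hedged standalone sketch of that pipeline follows; the model.onnx location is an assumption and must first be produced by an export step such as the bge_rerank_convert.py utility listed above.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
session = ort.InferenceSession("model_converted/bge-reranker-v2-m3/model.onnx")  # assumed path

def score(query: str, passage: str, normalize: bool = True) -> float:
    # Tokenize the pair the same way the service does (truncated to 512 tokens).
    inputs = tokenizer(query, passage, truncation=True, max_length=512, return_tensors="np")
    logits = session.run(None, {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
    })[0]
    raw = float(logits.flatten()[0])
    return float(1.0 / (1.0 + np.exp(-raw))) if normalize else raw

print(score("what is a panda?", "The giant panda is a bear species endemic to China."))
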
isa_model/inference/services/llm/__init__.py
@@ -0,0 +1,16 @@
+"""
+LLM Services - Business logic services for Language Models
+"""
+
+# Import LLM services here when created
+from .ollama_llm_service import OllamaLLMService
+from .triton_llm_service import TritonLLMService
+from .openai_llm_service import OpenAILLMService
+from .replicate_llm_service import ReplicateLLMService
+
+__all__ = [
+    "OllamaLLMService",
+    "TritonLLMService",
+    "OpenAILLMService",
+    "ReplicateLLMService",
+]
isa_model/inference/services/llm/gemma_service.py
@@ -0,0 +1,143 @@
+import json
+import logging
+import asyncio
+from typing import Dict, List, Any, Optional, Union
+
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.backends.triton_client import TritonClient
+
+logger = logging.getLogger(__name__)
+
+
+class GemmaService(BaseService):
+    """
+    Service for Gemma LLM using Triton Inference Server.
+    """
+
+    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "gemma"):
+        """
+        Initialize the Gemma service.
+
+        Args:
+            triton_url: URL of the Triton Inference Server
+            model_name: Name of the model in Triton
+        """
+        super().__init__()
+        self.triton_url = triton_url
+        self.model_name = model_name
+        self.client = None
+
+        # Default generation config
+        self.default_config = {
+            "max_new_tokens": 512,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "do_sample": True
+        }
+
+        self.logger = logger
+
+    async def load(self) -> None:
+        """
+        Load the client connection to Triton.
+        """
+        if self.is_loaded():
+            return
+
+        try:
+            # Create Triton client
+            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+            self.client = TritonClient(self.triton_url)
+
+            # Check if model is ready
+            if not await self.client.is_model_ready(self.model_name):
+                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+            self._loaded = True
+            self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+        except Exception as e:
+            self.logger.error(f"Failed to connect to Triton: {str(e)}")
+            raise
+
+    async def unload(self) -> None:
+        """
+        Unload the client connection.
+        """
+        if not self.is_loaded():
+            return
+
+        self.client = None
+        self._loaded = False
+        self.logger.info("Triton client connection closed")
+
+    async def generate(self,
+                       prompt: str,
+                       system_prompt: Optional[str] = None,
+                       generation_config: Optional[Dict[str, Any]] = None) -> str:
+        """
+        Generate text from a prompt using Triton.
+
+        Args:
+            prompt: User prompt
+            system_prompt: System prompt to control model behavior
+            generation_config: Configuration for text generation
+
+        Returns:
+            Generated text
+        """
+        if not self.is_loaded():
+            await self.load()
+
+        # Get configuration
+        merged_config = self.default_config.copy()
+        if generation_config:
+            merged_config.update(generation_config)
+
+        try:
+            # Prepare inputs
+            inputs = {
+                "prompt": [prompt],
+            }
+
+            # Add optional inputs
+            if system_prompt:
+                inputs["system_prompt"] = [system_prompt]
+
+            if merged_config:
+                inputs["generation_config"] = [json.dumps(merged_config)]
+
+            # Run inference
+            result = await self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=["text_output"]
+            )
+
+            # Extract generated text
+            generated_text = result["text_output"][0].decode('utf-8')
+
+            return generated_text
+
+        except Exception as e:
+            self.logger.error(f"Error during text generation: {str(e)}")
+            raise
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get information about the model.
+
+        Returns:
+            Dictionary containing model information
+        """
+        return {
+            "name": self.model_name,
+            "type": "llm",
+            "backend": "triton",
+            "url": self.triton_url,
+            "loaded": self.is_loaded(),
+            "config": self.default_config
+        }
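
Since GemmaService only depends on a reachable Triton endpoint, exercising it is a matter of constructing the service and awaiting generate(). Below is a minimal usage sketch, assuming the package is installed, a Triton server is serving a model named "gemma" on localhost:8001, and the import path follows the file layout listed above. LlamaService in the next hunk exposes the same interface, differing only in its default model_name.

import asyncio
from isa_model.inference.services.llm.gemma_service import GemmaService

async def main():
    service = GemmaService(triton_url="localhost:8001", model_name="gemma")
    text = await service.generate(
        prompt="Summarize what a reranker does in one sentence.",
        system_prompt="You are a concise assistant.",
        generation_config={"max_new_tokens": 128, "temperature": 0.2},
    )
    print(text)
    print(service.get_model_info())
    await service.unload()

asyncio.run(main())
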
isa_model/inference/services/llm/llama_service.py
@@ -0,0 +1,143 @@
+import json
+import logging
+import asyncio
+from typing import Dict, List, Any, Optional, Union
+
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.backends.triton_client import TritonClient
+
+logger = logging.getLogger(__name__)
+
+
+class LlamaService(BaseService):
+    """
+    Service for Llama LLM using Triton Inference Server.
+    """
+
+    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "llama"):
+        """
+        Initialize the Llama service.
+
+        Args:
+            triton_url: URL of the Triton Inference Server
+            model_name: Name of the model in Triton
+        """
+        super().__init__()
+        self.triton_url = triton_url
+        self.model_name = model_name
+        self.client = None
+
+        # Default generation config
+        self.default_config = {
+            "max_new_tokens": 512,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "do_sample": True
+        }
+
+        self.logger = logger
+
+    async def load(self) -> None:
+        """
+        Load the client connection to Triton.
+        """
+        if self.is_loaded():
+            return
+
+        try:
+            # Create Triton client
+            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+            self.client = TritonClient(self.triton_url)
+
+            # Check if model is ready
+            if not await self.client.is_model_ready(self.model_name):
+                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+            self._loaded = True
+            self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+        except Exception as e:
+            self.logger.error(f"Failed to connect to Triton: {str(e)}")
+            raise
+
+    async def unload(self) -> None:
+        """
+        Unload the client connection.
+        """
+        if not self.is_loaded():
+            return
+
+        self.client = None
+        self._loaded = False
+        self.logger.info("Triton client connection closed")
+
+    async def generate(self,
+                       prompt: str,
+                       system_prompt: Optional[str] = None,
+                       generation_config: Optional[Dict[str, Any]] = None) -> str:
+        """
+        Generate text from a prompt using Triton.
+
+        Args:
+            prompt: User prompt
+            system_prompt: System prompt to control model behavior
+            generation_config: Configuration for text generation
+
+        Returns:
+            Generated text
+        """
+        if not self.is_loaded():
+            await self.load()
+
+        # Get configuration
+        merged_config = self.default_config.copy()
+        if generation_config:
+            merged_config.update(generation_config)
+
+        try:
+            # Prepare inputs
+            inputs = {
+                "prompt": [prompt],
+            }
+
+            # Add optional inputs
+            if system_prompt:
+                inputs["system_prompt"] = [system_prompt]
+
+            if merged_config:
+                inputs["generation_config"] = [json.dumps(merged_config)]
+
+            # Run inference
+            result = await self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=["text_output"]
+            )
+
+            # Extract generated text
+            generated_text = result["text_output"][0].decode('utf-8')
+
+            return generated_text
+
+        except Exception as e:
+            self.logger.error(f"Error during text generation: {str(e)}")
+            raise
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get information about the model.
+
+        Returns:
+            Dictionary containing model information
+        """
+        return {
+            "name": self.model_name,
+            "type": "llm",
+            "backend": "triton",
+            "url": self.triton_url,
+            "loaded": self.is_loaded(),
+            "config": self.default_config
+        }
isa_model/inference/services/llm/ollama_llm_service.py
@@ -0,0 +1,108 @@
+import logging
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+from isa_model.inference.services.base_service import BaseLLMService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.backends.local_services import OllamaBackendClient
+
+logger = logging.getLogger(__name__)
+
+class OllamaLLMService(BaseLLMService):
+    """Ollama LLM service using backend client"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1", backend: Optional[OllamaBackendClient] = None):
+        super().__init__(provider, model_name)
+
+        # Use provided backend or create new one
+        if backend:
+            self.backend = backend
+        else:
+            host = self.config.get("host", "localhost")
+            port = self.config.get("port", 11434)
+            self.backend = OllamaBackendClient(host, port)
+
+        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        logger.info(f"Initialized OllamaLLMService with model {model_name}")
+
+    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+        """Universal invocation method"""
+        if isinstance(prompt, str):
+            return await self.acompletion(prompt)
+        elif isinstance(prompt, list):
+            return await self.achat(prompt)
+        else:
+            raise ValueError("Prompt must be string or list of messages")
+
+    async def achat(self, messages: List[Dict[str, str]]):
+        """Chat completion method"""
+        try:
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": False
+            }
+            response = await self.backend.post("/api/chat", payload)
+
+            # Update token usage if available
+            if "eval_count" in response:
+                self.last_token_usage = {
+                    "prompt_tokens": response.get("prompt_eval_count", 0),
+                    "completion_tokens": response.get("eval_count", 0),
+                    "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                }
+
+            return response["message"]["content"]
+
+        except Exception as e:
+            logger.error(f"Error in chat completion: {e}")
+            raise
+
+    async def acompletion(self, prompt: str):
+        """Text completion method"""
+        try:
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "stream": False
+            }
+            response = await self.backend.post("/api/generate", payload)
+
+            # Update token usage if available
+            if "eval_count" in response:
+                self.last_token_usage = {
+                    "prompt_tokens": response.get("prompt_eval_count", 0),
+                    "completion_tokens": response.get("eval_count", 0),
+                    "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                }
+
+            return response["response"]
+
+        except Exception as e:
+            logger.error(f"Error in text completion: {e}")
+            raise
+
+    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+        """Generate multiple completions"""
+        results = []
+        for _ in range(n):
+            result = await self.achat(messages)
+            results.append(result)
+        return results
+
+    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+        """Stream chat responses"""
+        # Note: This would require modifying the backend to support streaming
+        # For now, return the full response
+        response = await self.achat(messages)
+        yield response
+
+    def get_token_usage(self):
+        """Get total token usage statistics"""
+        return self.last_token_usage
+
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """Get token usage from last request"""
+        return self.last_token_usage
+
+    async def close(self):
+        """Close the backend client"""
+        await self.backend.close()
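
A minimal usage sketch for the Ollama-backed service, assuming a local Ollama server on the default port. The service requires a provider instance; the OllamaProvider class name and the config dict passed to it are assumptions here, since that constructor is not shown in this diff (only its path, isa_model/inference/providers/ollama_provider.py, appears in the listing above).

import asyncio
from isa_model.inference.providers.ollama_provider import OllamaProvider  # constructor assumed
from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService

async def main():
    provider = OllamaProvider(config={"host": "localhost", "port": 11434})  # assumed signature
    service = OllamaLLMService(provider, model_name="llama3.1")
    # ainvoke() routes a plain string to acompletion() and a message list to achat().
    reply = await service.ainvoke([{"role": "user", "content": "Say hello in five words."}])
    print(reply)
    print(service.get_last_token_usage())
    await service.close()

asyncio.run(main())
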