isa-model 0.2.0__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +243 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.9.dist-info/METADATA +465 -0
- isa_model-0.2.9.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.2.9.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.2.9.dist-info}/top_level.txt +0 -0
isa_model/inference/services/embedding/base_embed_service.py (new file):

@@ -0,0 +1,112 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional
+from isa_model.inference.services.base_service import BaseService
+
+class BaseEmbedService(BaseService):
+    """Base class for embedding services"""
+
+    @abstractmethod
+    async def create_text_embedding(self, text: str) -> List[float]:
+        """
+        Create embedding for single text
+
+        Args:
+            text: Input text to embed
+
+        Returns:
+            List of float values representing the embedding vector
+        """
+        pass
+
+    @abstractmethod
+    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Create embeddings for multiple texts
+
+        Args:
+            texts: List of input texts to embed
+
+        Returns:
+            List of embedding vectors, one for each input text
+        """
+        pass
+
+    @abstractmethod
+    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+        """
+        Create text chunks with embeddings
+
+        Args:
+            text: Input text to chunk and embed
+            metadata: Optional metadata to attach to chunks
+
+        Returns:
+            List of dictionaries containing:
+            - text: The chunk text
+            - embedding: The embedding vector
+            - metadata: Associated metadata
+            - start_index: Start position in original text
+            - end_index: End position in original text
+        """
+        pass
+
+    @abstractmethod
+    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+        """
+        Compute similarity between two embeddings
+
+        Args:
+            embedding1: First embedding vector
+            embedding2: Second embedding vector
+
+        Returns:
+            Similarity score (typically cosine similarity, range -1 to 1)
+        """
+        pass
+
+    @abstractmethod
+    async def find_similar_texts(
+        self,
+        query_embedding: List[float],
+        candidate_embeddings: List[List[float]],
+        top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        """
+        Find most similar texts based on embeddings
+
+        Args:
+            query_embedding: Query embedding vector
+            candidate_embeddings: List of candidate embedding vectors
+            top_k: Number of top similar results to return
+
+        Returns:
+            List of dictionaries containing:
+            - index: Index in candidate_embeddings
+            - similarity: Similarity score
+        """
+        pass
+
+    @abstractmethod
+    def get_embedding_dimension(self) -> int:
+        """
+        Get the dimension of embeddings produced by this service
+
+        Returns:
+            Integer dimension of embedding vectors
+        """
+        pass
+
+    @abstractmethod
+    def get_max_input_length(self) -> int:
+        """
+        Get maximum input text length supported
+
+        Returns:
+            Maximum number of characters/tokens supported
+        """
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
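The interface leaves `compute_similarity` to implementors, but its docstring pins down the expected semantics (typically cosine similarity in the range -1 to 1). A minimal illustrative sketch of that contract, not part of the package:

```python
# Illustrative only: plain cosine similarity matching the documented
# contract of BaseEmbedService.compute_similarity().
import math
from typing import List

def cosine_similarity(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # ~0.7071
```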
isa_model/inference/services/embedding/ollama_embed_service.py:

@@ -4,12 +4,12 @@ import asyncio
 from typing import List, Dict, Any, Optional
 
 # Keep the imports and framework structure you specified
-from isa_model.inference.services.
+from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
 from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
-class OllamaEmbedService(
+class OllamaEmbedService(BaseEmbedService):
     """
     Ollama embedding service.
     This class follows the base service architecture, but uses its own HTTP client to talk to the Ollama API,
@@ -91,6 +91,32 @@ class OllamaEmbedService(BaseEmbeddingService):
 
         return dot_product / (norm1 * norm2)
 
+    async def find_similar_texts(
+        self,
+        query_embedding: List[float],
+        candidate_embeddings: List[List[float]],
+        top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        """Find most similar texts based on embeddings"""
+        similarities = []
+        for i, candidate in enumerate(candidate_embeddings):
+            similarity = await self.compute_similarity(query_embedding, candidate)
+            similarities.append({"index": i, "similarity": similarity})
+
+        # Sort by similarity in descending order and return top_k
+        similarities.sort(key=lambda x: x["similarity"], reverse=True)
+        return similarities[:top_k]
+
+    def get_embedding_dimension(self) -> int:
+        """Get the dimension of embeddings produced by this service"""
+        # BGE-M3 produces 1024-dimensional embeddings
+        return 1024
+
+    def get_max_input_length(self) -> int:
+        """Get maximum input text length supported"""
+        # BGE-M3 supports up to 8192 tokens
+        return 8192
+
     async def close(self):
         """Close the built-in HTTP client"""
         await self.client.aclose()
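Together with the new base interface, these methods support a simple retrieval loop. A usage sketch (the diff does not show how the service is constructed, so `embed` is assumed to be an already-initialized `OllamaEmbedService`):

```python
async def search_demo(embed) -> None:
    # embed: an initialized OllamaEmbedService (construction not shown in this diff)
    query = await embed.create_text_embedding("capital of France?")
    candidates = await embed.create_text_embeddings([
        "Paris is the capital of France.",
        "The Eiffel Tower is 330 m tall.",
    ])
    # find_similar_texts returns [{"index": ..., "similarity": ...}],
    # sorted by similarity in descending order
    top = await embed.find_similar_texts(query, candidates, top_k=1)
    print(top[0]["index"], round(top[0]["similarity"], 3))
    await embed.close()
```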
isa_model/inference/services/llm/__init__.py:

@@ -5,8 +5,10 @@ LLM Services - Business logic services for Language Models
 # Import LLM services here when created
 from .ollama_llm_service import OllamaLLMService
 from .openai_llm_service import OpenAILLMService
+from .triton_llm_service import TritonLLMService
 
 __all__ = [
     "OllamaLLMService",
     "OpenAILLMService",
+    "TritonLLMService"
 ]
isa_model/inference/services/llm/base_llm_service.py:

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar
+from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar, Callable
 from isa_model.inference.services.base_service import BaseService
 
 T = TypeVar('T')  # Generic type for responses
@@ -87,6 +87,116 @@ class BaseLLMService(BaseService):
         """
         pass
 
+    def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'BaseLLMService':
+        """
+        Bind tools to this LLM service for function calling (LangChain interface)
+
+        Args:
+            tools: List of tools to bind. Can be:
+                   - Dictionary with tool schema
+                   - Callable functions (will be converted to schema)
+            **kwargs: Additional tool binding parameters
+
+        Returns:
+            A new instance of the service with tools bound
+
+        Example:
+            def get_weather(location: str) -> str:
+                '''Get weather for a location'''
+                return f"Weather in {location}: Sunny, 25°C"
+
+            llm_with_tools = llm.bind_tools([get_weather])
+            response = await llm_with_tools.ainvoke("What's the weather in Paris?")
+        """
+        # Create a copy of the current service
+        bound_service = self._create_bound_copy()
+        bound_service._bound_tools = self._convert_tools_to_schema(tools)
+        bound_service._tool_binding_kwargs = kwargs
+        return bound_service
+
+    def _create_bound_copy(self) -> 'BaseLLMService':
+        """Create a copy of this service for tool binding"""
+        # Default implementation - subclasses should override if needed
+        bound_service = self.__class__(self.provider, self.model_name)
+        bound_service.config = self.config.copy()
+        return bound_service
+
+    def _convert_tools_to_schema(self, tools: List[Union[Dict[str, Any], Callable]]) -> List[Dict[str, Any]]:
+        """Convert tools to OpenAI function calling schema"""
+        schemas = []
+        for tool in tools:
+            if callable(tool):
+                schema = self._function_to_schema(tool)
+            elif isinstance(tool, dict):
+                schema = tool
+            else:
+                raise ValueError(f"Tool must be callable or dict, got {type(tool)}")
+            schemas.append(schema)
+        return schemas
+
+    def _function_to_schema(self, func: Callable) -> Dict[str, Any]:
+        """Convert a Python function to OpenAI function schema"""
+        import inspect
+        import json
+        from typing import get_type_hints
+
+        sig = inspect.signature(func)
+        type_hints = get_type_hints(func)
+
+        properties = {}
+        required = []
+
+        for param_name, param in sig.parameters.items():
+            param_type = type_hints.get(param_name, str)
+
+            # Convert Python types to JSON schema types
+            if param_type == str:
+                prop_type = "string"
+            elif param_type == int:
+                prop_type = "integer"
+            elif param_type == float:
+                prop_type = "number"
+            elif param_type == bool:
+                prop_type = "boolean"
+            elif param_type == list:
+                prop_type = "array"
+            elif param_type == dict:
+                prop_type = "object"
+            else:
+                prop_type = "string"  # Default fallback
+
+            properties[param_name] = {"type": prop_type}
+
+            # Add parameter to required if it has no default value
+            if param.default == inspect.Parameter.empty:
+                required.append(param_name)
+
+        return {
+            "type": "function",
+            "function": {
+                "name": func.__name__,
+                "description": func.__doc__ or f"Function {func.__name__}",
+                "parameters": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": required
+                }
+            }
+        }
+
+    def _has_bound_tools(self) -> bool:
+        """Check if this service has bound tools"""
+        return hasattr(self, '_bound_tools') and self._bound_tools
+
+    def _get_bound_tools(self) -> List[Dict[str, Any]]:
+        """Get the bound tools schema"""
+        return getattr(self, '_bound_tools', [])
+
+    def _execute_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
+        """Execute a tool call by name with arguments"""
+        # This is a placeholder - subclasses should implement actual tool execution
+        raise NotImplementedError("Tool execution not implemented in base class")
+
     @abstractmethod
     def get_token_usage(self) -> Dict[str, Any]:
         """
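Tracing the `_function_to_schema` helper added above: for the `get_weather` example from the `bind_tools` docstring, it yields the standard OpenAI function-calling schema:

```python
{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather for a location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"]  # no default value, so the parameter is required
        }
    }
}
```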
isa_model/inference/services/llm/ollama_llm_service.py:

@@ -1,18 +1,57 @@
 import logging
-
-
+import httpx
+import json
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
+from isa_model.inference.services.llm.base_llm_service import BaseLLMService
 from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
 class OllamaLLMService(BaseLLMService):
-    """Ollama LLM service using
+    """Ollama LLM service using HTTP client"""
 
     def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"):
         super().__init__(provider, model_name)
+
+        # Create HTTP client for Ollama API
+        base_url = self.config.get("base_url", "http://localhost:11434")
+        timeout = self.config.get("timeout", 60)
+
+        self.client = httpx.AsyncClient(
+            base_url=base_url,
+            timeout=timeout
+        )
 
         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-
+        self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
+
+        # Tool binding attributes
+        self._bound_tools: List[Dict[str, Any]] = []
+        self._tool_binding_kwargs: Dict[str, Any] = {}
+        self._tool_functions: Dict[str, Callable] = {}
+
+        logger.info(f"Initialized OllamaLLMService with model {model_name} at {base_url}")
+
+    def _create_bound_copy(self) -> 'OllamaLLMService':
+        """Create a copy of this service for tool binding"""
+        bound_service = OllamaLLMService(self.provider, self.model_name)
+        bound_service._bound_tools = self._bound_tools.copy()
+        bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
+        bound_service._tool_functions = self._tool_functions.copy()
+        return bound_service
+
+    def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'OllamaLLMService':
+        """Bind tools to this LLM service for function calling"""
+        bound_service = self._create_bound_copy()
+        bound_service._bound_tools = self._convert_tools_to_schema(tools)
+        bound_service._tool_binding_kwargs = kwargs
+
+        # Store the actual functions for execution
+        for tool in tools:
+            if callable(tool):
+                bound_service._tool_functions[tool.__name__] = tool
+
+        return bound_service
 
     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
         """Universal invocation method"""
@@ -29,44 +68,130 @@ class OllamaLLMService(BaseLLMService):
             payload = {
                 "model": self.model_name,
                 "messages": messages,
-                "stream": False
+                "stream": False,
+                "options": {
+                    "temperature": self.config.get("temperature", 0.7),
+                    "top_p": self.config.get("top_p", 0.9),
+                    "num_predict": self.config.get("max_tokens", 2048)
+                }
             }
-
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                payload["tools"] = self._get_bound_tools()
+
+            response = await self.client.post("/api/chat", json=payload)
+            response.raise_for_status()
+            result = response.json()
 
             # Update token usage if available
-            if "eval_count" in
+            if "eval_count" in result:
                 self.last_token_usage = {
-                    "prompt_tokens":
-                    "completion_tokens":
-                    "total_tokens":
+                    "prompt_tokens": result.get("prompt_eval_count", 0),
+                    "completion_tokens": result.get("eval_count", 0),
+                    "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
+
+            # Handle tool calls if present
+            message = result["message"]
+            if "tool_calls" in message and message["tool_calls"]:
+                return await self._handle_tool_calls(message, messages)
 
-            return
+            return message["content"]
 
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in chat completion: {e}")
+            raise
         except Exception as e:
             logger.error(f"Error in chat completion: {e}")
             raise
 
+    async def _handle_tool_calls(self, assistant_message: Dict[str, Any], original_messages: List[Dict[str, str]]) -> str:
+        """Handle tool calls from the assistant"""
+        tool_calls = assistant_message.get("tool_calls", [])
+
+        # Add assistant message with tool calls to conversation
+        messages = original_messages + [assistant_message]
+
+        # Execute each tool call
+        for tool_call in tool_calls:
+            function_name = tool_call["function"]["name"]
+            arguments = tool_call["function"]["arguments"]
+
+            try:
+                # Parse arguments if they're a string
+                if isinstance(arguments, str):
+                    arguments = json.loads(arguments)
+
+                # Execute the tool
+                if function_name in self._tool_functions:
+                    result = self._tool_functions[function_name](**arguments)
+                    if hasattr(result, '__await__'):  # Handle async functions
+                        result = await result
+                else:
+                    result = f"Error: Function {function_name} not found"
+
+                # Add tool result to messages
+                messages.append({
+                    "role": "tool",
+                    "content": str(result),
+                    "tool_call_id": tool_call.get("id", function_name)
+                })
+
+            except Exception as e:
+                logger.error(f"Error executing tool {function_name}: {e}")
+                messages.append({
+                    "role": "tool",
+                    "content": f"Error executing {function_name}: {str(e)}",
+                    "tool_call_id": tool_call.get("id", function_name)
+                })
+
+        # Get final response from the model
+        return await self.achat(messages)
+
     async def acompletion(self, prompt: str):
         """Text completion method"""
         try:
             payload = {
                 "model": self.model_name,
                 "prompt": prompt,
-                "stream": False
+                "stream": False,
+                "options": {
+                    "temperature": self.config.get("temperature", 0.7),
+                    "top_p": self.config.get("top_p", 0.9),
+                    "num_predict": self.config.get("max_tokens", 2048)
+                }
             }
-
+
+            response = await self.client.post("/api/generate", json=payload)
+            response.raise_for_status()
+            result = response.json()
 
             # Update token usage if available
-            if "eval_count" in
+            if "eval_count" in result:
                 self.last_token_usage = {
-                    "prompt_tokens":
-                    "completion_tokens":
-                    "total_tokens":
+                    "prompt_tokens": result.get("prompt_eval_count", 0),
+                    "completion_tokens": result.get("eval_count", 0),
+                    "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
 
-            return
+            return result["response"]
 
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in text completion: {e}")
+            raise
         except Exception as e:
             logger.error(f"Error in text completion: {e}")
             raise
@@ -81,19 +206,102 @@
 
     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
         """Stream chat responses"""
-
-
-
-
+        try:
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": True,
+                "options": {
+                    "temperature": self.config.get("temperature", 0.7),
+                    "top_p": self.config.get("top_p", 0.9),
+                    "num_predict": self.config.get("max_tokens", 2048)
+                }
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                payload["tools"] = self._get_bound_tools()
+
+            async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line.strip():
+                        try:
+                            chunk = json.loads(line)
+                            if "message" in chunk and "content" in chunk["message"]:
+                                content = chunk["message"]["content"]
+                                if content:
+                                    yield content
+                        except json.JSONDecodeError:
+                            continue
+
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in stream chat: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in stream chat: {e}")
+            raise
 
-    def
+    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+        """Stream completion responses"""
+        try:
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "stream": True,
+                "options": {
+                    "temperature": self.config.get("temperature", 0.7),
+                    "top_p": self.config.get("top_p", 0.9),
+                    "num_predict": self.config.get("max_tokens", 2048)
+                }
+            }
+
+            async with self.client.stream("POST", "/api/generate", json=payload) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line.strip():
+                        try:
+                            chunk = json.loads(line)
+                            if "response" in chunk:
+                                content = chunk["response"]
+                                if content:
+                                    yield content
+                        except json.JSONDecodeError:
+                            continue
+
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in stream completion: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in stream completion: {e}")
+            raise
+
+    def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
-        return self.
+        return self.total_token_usage
 
     def get_last_token_usage(self) -> Dict[str, int]:
         """Get token usage from last request"""
         return self.last_token_usage
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "max_tokens": self.config.get("max_tokens", 2048),
+            "supports_streaming": True,
+            "supports_functions": True,
+            "provider": "ollama"
+        }
+
+    def _has_bound_tools(self) -> bool:
+        """Check if this service has bound tools"""
+        return bool(self._bound_tools)
+
+    def _get_bound_tools(self) -> List[Dict[str, Any]]:
+        """Get the bound tools schema"""
+        return self._bound_tools
 
     async def close(self):
-        """Close the
-        await self.
+        """Close the HTTP client"""
+        await self.client.aclose()
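End to end, the new Ollama tool-calling path can be exercised roughly as below. A sketch: only `bind_tools`, `ainvoke`, `get_token_usage`, and `close` come from this diff; the `llm` instance is assumed to be constructed elsewhere, and the tool mirrors the docstring example.

```python
def get_weather(location: str) -> str:
    """Get weather for a location"""
    return f"Weather in {location}: Sunny, 25°C"

async def demo(llm) -> None:
    # llm: an initialized OllamaLLMService (construction not shown in this diff)
    llm_with_tools = llm.bind_tools([get_weather])

    # ainvoke posts the tool schema to /api/chat; if the model responds with
    # tool_calls, _handle_tool_calls runs get_weather and requests a final answer.
    answer = await llm_with_tools.ainvoke("What's the weather in Paris?")
    print(answer)

    # Cumulative usage, including requests_count, from total_token_usage
    print(llm_with_tools.get_token_usage())
    await llm_with_tools.close()
```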
|