isa-model 0.0.3__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (78)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/core/storage/hf_storage.py +419 -0
  4. isa_model/deployment/__init__.py +52 -0
  5. isa_model/deployment/core/__init__.py +34 -0
  6. isa_model/deployment/core/deployment_config.py +356 -0
  7. isa_model/deployment/core/deployment_manager.py +549 -0
  8. isa_model/deployment/core/isa_deployment_service.py +401 -0
  9. isa_model/eval/factory.py +381 -140
  10. isa_model/inference/ai_factory.py +142 -240
  11. isa_model/inference/providers/ml_provider.py +50 -0
  12. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  13. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  14. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  15. isa_model/inference/services/llm/__init__.py +2 -0
  16. isa_model/inference/services/llm/base_llm_service.py +111 -1
  17. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  18. isa_model/inference/services/llm/openai_llm_service.py +180 -26
  19. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  20. isa_model/inference/services/ml/base_ml_service.py +78 -0
  21. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  22. isa_model/inference/services/vision/__init__.py +3 -3
  23. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  24. isa_model/inference/services/vision/base_vision_service.py +177 -0
  25. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  26. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  27. isa_model/training/__init__.py +62 -32
  28. isa_model/training/cloud/__init__.py +22 -0
  29. isa_model/training/cloud/job_orchestrator.py +402 -0
  30. isa_model/training/cloud/runpod_trainer.py +454 -0
  31. isa_model/training/cloud/storage_manager.py +482 -0
  32. isa_model/training/core/__init__.py +23 -0
  33. isa_model/training/core/config.py +181 -0
  34. isa_model/training/core/dataset.py +222 -0
  35. isa_model/training/core/trainer.py +720 -0
  36. isa_model/training/core/utils.py +213 -0
  37. isa_model/training/factory.py +229 -198
  38. isa_model-0.0.8.dist-info/METADATA +465 -0
  39. isa_model-0.0.8.dist-info/RECORD +86 -0
  40. isa_model/core/model_router.py +0 -226
  41. isa_model/core/model_version.py +0 -0
  42. isa_model/core/resource_manager.py +0 -202
  43. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  44. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  45. isa_model/training/engine/llama_factory/__init__.py +0 -39
  46. isa_model/training/engine/llama_factory/config.py +0 -115
  47. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  48. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  49. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  50. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  51. isa_model/training/engine/llama_factory/factory.py +0 -331
  52. isa_model/training/engine/llama_factory/rl.py +0 -254
  53. isa_model/training/engine/llama_factory/trainer.py +0 -171
  54. isa_model/training/image_model/configs/create_config.py +0 -37
  55. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  56. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  57. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  58. isa_model/training/image_model/prepare_upload.py +0 -17
  59. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  60. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  61. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  62. isa_model/training/image_model/train/train.py +0 -42
  63. isa_model/training/image_model/train/train_flux.py +0 -41
  64. isa_model/training/image_model/train/train_lora.py +0 -57
  65. isa_model/training/image_model/train_main.py +0 -25
  66. isa_model-0.0.3.dist-info/METADATA +0 -327
  67. isa_model-0.0.3.dist-info/RECORD +0 -92
  68. isa_model-0.0.3.dist-info/licenses/LICENSE +0 -21
  69. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  76. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  77. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/WHEEL +0 -0
  78. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/top_level.txt +0 -0
isa_model/inference/services/embedding/base_embed_service.py
@@ -0,0 +1,112 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Union, Optional
+ from isa_model.inference.services.base_service import BaseService
+
+ class BaseEmbedService(BaseService):
+     """Base class for embedding services"""
+
+     @abstractmethod
+     async def create_text_embedding(self, text: str) -> List[float]:
+         """
+         Create embedding for single text
+
+         Args:
+             text: Input text to embed
+
+         Returns:
+             List of float values representing the embedding vector
+         """
+         pass
+
+     @abstractmethod
+     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """
+         Create embeddings for multiple texts
+
+         Args:
+             texts: List of input texts to embed
+
+         Returns:
+             List of embedding vectors, one for each input text
+         """
+         pass
+
+     @abstractmethod
+     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """
+         Create text chunks with embeddings
+
+         Args:
+             text: Input text to chunk and embed
+             metadata: Optional metadata to attach to chunks
+
+         Returns:
+             List of dictionaries containing:
+             - text: The chunk text
+             - embedding: The embedding vector
+             - metadata: Associated metadata
+             - start_index: Start position in original text
+             - end_index: End position in original text
+         """
+         pass
+
+     @abstractmethod
+     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+         """
+         Compute similarity between two embeddings
+
+         Args:
+             embedding1: First embedding vector
+             embedding2: Second embedding vector
+
+         Returns:
+             Similarity score (typically cosine similarity, range -1 to 1)
+         """
+         pass
+
+     @abstractmethod
+     async def find_similar_texts(
+         self,
+         query_embedding: List[float],
+         candidate_embeddings: List[List[float]],
+         top_k: int = 5
+     ) -> List[Dict[str, Any]]:
+         """
+         Find most similar texts based on embeddings
+
+         Args:
+             query_embedding: Query embedding vector
+             candidate_embeddings: List of candidate embedding vectors
+             top_k: Number of top similar results to return
+
+         Returns:
+             List of dictionaries containing:
+             - index: Index in candidate_embeddings
+             - similarity: Similarity score
+         """
+         pass
+
+     @abstractmethod
+     def get_embedding_dimension(self) -> int:
+         """
+         Get the dimension of embeddings produced by this service
+
+         Returns:
+             Integer dimension of embedding vectors
+         """
+         pass
+
+     @abstractmethod
+     def get_max_input_length(self) -> int:
+         """
+         Get maximum input text length supported
+
+         Returns:
+             Maximum number of characters/tokens supported
+         """
+         pass
+
+     @abstractmethod
+     async def close(self):
+         """Cleanup resources"""
+         pass
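To make the interface contract concrete, here is a standalone sketch (not part of the package) of the cosine-similarity and top-k logic that the compute_similarity and find_similar_texts docstrings describe:

# Standalone sketch, not from the package: plain-Python cosine similarity
# and top-k selection matching the contracts documented above.
import math
from typing import Any, Dict, List

def cosine_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    dot = sum(a * b for a, b in zip(embedding1, embedding2))
    norm1 = math.sqrt(sum(a * a for a in embedding1))
    norm2 = math.sqrt(sum(b * b for b in embedding2))
    if norm1 == 0 or norm2 == 0:
        return 0.0  # guard against zero vectors
    return dot / (norm1 * norm2)  # in [-1, 1], as the docstring notes

def top_k_similar(query: List[float], candidates: List[List[float]], top_k: int = 5) -> List[Dict[str, Any]]:
    scored = [{"index": i, "similarity": cosine_similarity(query, c)} for i, c in enumerate(candidates)]
    scored.sort(key=lambda item: item["similarity"], reverse=True)
    return scored[:top_k]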
isa_model/inference/services/embedding/ollama_embed_service.py
@@ -4,12 +4,12 @@ import asyncio
  from typing import List, Dict, Any, Optional
 
  # Keep the imports and framework structure you specified
- from isa_model.inference.services.base_service import BaseEmbeddingService
+ from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
  from isa_model.inference.providers.base_provider import BaseProvider
 
  logger = logging.getLogger(__name__)
 
- class OllamaEmbedService(BaseEmbeddingService):
+ class OllamaEmbedService(BaseEmbedService):
      """
      Ollama embedding service.
      This class follows the base service architecture, but uses its own HTTP client to communicate with the Ollama API,
@@ -91,6 +91,32 @@ class OllamaEmbedService(BaseEmbeddingService):
 
          return dot_product / (norm1 * norm2)
 
+     async def find_similar_texts(
+         self,
+         query_embedding: List[float],
+         candidate_embeddings: List[List[float]],
+         top_k: int = 5
+     ) -> List[Dict[str, Any]]:
+         """Find most similar texts based on embeddings"""
+         similarities = []
+         for i, candidate in enumerate(candidate_embeddings):
+             similarity = await self.compute_similarity(query_embedding, candidate)
+             similarities.append({"index": i, "similarity": similarity})
+
+         # Sort by similarity in descending order and return top_k
+         similarities.sort(key=lambda x: x["similarity"], reverse=True)
+         return similarities[:top_k]
+
+     def get_embedding_dimension(self) -> int:
+         """Get the dimension of embeddings produced by this service"""
+         # BGE-M3 produces 1024-dimensional embeddings
+         return 1024
+
+     def get_max_input_length(self) -> int:
+         """Get maximum input text length supported"""
+         # BGE-M3 supports up to 8192 tokens
+         return 8192
+
      async def close(self):
          """Close the built-in HTTP client"""
          await self.client.aclose()
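A hedged usage sketch of the new methods follows; the service's constructor and provider wiring are not shown in this diff, so the embed instance is assumed to be built elsewhere:

# Usage sketch, not from the package: assumes an already-constructed
# OllamaEmbedService instance, since its __init__ is not part of this hunk.
from typing import List, Tuple

async def rank_documents(embed, query: str, documents: List[str]) -> List[Tuple[str, float]]:
    query_vec = await embed.create_text_embedding(query)
    doc_vecs = await embed.create_text_embeddings(documents)
    hits = await embed.find_similar_texts(query_vec, doc_vecs, top_k=3)
    # hits come back best-first: [{"index": ..., "similarity": ...}, ...]
    return [(documents[h["index"]], h["similarity"]) for h in hits]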
isa_model/inference/services/llm/__init__.py
@@ -5,8 +5,10 @@ LLM Services - Business logic services for Language Models
  # Import LLM services here when created
  from .ollama_llm_service import OllamaLLMService
  from .openai_llm_service import OpenAILLMService
+ from .triton_llm_service import TritonLLMService
 
  __all__ = [
      "OllamaLLMService",
      "OpenAILLMService",
+     "TritonLLMService"
  ]
isa_model/inference/services/llm/base_llm_service.py
@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
- from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar
+ from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar, Callable
  from isa_model.inference.services.base_service import BaseService
 
  T = TypeVar('T')  # Generic type for responses
@@ -87,6 +87,116 @@ class BaseLLMService(BaseService):
          """
          pass
 
+     def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'BaseLLMService':
+         """
+         Bind tools to this LLM service for function calling (LangChain interface)
+
+         Args:
+             tools: List of tools to bind. Can be:
+                 - Dictionary with tool schema
+                 - Callable functions (will be converted to schema)
+             **kwargs: Additional tool binding parameters
+
+         Returns:
+             A new instance of the service with tools bound
+
+         Example:
+             def get_weather(location: str) -> str:
+                 '''Get weather for a location'''
+                 return f"Weather in {location}: Sunny, 25°C"
+
+             llm_with_tools = llm.bind_tools([get_weather])
+             response = await llm_with_tools.ainvoke("What's the weather in Paris?")
+         """
+         # Create a copy of the current service
+         bound_service = self._create_bound_copy()
+         bound_service._bound_tools = self._convert_tools_to_schema(tools)
+         bound_service._tool_binding_kwargs = kwargs
+         return bound_service
+
+     def _create_bound_copy(self) -> 'BaseLLMService':
+         """Create a copy of this service for tool binding"""
+         # Default implementation - subclasses should override if needed
+         bound_service = self.__class__(self.provider, self.model_name)
+         bound_service.config = self.config.copy()
+         return bound_service
+
+     def _convert_tools_to_schema(self, tools: List[Union[Dict[str, Any], Callable]]) -> List[Dict[str, Any]]:
+         """Convert tools to OpenAI function calling schema"""
+         schemas = []
+         for tool in tools:
+             if callable(tool):
+                 schema = self._function_to_schema(tool)
+             elif isinstance(tool, dict):
+                 schema = tool
+             else:
+                 raise ValueError(f"Tool must be callable or dict, got {type(tool)}")
+             schemas.append(schema)
+         return schemas
+
+     def _function_to_schema(self, func: Callable) -> Dict[str, Any]:
+         """Convert a Python function to OpenAI function schema"""
+         import inspect
+         import json
+         from typing import get_type_hints
+
+         sig = inspect.signature(func)
+         type_hints = get_type_hints(func)
+
+         properties = {}
+         required = []
+
+         for param_name, param in sig.parameters.items():
+             param_type = type_hints.get(param_name, str)
+
+             # Convert Python types to JSON schema types
+             if param_type == str:
+                 prop_type = "string"
+             elif param_type == int:
+                 prop_type = "integer"
+             elif param_type == float:
+                 prop_type = "number"
+             elif param_type == bool:
+                 prop_type = "boolean"
+             elif param_type == list:
+                 prop_type = "array"
+             elif param_type == dict:
+                 prop_type = "object"
+             else:
+                 prop_type = "string"  # Default fallback
+
+             properties[param_name] = {"type": prop_type}
+
+             # Add parameter to required if it has no default value
+             if param.default == inspect.Parameter.empty:
+                 required.append(param_name)
+
+         return {
+             "type": "function",
+             "function": {
+                 "name": func.__name__,
+                 "description": func.__doc__ or f"Function {func.__name__}",
+                 "parameters": {
+                     "type": "object",
+                     "properties": properties,
+                     "required": required
+                 }
+             }
+         }
+
+     def _has_bound_tools(self) -> bool:
+         """Check if this service has bound tools"""
+         return hasattr(self, '_bound_tools') and self._bound_tools
+
+     def _get_bound_tools(self) -> List[Dict[str, Any]]:
+         """Get the bound tools schema"""
+         return getattr(self, '_bound_tools', [])
+
+     def _execute_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
+         """Execute a tool call by name with arguments"""
+         # This is a placeholder - subclasses should implement actual tool execution
+         raise NotImplementedError("Tool execution not implemented in base class")
+
      @abstractmethod
      def get_token_usage(self) -> Dict[str, Any]:
          """
isa_model/inference/services/llm/ollama_llm_service.py
@@ -1,18 +1,57 @@
  import logging
- from typing import Dict, Any, List, Union, AsyncGenerator, Optional
- from isa_model.inference.services.base_service import BaseLLMService
+ import httpx
+ import json
+ from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
+ from isa_model.inference.services.llm.base_llm_service import BaseLLMService
  from isa_model.inference.providers.base_provider import BaseProvider
 
  logger = logging.getLogger(__name__)
 
  class OllamaLLMService(BaseLLMService):
-     """Ollama LLM service using backend client"""
+     """Ollama LLM service using HTTP client"""
 
      def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"):
          super().__init__(provider, model_name)
+
+         # Create HTTP client for Ollama API
+         base_url = self.config.get("base_url", "http://localhost:11434")
+         timeout = self.config.get("timeout", 60)
+
+         self.client = httpx.AsyncClient(
+             base_url=base_url,
+             timeout=timeout
+         )
 
          self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-         logger.info(f"Initialized OllamaLLMService with model {model_name}")
+         self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
+
+         # Tool binding attributes
+         self._bound_tools: List[Dict[str, Any]] = []
+         self._tool_binding_kwargs: Dict[str, Any] = {}
+         self._tool_functions: Dict[str, Callable] = {}
+
+         logger.info(f"Initialized OllamaLLMService with model {model_name} at {base_url}")
+
+     def _create_bound_copy(self) -> 'OllamaLLMService':
+         """Create a copy of this service for tool binding"""
+         bound_service = OllamaLLMService(self.provider, self.model_name)
+         bound_service._bound_tools = self._bound_tools.copy()
+         bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
+         bound_service._tool_functions = self._tool_functions.copy()
+         return bound_service
+
+     def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'OllamaLLMService':
+         """Bind tools to this LLM service for function calling"""
+         bound_service = self._create_bound_copy()
+         bound_service._bound_tools = self._convert_tools_to_schema(tools)
+         bound_service._tool_binding_kwargs = kwargs
+
+         # Store the actual functions for execution
+         for tool in tools:
+             if callable(tool):
+                 bound_service._tool_functions[tool.__name__] = tool
+
+         return bound_service
 
      async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
          """Universal invocation method"""
@@ -29,44 +68,130 @@ class OllamaLLMService(BaseLLMService):
              payload = {
                  "model": self.model_name,
                  "messages": messages,
-                 "stream": False
+                 "stream": False,
+                 "options": {
+                     "temperature": self.config.get("temperature", 0.7),
+                     "top_p": self.config.get("top_p", 0.9),
+                     "num_predict": self.config.get("max_tokens", 2048)
+                 }
              }
-             response = await self.backend.post("/api/chat", payload)
+
+             # Add tools if bound
+             if self._has_bound_tools():
+                 payload["tools"] = self._get_bound_tools()
+
+             response = await self.client.post("/api/chat", json=payload)
+             response.raise_for_status()
+             result = response.json()
 
              # Update token usage if available
-             if "eval_count" in response:
+             if "eval_count" in result:
                  self.last_token_usage = {
-                     "prompt_tokens": response.get("prompt_eval_count", 0),
-                     "completion_tokens": response.get("eval_count", 0),
-                     "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                     "prompt_tokens": result.get("prompt_eval_count", 0),
+                     "completion_tokens": result.get("eval_count", 0),
+                     "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
                  }
+
+                 # Update total usage
+                 self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                 self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                 self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                 self.total_token_usage["requests_count"] += 1
+
+             # Handle tool calls if present
+             message = result["message"]
+             if "tool_calls" in message and message["tool_calls"]:
+                 return await self._handle_tool_calls(message, messages)
 
-             return response["message"]["content"]
+             return message["content"]
 
+         except httpx.RequestError as e:
+             logger.error(f"HTTP request error in chat completion: {e}")
+             raise
          except Exception as e:
              logger.error(f"Error in chat completion: {e}")
              raise
 
+     async def _handle_tool_calls(self, assistant_message: Dict[str, Any], original_messages: List[Dict[str, str]]) -> str:
+         """Handle tool calls from the assistant"""
+         tool_calls = assistant_message.get("tool_calls", [])
+
+         # Add assistant message with tool calls to conversation
+         messages = original_messages + [assistant_message]
+
+         # Execute each tool call
+         for tool_call in tool_calls:
+             function_name = tool_call["function"]["name"]
+             arguments = tool_call["function"]["arguments"]
+
+             try:
+                 # Parse arguments if they're a string
+                 if isinstance(arguments, str):
+                     arguments = json.loads(arguments)
+
+                 # Execute the tool
+                 if function_name in self._tool_functions:
+                     result = self._tool_functions[function_name](**arguments)
+                     if hasattr(result, '__await__'):  # Handle async functions
+                         result = await result
+                 else:
+                     result = f"Error: Function {function_name} not found"
+
+                 # Add tool result to messages
+                 messages.append({
+                     "role": "tool",
+                     "content": str(result),
+                     "tool_call_id": tool_call.get("id", function_name)
+                 })
+
+             except Exception as e:
+                 logger.error(f"Error executing tool {function_name}: {e}")
+                 messages.append({
+                     "role": "tool",
+                     "content": f"Error executing {function_name}: {str(e)}",
+                     "tool_call_id": tool_call.get("id", function_name)
+                 })
+
+         # Get final response from the model
+         return await self.achat(messages)
+
      async def acompletion(self, prompt: str):
          """Text completion method"""
          try:
              payload = {
                  "model": self.model_name,
                  "prompt": prompt,
-                 "stream": False
+                 "stream": False,
+                 "options": {
+                     "temperature": self.config.get("temperature", 0.7),
+                     "top_p": self.config.get("top_p", 0.9),
+                     "num_predict": self.config.get("max_tokens", 2048)
+                 }
              }
-             response = await self.backend.post("/api/generate", payload)
+
+             response = await self.client.post("/api/generate", json=payload)
+             response.raise_for_status()
+             result = response.json()
 
              # Update token usage if available
-             if "eval_count" in response:
+             if "eval_count" in result:
                  self.last_token_usage = {
-                     "prompt_tokens": response.get("prompt_eval_count", 0),
-                     "completion_tokens": response.get("eval_count", 0),
-                     "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                     "prompt_tokens": result.get("prompt_eval_count", 0),
+                     "completion_tokens": result.get("eval_count", 0),
+                     "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
                  }
+
+                 # Update total usage
+                 self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                 self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                 self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                 self.total_token_usage["requests_count"] += 1
 
-             return response["response"]
+             return result["response"]
 
+         except httpx.RequestError as e:
+             logger.error(f"HTTP request error in text completion: {e}")
+             raise
          except Exception as e:
              logger.error(f"Error in text completion: {e}")
              raise
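For reference, this is the non-streaming request the rewritten achat now sends to Ollama's /api/chat, and the response fields it reads back (values illustrative, shown as a Python literal):

# Request body built by achat (illustrative values):
{
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    "options": {"temperature": 0.7, "top_p": 0.9, "num_predict": 2048}
}
# From the response, the code reads result["message"]["content"] for the text,
# plus prompt_eval_count / eval_count for token accounting.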
@@ -81,19 +206,102 @@ class OllamaLLMService(BaseLLMService):
 
      async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
          """Stream chat responses"""
-         # Note: This would require modifying the backend to support streaming
-         # For now, return the full response
-         response = await self.achat(messages)
-         yield response
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "stream": True,
+                 "options": {
+                     "temperature": self.config.get("temperature", 0.7),
+                     "top_p": self.config.get("top_p", 0.9),
+                     "num_predict": self.config.get("max_tokens", 2048)
+                 }
+             }
+
+             # Add tools if bound
+             if self._has_bound_tools():
+                 payload["tools"] = self._get_bound_tools()
+
+             async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                 response.raise_for_status()
+                 async for line in response.aiter_lines():
+                     if line.strip():
+                         try:
+                             chunk = json.loads(line)
+                             if "message" in chunk and "content" in chunk["message"]:
+                                 content = chunk["message"]["content"]
+                                 if content:
+                                     yield content
+                         except json.JSONDecodeError:
+                             continue
+
+         except httpx.RequestError as e:
+             logger.error(f"HTTP request error in stream chat: {e}")
+             raise
+         except Exception as e:
+             logger.error(f"Error in stream chat: {e}")
+             raise
 
-     def get_token_usage(self):
+     async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+         """Stream completion responses"""
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "prompt": prompt,
+                 "stream": True,
+                 "options": {
+                     "temperature": self.config.get("temperature", 0.7),
+                     "top_p": self.config.get("top_p", 0.9),
+                     "num_predict": self.config.get("max_tokens", 2048)
+                 }
+             }
+
+             async with self.client.stream("POST", "/api/generate", json=payload) as response:
+                 response.raise_for_status()
+                 async for line in response.aiter_lines():
+                     if line.strip():
+                         try:
+                             chunk = json.loads(line)
+                             if "response" in chunk:
+                                 content = chunk["response"]
+                                 if content:
+                                     yield content
+                         except json.JSONDecodeError:
+                             continue
+
+         except httpx.RequestError as e:
+             logger.error(f"HTTP request error in stream completion: {e}")
+             raise
+         except Exception as e:
+             logger.error(f"Error in stream completion: {e}")
+             raise
+
+     def get_token_usage(self) -> Dict[str, Any]:
          """Get total token usage statistics"""
-         return self.last_token_usage
+         return self.total_token_usage
 
      def get_last_token_usage(self) -> Dict[str, int]:
          """Get token usage from last request"""
          return self.last_token_usage
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """Get information about the current model"""
+         return {
+             "name": self.model_name,
+             "max_tokens": self.config.get("max_tokens", 2048),
+             "supports_streaming": True,
+             "supports_functions": True,
+             "provider": "ollama"
+         }
+
+     def _has_bound_tools(self) -> bool:
+         """Check if this service has bound tools"""
+         return bool(self._bound_tools)
+
+     def _get_bound_tools(self) -> List[Dict[str, Any]]:
+         """Get the bound tools schema"""
+         return self._bound_tools
 
      async def close(self):
-         """Close the backend client"""
-         await self.backend.close()
+         """Close the HTTP client"""
+         await self.client.aclose()
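Taken together, a hedged end-to-end sketch of the reworked Ollama service (provider construction is outside this diff, so `provider` below is a placeholder for however BaseProvider is configured):

# Usage sketch, not from the package.
import asyncio

def add(a: int, b: int) -> int:
    """Add two integers"""
    return a + b

async def main(provider):
    svc = OllamaLLMService(provider, model_name="llama3.1")

    # Tool calls are executed via _handle_tool_calls, then fed back to the model.
    svc_with_tools = svc.bind_tools([add])
    print(await svc_with_tools.achat([{"role": "user", "content": "What is 2 + 3?"}]))

    # Streaming now yields real incremental chunks from /api/chat.
    async for token in svc.astream_chat([{"role": "user", "content": "Tell me a joke"}]):
        print(token, end="", flush=True)

    print(svc.get_token_usage())  # cumulative totals via total_token_usage
    await svc.close()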