isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +225 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.8.dist-info/METADATA +465 -0
- isa_model-0.2.8.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.1.1.dist-info/METADATA +0 -327
- isa_model-0.1.1.dist-info/RECORD +0 -92
- isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
--- a/isa_model/inference/services/llm/openai_llm_service.py
+++ b/isa_model/inference/services/llm/openai_llm_service.py
@@ -1,12 +1,13 @@
 import logging
 import os
-from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+import json
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 
 # Use the official OpenAI library and dotenv
 from openai import AsyncOpenAI
 from dotenv import load_dotenv
 
-from isa_model.inference.services.
+from isa_model.inference.services.llm.base_llm_service import BaseLLMService
 from isa_model.inference.providers.base_provider import BaseProvider
 
 # Load environment variables from .env.local
@@ -34,14 +35,85 @@ class OpenAILLMService(BaseLLMService):
             raise ValueError("OPENAI_API_KEY is not set.") from e
 
         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
+
+        # Tool binding attributes
+        self._bound_tools: List[Dict[str, Any]] = []
+        self._tool_binding_kwargs: Dict[str, Any] = {}
+        self._tool_functions: Dict[str, Callable] = {}
+
         logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
 
-
+    def _create_bound_copy(self) -> 'OpenAILLMService':
+        """Create a copy of this service for tool binding"""
+        bound_service = OpenAILLMService(self.provider, self.model_name)
+        bound_service._bound_tools = self._bound_tools.copy()
+        bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
+        bound_service._tool_functions = self._tool_functions.copy()
+        return bound_service
+
+    def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'OpenAILLMService':
+        """Bind tools to this LLM service for function calling"""
+        bound_service = self._create_bound_copy()
+        bound_service._bound_tools = self._convert_tools_to_schema(tools)
+        bound_service._tool_binding_kwargs = kwargs
+
+        # Store the actual functions for execution
+        for tool in tools:
+            if callable(tool):
+                bound_service._tool_functions[tool.__name__] = tool
+
+        return bound_service
+
+    async def ainvoke(self, prompt: Union[str, List[Any], Any]) -> str:
         """Universal invocation method"""
         if isinstance(prompt, str):
             return await self.acompletion(prompt)
         elif isinstance(prompt, list):
-            return await self.achat(prompt)
+            # Check whether these are LangGraph message objects
+            if prompt and hasattr(prompt[0], 'content'):
+                # Convert LangGraph message objects to the standard dict format
+                converted_messages = []
+                for msg in prompt:
+                    if hasattr(msg, 'type'):
+                        # LangGraph message object
+                        msg_dict = {"content": msg.content}
+
+                        # Set the role from the message type
+                        if msg.type == "system":
+                            msg_dict["role"] = "system"
+                        elif msg.type == "human":
+                            msg_dict["role"] = "user"
+                        elif msg.type == "ai":
+                            msg_dict["role"] = "assistant"
+                            # Handle tool calls
+                            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+                                msg_dict["tool_calls"] = [
+                                    {
+                                        "id": tc.get("id", f"call_{i}"),
+                                        "type": "function",
+                                        "function": {
+                                            "name": tc["name"],
+                                            "arguments": json.dumps(tc["args"])
+                                        }
+                                    } for i, tc in enumerate(msg.tool_calls)
+                                ]
+                        elif msg.type == "tool":
+                            msg_dict["role"] = "tool"
+                            if hasattr(msg, 'tool_call_id'):
+                                msg_dict["tool_call_id"] = msg.tool_call_id
+                        else:
+                            msg_dict["role"] = "user"  # Default to a user message
+
+                        converted_messages.append(msg_dict)
+                    else:
+                        # Already in dict format
+                        converted_messages.append(msg)
+
+                return await self.achat(converted_messages)
+            else:
+                # Messages already in standard dict format
+                return await self.achat(prompt)
         else:
             raise ValueError("Prompt must be a string or a list of messages")
 
@@ -51,12 +123,19 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            response = await self.client.chat.completions.create(**kwargs)
 
             if response.usage:
                 self.last_token_usage = {
@@ -64,13 +143,87 @@ class OpenAILLMService(BaseLLMService):
                     "completion_tokens": response.usage.completion_tokens,
                     "total_tokens": response.usage.total_tokens
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
 
-            return response.choices[0].message.content or ""
+            # Handle tool calls if present
+            message = response.choices[0].message
+            if message.tool_calls:
+                return await self._handle_tool_calls(message, messages)
+
+            return message.content or ""
 
         except Exception as e:
             logger.error(f"Error in chat completion: {e}")
             raise
 
+    async def _handle_tool_calls(self, assistant_message, original_messages: List[Dict[str, str]]) -> str:
+        """Handle tool calls from the assistant"""
+        # Add assistant message with tool calls to conversation
+        messages = original_messages + [{
+            "role": "assistant",
+            "content": assistant_message.content or "",
+            "tool_calls": [
+                {
+                    "id": tc.id,
+                    "type": tc.type,
+                    "function": {
+                        "name": tc.function.name,
+                        "arguments": tc.function.arguments
+                    }
+                } for tc in assistant_message.tool_calls
+            ]
+        }]
+
+        # Execute each tool call
+        for tool_call in assistant_message.tool_calls:
+            function_name = tool_call.function.name
+            arguments = json.loads(tool_call.function.arguments)
+
+            try:
+                # Execute the tool
+                if function_name in self._tool_functions:
+                    result = self._tool_functions[function_name](**arguments)
+                    if hasattr(result, '__await__'):  # Handle async functions
+                        result = await result
+                else:
+                    result = f"Error: Function {function_name} not found"
+
+                # Add tool result to messages
+                messages.append({
+                    "role": "tool",
+                    "content": str(result),
+                    "tool_call_id": tool_call.id
+                })
+
+            except Exception as e:
+                logger.error(f"Error executing tool {function_name}: {e}")
+                messages.append({
+                    "role": "tool",
+                    "content": f"Error executing {function_name}: {str(e)}",
+                    "tool_call_id": tool_call.id
+                })
+
+        # Get final response from the model with all context
+        try:
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
+            }
+
+            response = await self.client.chat.completions.create(**kwargs)
+            return response.choices[0].message.content or ""
+
+        except Exception as e:
+            logger.error(f"Error getting final response after tool calls: {e}")
+            raise
+
     async def acompletion(self, prompt: str) -> str:
         """Text completion method (using chat API)"""
         messages = [{"role": "user", "content": prompt}]
@@ -82,13 +235,20 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=n
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "n": n
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            response = await self.client.chat.completions.create(**kwargs)
 
             if response.usage:
                 self.last_token_usage = {
@@ -96,6 +256,12 @@ class OpenAILLMService(BaseLLMService):
                     "completion_tokens": response.usage.completion_tokens,
                     "total_tokens": response.usage.total_tokens
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
 
             return [choice.message.content or "" for choice in response.choices]
         except Exception as e:
@@ -108,13 +274,20 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            stream = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                stream=True
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "stream": True
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            stream = await self.client.chat.completions.create(**kwargs)
 
             async for chunk in stream:
                 content = chunk.choices[0].delta.content
@@ -125,14 +298,38 @@ class OpenAILLMService(BaseLLMService):
             logger.error(f"Error in stream chat: {e}")
             raise
 
-    def get_token_usage(self):
+    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+        """Stream completion responses"""
+        messages = [{"role": "user", "content": prompt}]
+        async for chunk in self.astream_chat(messages):
+            yield chunk
+
+    def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
-        return self.last_token_usage
+        return self.total_token_usage
 
     def get_last_token_usage(self) -> Dict[str, int]:
         """Get token usage from last request"""
        return self.last_token_usage
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "max_tokens": self.config.get("max_tokens", 1024),
+            "supports_streaming": True,
+            "supports_functions": True,
+            "provider": "openai"
+        }
+
+    def _has_bound_tools(self) -> bool:
+        """Check if this service has bound tools"""
+        return bool(self._bound_tools)
+
+    def _get_bound_tools(self) -> List[Dict[str, Any]]:
+        """Get the bound tools schema"""
+        return self._bound_tools
 
     async def close(self):
         """Close the backend client"""
-        await self.client.
+        await self.client.close()