isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (77)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/storage/hf_storage.py +419 -0
  3. isa_model/deployment/__init__.py +52 -0
  4. isa_model/deployment/core/__init__.py +34 -0
  5. isa_model/deployment/core/deployment_config.py +356 -0
  6. isa_model/deployment/core/deployment_manager.py +549 -0
  7. isa_model/deployment/core/isa_deployment_service.py +401 -0
  8. isa_model/eval/factory.py +381 -140
  9. isa_model/inference/ai_factory.py +142 -240
  10. isa_model/inference/providers/ml_provider.py +50 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  12. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  13. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  14. isa_model/inference/services/llm/__init__.py +2 -0
  15. isa_model/inference/services/llm/base_llm_service.py +111 -1
  16. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  17. isa_model/inference/services/llm/openai_llm_service.py +225 -28
  18. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  19. isa_model/inference/services/ml/base_ml_service.py +78 -0
  20. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  21. isa_model/inference/services/vision/__init__.py +3 -3
  22. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  23. isa_model/inference/services/vision/base_vision_service.py +177 -0
  24. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  25. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  26. isa_model/training/__init__.py +62 -32
  27. isa_model/training/cloud/__init__.py +22 -0
  28. isa_model/training/cloud/job_orchestrator.py +402 -0
  29. isa_model/training/cloud/runpod_trainer.py +454 -0
  30. isa_model/training/cloud/storage_manager.py +482 -0
  31. isa_model/training/core/__init__.py +23 -0
  32. isa_model/training/core/config.py +181 -0
  33. isa_model/training/core/dataset.py +222 -0
  34. isa_model/training/core/trainer.py +720 -0
  35. isa_model/training/core/utils.py +213 -0
  36. isa_model/training/factory.py +229 -198
  37. isa_model-0.2.8.dist-info/METADATA +465 -0
  38. isa_model-0.2.8.dist-info/RECORD +86 -0
  39. isa_model/core/model_router.py +0 -226
  40. isa_model/core/model_version.py +0 -0
  41. isa_model/core/resource_manager.py +0 -202
  42. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  43. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  44. isa_model/training/engine/llama_factory/__init__.py +0 -39
  45. isa_model/training/engine/llama_factory/config.py +0 -115
  46. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  47. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  48. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  49. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  50. isa_model/training/engine/llama_factory/factory.py +0 -331
  51. isa_model/training/engine/llama_factory/rl.py +0 -254
  52. isa_model/training/engine/llama_factory/trainer.py +0 -171
  53. isa_model/training/image_model/configs/create_config.py +0 -37
  54. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  55. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  56. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  57. isa_model/training/image_model/prepare_upload.py +0 -17
  58. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  59. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  60. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  61. isa_model/training/image_model/train/train.py +0 -42
  62. isa_model/training/image_model/train/train_flux.py +0 -41
  63. isa_model/training/image_model/train/train_lora.py +0 -57
  64. isa_model/training/image_model/train_main.py +0 -25
  65. isa_model-0.1.1.dist-info/METADATA +0 -327
  66. isa_model-0.1.1.dist-info/RECORD +0 -92
  67. isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
  68. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  69. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  76. {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
  77. {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/openai_llm_service.py
@@ -1,12 +1,13 @@
 import logging
 import os
-from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+import json
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 
 # Use the official OpenAI library and dotenv
 from openai import AsyncOpenAI
 from dotenv import load_dotenv
 
-from isa_model.inference.services.base_service import BaseLLMService
+from isa_model.inference.services.llm.base_llm_service import BaseLLMService
 from isa_model.inference.providers.base_provider import BaseProvider
 
 # Load environment variables from .env.local
@@ -34,14 +35,85 @@ class OpenAILLMService(BaseLLMService):
             raise ValueError("OPENAI_API_KEY is not set.") from e
 
         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
+
+        # Tool binding attributes
+        self._bound_tools: List[Dict[str, Any]] = []
+        self._tool_binding_kwargs: Dict[str, Any] = {}
+        self._tool_functions: Dict[str, Callable] = {}
+
         logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
 
-    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
+    def _create_bound_copy(self) -> 'OpenAILLMService':
+        """Create a copy of this service for tool binding"""
+        bound_service = OpenAILLMService(self.provider, self.model_name)
+        bound_service._bound_tools = self._bound_tools.copy()
+        bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
+        bound_service._tool_functions = self._tool_functions.copy()
+        return bound_service
+
+    def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'OpenAILLMService':
+        """Bind tools to this LLM service for function calling"""
+        bound_service = self._create_bound_copy()
+        bound_service._bound_tools = self._convert_tools_to_schema(tools)
+        bound_service._tool_binding_kwargs = kwargs
+
+        # Store the actual functions for execution
+        for tool in tools:
+            if callable(tool):
+                bound_service._tool_functions[tool.__name__] = tool
+
+        return bound_service
+
+    async def ainvoke(self, prompt: Union[str, List[Any], Any]) -> str:
         """Universal invocation method"""
         if isinstance(prompt, str):
             return await self.acompletion(prompt)
         elif isinstance(prompt, list):
-            return await self.achat(prompt)
+            # Check whether these are LangGraph message objects
+            if prompt and hasattr(prompt[0], 'content'):
+                # Convert LangGraph message objects to the standard format
+                converted_messages = []
+                for msg in prompt:
+                    if hasattr(msg, 'type'):
+                        # LangGraph message object
+                        msg_dict = {"content": msg.content}
+
+                        # Set the role according to the message type
+                        if msg.type == "system":
+                            msg_dict["role"] = "system"
+                        elif msg.type == "human":
+                            msg_dict["role"] = "user"
+                        elif msg.type == "ai":
+                            msg_dict["role"] = "assistant"
+                            # Handle tool calls
+                            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+                                msg_dict["tool_calls"] = [
+                                    {
+                                        "id": tc.get("id", f"call_{i}"),
+                                        "type": "function",
+                                        "function": {
+                                            "name": tc["name"],
+                                            "arguments": json.dumps(tc["args"])
+                                        }
+                                    } for i, tc in enumerate(msg.tool_calls)
+                                ]
+                        elif msg.type == "tool":
+                            msg_dict["role"] = "tool"
+                            if hasattr(msg, 'tool_call_id'):
+                                msg_dict["tool_call_id"] = msg.tool_call_id
+                        else:
+                            msg_dict["role"] = "user"  # Default to a user message
+
+                        converted_messages.append(msg_dict)
+                    else:
+                        # Already in dict format
+                        converted_messages.append(msg)
+
+                return await self.achat(converted_messages)
+            else:
+                # Messages already in standard dict format
+                return await self.achat(prompt)
         else:
             raise ValueError("Prompt must be a string or a list of messages")
 
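Note: `bind_tools` returns a bound copy rather than mutating the service, so the original instance stays tool-free. A minimal usage sketch (not part of the diff), assuming a configured `provider` object and calling the constructor the same way `_create_bound_copy` does; the model name and `get_weather` tool are illustrative only:

```python
import asyncio
from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

def get_weather(city: str) -> str:
    """Toy tool: return a canned weather report."""
    return f"It is sunny in {city}."

async def main(provider) -> None:
    llm = OpenAILLMService(provider, "gpt-4o-mini")  # example model name
    weather_llm = llm.bind_tools([get_weather])      # bound copy; `llm` is unchanged

    # ainvoke accepts plain strings, OpenAI-style dicts, or LangGraph messages:
    print(await weather_llm.ainvoke("What's the weather in Paris?"))
    print(await weather_llm.ainvoke([{"role": "user", "content": "And in Rome?"}]))

# asyncio.run(main(provider))  # run once a configured provider is in hand
```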
@@ -51,12 +123,19 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            response = await self.client.chat.completions.create(**kwargs)
 
             if response.usage:
                 self.last_token_usage = {
@@ -64,13 +143,87 @@ class OpenAILLMService(BaseLLMService):
                     "completion_tokens": response.usage.completion_tokens,
                     "total_tokens": response.usage.total_tokens
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
 
-            return response.choices[0].message.content or ""
+            # Handle tool calls if present
+            message = response.choices[0].message
+            if message.tool_calls:
+                return await self._handle_tool_calls(message, messages)
+
+            return message.content or ""
 
         except Exception as e:
             logger.error(f"Error in chat completion: {e}")
             raise
 
+    async def _handle_tool_calls(self, assistant_message, original_messages: List[Dict[str, str]]) -> str:
+        """Handle tool calls from the assistant"""
+        # Add assistant message with tool calls to conversation
+        messages = original_messages + [{
+            "role": "assistant",
+            "content": assistant_message.content or "",
+            "tool_calls": [
+                {
+                    "id": tc.id,
+                    "type": tc.type,
+                    "function": {
+                        "name": tc.function.name,
+                        "arguments": tc.function.arguments
+                    }
+                } for tc in assistant_message.tool_calls
+            ]
+        }]
+
+        # Execute each tool call
+        for tool_call in assistant_message.tool_calls:
+            function_name = tool_call.function.name
+            arguments = json.loads(tool_call.function.arguments)
+
+            try:
+                # Execute the tool
+                if function_name in self._tool_functions:
+                    result = self._tool_functions[function_name](**arguments)
+                    if hasattr(result, '__await__'):  # Handle async functions
+                        result = await result
+                else:
+                    result = f"Error: Function {function_name} not found"
+
+                # Add tool result to messages
+                messages.append({
+                    "role": "tool",
+                    "content": str(result),
+                    "tool_call_id": tool_call.id
+                })
+
+            except Exception as e:
+                logger.error(f"Error executing tool {function_name}: {e}")
+                messages.append({
+                    "role": "tool",
+                    "content": f"Error executing {function_name}: {str(e)}",
+                    "tool_call_id": tool_call.id
+                })
+
+        # Get final response from the model with all context
+        try:
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
+            }
+
+            response = await self.client.chat.completions.create(**kwargs)
+            return response.choices[0].message.content or ""
+
+        except Exception as e:
+            logger.error(f"Error getting final response after tool calls: {e}")
+            raise
+
     async def acompletion(self, prompt: str) -> str:
         """Text completion method (using chat API)"""
         messages = [{"role": "user", "content": prompt}]
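Note: the handler above detects awaitable results with `hasattr(result, '__await__')`, so both sync and async callables can be registered as tools. A self-contained sketch of that dispatch pattern (hypothetical tools, not from the package):

```python
import asyncio

async def call_tool(fn, **arguments) -> str:
    """Dispatch a tool that may be sync or async, mirroring _handle_tool_calls."""
    result = fn(**arguments)
    if hasattr(result, '__await__'):  # async def tools return a coroutine
        result = await result
    return str(result)  # tool messages carry string content

def add(a: int, b: int) -> int:
    return a + b

async def fetch_motd() -> str:
    await asyncio.sleep(0)  # stand-in for real I/O
    return "hello"

async def demo() -> None:
    print(await call_tool(add, a=2, b=3))  # -> 5
    print(await call_tool(fetch_motd))     # -> hello

asyncio.run(demo())
```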
@@ -82,13 +235,20 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            response = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=n
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "n": n
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            response = await self.client.chat.completions.create(**kwargs)
 
             if response.usage:
                 self.last_token_usage = {
@@ -96,6 +256,12 @@ class OpenAILLMService(BaseLLMService):
                     "completion_tokens": response.usage.completion_tokens,
                     "total_tokens": response.usage.total_tokens
                 }
+
+                # Update total usage
+                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+                self.total_token_usage["requests_count"] += 1
 
             return [choice.message.content or "" for choice in response.choices]
         except Exception as e:
@@ -108,13 +274,20 @@ class OpenAILLMService(BaseLLMService):
             temperature = self.config.get("temperature", 0.7)
             max_tokens = self.config.get("max_tokens", 1024)
 
-            stream = await self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                stream=True
-            )
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "stream": True
+            }
+
+            # Add tools if bound
+            if self._has_bound_tools():
+                kwargs["tools"] = self._get_bound_tools()
+                kwargs["tool_choice"] = "auto"
+
+            stream = await self.client.chat.completions.create(**kwargs)
 
             async for chunk in stream:
                 content = chunk.choices[0].delta.content
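Note: with `"stream": True` the client returns an async iterator of chunks, and the `astream_completion` wrapper added in the next hunk reuses this same path. A consumption sketch, assuming `llm` is a configured service instance:

```python
async def print_stream(llm, prompt: str) -> None:
    # astream_chat yields content deltas as they arrive
    async for chunk in llm.astream_chat([{"role": "user", "content": prompt}]):
        print(chunk, end="", flush=True)
    print()
```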
@@ -125,14 +298,38 @@ class OpenAILLMService(BaseLLMService):
             logger.error(f"Error in stream chat: {e}")
             raise
 
-    def get_token_usage(self) -> Dict[str, int]:
+    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+        """Stream completion responses"""
+        messages = [{"role": "user", "content": prompt}]
+        async for chunk in self.astream_chat(messages):
+            yield chunk
+
+    def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
-        return self.last_token_usage
+        return self.total_token_usage
 
     def get_last_token_usage(self) -> Dict[str, int]:
         """Get token usage from last request"""
         return self.last_token_usage
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "max_tokens": self.config.get("max_tokens", 1024),
+            "supports_streaming": True,
+            "supports_functions": True,
+            "provider": "openai"
+        }
+
+    def _has_bound_tools(self) -> bool:
+        """Check if this service has bound tools"""
+        return bool(self._bound_tools)
+
+    def _get_bound_tools(self) -> List[Dict[str, Any]]:
+        """Get the bound tools schema"""
+        return self._bound_tools
 
     async def close(self):
         """Close the backend client"""
-        await self.client.aclose()
+        await self.client.close()
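Note: since every request now folds its usage into `total_token_usage`, session-level accounting is a single call. A sketch, assuming `llm` is a live service instance:

```python
async def report_usage(llm) -> None:
    await llm.acompletion("One haiku about diffs, please.")
    await llm.acompletion("Another, on semantic versioning.")

    last = llm.get_last_token_usage()   # usage of the most recent request only
    total = llm.get_token_usage()       # cumulative totals plus "requests_count"
    print(f"last request: {last['total_tokens']} tokens")
    print(f"session: {total['total_tokens']} tokens over {total['requests_count']} requests")
```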