bizyengine 1.2.45__py3-none-any.whl → 1.2.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/bizy_server/errno.py +21 -0
- bizyengine/bizy_server/server.py +130 -160
- bizyengine/bizy_server/utils.py +3 -0
- bizyengine/bizyair_extras/__init__.py +38 -31
- bizyengine/bizyair_extras/third_party_api/__init__.py +15 -0
- bizyengine/bizyair_extras/third_party_api/nodes_doubao.py +535 -0
- bizyengine/bizyair_extras/third_party_api/nodes_flux.py +173 -0
- bizyengine/bizyair_extras/third_party_api/nodes_gemini.py +403 -0
- bizyengine/bizyair_extras/third_party_api/nodes_gpt.py +101 -0
- bizyengine/bizyair_extras/third_party_api/nodes_hailuo.py +115 -0
- bizyengine/bizyair_extras/third_party_api/nodes_kling.py +404 -0
- bizyengine/bizyair_extras/third_party_api/nodes_sora.py +218 -0
- bizyengine/bizyair_extras/third_party_api/nodes_veo3.py +193 -0
- bizyengine/bizyair_extras/third_party_api/nodes_wan_api.py +198 -0
- bizyengine/bizyair_extras/third_party_api/trd_nodes_base.py +183 -0
- bizyengine/bizyair_extras/utils/aliyun_oss.py +92 -0
- bizyengine/bizyair_extras/utils/audio.py +88 -0
- bizyengine/bizybot/__init__.py +12 -0
- bizyengine/bizybot/client.py +774 -0
- bizyengine/bizybot/config.py +129 -0
- bizyengine/bizybot/coordinator.py +556 -0
- bizyengine/bizybot/exceptions.py +186 -0
- bizyengine/bizybot/mcp/__init__.py +3 -0
- bizyengine/bizybot/mcp/manager.py +520 -0
- bizyengine/bizybot/mcp/models.py +46 -0
- bizyengine/bizybot/mcp/registry.py +129 -0
- bizyengine/bizybot/mcp/routing.py +378 -0
- bizyengine/bizybot/models.py +344 -0
- bizyengine/core/__init__.py +1 -0
- bizyengine/core/commands/servers/prompt_server.py +10 -1
- bizyengine/core/common/client.py +8 -7
- bizyengine/core/common/utils.py +30 -1
- bizyengine/core/image_utils.py +12 -283
- bizyengine/misc/llm.py +32 -15
- bizyengine/misc/utils.py +179 -2
- bizyengine/version.txt +1 -1
- {bizyengine-1.2.45.dist-info → bizyengine-1.2.71.dist-info}/METADATA +3 -1
- {bizyengine-1.2.45.dist-info → bizyengine-1.2.71.dist-info}/RECORD +40 -16
- {bizyengine-1.2.45.dist-info → bizyengine-1.2.71.dist-info}/WHEEL +0 -0
- {bizyengine-1.2.45.dist-info → bizyengine-1.2.71.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data models for conversation management
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import uuid
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConversationValidationError(Exception):
    """Raised when conversation data (e.g. an OpenAI-style message list) fails validation."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ToolCall:
    """A tool invocation requested by the LLM.

    ``result`` and ``error`` are carried on the object but are not part of
    the serialized form produced by ``to_dict``.
    """

    id: str
    type: str  # currently only "function"
    function: "ToolFunction"
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to an OpenAI-style ``tool_calls`` entry (id, type, function)."""
        fn = self.function
        function_payload = {"name": fn.name, "arguments": fn.arguments}
        return {"id": self.id, "type": self.type, "function": function_payload}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass()
class ToolFunction:
    """The function portion of a tool call: its name and raw argument string."""

    name: str  # function name as supplied in the tool call
    arguments: str  # arguments as a JSON-encoded string; not parsed by this class
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class Message:
    """A single conversation message.

    ``__post_init__`` raises ``ValueError`` for a role outside
    user/assistant/system/tool, and for a tool message that lacks a truthy
    ``tool_call_id`` or truthy ``content``.
    """

    role: str  # one of "user", "assistant", "system", "tool"
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None  # tool invocations on this message
    tool_call_id: Optional[str] = None  # for tool messages: the call being answered
    timestamp: datetime = field(default_factory=datetime.now)
    reasoning_content: Optional[str] = None  # for reasoning models like deepseek-R1

    def __post_init__(self):
        """Validate the role and, for tool messages, the required fields."""
        permitted_roles = ["user", "assistant", "system", "tool"]
        if self.role not in permitted_roles:
            raise ValueError(f"Invalid role: {self.role}")

        # Only tool messages carry extra requirements.
        if self.role != "tool":
            return
        if not self.tool_call_id:
            raise ValueError("Tool messages must have a tool_call_id")
        if not self.content:
            raise ValueError("Tool messages must have content")

    def to_openai_format(self) -> Dict[str, Any]:
        """Serialize to an OpenAI chat-message dict.

        ``content`` is included only when it is not None. ``tool_calls`` and
        ``tool_call_id`` are included only when truthy. ``reasoning_content``
        and ``timestamp`` are never emitted; reasoning content is handled
        separately in the conversation flow.
        """
        payload = {"role": self.role}
        if self.content is not None:
            payload["content"] = self.content
        if self.tool_calls:
            payload["tool_calls"] = [call.to_dict() for call in self.tool_calls]
        if self.tool_call_id:
            payload["tool_call_id"] = self.tool_call_id
        return payload
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class Conversation:
    """A conversation and its ordered message history.

    Every mutating helper (``add_*``, ``clear_messages``) refreshes
    ``updated_at`` with a naive local ``datetime.now()`` timestamp.
    """

    id: str  # conversation identifier
    messages: List[Message]  # ordered history, oldest first
    created_at: datetime
    updated_at: datetime  # refreshed on every mutation

    def add_user_message(self, content: str) -> None:
        """Append a ``user`` message with the given text."""
        self.add_message(Message(role="user", content=content))

    def add_assistant_message(
        self, content: str, tool_calls: Optional[List[ToolCall]] = None
    ) -> None:
        """Append an ``assistant`` message, optionally carrying tool calls."""
        self.add_message(
            Message(role="assistant", content=content, tool_calls=tool_calls)
        )

    def add_tool_result(self, tool_call_id: str, result: Any) -> None:
        """Append a ``tool`` message holding the result of a tool call.

        Args:
            tool_call_id: ID of the tool call this result answers.
            result: Tool output. Strings are stored as-is; every other value,
                including dicts, is converted with ``str()``.

        Raises:
            ValueError: from ``Message`` validation if ``tool_call_id`` is
                empty or the resulting content is an empty string.
        """
        # NOTE(review): str() renders dicts in Python repr form (single quotes,
        # None/True), not JSON; consider json.dumps if consumers parse this.
        content = result if isinstance(result, str) else str(result)
        self.add_message(
            Message(role="tool", content=content, tool_call_id=tool_call_id)
        )

    def add_system_message(self, content: str) -> None:
        """Append a ``system`` message with the given text."""
        self.add_message(Message(role="system", content=content))

    def add_message(self, message: Message) -> None:
        """Append an already-built message and refresh ``updated_at``."""
        self.messages.append(message)
        self.updated_at = datetime.now()

    def to_openai_format(self) -> List[Dict[str, Any]]:
        """Return the history as a list of OpenAI-style message dicts."""
        return [msg.to_openai_format() for msg in self.messages]

    def get_message_count(self) -> int:
        """Return the total number of messages."""
        return len(self.messages)

    def get_last_message(self) -> Optional[Message]:
        """Return the most recent message, or None if there are none."""
        return self.messages[-1] if self.messages else None

    def get_messages_by_role(self, role: str) -> List[Message]:
        """Return all messages whose role equals ``role``, in order."""
        return [msg for msg in self.messages if msg.role == role]

    def get_recent_messages(self, count: int) -> List[Message]:
        """Return the last ``count`` messages; ``[]`` when ``count <= 0``."""
        # The guard is required: messages[-0:] would return the whole list.
        return self.messages[-count:] if count > 0 else []

    def clear_messages(self) -> None:
        """Remove all messages and refresh ``updated_at``."""
        self.messages.clear()
        self.updated_at = datetime.now()

    def has_tool_calls(self) -> bool:
        """Return True if the last message is an assistant message with tool calls."""
        last_message = self.get_last_message()
        # bool() so an empty conversation yields False rather than None.
        return bool(
            last_message is not None
            and last_message.role == "assistant"
            and last_message.tool_calls
        )

    @classmethod
    def from_openai_format(
        cls,
        conversation_history: List[Dict[str, Any]],
        conversation_id: Optional[str] = None,
    ) -> "Conversation":
        """Build a Conversation from an OpenAI-format message history.

        Args:
            conversation_history: Messages in OpenAI chat format.
            conversation_id: Optional conversation ID; a new uuid4 string is
                generated when omitted.

        Returns:
            A new Conversation whose ``created_at`` and ``updated_at`` are
            both set to the construction time.

        Raises:
            ConversationValidationError: if ``conversation_history`` is not
                iterable, or if any message fails to parse. In the parse case,
                the error message names the offending index.
        """
        if conversation_id is None:
            conversation_id = str(uuid.uuid4())

        messages = []
        current_time = datetime.now()

        try:
            indexed_history = enumerate(conversation_history)
        except TypeError as e:
            # e.g. None was passed: report it via the documented error type.
            raise ConversationValidationError(
                f"conversation_history must be a list of messages: {e}"
            ) from e

        for i, msg_data in indexed_history:
            try:
                messages.append(cls._parse_message_from_dict(msg_data))
            except Exception as e:
                raise ConversationValidationError(
                    f"Invalid message format at index {i}: {e}"
                ) from e

        return cls(
            id=conversation_id,
            messages=messages,
            created_at=current_time,
            updated_at=current_time,
        )

    @staticmethod
    def _parse_message_from_dict(msg_data: Dict[str, Any]) -> Message:
        """Parse one OpenAI-format message dict into a ``Message``.

        Raises:
            ConversationValidationError: if the message is not a dict, or its
                role is missing or invalid, or its tool calls are malformed,
                or its tool arguments are not valid JSON, or it is a tool
                message without ``tool_call_id``.
            ValueError: from ``Message`` validation, e.g. a tool message with
                empty content.
        """
        if not isinstance(msg_data, dict):
            raise ConversationValidationError("Message must be a dictionary")

        # Required field.
        if "role" not in msg_data:
            raise ConversationValidationError("Message must have 'role' field")

        role = msg_data["role"]
        if role not in ["user", "assistant", "system", "tool"]:
            raise ConversationValidationError(f"Invalid role: {role}")

        content = msg_data.get("content")

        # Tool calls, if present and non-empty.
        tool_calls = None
        if msg_data.get("tool_calls"):
            tool_calls = []
            for tc_data in msg_data["tool_calls"]:
                if not isinstance(tc_data, dict):
                    raise ConversationValidationError(
                        "Each tool_call must be a dictionary"
                    )

                if "id" not in tc_data or "function" not in tc_data:
                    raise ConversationValidationError(
                        "tool_call must have 'id' and 'function' fields"
                    )

                function = tc_data["function"]
                if not isinstance(function, dict):
                    raise ConversationValidationError(
                        "tool_call 'function' must be a dictionary"
                    )
                if "name" not in function or "arguments" not in function:
                    raise ConversationValidationError(
                        "function must have 'name' and 'arguments' fields"
                    )

                # Arguments must be JSON text; json.loads raises TypeError for
                # values that are not str/bytes/bytearray (e.g. a decoded dict).
                try:
                    json.loads(function["arguments"])
                except (TypeError, json.JSONDecodeError) as e:
                    raise ConversationValidationError(
                        f"Invalid JSON in function arguments: {e}"
                    ) from e

                tool_calls.append(
                    ToolCall(
                        id=str(tc_data["id"]),
                        type=tc_data.get("type", "function"),
                        function=ToolFunction(
                            name=str(function["name"]),
                            arguments=str(function["arguments"]),
                        ),
                    )
                )

        # Tool-role messages must reference the call they answer.
        tool_call_id = msg_data.get("tool_call_id")
        if role == "tool" and not tool_call_id:
            raise ConversationValidationError("Tool messages must have a tool_call_id")

        # Reasoning content, if present.
        reasoning_content = msg_data.get("reasoning_content")

        return Message(
            role=role,
            content=content,
            tool_calls=tool_calls,
            tool_call_id=str(tool_call_id) if tool_call_id else None,
            reasoning_content=reasoning_content,
        )

    @staticmethod
    def validate_conversation_history(
        conversation_history: List[Dict[str, Any]],
    ) -> bool:
        """Return True if the history is well-formed, False otherwise; never raises.

        The checks are intended to match those in ``_parse_message_from_dict``
        and ``Message.__post_init__``. In particular, tool messages need a
        truthy ``tool_call_id`` and truthy ``content``.
        """
        try:
            if not isinstance(conversation_history, list):
                return False

            for msg_data in conversation_history:
                # Basic structure.
                if not isinstance(msg_data, dict) or "role" not in msg_data:
                    return False

                role = msg_data["role"]
                if role not in ["user", "assistant", "system", "tool"]:
                    return False

                # Tool messages: Message() rejects empty tool_call_id/content,
                # so treat them as invalid here too.
                if role == "tool" and not (
                    msg_data.get("tool_call_id") and msg_data.get("content")
                ):
                    return False

                # Tool-call structure.
                tool_calls = msg_data.get("tool_calls")
                if tool_calls:
                    if not isinstance(tool_calls, list):
                        return False

                    for tc in tool_calls:
                        if not isinstance(tc, dict):
                            return False
                        if "id" not in tc or "function" not in tc:
                            return False
                        function = tc["function"]
                        if (
                            not isinstance(function, dict)
                            or "name" not in function
                            or "arguments" not in function
                        ):
                            return False

                        # Arguments must be valid JSON text.
                        try:
                            json.loads(function["arguments"])
                        except (TypeError, json.JSONDecodeError):
                            return False

            return True
        except Exception:
            # Defensive: any unexpected shape counts as invalid, never raises.
            return False

    def add_user_message_from_request(self, message: str) -> None:
        """Append a user message taken from an incoming request."""
        self.add_user_message(message)

    def get_openai_messages_for_llm(self) -> List[Dict[str, Any]]:
        """Return the OpenAI-format messages to send to the LLM."""
        return self.to_openai_format()
|
bizyengine/core/__init__.py
CHANGED
|
@@ -294,7 +294,16 @@ class PromptServer(Command):
|
|
|
294
294
|
|
|
295
295
|
try:
|
|
296
296
|
real_out = decode_data(out)
|
|
297
|
-
|
|
297
|
+
out_lst = []
|
|
298
|
+
for x in real_out:
|
|
299
|
+
if (
|
|
300
|
+
x is None
|
|
301
|
+
): # ref: https://github.com/siliconflow/comfybridge/blob/ecf2e835d4db9816514078f9eed98ab8ba12e23e/custom_plugins/comfy_pipeline/executor.py#L75-L78
|
|
302
|
+
out_lst.append(None)
|
|
303
|
+
else:
|
|
304
|
+
# ref: https://github.com/comfyanonymous/ComfyUI/blob/c170fd2db598a0bdce56f80e22e83e10ad731421/execution.py#L312
|
|
305
|
+
out_lst.append(x[0])
|
|
306
|
+
return out_lst
|
|
298
307
|
except Exception as e:
|
|
299
308
|
print("Exception occurred while decoding data")
|
|
300
309
|
self.cache_manager.delete(sh256)
|
bizyengine/core/common/client.py
CHANGED
|
@@ -107,7 +107,6 @@ def get_api_key() -> str:
|
|
|
107
107
|
if validate_api_key(BIZYAIR_API_KEY):
|
|
108
108
|
api_key_state.is_valid = True
|
|
109
109
|
api_key_state.current_api_key = BIZYAIR_API_KEY
|
|
110
|
-
logging.info("API key set successfully")
|
|
111
110
|
except Exception as e:
|
|
112
111
|
logging.error(str(e))
|
|
113
112
|
raise ValueError(str(e))
|
|
@@ -161,7 +160,7 @@ def send_request(
|
|
|
161
160
|
req = urllib.request.Request(
|
|
162
161
|
url, data=data, headers=headers, method=method, **kwargs
|
|
163
162
|
)
|
|
164
|
-
with urllib.request.urlopen(req) as response:
|
|
163
|
+
with urllib.request.urlopen(req, timeout=3600) as response:
|
|
165
164
|
response_data = response.read().decode("utf-8")
|
|
166
165
|
|
|
167
166
|
except urllib.error.URLError as e:
|
|
@@ -169,7 +168,7 @@ def send_request(
|
|
|
169
168
|
response_body = e.read().decode("utf-8") if hasattr(e, "read") else "N/A"
|
|
170
169
|
if verbose:
|
|
171
170
|
logging.error(f"URLError encountered: {error_message}")
|
|
172
|
-
logging.info(f"Response Body: {
|
|
171
|
+
logging.info(f"Response Body: {response_body}")
|
|
173
172
|
code, message = "N/A", "N/A"
|
|
174
173
|
try:
|
|
175
174
|
response_dict = json.loads(response_body)
|
|
@@ -178,8 +177,8 @@ def send_request(
|
|
|
178
177
|
message = response_dict.get("message", "N/A")
|
|
179
178
|
|
|
180
179
|
except json.JSONDecodeError:
|
|
181
|
-
|
|
182
|
-
|
|
180
|
+
error_info = f"Failed to decode response body as JSON: {str(e)}. Response: {response_body[:200]}"
|
|
181
|
+
raise ConnectionError(f"Invalid server response: {error_info}") from e
|
|
183
182
|
|
|
184
183
|
if "Unauthorized" in error_message:
|
|
185
184
|
raise PermissionError(
|
|
@@ -225,7 +224,7 @@ def send_request(
|
|
|
225
224
|
except urllib.error.HTTPError as e:
|
|
226
225
|
success = 200 <= e.code < 400
|
|
227
226
|
except (urllib.error.URLError, TimeoutError):
|
|
228
|
-
|
|
227
|
+
logging.error(f"Failed ({type(e).__name__}: {str(e)})")
|
|
229
228
|
results[site] = "Success" if success else "Failed"
|
|
230
229
|
raise ConnectionError(
|
|
231
230
|
f"Failed to connect to the server: {url}.\n"
|
|
@@ -253,7 +252,9 @@ async def async_send_request(
|
|
|
253
252
|
) -> dict:
|
|
254
253
|
headers = kwargs.pop("headers") if "headers" in kwargs else _headers()
|
|
255
254
|
try:
|
|
256
|
-
async with aiohttp.ClientSession(
|
|
255
|
+
async with aiohttp.ClientSession(
|
|
256
|
+
timeout=aiohttp.ClientTimeout(total=600)
|
|
257
|
+
) as session:
|
|
257
258
|
async with session.request(
|
|
258
259
|
method, url, data=data, headers=headers, **kwargs
|
|
259
260
|
) as response:
|
bizyengine/core/common/utils.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import importlib
|
|
2
3
|
import json
|
|
3
4
|
import os
|
|
4
|
-
|
|
5
|
+
import sys
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import Any, List, Optional
|
|
5
8
|
|
|
6
9
|
import torch
|
|
7
10
|
import yaml
|
|
@@ -91,3 +94,29 @@ def load_config_file(file_path: str) -> dict:
|
|
|
91
94
|
return _load_yaml_config(file_path)
|
|
92
95
|
else:
|
|
93
96
|
raise ValueError(f"Unsupported file extension: {file_path}")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def safe_star_import(module: str, package: Optional[str] = None) -> None:
    """Emulate ``from <module> import *`` into the caller's module globals.

    Relative module names (leading dots) are supported when ``package`` is
    given. Any ``Exception`` is reported on stderr with a traceback and then
    swallowed, so a failing module does not abort the importing code.

    :param module: module name, optionally relative, e.g. '.nodes_advanced_refluxcontrol'
    :param package: anchor package, e.g. 'ui.nodes'; required when ``module`` is relative
    """
    try:
        imported = importlib.import_module(module, package=package)

        # Star-import selection rule: honour __all__ when it is defined,
        # otherwise take every public (non-underscore) name.
        exported = getattr(imported, "__all__", None)
        if exported is None:
            exported = [
                attr_name
                for attr_name in imported.__dict__
                if not attr_name.startswith("_")
            ]

        # Frame 1 is our direct caller; inject into that frame's module globals.
        target_namespace = sys._getframe(1).f_globals
        for attr_name in exported:
            target_namespace[attr_name] = getattr(imported, attr_name)
    except Exception as exc:
        print(
            f"\033[92m[BizyAir]\033[0m safe_star_import {module} failed: {exc}",
            file=sys.stderr,
        )
        traceback.print_exc()
|