jarvis-ai-assistant 0.1.46__py3-none-any.whl → 0.1.48__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +32 -47
- jarvis/main.py +35 -51
- jarvis/models/__init__.py +1 -1
- jarvis/models/ai8.py +58 -88
- jarvis/models/base.py +6 -6
- jarvis/models/kimi.py +80 -171
- jarvis/models/openai.py +23 -43
- jarvis/models/oyi.py +91 -113
- jarvis/models/registry.py +44 -63
- jarvis/tools/__init__.py +1 -1
- jarvis/tools/base.py +2 -2
- jarvis/tools/file_ops.py +15 -19
- jarvis/tools/generator.py +12 -15
- jarvis/tools/methodology.py +20 -20
- jarvis/tools/registry.py +30 -44
- jarvis/tools/shell.py +11 -12
- jarvis/tools/sub_agent.py +2 -1
- jarvis/utils.py +27 -47
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/METADATA +1 -1
- jarvis_ai_assistant-0.1.48.dist-info/RECORD +25 -0
- jarvis_ai_assistant-0.1.46.dist-info/RECORD +0 -25
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/top_level.txt +0 -0
jarvis/models/oyi.py
CHANGED
@@ -6,17 +6,16 @@ from jarvis.utils import PrettyOutput, OutputType
|
|
6
6
|
import requests
|
7
7
|
import json
|
8
8
|
|
9
|
-
|
10
9
|
class OyiModel(BasePlatform):
|
11
10
|
"""Oyi model implementation"""
|
12
|
-
|
11
|
+
|
13
12
|
platform_name = "oyi"
|
14
13
|
BASE_URL = "https://api-10086.rcouyi.com"
|
15
|
-
|
14
|
+
|
16
15
|
def __init__(self):
|
17
16
|
"""Initialize model"""
|
18
17
|
PrettyOutput.section("支持的模型", OutputType.SUCCESS)
|
19
|
-
|
18
|
+
|
20
19
|
# 获取可用模型列表
|
21
20
|
available_models = self.get_available_models()
|
22
21
|
if available_models:
|
@@ -24,32 +23,30 @@ class OyiModel(BasePlatform):
|
|
24
23
|
PrettyOutput.print(model, OutputType.INFO)
|
25
24
|
else:
|
26
25
|
PrettyOutput.print("获取模型列表失败", OutputType.WARNING)
|
27
|
-
|
26
|
+
|
28
27
|
PrettyOutput.print("使用OYI_MODEL环境变量配置模型", OutputType.SUCCESS)
|
29
|
-
|
28
|
+
|
30
29
|
self.messages = []
|
31
30
|
self.system_message = ""
|
32
31
|
self.conversation = None
|
33
|
-
self.upload_files = []
|
32
|
+
self.upload_files = [] # 重命名 files 为 upload_files
|
34
33
|
self.first_chat = True
|
35
|
-
|
34
|
+
|
36
35
|
self.token = os.getenv("OYI_API_KEY")
|
37
36
|
if not self.token:
|
38
37
|
raise Exception("OYI_API_KEY is not set")
|
39
|
-
|
38
|
+
|
40
39
|
self.model_name = os.getenv("OYI_MODEL") or "deepseek-chat"
|
41
40
|
if self.model_name not in [m.split()[0] for m in available_models]:
|
42
|
-
PrettyOutput.print(
|
43
|
-
|
44
|
-
self.model_name} 不在可用列表中",
|
45
|
-
OutputType.WARNING)
|
46
|
-
|
41
|
+
PrettyOutput.print(f"警告: 当前选择的模型 {self.model_name} 不在可用列表中", OutputType.WARNING)
|
42
|
+
|
47
43
|
PrettyOutput.print(f"当前使用模型: {self.model_name}", OutputType.SYSTEM)
|
48
44
|
|
49
45
|
def set_model_name(self, model_name: str):
|
50
46
|
"""设置模型名称"""
|
51
47
|
self.model_name = model_name
|
52
48
|
|
49
|
+
|
53
50
|
def create_conversation(self) -> bool:
|
54
51
|
"""Create a new conversation"""
|
55
52
|
try:
|
@@ -59,7 +56,7 @@ class OyiModel(BasePlatform):
|
|
59
56
|
'Accept': 'application/json',
|
60
57
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
|
61
58
|
}
|
62
|
-
|
59
|
+
|
63
60
|
payload = {
|
64
61
|
"id": 0,
|
65
62
|
"roleId": 0,
|
@@ -78,49 +75,40 @@ class OyiModel(BasePlatform):
|
|
78
75
|
"chatPluginIds": []
|
79
76
|
})
|
80
77
|
}
|
81
|
-
|
78
|
+
|
82
79
|
response = requests.post(
|
83
80
|
f"{self.BASE_URL}/chatapi/chat/save",
|
84
81
|
headers=headers,
|
85
82
|
json=payload
|
86
83
|
)
|
87
|
-
|
84
|
+
|
88
85
|
if response.status_code == 200:
|
89
86
|
data = response.json()
|
90
87
|
if data['code'] == 200 and data['type'] == 'success':
|
91
88
|
self.conversation = data
|
92
|
-
PrettyOutput.print(
|
93
|
-
f"创建会话成功: {
|
94
|
-
data['result']['id']}",
|
95
|
-
OutputType.SUCCESS)
|
89
|
+
PrettyOutput.print(f"创建会话成功: {data['result']['id']}", OutputType.SUCCESS)
|
96
90
|
return True
|
97
91
|
else:
|
98
|
-
PrettyOutput.print(
|
99
|
-
f"创建会话失败: {
|
100
|
-
data['message']}",
|
101
|
-
OutputType.ERROR)
|
92
|
+
PrettyOutput.print(f"创建会话失败: {data['message']}", OutputType.ERROR)
|
102
93
|
return False
|
103
94
|
else:
|
104
|
-
PrettyOutput.print(
|
105
|
-
f"创建会话失败: {
|
106
|
-
response.status_code}",
|
107
|
-
OutputType.ERROR)
|
95
|
+
PrettyOutput.print(f"创建会话失败: {response.status_code}", OutputType.ERROR)
|
108
96
|
return False
|
109
|
-
|
97
|
+
|
110
98
|
except Exception as e:
|
111
99
|
PrettyOutput.print(f"创建会话异常: {str(e)}", OutputType.ERROR)
|
112
100
|
return False
|
113
|
-
|
101
|
+
|
114
102
|
def set_system_message(self, message: str):
|
115
103
|
"""Set system message"""
|
116
104
|
self.system_message = message
|
117
|
-
|
105
|
+
|
118
106
|
def chat(self, message: str) -> str:
|
119
107
|
"""Execute chat with the model
|
120
|
-
|
108
|
+
|
121
109
|
Args:
|
122
110
|
message: User input message
|
123
|
-
|
111
|
+
|
124
112
|
Returns:
|
125
113
|
str: Model response
|
126
114
|
"""
|
@@ -129,7 +117,7 @@ class OyiModel(BasePlatform):
|
|
129
117
|
if not self.conversation:
|
130
118
|
if not self.create_conversation():
|
131
119
|
raise Exception("Failed to create conversation")
|
132
|
-
|
120
|
+
|
133
121
|
# 1. 发送消息
|
134
122
|
headers = {
|
135
123
|
'Authorization': f'Bearer {self.token}',
|
@@ -139,14 +127,14 @@ class OyiModel(BasePlatform):
|
|
139
127
|
'Origin': 'https://ai.rcouyi.com',
|
140
128
|
'Referer': 'https://ai.rcouyi.com/'
|
141
129
|
}
|
142
|
-
|
130
|
+
|
143
131
|
payload = {
|
144
132
|
"topicId": self.conversation['result']['id'],
|
145
133
|
"messages": self.messages,
|
146
134
|
"content": message,
|
147
135
|
"contentFiles": []
|
148
136
|
}
|
149
|
-
|
137
|
+
|
150
138
|
# 如果有上传的文件,添加到请求中
|
151
139
|
if self.first_chat:
|
152
140
|
if self.upload_files:
|
@@ -165,64 +153,63 @@ class OyiModel(BasePlatform):
|
|
165
153
|
self.first_chat = False
|
166
154
|
|
167
155
|
self.messages.append({"role": "user", "content": message})
|
168
|
-
|
156
|
+
|
169
157
|
# 发送消息
|
170
158
|
response = requests.post(
|
171
159
|
f"{self.BASE_URL}/chatapi/chat/message",
|
172
160
|
headers=headers,
|
173
161
|
json=payload
|
174
162
|
)
|
175
|
-
|
163
|
+
|
176
164
|
if response.status_code != 200:
|
177
165
|
error_msg = f"聊天请求失败: {response.status_code}"
|
178
166
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
179
167
|
raise Exception(error_msg)
|
180
|
-
|
168
|
+
|
181
169
|
data = response.json()
|
182
170
|
if data['code'] != 200 or data['type'] != 'success':
|
183
171
|
error_msg = f"聊天失败: {data.get('message', '未知错误')}"
|
184
172
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
185
173
|
raise Exception(error_msg)
|
186
|
-
|
174
|
+
|
187
175
|
message_id = data['result'][-1]
|
188
|
-
|
176
|
+
|
189
177
|
# 获取响应内容
|
190
178
|
response = requests.post(
|
191
179
|
f"{self.BASE_URL}/chatapi/chat/message/{message_id}",
|
192
180
|
headers=headers
|
193
181
|
)
|
194
|
-
|
182
|
+
|
195
183
|
if response.status_code == 200:
|
196
184
|
PrettyOutput.print(response.text, OutputType.SYSTEM)
|
197
|
-
self.messages.append(
|
198
|
-
{"role": "assistant", "content": response.text})
|
185
|
+
self.messages.append({"role": "assistant", "content": response.text})
|
199
186
|
return response.text
|
200
187
|
else:
|
201
188
|
error_msg = f"获取响应失败: {response.status_code}"
|
202
189
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
203
190
|
raise Exception(error_msg)
|
204
|
-
|
191
|
+
|
205
192
|
except Exception as e:
|
206
193
|
PrettyOutput.print(f"聊天异常: {str(e)}", OutputType.ERROR)
|
207
194
|
raise e
|
208
|
-
|
195
|
+
|
209
196
|
def name(self) -> str:
|
210
197
|
"""Return model name"""
|
211
198
|
return self.model_name
|
212
|
-
|
199
|
+
|
213
200
|
def reset(self):
|
214
201
|
"""Reset model state"""
|
215
202
|
self.messages = []
|
216
203
|
self.conversation = None
|
217
204
|
self.upload_files = []
|
218
205
|
self.first_chat = True
|
219
|
-
|
206
|
+
|
220
207
|
def delete_chat(self) -> bool:
|
221
208
|
"""Delete current chat session"""
|
222
209
|
try:
|
223
210
|
if not self.conversation:
|
224
211
|
return True
|
225
|
-
|
212
|
+
|
226
213
|
headers = {
|
227
214
|
'Authorization': f'Bearer {self.token}',
|
228
215
|
'Content-Type': 'application/json',
|
@@ -231,15 +218,13 @@ class OyiModel(BasePlatform):
|
|
231
218
|
'Origin': 'https://ai.rcouyi.com',
|
232
219
|
'Referer': 'https://ai.rcouyi.com/'
|
233
220
|
}
|
234
|
-
|
221
|
+
|
235
222
|
response = requests.post(
|
236
|
-
f"{
|
237
|
-
self.BASE_URL}/chatapi/chat/{
|
238
|
-
self.conversation['result']['id']}",
|
223
|
+
f"{self.BASE_URL}/chatapi/chat/{self.conversation['result']['id']}",
|
239
224
|
headers=headers,
|
240
225
|
json={}
|
241
226
|
)
|
242
|
-
|
227
|
+
|
243
228
|
if response.status_code == 200:
|
244
229
|
data = response.json()
|
245
230
|
if data['code'] == 200 and data['type'] == 'success':
|
@@ -254,17 +239,17 @@ class OyiModel(BasePlatform):
|
|
254
239
|
error_msg = f"删除会话请求失败: {response.status_code}"
|
255
240
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
256
241
|
return False
|
257
|
-
|
242
|
+
|
258
243
|
except Exception as e:
|
259
244
|
PrettyOutput.print(f"删除会话异常: {str(e)}", OutputType.ERROR)
|
260
245
|
return False
|
261
|
-
|
262
|
-
def
|
246
|
+
|
247
|
+
def upload_files(self, file_list: List[str]) -> List[Dict]:
|
263
248
|
"""Upload a file to OYI API
|
264
|
-
|
249
|
+
|
265
250
|
Args:
|
266
251
|
file_path: Path to the file to upload
|
267
|
-
|
252
|
+
|
268
253
|
Returns:
|
269
254
|
Dict: Upload response data
|
270
255
|
"""
|
@@ -272,12 +257,9 @@ class OyiModel(BasePlatform):
|
|
272
257
|
# 检查当前模型是否支持文件上传
|
273
258
|
model_info = self.models.get(self.model_name)
|
274
259
|
if not model_info or not model_info.get('uploadFile', False):
|
275
|
-
PrettyOutput.print(
|
276
|
-
f"当前模型 {
|
277
|
-
self.model_name} 不支持文件上传",
|
278
|
-
OutputType.WARNING)
|
260
|
+
PrettyOutput.print(f"当前模型 {self.model_name} 不支持文件上传", OutputType.WARNING)
|
279
261
|
return None
|
280
|
-
|
262
|
+
|
281
263
|
headers = {
|
282
264
|
'Authorization': f'Bearer {self.token}',
|
283
265
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
|
@@ -286,45 +268,45 @@ class OyiModel(BasePlatform):
|
|
286
268
|
'Origin': 'https://ai.rcouyi.com',
|
287
269
|
'Referer': 'https://ai.rcouyi.com/'
|
288
270
|
}
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
files=
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
271
|
+
|
272
|
+
for file_path in file_list:
|
273
|
+
# 检查文件类型
|
274
|
+
file_type = mimetypes.guess_type(file_path)[0]
|
275
|
+
if not file_type or not file_type.startswith(('image/', 'text/', 'application/')):
|
276
|
+
PrettyOutput.print(f"文件类型不支持: {file_type}", OutputType.ERROR)
|
277
|
+
continue
|
278
|
+
|
279
|
+
with open(file_path, 'rb') as f:
|
280
|
+
files = {
|
281
|
+
'file': (os.path.basename(file_path), f, file_type)
|
282
|
+
}
|
283
|
+
|
284
|
+
response = requests.post(
|
285
|
+
f"{self.BASE_URL}/chatapi/m_file/uploadfile",
|
286
|
+
headers=headers,
|
287
|
+
files=files
|
288
|
+
)
|
289
|
+
|
290
|
+
if response.status_code == 200:
|
291
|
+
data = response.json()
|
292
|
+
if data.get('code') == 200:
|
293
|
+
PrettyOutput.print("文件上传成功", OutputType.SUCCESS)
|
294
|
+
self.upload_files.append(data)
|
295
|
+
return data
|
296
|
+
else:
|
297
|
+
PrettyOutput.print(f"文件上传失败: {data.get('message')}", OutputType.ERROR)
|
298
|
+
return None
|
308
299
|
else:
|
309
|
-
PrettyOutput.print(
|
310
|
-
f"文件上传失败: {
|
311
|
-
data.get('message')}",
|
312
|
-
OutputType.ERROR)
|
300
|
+
PrettyOutput.print(f"文件上传失败: {response.status_code}", OutputType.ERROR)
|
313
301
|
return None
|
314
|
-
|
315
|
-
PrettyOutput.print(
|
316
|
-
f"文件上传失败: {
|
317
|
-
response.status_code}",
|
318
|
-
OutputType.ERROR)
|
319
|
-
return None
|
320
|
-
|
302
|
+
|
321
303
|
except Exception as e:
|
322
304
|
PrettyOutput.print(f"文件上传异常: {str(e)}", OutputType.ERROR)
|
323
305
|
return None
|
324
306
|
|
325
307
|
def get_available_models(self) -> List[str]:
|
326
308
|
"""获取可用的模型列表
|
327
|
-
|
309
|
+
|
328
310
|
Returns:
|
329
311
|
List[str]: 可用模型名称列表
|
330
312
|
"""
|
@@ -336,59 +318,55 @@ class OyiModel(BasePlatform):
|
|
336
318
|
'Origin': 'https://ai.rcouyi.com',
|
337
319
|
'Referer': 'https://ai.rcouyi.com/'
|
338
320
|
}
|
339
|
-
|
321
|
+
|
340
322
|
response = requests.get(
|
341
323
|
"https://ai.rcouyi.com/config/system.json",
|
342
324
|
headers=headers
|
343
325
|
)
|
344
|
-
|
326
|
+
|
345
327
|
if response.status_code != 200:
|
346
|
-
PrettyOutput.print(
|
347
|
-
f"获取模型列表失败: {
|
348
|
-
response.status_code}",
|
349
|
-
OutputType.ERROR)
|
328
|
+
PrettyOutput.print(f"获取模型列表失败: {response.status_code}", OutputType.ERROR)
|
350
329
|
return []
|
351
|
-
|
330
|
+
|
352
331
|
data = response.json()
|
353
|
-
|
332
|
+
|
354
333
|
# 保存模型信息
|
355
334
|
self.models = {
|
356
335
|
model['value']: model
|
357
336
|
for model in data.get('model', [])
|
358
337
|
if model.get('enable', False) # 只保存启用的模型
|
359
338
|
}
|
360
|
-
|
339
|
+
|
361
340
|
# 格式化显示
|
362
341
|
models = []
|
363
342
|
for model in self.models.values():
|
364
343
|
# 基本信息
|
365
344
|
model_str = f"{model['value']:<30} {model['label']}"
|
366
|
-
|
345
|
+
|
367
346
|
# 添加后缀标签
|
368
347
|
suffix = model.get('suffix', [])
|
369
348
|
if suffix:
|
370
349
|
# 处理新格式的suffix (字典列表)
|
371
350
|
if suffix and isinstance(suffix[0], dict):
|
372
|
-
suffix_str = ', '.join(s.get('tag', '')
|
373
|
-
for s in suffix)
|
351
|
+
suffix_str = ', '.join(s.get('tag', '') for s in suffix)
|
374
352
|
# 处理旧格式的suffix (字符串列表)
|
375
353
|
else:
|
376
354
|
suffix_str = ', '.join(str(s) for s in suffix)
|
377
355
|
model_str += f" ({suffix_str})"
|
378
|
-
|
356
|
+
|
379
357
|
# 添加描述或提示
|
380
358
|
info = model.get('tooltip') or model.get('description', '')
|
381
359
|
if info:
|
382
360
|
model_str += f" - {info}"
|
383
|
-
|
361
|
+
|
384
362
|
# 添加文件上传支持标记
|
385
363
|
if model.get('uploadFile'):
|
386
364
|
model_str += " [支持文件上传]"
|
387
|
-
|
365
|
+
|
388
366
|
models.append(model_str)
|
389
|
-
|
367
|
+
|
390
368
|
return sorted(models)
|
391
|
-
|
369
|
+
|
392
370
|
except Exception as e:
|
393
371
|
PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.ERROR)
|
394
372
|
return []
|
jarvis/models/registry.py
CHANGED
@@ -14,7 +14,6 @@ REQUIRED_METHODS = [
|
|
14
14
|
('set_system_message', ['message'])
|
15
15
|
]
|
16
16
|
|
17
|
-
|
18
17
|
class PlatformRegistry:
|
19
18
|
"""平台注册器"""
|
20
19
|
|
@@ -30,82 +29,76 @@ class PlatformRegistry:
|
|
30
29
|
# 创建 __init__.py 使其成为 Python 包
|
31
30
|
with open(os.path.join(user_platform_dir, "__init__.py"), "w") as f:
|
32
31
|
pass
|
33
|
-
PrettyOutput.print(
|
34
|
-
f"已创建平台目录: {user_platform_dir}",
|
35
|
-
OutputType.INFO)
|
32
|
+
PrettyOutput.print(f"已创建平台目录: {user_platform_dir}", OutputType.INFO)
|
36
33
|
except Exception as e:
|
37
34
|
PrettyOutput.print(f"创建平台目录失败: {str(e)}", OutputType.ERROR)
|
38
35
|
return ""
|
39
36
|
return user_platform_dir
|
40
37
|
|
41
38
|
@staticmethod
|
42
|
-
def check_platform_implementation(
|
43
|
-
platform_class: Type[BasePlatform]) -> bool:
|
39
|
+
def check_platform_implementation(platform_class: Type[BasePlatform]) -> bool:
|
44
40
|
"""检查平台类是否实现了所有必要的方法
|
45
|
-
|
41
|
+
|
46
42
|
Args:
|
47
43
|
platform_class: 要检查的平台类
|
48
|
-
|
44
|
+
|
49
45
|
Returns:
|
50
46
|
bool: 是否实现了所有必要的方法
|
51
47
|
"""
|
52
48
|
missing_methods = []
|
53
|
-
|
49
|
+
|
54
50
|
for method_name, params in REQUIRED_METHODS:
|
55
51
|
if not hasattr(platform_class, method_name):
|
56
52
|
missing_methods.append(method_name)
|
57
53
|
continue
|
58
|
-
|
54
|
+
|
59
55
|
method = getattr(platform_class, method_name)
|
60
56
|
if not callable(method):
|
61
57
|
missing_methods.append(method_name)
|
62
58
|
continue
|
63
|
-
|
59
|
+
|
64
60
|
# 检查方法参数
|
65
61
|
import inspect
|
66
62
|
sig = inspect.signature(method)
|
67
63
|
method_params = [p for p in sig.parameters if p != 'self']
|
68
64
|
if len(method_params) != len(params):
|
69
65
|
missing_methods.append(f"{method_name}(参数不匹配)")
|
70
|
-
|
66
|
+
|
71
67
|
if missing_methods:
|
72
68
|
PrettyOutput.print(
|
73
|
-
f"平台 {
|
74
|
-
platform_class.__name__} 缺少必要的方法: {
|
75
|
-
', '.join(missing_methods)}",
|
69
|
+
f"平台 {platform_class.__name__} 缺少必要的方法: {', '.join(missing_methods)}",
|
76
70
|
OutputType.ERROR
|
77
71
|
)
|
78
72
|
return False
|
79
|
-
|
73
|
+
|
80
74
|
return True
|
81
75
|
|
82
76
|
@staticmethod
|
83
|
-
def load_platform_from_dir(
|
84
|
-
directory: str) -> Dict[str, Type[BasePlatform]]:
|
77
|
+
def load_platform_from_dir(directory: str) -> Dict[str, Type[BasePlatform]]:
|
85
78
|
"""从指定目录加载平台
|
86
|
-
|
79
|
+
|
87
80
|
Args:
|
88
81
|
directory: 平台目录路径
|
89
|
-
|
82
|
+
|
90
83
|
Returns:
|
91
84
|
Dict[str, Type[BaseModel]]: 平台名称到平台类的映射
|
92
85
|
"""
|
93
86
|
platforms = {}
|
94
|
-
|
87
|
+
|
95
88
|
# 确保目录存在
|
96
89
|
if not os.path.exists(directory):
|
97
90
|
PrettyOutput.print(f"平台目录不存在: {directory}", OutputType.ERROR)
|
98
91
|
return platforms
|
99
|
-
|
92
|
+
|
100
93
|
# 获取目录的包名
|
101
94
|
package_name = None
|
102
95
|
if directory == os.path.dirname(__file__):
|
103
96
|
package_name = "jarvis.models"
|
104
|
-
|
97
|
+
|
105
98
|
# 添加目录到Python路径
|
106
99
|
if directory not in sys.path:
|
107
100
|
sys.path.append(directory)
|
108
|
-
|
101
|
+
|
109
102
|
# 遍历目录下的所有.py文件
|
110
103
|
for filename in os.listdir(directory):
|
111
104
|
if filename.endswith('.py') and not filename.startswith('__'):
|
@@ -113,55 +106,46 @@ class PlatformRegistry:
|
|
113
106
|
try:
|
114
107
|
# 导入模块
|
115
108
|
if package_name:
|
116
|
-
module = importlib.import_module(
|
117
|
-
f"{package_name}.{module_name}")
|
109
|
+
module = importlib.import_module(f"{package_name}.{module_name}")
|
118
110
|
else:
|
119
111
|
module = importlib.import_module(module_name)
|
120
|
-
|
112
|
+
|
121
113
|
# 遍历模块中的所有类
|
122
114
|
for name, obj in inspect.getmembers(module):
|
123
115
|
# 检查是否是BaseModel的子类,但不是BaseModel本身
|
124
|
-
if (inspect.isclass(obj) and
|
125
|
-
issubclass(obj, BasePlatform) and
|
116
|
+
if (inspect.isclass(obj) and
|
117
|
+
issubclass(obj, BasePlatform) and
|
126
118
|
obj != BasePlatform and
|
127
|
-
|
119
|
+
hasattr(obj, 'platform_name')):
|
128
120
|
# 检查平台实现
|
129
|
-
if not PlatformRegistry.check_platform_implementation(
|
130
|
-
obj):
|
121
|
+
if not PlatformRegistry.check_platform_implementation(obj):
|
131
122
|
continue
|
132
123
|
platforms[obj.platform_name] = obj
|
133
|
-
PrettyOutput.print(
|
134
|
-
f"从 {directory} 加载平台: {
|
135
|
-
obj.platform_name}", OutputType.INFO)
|
124
|
+
PrettyOutput.print(f"从 {directory} 加载平台: {obj.platform_name}", OutputType.INFO)
|
136
125
|
break
|
137
126
|
except Exception as e:
|
138
|
-
PrettyOutput.print(
|
139
|
-
|
140
|
-
str(e)}", OutputType.ERROR)
|
141
|
-
|
127
|
+
PrettyOutput.print(f"加载平台 {module_name} 失败: {str(e)}", OutputType.ERROR)
|
128
|
+
|
142
129
|
return platforms
|
143
130
|
|
131
|
+
|
144
132
|
@staticmethod
|
145
133
|
def get_global_platform_registry():
|
146
134
|
"""获取全局平台注册器"""
|
147
135
|
if PlatformRegistry.global_platform_registry is None:
|
148
136
|
PlatformRegistry.global_platform_registry = PlatformRegistry()
|
149
|
-
|
137
|
+
|
150
138
|
# 从用户平台目录加载额外平台
|
151
139
|
platform_dir = PlatformRegistry.get_platform_dir()
|
152
140
|
if platform_dir and os.path.exists(platform_dir):
|
153
|
-
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(
|
154
|
-
|
155
|
-
PlatformRegistry.global_platform_registry.register_platform(
|
156
|
-
platform_name, platform_class)
|
141
|
+
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(platform_dir).items():
|
142
|
+
PlatformRegistry.global_platform_registry.register_platform(platform_name, platform_class)
|
157
143
|
platform_dir = os.path.dirname(__file__)
|
158
144
|
if platform_dir and os.path.exists(platform_dir):
|
159
|
-
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(
|
160
|
-
|
161
|
-
PlatformRegistry.global_platform_registry.register_platform(
|
162
|
-
platform_name, platform_class)
|
145
|
+
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(platform_dir).items():
|
146
|
+
PlatformRegistry.global_platform_registry.register_platform(platform_name, platform_class)
|
163
147
|
return PlatformRegistry.global_platform_registry
|
164
|
-
|
148
|
+
|
165
149
|
def __init__(self):
|
166
150
|
"""初始化平台注册器
|
167
151
|
"""
|
@@ -170,37 +154,34 @@ class PlatformRegistry:
|
|
170
154
|
@staticmethod
|
171
155
|
def get_global_platform() -> BasePlatform:
|
172
156
|
"""获取全局平台实例"""
|
173
|
-
platform = PlatformRegistry.get_global_platform_registry(
|
174
|
-
).create_platform(PlatformRegistry.global_platform_name)
|
157
|
+
platform = PlatformRegistry.get_global_platform_registry().create_platform(PlatformRegistry.global_platform_name)
|
175
158
|
if not platform:
|
176
|
-
raise Exception(
|
177
|
-
f"Failed to create platform: {
|
178
|
-
PlatformRegistry.global_platform_name}")
|
159
|
+
raise Exception(f"Failed to create platform: {PlatformRegistry.global_platform_name}")
|
179
160
|
return platform
|
180
|
-
|
161
|
+
|
181
162
|
def register_platform(self, name: str, platform_class: Type[BasePlatform]):
|
182
163
|
"""注册平台类
|
183
|
-
|
164
|
+
|
184
165
|
Args:
|
185
166
|
name: 平台名称
|
186
167
|
model_class: 平台类
|
187
168
|
"""
|
188
169
|
self.platforms[name] = platform_class
|
189
170
|
PrettyOutput.print(f"已注册平台: {name}", OutputType.INFO)
|
190
|
-
|
171
|
+
|
191
172
|
def create_platform(self, name: str) -> Optional[BasePlatform]:
|
192
173
|
"""创建平台实例
|
193
|
-
|
174
|
+
|
194
175
|
Args:
|
195
176
|
name: 平台名称
|
196
|
-
|
177
|
+
|
197
178
|
Returns:
|
198
179
|
BaseModel: 平台实例
|
199
180
|
"""
|
200
181
|
if name not in self.platforms:
|
201
182
|
PrettyOutput.print(f"未找到平台: {name}", OutputType.ERROR)
|
202
183
|
return None
|
203
|
-
|
184
|
+
|
204
185
|
try:
|
205
186
|
platform = self.platforms[name]()
|
206
187
|
PrettyOutput.print(f"已创建平台实例: {name}", OutputType.INFO)
|
@@ -208,11 +189,11 @@ class PlatformRegistry:
|
|
208
189
|
except Exception as e:
|
209
190
|
PrettyOutput.print(f"创建平台失败: {str(e)}", OutputType.ERROR)
|
210
191
|
return None
|
211
|
-
|
192
|
+
|
212
193
|
def get_available_platforms(self) -> List[str]:
|
213
194
|
"""获取可用平台列表"""
|
214
|
-
return list(self.platforms.keys())
|
215
|
-
|
195
|
+
return list(self.platforms.keys())
|
196
|
+
|
216
197
|
def set_global_platform_name(self, platform_name: str):
|
217
198
|
"""设置全局平台"""
|
218
199
|
PlatformRegistry.global_platform_name = platform_name
|