jarvis-ai-assistant 0.1.46__py3-none-any.whl → 0.1.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +32 -47
- jarvis/main.py +35 -51
- jarvis/models/__init__.py +1 -1
- jarvis/models/ai8.py +58 -88
- jarvis/models/base.py +6 -6
- jarvis/models/kimi.py +80 -171
- jarvis/models/openai.py +23 -43
- jarvis/models/oyi.py +91 -113
- jarvis/models/registry.py +44 -63
- jarvis/tools/__init__.py +1 -1
- jarvis/tools/base.py +2 -2
- jarvis/tools/file_ops.py +15 -19
- jarvis/tools/generator.py +12 -15
- jarvis/tools/methodology.py +20 -20
- jarvis/tools/registry.py +30 -44
- jarvis/tools/shell.py +11 -12
- jarvis/tools/sub_agent.py +2 -1
- jarvis/utils.py +27 -47
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/METADATA +1 -1
- jarvis_ai_assistant-0.1.48.dist-info/RECORD +25 -0
- jarvis_ai_assistant-0.1.46.dist-info/RECORD +0 -25
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.46.dist-info → jarvis_ai_assistant-0.1.48.dist-info}/top_level.txt +0 -0
jarvis/models/oyi.py
CHANGED
@@ -6,17 +6,16 @@ from jarvis.utils import PrettyOutput, OutputType
|
|
6
6
|
import requests
|
7
7
|
import json
|
8
8
|
|
9
|
-
|
10
9
|
class OyiModel(BasePlatform):
|
11
10
|
"""Oyi model implementation"""
|
12
|
-
|
11
|
+
|
13
12
|
platform_name = "oyi"
|
14
13
|
BASE_URL = "https://api-10086.rcouyi.com"
|
15
|
-
|
14
|
+
|
16
15
|
def __init__(self):
|
17
16
|
"""Initialize model"""
|
18
17
|
PrettyOutput.section("支持的模型", OutputType.SUCCESS)
|
19
|
-
|
18
|
+
|
20
19
|
# 获取可用模型列表
|
21
20
|
available_models = self.get_available_models()
|
22
21
|
if available_models:
|
@@ -24,32 +23,30 @@ class OyiModel(BasePlatform):
|
|
24
23
|
PrettyOutput.print(model, OutputType.INFO)
|
25
24
|
else:
|
26
25
|
PrettyOutput.print("获取模型列表失败", OutputType.WARNING)
|
27
|
-
|
26
|
+
|
28
27
|
PrettyOutput.print("使用OYI_MODEL环境变量配置模型", OutputType.SUCCESS)
|
29
|
-
|
28
|
+
|
30
29
|
self.messages = []
|
31
30
|
self.system_message = ""
|
32
31
|
self.conversation = None
|
33
|
-
self.upload_files = []
|
32
|
+
self.upload_files = [] # 重命名 files 为 upload_files
|
34
33
|
self.first_chat = True
|
35
|
-
|
34
|
+
|
36
35
|
self.token = os.getenv("OYI_API_KEY")
|
37
36
|
if not self.token:
|
38
37
|
raise Exception("OYI_API_KEY is not set")
|
39
|
-
|
38
|
+
|
40
39
|
self.model_name = os.getenv("OYI_MODEL") or "deepseek-chat"
|
41
40
|
if self.model_name not in [m.split()[0] for m in available_models]:
|
42
|
-
PrettyOutput.print(
|
43
|
-
|
44
|
-
self.model_name} 不在可用列表中",
|
45
|
-
OutputType.WARNING)
|
46
|
-
|
41
|
+
PrettyOutput.print(f"警告: 当前选择的模型 {self.model_name} 不在可用列表中", OutputType.WARNING)
|
42
|
+
|
47
43
|
PrettyOutput.print(f"当前使用模型: {self.model_name}", OutputType.SYSTEM)
|
48
44
|
|
49
45
|
def set_model_name(self, model_name: str):
|
50
46
|
"""设置模型名称"""
|
51
47
|
self.model_name = model_name
|
52
48
|
|
49
|
+
|
53
50
|
def create_conversation(self) -> bool:
|
54
51
|
"""Create a new conversation"""
|
55
52
|
try:
|
@@ -59,7 +56,7 @@ class OyiModel(BasePlatform):
|
|
59
56
|
'Accept': 'application/json',
|
60
57
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
|
61
58
|
}
|
62
|
-
|
59
|
+
|
63
60
|
payload = {
|
64
61
|
"id": 0,
|
65
62
|
"roleId": 0,
|
@@ -78,49 +75,40 @@ class OyiModel(BasePlatform):
|
|
78
75
|
"chatPluginIds": []
|
79
76
|
})
|
80
77
|
}
|
81
|
-
|
78
|
+
|
82
79
|
response = requests.post(
|
83
80
|
f"{self.BASE_URL}/chatapi/chat/save",
|
84
81
|
headers=headers,
|
85
82
|
json=payload
|
86
83
|
)
|
87
|
-
|
84
|
+
|
88
85
|
if response.status_code == 200:
|
89
86
|
data = response.json()
|
90
87
|
if data['code'] == 200 and data['type'] == 'success':
|
91
88
|
self.conversation = data
|
92
|
-
PrettyOutput.print(
|
93
|
-
f"创建会话成功: {
|
94
|
-
data['result']['id']}",
|
95
|
-
OutputType.SUCCESS)
|
89
|
+
PrettyOutput.print(f"创建会话成功: {data['result']['id']}", OutputType.SUCCESS)
|
96
90
|
return True
|
97
91
|
else:
|
98
|
-
PrettyOutput.print(
|
99
|
-
f"创建会话失败: {
|
100
|
-
data['message']}",
|
101
|
-
OutputType.ERROR)
|
92
|
+
PrettyOutput.print(f"创建会话失败: {data['message']}", OutputType.ERROR)
|
102
93
|
return False
|
103
94
|
else:
|
104
|
-
PrettyOutput.print(
|
105
|
-
f"创建会话失败: {
|
106
|
-
response.status_code}",
|
107
|
-
OutputType.ERROR)
|
95
|
+
PrettyOutput.print(f"创建会话失败: {response.status_code}", OutputType.ERROR)
|
108
96
|
return False
|
109
|
-
|
97
|
+
|
110
98
|
except Exception as e:
|
111
99
|
PrettyOutput.print(f"创建会话异常: {str(e)}", OutputType.ERROR)
|
112
100
|
return False
|
113
|
-
|
101
|
+
|
114
102
|
def set_system_message(self, message: str):
|
115
103
|
"""Set system message"""
|
116
104
|
self.system_message = message
|
117
|
-
|
105
|
+
|
118
106
|
def chat(self, message: str) -> str:
|
119
107
|
"""Execute chat with the model
|
120
|
-
|
108
|
+
|
121
109
|
Args:
|
122
110
|
message: User input message
|
123
|
-
|
111
|
+
|
124
112
|
Returns:
|
125
113
|
str: Model response
|
126
114
|
"""
|
@@ -129,7 +117,7 @@ class OyiModel(BasePlatform):
|
|
129
117
|
if not self.conversation:
|
130
118
|
if not self.create_conversation():
|
131
119
|
raise Exception("Failed to create conversation")
|
132
|
-
|
120
|
+
|
133
121
|
# 1. 发送消息
|
134
122
|
headers = {
|
135
123
|
'Authorization': f'Bearer {self.token}',
|
@@ -139,14 +127,14 @@ class OyiModel(BasePlatform):
|
|
139
127
|
'Origin': 'https://ai.rcouyi.com',
|
140
128
|
'Referer': 'https://ai.rcouyi.com/'
|
141
129
|
}
|
142
|
-
|
130
|
+
|
143
131
|
payload = {
|
144
132
|
"topicId": self.conversation['result']['id'],
|
145
133
|
"messages": self.messages,
|
146
134
|
"content": message,
|
147
135
|
"contentFiles": []
|
148
136
|
}
|
149
|
-
|
137
|
+
|
150
138
|
# 如果有上传的文件,添加到请求中
|
151
139
|
if self.first_chat:
|
152
140
|
if self.upload_files:
|
@@ -165,64 +153,63 @@ class OyiModel(BasePlatform):
|
|
165
153
|
self.first_chat = False
|
166
154
|
|
167
155
|
self.messages.append({"role": "user", "content": message})
|
168
|
-
|
156
|
+
|
169
157
|
# 发送消息
|
170
158
|
response = requests.post(
|
171
159
|
f"{self.BASE_URL}/chatapi/chat/message",
|
172
160
|
headers=headers,
|
173
161
|
json=payload
|
174
162
|
)
|
175
|
-
|
163
|
+
|
176
164
|
if response.status_code != 200:
|
177
165
|
error_msg = f"聊天请求失败: {response.status_code}"
|
178
166
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
179
167
|
raise Exception(error_msg)
|
180
|
-
|
168
|
+
|
181
169
|
data = response.json()
|
182
170
|
if data['code'] != 200 or data['type'] != 'success':
|
183
171
|
error_msg = f"聊天失败: {data.get('message', '未知错误')}"
|
184
172
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
185
173
|
raise Exception(error_msg)
|
186
|
-
|
174
|
+
|
187
175
|
message_id = data['result'][-1]
|
188
|
-
|
176
|
+
|
189
177
|
# 获取响应内容
|
190
178
|
response = requests.post(
|
191
179
|
f"{self.BASE_URL}/chatapi/chat/message/{message_id}",
|
192
180
|
headers=headers
|
193
181
|
)
|
194
|
-
|
182
|
+
|
195
183
|
if response.status_code == 200:
|
196
184
|
PrettyOutput.print(response.text, OutputType.SYSTEM)
|
197
|
-
self.messages.append(
|
198
|
-
{"role": "assistant", "content": response.text})
|
185
|
+
self.messages.append({"role": "assistant", "content": response.text})
|
199
186
|
return response.text
|
200
187
|
else:
|
201
188
|
error_msg = f"获取响应失败: {response.status_code}"
|
202
189
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
203
190
|
raise Exception(error_msg)
|
204
|
-
|
191
|
+
|
205
192
|
except Exception as e:
|
206
193
|
PrettyOutput.print(f"聊天异常: {str(e)}", OutputType.ERROR)
|
207
194
|
raise e
|
208
|
-
|
195
|
+
|
209
196
|
def name(self) -> str:
|
210
197
|
"""Return model name"""
|
211
198
|
return self.model_name
|
212
|
-
|
199
|
+
|
213
200
|
def reset(self):
|
214
201
|
"""Reset model state"""
|
215
202
|
self.messages = []
|
216
203
|
self.conversation = None
|
217
204
|
self.upload_files = []
|
218
205
|
self.first_chat = True
|
219
|
-
|
206
|
+
|
220
207
|
def delete_chat(self) -> bool:
|
221
208
|
"""Delete current chat session"""
|
222
209
|
try:
|
223
210
|
if not self.conversation:
|
224
211
|
return True
|
225
|
-
|
212
|
+
|
226
213
|
headers = {
|
227
214
|
'Authorization': f'Bearer {self.token}',
|
228
215
|
'Content-Type': 'application/json',
|
@@ -231,15 +218,13 @@ class OyiModel(BasePlatform):
|
|
231
218
|
'Origin': 'https://ai.rcouyi.com',
|
232
219
|
'Referer': 'https://ai.rcouyi.com/'
|
233
220
|
}
|
234
|
-
|
221
|
+
|
235
222
|
response = requests.post(
|
236
|
-
f"{
|
237
|
-
self.BASE_URL}/chatapi/chat/{
|
238
|
-
self.conversation['result']['id']}",
|
223
|
+
f"{self.BASE_URL}/chatapi/chat/{self.conversation['result']['id']}",
|
239
224
|
headers=headers,
|
240
225
|
json={}
|
241
226
|
)
|
242
|
-
|
227
|
+
|
243
228
|
if response.status_code == 200:
|
244
229
|
data = response.json()
|
245
230
|
if data['code'] == 200 and data['type'] == 'success':
|
@@ -254,17 +239,17 @@ class OyiModel(BasePlatform):
|
|
254
239
|
error_msg = f"删除会话请求失败: {response.status_code}"
|
255
240
|
PrettyOutput.print(error_msg, OutputType.ERROR)
|
256
241
|
return False
|
257
|
-
|
242
|
+
|
258
243
|
except Exception as e:
|
259
244
|
PrettyOutput.print(f"删除会话异常: {str(e)}", OutputType.ERROR)
|
260
245
|
return False
|
261
|
-
|
262
|
-
def
|
246
|
+
|
247
|
+
def upload_files(self, file_list: List[str]) -> List[Dict]:
|
263
248
|
"""Upload a file to OYI API
|
264
|
-
|
249
|
+
|
265
250
|
Args:
|
266
251
|
file_path: Path to the file to upload
|
267
|
-
|
252
|
+
|
268
253
|
Returns:
|
269
254
|
Dict: Upload response data
|
270
255
|
"""
|
@@ -272,12 +257,9 @@ class OyiModel(BasePlatform):
|
|
272
257
|
# 检查当前模型是否支持文件上传
|
273
258
|
model_info = self.models.get(self.model_name)
|
274
259
|
if not model_info or not model_info.get('uploadFile', False):
|
275
|
-
PrettyOutput.print(
|
276
|
-
f"当前模型 {
|
277
|
-
self.model_name} 不支持文件上传",
|
278
|
-
OutputType.WARNING)
|
260
|
+
PrettyOutput.print(f"当前模型 {self.model_name} 不支持文件上传", OutputType.WARNING)
|
279
261
|
return None
|
280
|
-
|
262
|
+
|
281
263
|
headers = {
|
282
264
|
'Authorization': f'Bearer {self.token}',
|
283
265
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
|
@@ -286,45 +268,45 @@ class OyiModel(BasePlatform):
|
|
286
268
|
'Origin': 'https://ai.rcouyi.com',
|
287
269
|
'Referer': 'https://ai.rcouyi.com/'
|
288
270
|
}
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
files=
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
271
|
+
|
272
|
+
for file_path in file_list:
|
273
|
+
# 检查文件类型
|
274
|
+
file_type = mimetypes.guess_type(file_path)[0]
|
275
|
+
if not file_type or not file_type.startswith(('image/', 'text/', 'application/')):
|
276
|
+
PrettyOutput.print(f"文件类型不支持: {file_type}", OutputType.ERROR)
|
277
|
+
continue
|
278
|
+
|
279
|
+
with open(file_path, 'rb') as f:
|
280
|
+
files = {
|
281
|
+
'file': (os.path.basename(file_path), f, file_type)
|
282
|
+
}
|
283
|
+
|
284
|
+
response = requests.post(
|
285
|
+
f"{self.BASE_URL}/chatapi/m_file/uploadfile",
|
286
|
+
headers=headers,
|
287
|
+
files=files
|
288
|
+
)
|
289
|
+
|
290
|
+
if response.status_code == 200:
|
291
|
+
data = response.json()
|
292
|
+
if data.get('code') == 200:
|
293
|
+
PrettyOutput.print("文件上传成功", OutputType.SUCCESS)
|
294
|
+
self.upload_files.append(data)
|
295
|
+
return data
|
296
|
+
else:
|
297
|
+
PrettyOutput.print(f"文件上传失败: {data.get('message')}", OutputType.ERROR)
|
298
|
+
return None
|
308
299
|
else:
|
309
|
-
PrettyOutput.print(
|
310
|
-
f"文件上传失败: {
|
311
|
-
data.get('message')}",
|
312
|
-
OutputType.ERROR)
|
300
|
+
PrettyOutput.print(f"文件上传失败: {response.status_code}", OutputType.ERROR)
|
313
301
|
return None
|
314
|
-
|
315
|
-
PrettyOutput.print(
|
316
|
-
f"文件上传失败: {
|
317
|
-
response.status_code}",
|
318
|
-
OutputType.ERROR)
|
319
|
-
return None
|
320
|
-
|
302
|
+
|
321
303
|
except Exception as e:
|
322
304
|
PrettyOutput.print(f"文件上传异常: {str(e)}", OutputType.ERROR)
|
323
305
|
return None
|
324
306
|
|
325
307
|
def get_available_models(self) -> List[str]:
|
326
308
|
"""获取可用的模型列表
|
327
|
-
|
309
|
+
|
328
310
|
Returns:
|
329
311
|
List[str]: 可用模型名称列表
|
330
312
|
"""
|
@@ -336,59 +318,55 @@ class OyiModel(BasePlatform):
|
|
336
318
|
'Origin': 'https://ai.rcouyi.com',
|
337
319
|
'Referer': 'https://ai.rcouyi.com/'
|
338
320
|
}
|
339
|
-
|
321
|
+
|
340
322
|
response = requests.get(
|
341
323
|
"https://ai.rcouyi.com/config/system.json",
|
342
324
|
headers=headers
|
343
325
|
)
|
344
|
-
|
326
|
+
|
345
327
|
if response.status_code != 200:
|
346
|
-
PrettyOutput.print(
|
347
|
-
f"获取模型列表失败: {
|
348
|
-
response.status_code}",
|
349
|
-
OutputType.ERROR)
|
328
|
+
PrettyOutput.print(f"获取模型列表失败: {response.status_code}", OutputType.ERROR)
|
350
329
|
return []
|
351
|
-
|
330
|
+
|
352
331
|
data = response.json()
|
353
|
-
|
332
|
+
|
354
333
|
# 保存模型信息
|
355
334
|
self.models = {
|
356
335
|
model['value']: model
|
357
336
|
for model in data.get('model', [])
|
358
337
|
if model.get('enable', False) # 只保存启用的模型
|
359
338
|
}
|
360
|
-
|
339
|
+
|
361
340
|
# 格式化显示
|
362
341
|
models = []
|
363
342
|
for model in self.models.values():
|
364
343
|
# 基本信息
|
365
344
|
model_str = f"{model['value']:<30} {model['label']}"
|
366
|
-
|
345
|
+
|
367
346
|
# 添加后缀标签
|
368
347
|
suffix = model.get('suffix', [])
|
369
348
|
if suffix:
|
370
349
|
# 处理新格式的suffix (字典列表)
|
371
350
|
if suffix and isinstance(suffix[0], dict):
|
372
|
-
suffix_str = ', '.join(s.get('tag', '')
|
373
|
-
for s in suffix)
|
351
|
+
suffix_str = ', '.join(s.get('tag', '') for s in suffix)
|
374
352
|
# 处理旧格式的suffix (字符串列表)
|
375
353
|
else:
|
376
354
|
suffix_str = ', '.join(str(s) for s in suffix)
|
377
355
|
model_str += f" ({suffix_str})"
|
378
|
-
|
356
|
+
|
379
357
|
# 添加描述或提示
|
380
358
|
info = model.get('tooltip') or model.get('description', '')
|
381
359
|
if info:
|
382
360
|
model_str += f" - {info}"
|
383
|
-
|
361
|
+
|
384
362
|
# 添加文件上传支持标记
|
385
363
|
if model.get('uploadFile'):
|
386
364
|
model_str += " [支持文件上传]"
|
387
|
-
|
365
|
+
|
388
366
|
models.append(model_str)
|
389
|
-
|
367
|
+
|
390
368
|
return sorted(models)
|
391
|
-
|
369
|
+
|
392
370
|
except Exception as e:
|
393
371
|
PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.ERROR)
|
394
372
|
return []
|
jarvis/models/registry.py
CHANGED
@@ -14,7 +14,6 @@ REQUIRED_METHODS = [
|
|
14
14
|
('set_system_message', ['message'])
|
15
15
|
]
|
16
16
|
|
17
|
-
|
18
17
|
class PlatformRegistry:
|
19
18
|
"""平台注册器"""
|
20
19
|
|
@@ -30,82 +29,76 @@ class PlatformRegistry:
|
|
30
29
|
# 创建 __init__.py 使其成为 Python 包
|
31
30
|
with open(os.path.join(user_platform_dir, "__init__.py"), "w") as f:
|
32
31
|
pass
|
33
|
-
PrettyOutput.print(
|
34
|
-
f"已创建平台目录: {user_platform_dir}",
|
35
|
-
OutputType.INFO)
|
32
|
+
PrettyOutput.print(f"已创建平台目录: {user_platform_dir}", OutputType.INFO)
|
36
33
|
except Exception as e:
|
37
34
|
PrettyOutput.print(f"创建平台目录失败: {str(e)}", OutputType.ERROR)
|
38
35
|
return ""
|
39
36
|
return user_platform_dir
|
40
37
|
|
41
38
|
@staticmethod
|
42
|
-
def check_platform_implementation(
|
43
|
-
platform_class: Type[BasePlatform]) -> bool:
|
39
|
+
def check_platform_implementation(platform_class: Type[BasePlatform]) -> bool:
|
44
40
|
"""检查平台类是否实现了所有必要的方法
|
45
|
-
|
41
|
+
|
46
42
|
Args:
|
47
43
|
platform_class: 要检查的平台类
|
48
|
-
|
44
|
+
|
49
45
|
Returns:
|
50
46
|
bool: 是否实现了所有必要的方法
|
51
47
|
"""
|
52
48
|
missing_methods = []
|
53
|
-
|
49
|
+
|
54
50
|
for method_name, params in REQUIRED_METHODS:
|
55
51
|
if not hasattr(platform_class, method_name):
|
56
52
|
missing_methods.append(method_name)
|
57
53
|
continue
|
58
|
-
|
54
|
+
|
59
55
|
method = getattr(platform_class, method_name)
|
60
56
|
if not callable(method):
|
61
57
|
missing_methods.append(method_name)
|
62
58
|
continue
|
63
|
-
|
59
|
+
|
64
60
|
# 检查方法参数
|
65
61
|
import inspect
|
66
62
|
sig = inspect.signature(method)
|
67
63
|
method_params = [p for p in sig.parameters if p != 'self']
|
68
64
|
if len(method_params) != len(params):
|
69
65
|
missing_methods.append(f"{method_name}(参数不匹配)")
|
70
|
-
|
66
|
+
|
71
67
|
if missing_methods:
|
72
68
|
PrettyOutput.print(
|
73
|
-
f"平台 {
|
74
|
-
platform_class.__name__} 缺少必要的方法: {
|
75
|
-
', '.join(missing_methods)}",
|
69
|
+
f"平台 {platform_class.__name__} 缺少必要的方法: {', '.join(missing_methods)}",
|
76
70
|
OutputType.ERROR
|
77
71
|
)
|
78
72
|
return False
|
79
|
-
|
73
|
+
|
80
74
|
return True
|
81
75
|
|
82
76
|
@staticmethod
|
83
|
-
def load_platform_from_dir(
|
84
|
-
directory: str) -> Dict[str, Type[BasePlatform]]:
|
77
|
+
def load_platform_from_dir(directory: str) -> Dict[str, Type[BasePlatform]]:
|
85
78
|
"""从指定目录加载平台
|
86
|
-
|
79
|
+
|
87
80
|
Args:
|
88
81
|
directory: 平台目录路径
|
89
|
-
|
82
|
+
|
90
83
|
Returns:
|
91
84
|
Dict[str, Type[BaseModel]]: 平台名称到平台类的映射
|
92
85
|
"""
|
93
86
|
platforms = {}
|
94
|
-
|
87
|
+
|
95
88
|
# 确保目录存在
|
96
89
|
if not os.path.exists(directory):
|
97
90
|
PrettyOutput.print(f"平台目录不存在: {directory}", OutputType.ERROR)
|
98
91
|
return platforms
|
99
|
-
|
92
|
+
|
100
93
|
# 获取目录的包名
|
101
94
|
package_name = None
|
102
95
|
if directory == os.path.dirname(__file__):
|
103
96
|
package_name = "jarvis.models"
|
104
|
-
|
97
|
+
|
105
98
|
# 添加目录到Python路径
|
106
99
|
if directory not in sys.path:
|
107
100
|
sys.path.append(directory)
|
108
|
-
|
101
|
+
|
109
102
|
# 遍历目录下的所有.py文件
|
110
103
|
for filename in os.listdir(directory):
|
111
104
|
if filename.endswith('.py') and not filename.startswith('__'):
|
@@ -113,55 +106,46 @@ class PlatformRegistry:
|
|
113
106
|
try:
|
114
107
|
# 导入模块
|
115
108
|
if package_name:
|
116
|
-
module = importlib.import_module(
|
117
|
-
f"{package_name}.{module_name}")
|
109
|
+
module = importlib.import_module(f"{package_name}.{module_name}")
|
118
110
|
else:
|
119
111
|
module = importlib.import_module(module_name)
|
120
|
-
|
112
|
+
|
121
113
|
# 遍历模块中的所有类
|
122
114
|
for name, obj in inspect.getmembers(module):
|
123
115
|
# 检查是否是BaseModel的子类,但不是BaseModel本身
|
124
|
-
if (inspect.isclass(obj) and
|
125
|
-
issubclass(obj, BasePlatform) and
|
116
|
+
if (inspect.isclass(obj) and
|
117
|
+
issubclass(obj, BasePlatform) and
|
126
118
|
obj != BasePlatform and
|
127
|
-
|
119
|
+
hasattr(obj, 'platform_name')):
|
128
120
|
# 检查平台实现
|
129
|
-
if not PlatformRegistry.check_platform_implementation(
|
130
|
-
obj):
|
121
|
+
if not PlatformRegistry.check_platform_implementation(obj):
|
131
122
|
continue
|
132
123
|
platforms[obj.platform_name] = obj
|
133
|
-
PrettyOutput.print(
|
134
|
-
f"从 {directory} 加载平台: {
|
135
|
-
obj.platform_name}", OutputType.INFO)
|
124
|
+
PrettyOutput.print(f"从 {directory} 加载平台: {obj.platform_name}", OutputType.INFO)
|
136
125
|
break
|
137
126
|
except Exception as e:
|
138
|
-
PrettyOutput.print(
|
139
|
-
|
140
|
-
str(e)}", OutputType.ERROR)
|
141
|
-
|
127
|
+
PrettyOutput.print(f"加载平台 {module_name} 失败: {str(e)}", OutputType.ERROR)
|
128
|
+
|
142
129
|
return platforms
|
143
130
|
|
131
|
+
|
144
132
|
@staticmethod
|
145
133
|
def get_global_platform_registry():
|
146
134
|
"""获取全局平台注册器"""
|
147
135
|
if PlatformRegistry.global_platform_registry is None:
|
148
136
|
PlatformRegistry.global_platform_registry = PlatformRegistry()
|
149
|
-
|
137
|
+
|
150
138
|
# 从用户平台目录加载额外平台
|
151
139
|
platform_dir = PlatformRegistry.get_platform_dir()
|
152
140
|
if platform_dir and os.path.exists(platform_dir):
|
153
|
-
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(
|
154
|
-
|
155
|
-
PlatformRegistry.global_platform_registry.register_platform(
|
156
|
-
platform_name, platform_class)
|
141
|
+
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(platform_dir).items():
|
142
|
+
PlatformRegistry.global_platform_registry.register_platform(platform_name, platform_class)
|
157
143
|
platform_dir = os.path.dirname(__file__)
|
158
144
|
if platform_dir and os.path.exists(platform_dir):
|
159
|
-
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(
|
160
|
-
|
161
|
-
PlatformRegistry.global_platform_registry.register_platform(
|
162
|
-
platform_name, platform_class)
|
145
|
+
for platform_name, platform_class in PlatformRegistry.load_platform_from_dir(platform_dir).items():
|
146
|
+
PlatformRegistry.global_platform_registry.register_platform(platform_name, platform_class)
|
163
147
|
return PlatformRegistry.global_platform_registry
|
164
|
-
|
148
|
+
|
165
149
|
def __init__(self):
|
166
150
|
"""初始化平台注册器
|
167
151
|
"""
|
@@ -170,37 +154,34 @@ class PlatformRegistry:
|
|
170
154
|
@staticmethod
|
171
155
|
def get_global_platform() -> BasePlatform:
|
172
156
|
"""获取全局平台实例"""
|
173
|
-
platform = PlatformRegistry.get_global_platform_registry(
|
174
|
-
).create_platform(PlatformRegistry.global_platform_name)
|
157
|
+
platform = PlatformRegistry.get_global_platform_registry().create_platform(PlatformRegistry.global_platform_name)
|
175
158
|
if not platform:
|
176
|
-
raise Exception(
|
177
|
-
f"Failed to create platform: {
|
178
|
-
PlatformRegistry.global_platform_name}")
|
159
|
+
raise Exception(f"Failed to create platform: {PlatformRegistry.global_platform_name}")
|
179
160
|
return platform
|
180
|
-
|
161
|
+
|
181
162
|
def register_platform(self, name: str, platform_class: Type[BasePlatform]):
|
182
163
|
"""注册平台类
|
183
|
-
|
164
|
+
|
184
165
|
Args:
|
185
166
|
name: 平台名称
|
186
167
|
model_class: 平台类
|
187
168
|
"""
|
188
169
|
self.platforms[name] = platform_class
|
189
170
|
PrettyOutput.print(f"已注册平台: {name}", OutputType.INFO)
|
190
|
-
|
171
|
+
|
191
172
|
def create_platform(self, name: str) -> Optional[BasePlatform]:
|
192
173
|
"""创建平台实例
|
193
|
-
|
174
|
+
|
194
175
|
Args:
|
195
176
|
name: 平台名称
|
196
|
-
|
177
|
+
|
197
178
|
Returns:
|
198
179
|
BaseModel: 平台实例
|
199
180
|
"""
|
200
181
|
if name not in self.platforms:
|
201
182
|
PrettyOutput.print(f"未找到平台: {name}", OutputType.ERROR)
|
202
183
|
return None
|
203
|
-
|
184
|
+
|
204
185
|
try:
|
205
186
|
platform = self.platforms[name]()
|
206
187
|
PrettyOutput.print(f"已创建平台实例: {name}", OutputType.INFO)
|
@@ -208,11 +189,11 @@ class PlatformRegistry:
|
|
208
189
|
except Exception as e:
|
209
190
|
PrettyOutput.print(f"创建平台失败: {str(e)}", OutputType.ERROR)
|
210
191
|
return None
|
211
|
-
|
192
|
+
|
212
193
|
def get_available_platforms(self) -> List[str]:
|
213
194
|
"""获取可用平台列表"""
|
214
|
-
return list(self.platforms.keys())
|
215
|
-
|
195
|
+
return list(self.platforms.keys())
|
196
|
+
|
216
197
|
def set_global_platform_name(self, platform_name: str):
|
217
198
|
"""设置全局平台"""
|
218
199
|
PlatformRegistry.global_platform_name = platform_name
|