device-protocol-sdk 1.2.8__tar.gz → 1.2.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21) hide show
  1. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/PKG-INFO +1 -1
  2. device_protocol_sdk-1.2.9/device_protocol_sdk/model_workflow/__init__.py +0 -0
  3. device_protocol_sdk-1.2.9/device_protocol_sdk/model_workflow/executer_service.py +673 -0
  4. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk.egg-info/PKG-INFO +1 -1
  5. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk.egg-info/SOURCES.txt +3 -1
  6. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/setup.py +1 -1
  7. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/README.md +0 -0
  8. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/__init__.py +0 -0
  9. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/abstract_device.py +0 -0
  10. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/device_pb2.py +0 -0
  11. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/device_pb2_grpc.py +0 -0
  12. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/http_client.py +0 -0
  13. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/model/__init__.py +0 -0
  14. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/model/action_item.py +0 -0
  15. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/model/device_key.py +0 -0
  16. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/model/device_status.py +0 -0
  17. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk/pusher.py +0 -0
  18. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk.egg-info/dependency_links.txt +0 -0
  19. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk.egg-info/requires.txt +0 -0
  20. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/device_protocol_sdk.egg-info/top_level.txt +0 -0
  21. {device_protocol_sdk-1.2.8 → device_protocol_sdk-1.2.9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: device_protocol_sdk
3
- Version: 1.2.8
3
+ Version: 1.2.9
4
4
  Summary: 无人设备协议开发SDK
5
5
  Author: fuhl
6
6
  Requires-Python: >=3.8
@@ -0,0 +1,673 @@
1
+ import logging
2
+ from typing import Dict, List, Any, Optional,Callable,Awaitable
3
+ import uuid
4
+ import time
5
+ from datetime import datetime
6
+ import inspect
7
+ import aiohttp
8
+ from ..model.device_status import NodeExecutionResult,WorkflowExecutionResponse,WorkflowExecutionRequest,ExecutionConfig
9
+
10
+ # 导入上面两个文件的内容
11
+ # 假设 model_workflow.py 定义了 Base, WorkflowTemplate, WorkflowNode 等
12
+ # 假设 workflow_dataclasses.py 定义了 ExecutionConfig, WorkflowExecutionRequest 等
13
+
14
class WorkflowExecutor:
    """Workflow execution engine: runs the nodes of a workflow template in order."""

    async def execute_workflow(
        self,
        request: WorkflowExecutionRequest,
        current_mqtt_data,
        device_id: Optional[str] = None,
        get_realtime_image: Optional[Callable[[str], Awaitable[Dict]]] = None,
        call_edge_model: Optional[Callable[[str], Awaitable[Dict]]] = None
    ) -> WorkflowExecutionResponse:
        """Execute a workflow template node by node.

        Args:
            request: execution request carrying the template, execution config
                and workflow input data.
            current_mqtt_data: latest MQTT payload, forwarded to each node.
            device_id: optional device the workflow runs against.
            get_realtime_image: optional callable returning realtime image data
                for a device id; it is placed into the execution context under
                ``_get_realtime_image``.
            call_edge_model: optional callable used to invoke non-chat models.

        Returns:
            WorkflowExecutionResponse with per-node results and final output.
        """
        success = True
        execution_id = request.execution_id or str(uuid.uuid4())
        start_time = time.time()
        execution_record = {
            'template_id': request.template_id,
            'execution_id': execution_id,
            'model_mappings': {
                node_name: {
                    'model_id': mapping.model_id,
                    'endpoint': mapping.endpoint
                }
                for node_name, mapping in request.execution_config.model_mappings.items()
            },
            'parameter_overrides': request.execution_config.parameter_overrides,
            'status': 'running',
            'input_data': request.input_data,
            'start_time': datetime.utcnow()
        }

        node_results = {}

        try:
            template = request.template
            if not template:
                raise ValueError(f"工作流模板 {request.template_id} 不存在")

            # Nodes run in ascending execution_order.
            nodes = sorted(template['nodes'], key=lambda n: n['execution_order'])

            # Execution context shared across nodes.
            context = {
                'workflow_input': request.input_data,
                'device_id': device_id,
                '_get_realtime_image': get_realtime_image
            }
            # Pre-compute, for every node, the list of incoming connections.
            connections_map = self._build_connections_map(template.get('connections', []))
            for node in nodes:
                # Attach the node's incoming connections so _build_node_input
                # can consume them.
                node['target_connections'] = connections_map.get(node['name'], [])

                node_result = await self._execute_node(
                    device_id,
                    node,
                    context,
                    execution_record,
                    request.execution_config,
                    current_mqtt_data, call_edge_model
                )

                node_results[node['name']] = node_result

                if node_result.success and node_result.data:
                    # Whole output dict under the node name...
                    context[node['name']] = node_result.data
                    # ...and each field under "node.field" for direct reference.
                    for field_name, field_value in node_result.data.items():
                        context[f"{node['name']}.{field_name}"] = field_value

                # Stop on failure unless the node is marked skip_on_failure.
                if not node_result.success and not node['skip_on_failure']:
                    execution_record['status'] = 'failed'
                    # BUG FIX: node is a dict — `node.name` raised AttributeError here.
                    execution_record['error_message'] = f"节点 {node['name']} 执行失败: {node_result.error}"
                    # BUG FIX: a failed workflow previously still reported success=True.
                    success = False
                    break

            final_output = self._build_final_output(template, context)

        except Exception as e:
            execution_record['status'] = 'failed'
            execution_record['error_message'] = str(e)
            execution_record['end_time'] = datetime.utcnow()
            success = False
            final_output = None

        return WorkflowExecutionResponse(
            success=success,
            execution_id=execution_id,
            data=final_output,
            error=execution_record.get('error_message', None),
            execution_time=time.time() - start_time,
            node_results=node_results
        )
126
+ def _build_connections_map(self, connections: List[Dict]) -> Dict[str, List[Dict]]:
127
+ """
128
+ 构建节点连接映射
129
+ 返回格式: {target_node_name: [connection_configs]}
130
+ """
131
+ connections_map = {}
132
+ for conn in connections:
133
+ target_node = conn['target_node_name']
134
+ if target_node not in connections_map:
135
+ connections_map[target_node] = []
136
+
137
+ # 构建连接配置
138
+ conn_config = {
139
+ 'source_node_name': conn['source_node_name'],
140
+ 'source_output': conn['source_output'],
141
+ 'target_input': conn['target_input'],
142
+ 'transform_script': conn.get('transform_script'),
143
+ 'condition': conn.get('condition')
144
+ }
145
+ connections_map[target_node].append(conn_config)
146
+
147
+ return connections_map
148
+
149
async def _execute_node(self, device_id,
                        node: Dict[str, Any],
                        context: Dict[str, Any],
                        execution_record: Dict[str, Any],
                        execution_config: ExecutionConfig,
                        current_mqtt_data: Dict[str, Any], call_edge_model) -> NodeExecutionResult:
    """Execute a single workflow node.

    Evaluates the node's optional condition, builds its input from the
    context, merges parameters, calls the mapped model, and shapes the
    output according to the node's ``outputs`` / ``output_mapping`` config.

    Returns:
        NodeExecutionResult; ``success=False`` with ``error`` set on failure.
    """
    start_time = time.time()

    try:
        # Conditional node: skip (successfully, with empty data) when the
        # condition evaluates falsy against the current context.
        if node['condition']:
            if not self._evaluate_condition(node['condition'], context):
                self._create_node_execution_record(
                    node, execution_record, execution_config,
                    status='skipped', start_time=datetime.utcnow()
                )
                return NodeExecutionResult(
                    node_name=node['name'],
                    success=True,
                    data={},
                    execution_time=0.0
                )

        # Model mapping for this node is mandatory.
        model_mapping = execution_config.model_mappings.get(node['name'])
        if not model_mapping:
            raise ValueError(f"节点 '{node['name']}' 未配置模型映射")

        node_input = self._build_node_input(node, context)

        # Parameter precedence: node defaults < workflow overrides < mapping overrides.
        actual_parameters = node.get('default_parameters', {}).copy()
        if node['name'] in execution_config.parameter_overrides:
            actual_parameters.update(execution_config.parameter_overrides[node['name']])
        if model_mapping.parameters_override:
            actual_parameters.update(model_mapping.parameters_override)

        node_execution = self._create_node_execution_record(
            node, execution_record, execution_config,
            status='running',
            start_time=datetime.utcnow(),
            input_data=node_input,
            actual_parameters=actual_parameters
        )

        model_output = await self._call_model(
            device_id,
            model_mapping.model_type,
            model_mapping.model_id,
            model_mapping.endpoint,
            node_input,
            actual_parameters,
            node['timeout'] or 30,
            call_edge_model
        )

        # Shape the raw model output according to the node's declared outputs.
        formatted_output = self._format_node_output(node, model_output)

        # Optional output_mapping renames/projects fields (supports dotted paths).
        if node.get('output_mapping'):
            mapped_output = {}
            for source_key, target_key in node['output_mapping'].items():
                value = self._get_nested_value(formatted_output, source_key)
                if value is not None:
                    mapped_output[target_key] = value
                else:
                    logging.warning(f"节点 {node['name']} 输出映射: 源字段 {source_key} 不存在")
            formatted_output = mapped_output

        return NodeExecutionResult(
            node_name=node['name'],
            success=True,
            model_id=model_mapping.model_id,
            endpoint_url=model_mapping.endpoint,
            # BUG FIX: the raw model_output was returned here, silently
            # discarding the formatted/mapped output computed above.
            data=formatted_output,
            execution_time=time.time() - start_time
        )

    except Exception as e:
        logging.error(f"执行节点报错:{e}")
        return NodeExecutionResult(
            node_name=node['name'],
            success=False,
            # model_mapping may not exist yet if we failed before resolving it.
            model_id=model_mapping.model_id if 'model_mapping' in locals() else None,
            endpoint_url=model_mapping.endpoint if 'model_mapping' in locals() else None,
            error=str(e),
            execution_time=time.time() - start_time
        )
259
+ def _format_node_output(self, node: Dict[str, Any], model_output: Dict[str, Any]) -> Dict[str, Any]:
260
+ """
261
+ 根据节点配置的 outputs 字段格式化输出
262
+
263
+ Args:
264
+ node: 节点配置
265
+ model_output: 模型原始输出
266
+
267
+ Returns:
268
+ 格式化后的输出,只包含节点 outputs 中定义的字段
269
+
270
+ Raises:
271
+ ValueError: 当模型输出缺少节点定义的必需字段时
272
+ """
273
+ formatted_output = {}
274
+
275
+ # 获取节点定义的输出字段
276
+ node_outputs = node.get('outputs', [])
277
+
278
+ if not node_outputs:
279
+ # 如果没有定义 outputs,返回原始输出
280
+ logging.warning(f"节点 {node['name']} 未定义 outputs,直接返回模型输出")
281
+ return model_output
282
+
283
+ # 根据 outputs 定义提取字段
284
+ missing_fields = []
285
+ for output_field in node_outputs:
286
+ field_name = output_field.get('name')
287
+ if not field_name:
288
+ continue
289
+
290
+ required = output_field.get('required', True)
291
+
292
+ # 从模型输出中获取对应字段
293
+ if field_name in model_output:
294
+ formatted_output[field_name] = model_output[field_name]
295
+ else:
296
+ if required:
297
+ missing_fields.append(field_name)
298
+ else:
299
+ # 可选字段,使用默认值
300
+ data_type = output_field.get('data_type', 'string')
301
+ formatted_output[field_name] = self._get_default_value_for_type(data_type)
302
+ logging.warning(f"节点 {node['name']} 可选输出字段 '{field_name}' 未在模型输出中找到,使用默认值")
303
+
304
+ # 如果有必需的字段缺失,抛出错误
305
+ if missing_fields:
306
+ error_msg = f"节点 {node['name']} 模型输出缺少必需字段: {missing_fields}. 模型实际输出字段: {list(model_output.keys())}"
307
+ logging.error(error_msg)
308
+ raise ValueError(error_msg)
309
+
310
+ # 检查是否有多余的字段(可选,用于调试)
311
+ extra_fields = set(model_output.keys()) - set(formatted_output.keys())
312
+ if extra_fields:
313
+ logging.debug(f"节点 {node['name']} 模型输出包含未定义的字段: {extra_fields}")
314
+
315
+ return formatted_output
316
+
317
+ def _get_default_value_for_type(self, data_type: str) -> Any:
318
+ """
319
+ 根据数据类型返回默认值
320
+ """
321
+ default_values = {
322
+ 'string': '',
323
+ 'number': 0,
324
+ 'integer': 0,
325
+ 'float': 0.0,
326
+ 'boolean': False,
327
+ 'array': [],
328
+ 'object': {},
329
+ 'any': None
330
+ }
331
+ return default_values.get(data_type, None)
332
+ def _get_nested_value(self, data: Dict[str, Any], key_path: str) -> Any:
333
+ """
334
+ 获取嵌套字典中的值
335
+ 例如: "results.data" 会获取 data['results']['data']
336
+ """
337
+ try:
338
+ if '.' not in key_path:
339
+ return data.get(key_path)
340
+
341
+ parts = key_path.split('.')
342
+ current = data
343
+ for part in parts:
344
+ if isinstance(current, dict) and part in current:
345
+ current = current[part]
346
+ else:
347
+ return None
348
+ return current
349
+ except Exception as e:
350
+ logging.error(f"获取嵌套值失败: {key_path}, 错误: {e}")
351
+ return None
352
+
353
+ def _create_node_execution_record(self,
354
+ node: Dict[str, Any],
355
+ execution_record: Dict[str, Any],
356
+ execution_config: Dict[str, Any],
357
+ status: str,
358
+ start_time: datetime,
359
+ input_data: Optional[Dict] = None,
360
+ actual_parameters: Optional[Dict] = None) :
361
+ """创建节点执行记录"""
362
+
363
+ model_mapping = execution_config.model_mappings.get(node['name'])
364
+
365
+ return {}
366
+
367
+ def _build_node_input(self, node: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
368
+ """
369
+ 构建节点输入,严格按照 input_mapping 和 connections 配置
370
+ """
371
+ input_data = {}
372
+ missing_sources = []
373
+
374
+ # 处理 input_mapping
375
+ if node.get('input_mapping'):
376
+ for target_key, source_key in node['input_mapping'].items():
377
+ value = self._parse_source_key(source_key, context)
378
+ if value is not None:
379
+ input_data[target_key] = value
380
+ else:
381
+ # 检查是否是必需字段
382
+ # 这里可以根据节点的 inputs 定义来判断是否必需
383
+ missing_sources.append(source_key)
384
+
385
+ # 处理 connections
386
+ if 'connections' in node:
387
+ for conn in node['connections']:
388
+ source_node_name = conn.get('source_node_name')
389
+ source_output = conn.get('source_output')
390
+ target_input = conn.get('target_input')
391
+
392
+ if source_node_name in context:
393
+ source_data = context[source_node_name]
394
+
395
+ if source_output:
396
+ value = self._get_nested_value(source_data, source_output)
397
+ if value is None:
398
+ error_msg = f"节点 {node['name']} 从源节点 {source_node_name} 获取字段 {source_output} 失败,源节点输出: {source_data}"
399
+ logging.error(error_msg)
400
+ raise ValueError(error_msg)
401
+ else:
402
+ value = source_data
403
+
404
+ if target_input:
405
+ input_data[target_input] = value
406
+ else:
407
+ if isinstance(value, dict):
408
+ input_data.update(value)
409
+ else:
410
+ input_data['data'] = value
411
+ else:
412
+ missing_sources.append(source_node_name)
413
+
414
+ # 如果有缺失的数据源,抛出错误
415
+ if missing_sources:
416
+ error_msg = f"节点 {node['name']} 缺少必需的数据源: {missing_sources}"
417
+ logging.error(error_msg)
418
+ raise ValueError(error_msg)
419
+
420
+ logging.debug(f"节点 {node['name']} 构建的输入: {input_data}")
421
+ return input_data
422
+
423
+ def _parse_source_key(self, source_key: str, context: Dict[str, Any]) -> Any:
424
+ """
425
+ 解析数据源键,支持多种格式:
426
+ 1. device.image.url # 设备图像URL
427
+ 2. device.image.data # 设备图像base64数据
428
+ 3. node_name.field_name # 其他节点输出字段
429
+ 4. workflow_input.field # 工作流输入参数
430
+ """
431
+ try:
432
+ if '.' not in source_key:
433
+ # 简写形式
434
+ if source_key in context.get("workflow_input", {}):
435
+ return context["workflow_input"][source_key]
436
+ # 尝试从节点输出中查找
437
+ if source_key in context:
438
+ return context[source_key]
439
+ return None
440
+
441
+ parts = source_key.split('.')
442
+ root = parts[0]
443
+
444
+ if root == "device":
445
+ # 设备数据处理
446
+ if len(parts) >= 3:
447
+ data_type = parts[1] # image 或 audio
448
+ field = parts[2] # url, data, timestamp 等
449
+ device_id = context.get('device_id')
450
+ get_realtime_image = context.get('_get_realtime_image')
451
+
452
+ if data_type == 'image' and device_id and get_realtime_image:
453
+ # 检查 get_realtime_image 是否是协程函数
454
+ if inspect.iscoroutinefunction(get_realtime_image):
455
+ # 这是一个问题:_parse_source_key 是同步方法,不能直接 await
456
+ # 需要从上层传入已经获取好的图像数据
457
+ logging.error("在同步方法中无法调用异步函数 get_realtime_image")
458
+ return None
459
+ else:
460
+ # 同步调用
461
+ image_data = get_realtime_image(device_id)
462
+ if field == 'url':
463
+ return image_data.get('url')
464
+ elif field == 'data':
465
+ return image_data.get('data')
466
+ elif field == 'timestamp':
467
+ return image_data.get('timestamp')
468
+ else:
469
+ return image_data.get(field)
470
+
471
+ elif root == "workflow_input":
472
+ # 工作流输入参数
473
+ workflow_input = context.get("workflow_input", {})
474
+ if len(parts) == 2:
475
+ return workflow_input.get(parts[1])
476
+
477
+ else:
478
+ # 其他节点输出
479
+ node_name = parts[0]
480
+ field_name = '.'.join(parts[1:]) # 支持嵌套字段
481
+
482
+ if node_name in context:
483
+ node_output = context[node_name]
484
+ # 获取嵌套字段的值
485
+ return self._get_nested_value(node_output, field_name)
486
+
487
+ return None
488
+
489
+ except Exception as e:
490
+ logging.error(f"解析source_key失败: {source_key}, 错误: {e}")
491
+ return None
492
+
493
+ def _build_final_output(self, template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
494
+ """构建最终输出"""
495
+ if template['output_schema']:
496
+ # 根据输出schema构建输出
497
+ output = {}
498
+ for field, config in template['output_schema'].items():
499
+ source = config.get('source', '')
500
+ if '.' in source:
501
+ node_name, output_field = source.split('.', 1)
502
+ if node_name in context and output_field in context[node_name]:
503
+ output[field] = context[node_name][output_field]
504
+ else:
505
+ # 默认输出最后一个节点的输出
506
+ last_node = template['nodes'][-1]
507
+ output = context.get(last_node.name, {})
508
+
509
+ # 添加执行元数据
510
+ output['_metadata'] = {
511
+ 'workflow_id': template['id'],
512
+ 'workflow_name': template['name'],
513
+ 'version': template['version'],
514
+ 'execution_time': datetime.utcnow().isoformat()
515
+ }
516
+
517
+ return output
518
+
519
+ def _evaluate_condition(self, condition: str, context: Dict[str, Any]) -> bool:
520
+ """评估执行条件"""
521
+ try:
522
+ # 安全地评估条件表达式
523
+ return eval(condition, {"__builtins__": {}}, context)
524
+ except:
525
+ return False
526
+
527
async def _call_model(self,
                      device_id,
                      model_type: str,
                      model_id: str,
                      endpoint: str,
                      input_data: Dict[str, Any],
                      parameters: Dict[str, Any],
                      timeout: int, call_edge_model) -> Dict[str, Any]:
    """Invoke a model for one workflow node.

    Args:
        device_id: device the call is made on behalf of.
        model_type: 'chat-completion' goes through the OpenAI-compatible
            HTTP endpoint; anything else goes through ``call_edge_model``.
        model_id: model identifier.
        endpoint: host(:port) of the serving endpoint.
        input_data: node input already shaped by input_mapping.
        parameters: merged model parameters (defaults + overrides).
        timeout: request timeout in seconds.
        call_edge_model: callback for non-chat models; may be sync or async.

    Returns:
        Model output dict (field names and structure are model-defined;
        output_mapping handles any renaming downstream).

    Raises:
        Exception: wraps any underlying failure, tagged with the model id.
    """
    try:
        model_input = input_data.copy()
        if parameters:
            model_input['parameters'] = parameters

        model_mapping = {
            "edge_url": endpoint,
            "model_id": model_id,
            "timeout": timeout
        }
        logging.info(f"调用模型 - ID: {model_id}, 端点: {endpoint}")
        logging.debug(f"模型输入数据: {input_data}")

        result = None
        if model_type == 'chat-completion':
            # TODO: the message content is a hard-coded placeholder and
            # ignores input_data — wire the node input into the prompt.
            messages = [
                {
                    "role": "user",
                    "content": "nihao"
                }
            ]

            chat_payload = {
                "model": model_id,
                "messages": messages,
                "stream": False  # streaming is not supported here
            }

            logging.debug(f"Chat completion 请求体: {chat_payload}")

            api_url = f"http://{endpoint}/v1/chat/completions"

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    api_url,
                    json=chat_payload,
                    headers={"Content-Type": "application/json"},
                    timeout=aiohttp.ClientTimeout(total=timeout)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"模型调用失败,HTTP状态码: {response.status}, 错误信息: {error_text}")

                    response_data = await response.json()

                    # OpenAI-compatible shape: choices[0].message.content.
                    if 'choices' in response_data and len(response_data['choices']) > 0:
                        result = response_data['choices'][0]['message']['content']
                    else:
                        # Unexpected shape: hand back the raw response.
                        result = response_data
        else:
            result = call_edge_model(device_id, model_id, model_input, model_mapping)
            # BUG FIX: call_edge_model is declared to return an Awaitable;
            # the resulting coroutine was previously never awaited, so an
            # async callback silently produced a coroutine object as output.
            if inspect.isawaitable(result):
                result = await result

        # Field mapping is handled later by output_mapping, so return the
        # extracted output without any hard-coded field handling.
        model_output = self._extract_model_output(result)

        logging.info(f"{model_id}模型调用成功")
        logging.debug(f"模型输出: {model_output}")

        return model_output
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise Exception(f"模型调用失败 ({model_id}): {str(e)}") from e
649
+ def _extract_model_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
650
+ """
651
+ 从模型调用结果中提取输出数据
652
+ 保持原始结构,让output_mapping来处理字段映射
653
+ """
654
+ # 如果result包含标准的数据结构
655
+ if isinstance(result, dict):
656
+ # 如果result有data字段,优先使用data
657
+ if 'data' in result:
658
+ output_data = result['data']
659
+ # 如果data是字典,直接返回
660
+ if isinstance(output_data, dict):
661
+ return output_data
662
+ # 如果data是其他类型,包装成字典
663
+ else:
664
+ return {'result': output_data}
665
+ # 如果result有output字段
666
+ elif 'output' in result:
667
+ return {'output': result['output']}
668
+ # 直接返回result
669
+ else:
670
+ return result
671
+ # 如果result不是字典,包装成标准格式
672
+ else:
673
+ return {'result': result}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: device_protocol_sdk
3
- Version: 1.2.8
3
+ Version: 1.2.9
4
4
  Summary: 无人设备协议开发SDK
5
5
  Author: fuhl
6
6
  Requires-Python: >=3.8
@@ -14,4 +14,6 @@ device_protocol_sdk.egg-info/top_level.txt
14
14
  device_protocol_sdk/model/__init__.py
15
15
  device_protocol_sdk/model/action_item.py
16
16
  device_protocol_sdk/model/device_key.py
17
- device_protocol_sdk/model/device_status.py
17
+ device_protocol_sdk/model/device_status.py
18
+ device_protocol_sdk/model_workflow/__init__.py
19
+ device_protocol_sdk/model_workflow/executer_service.py
@@ -7,7 +7,7 @@ long_description = readme_path.read_text(encoding="utf-8")
7
7
 
8
8
  setup(
9
9
  name="device_protocol_sdk",
10
- version="1.2.8",
10
+ version="1.2.9",
11
11
  packages=find_packages(include=["sdk*", "device_protocol_sdk*"]),
12
12
  install_requires=[
13
13
  "grpcio>=1.48.2", # gRPC 运行时依赖