zyworkflow 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zyworkflow/__init__.py ADDED
File without changes
@@ -0,0 +1,630 @@
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import time
5
+ import torch
6
+ import signal
7
+ import threading
8
+ import traceback
9
+ import subprocess
10
+ import numpy as np
11
+ from pydantic import BaseModel, Field
12
+ from typing import List, Optional, Dict, Any, Literal
13
+
14
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
15
+ from zyworkflow.utils.utils import *
16
+ from zyworkflow.utils.logger_config import setup_api_server_logger
17
+ from zyworkflow.policy.train_pick_policy import (
18
+ PersistentBNNPool,
19
+ SingleViewBNNActionPolicy,
20
+ )
21
+
22
+
logger = setup_api_server_logger()
app = FastAPI(title="BNN 训练和测试服务", description="用于机器人抓取任务的训练和真机测试")

# ---------------------------------------------------------------------------
# Process-wide state shared by the HTTP handlers and their background tasks.
# Everything lives in plain module-level objects; nothing here is persisted.
# ---------------------------------------------------------------------------

# PersistentBNNPool reused across test runs; created lazily by the first /test run.
bnn_pool_cache = None
# Test-run records keyed by "<task_id>-<ability_id>-<model_id>-<model_name>";
# read and written under test_tasks_lock.
test_tasks = {}
test_tasks_lock = threading.Lock()
# Training status dicts ({"status": ..., "message": ...}) keyed by
# "<task_id>-<ability_id>-<model_id>" (TrainRequest.task_name).
training_status = {}
# Live training subprocesses, keyed like training_status: {"popen", "pid", "pgid"}.
training_processes = {}
# Re-entrant because _update_proc_status_no_throw() acquires it and is also called
# from code that already holds it (e.g. the /train handler).
training_lock = threading.RLock()
# TaskResponse records of data-collection jobs keyed by "<dataset_id>-<ability_id>".
collection_task_store = {}
# Lazily created thread pool that runs blocking collection jobs;
# creation is guarded by collection_executor_lock.
collection_executor = None
collection_executor_lock = threading.Lock()
# Directory of this module; used as the training subprocess's working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Training entry point that POST /train launches as a subprocess.
TRAIN_SCRIPT = os.path.join(SCRIPT_DIR, "policy", "train_pick_policy.py")
37
+
38
+
class TrainRequest(BaseModel):
    """Body of ``POST /train``: the dataset to train on plus the training hyper-parameters.

    ``task_id``, ``ability_id`` and ``model_id`` together identify one training run.
    They are joined into :attr:`task_name` and used to derive the default log file
    and checkpoint directory. ``algo_type`` / ``action_type`` are accepted as free
    strings; this model does not validate them.
    """

    task_id: str = Field(..., description="任务ID,必填")
    dataset_id: str = Field(..., description="数据集ID,必填")
    model_id: str = Field(..., description="模型版本ID,必填")
    ability_id: str = Field(..., description="原子动作ID,必填")
    algo_type: str = Field(..., description="算法类型,目前仅支持'bnn',必填")
    action_type: str = Field(..., description="动作类型,目前仅支持'pick'和'place',必填")
    # Hyper-parameters below are forwarded as "--<name> <value>" flags to TRAIN_SCRIPT.
    batch_size: int = 48
    seq_len: int = 4
    action_chunk: int = 8
    lr: float = 1e-4
    num_epochs: int = 500
    start_epoch: int = 0
    lambda_joints: float = 10.0
    lambda_grip: float = 5.0
    lambda_success: float = 2.0
    # Optional overrides; when empty the per-run defaults from the getters below are used.
    log_path: Optional[str] = None
    ckpt_dir: Optional[str] = None
    success_mode: str = "within_horizon"
    report_url: Optional[str] = None

    @property
    def task_name(self) -> str:
        """Registry key of this run: ``"<task_id>-<ability_id>-<model_id>"``."""
        return "-".join((self.task_id, self.ability_id, self.model_id))

    def get_log_path(self) -> str:
        """Return the explicit ``log_path`` if it is non-empty, else the default per-run log file."""
        return self.log_path or f"/workspace/logs/{self.task_id}/{self.ability_id}/{self.model_id}/training_log.txt"

    def get_ckpt_dir(self) -> str:
        """Return the explicit ``ckpt_dir`` if it is non-empty, else the default per-run checkpoint dir."""
        return self.ckpt_dir or f"/workspace/checkpoints/{self.task_id}/{self.ability_id}/{self.model_id}"
73
+
74
+
class TestRequest(BaseModel):
    """Body of ``POST /test``: which checkpoint to load and how many inference steps to run.

    The checkpoint file is looked up under
    ``/workspace/checkpoints/<task_id>/<ability_id>/<model_id>/<model_name>``.
    """

    task_id: str = Field(..., description="任务ID,必填")
    ability_id: str = Field(..., description="原子动作ID,必填")
    model_id: str = Field(..., description="模型版本ID,必填")
    model_name: str = Field(..., description="模型文件名/别名,必填")
    algo_type: str = Field(..., description="算法类型,目前仅支持'bnn',必填")
    action_type: str = Field(..., description="动作类型,目前仅支持'pick'和'place',必填")
    # Length of the sliding camera-frame window; also passed to the policy constructor.
    seq_len: int = 4
    # Passed to the policy constructor when the checkpoint is loaded.
    action_chunk: int = 8
    # Number of inference steps to run.
    step: int = 200
    # Optional URL that receives a JSON notification when the run completes, fails or is stopped.
    callback_url: Optional[str] = None
86
+
87
+
class TaskRequest(BaseModel):
    """Body of ``POST /data/collection``: one data-collection job.

    ``algo_type`` and ``action_type`` are joined as ``"<algo_type>-<action_type>"`` and
    must name an entry of ``global_config``; otherwise the job is rejected.
    """

    # Passed to the collection handle's run_from_http_camera() call.
    sku: str
    ability_id: str = Field(..., description="原子动作ID,必填")
    dataset_id: str = Field(..., description="数据集ID,必填")
    algo_type: str = Field(..., description="算法类型,目前仅支持'bnn',必填")
    action_type: str = Field(..., description="动作类型,目前仅支持'pick'和'place',必填")
    # The key is required (Field(...)), but an explicit null is accepted (Optional).
    init_pose: Optional[List[float]] = Field(..., description="机械臂初始姿态,必填")
    # Forwarded as the speed / sampling_rate keywords to create(); null falls back to 40 / 20.
    speed: Optional[int] = 40
    sampling_rate: Optional[int] = 20
    # Optional URL that receives a JSON notification when the job finishes.
    callback_url: Optional[str] = None
98
+
99
+
class TrainResponse(BaseModel):
    """Reply of ``POST /train`` once the training subprocess has been launched."""

    # "started" on success.
    status: str
    message: str
    # "<task_id>-<ability_id>-<model_id>", the key for the status / stop endpoints.
    task_name: str
104
+
105
+
class TaskResponse(BaseModel):
    """Status record of a data-collection job.

    Returned by ``POST /data/collection`` and kept (and mutated in place) in
    ``collection_task_store``.
    """

    task_name: str
    # One of "pending", "running", "completed", "failed".
    status: str
    message: Optional[str] = None
    # The dict returned by the collection job, once it has finished.
    result: Optional[Dict] = None
111
+
112
+
async def execute_collection_task(algo_type: str, action_type: str, ability_id: str, dataset_id: str, sku: str, speed: int, init_pose: Optional[List[float]], sampling_rate: int = 20) -> Dict[str, Any]:
    """Create a data-collection handle and run it on the shared collection thread pool.

    The handle is built by ``create("<algo_type>-<action_type>", ...)``. Its blocking
    ``run_from_http_camera(sku)`` call executes on ``collection_executor``, which is
    created on first use with 3 workers, so the event loop stays responsive.

    Returns:
        The handle's result dict. On any failure it returns ``{"code": -1, "msg": ...}``
        instead; this coroutine does not raise for handle-creation or run errors.
    """
    import asyncio
    import concurrent.futures
    import functools

    global collection_executor

    try:
        handle = create(f"{algo_type}-{action_type}", args=(ability_id, dataset_id, init_pose), kws={"speed": speed, "sampling_rate": sampling_rate})
    except Exception as e:
        logger.error(f"[data collection] 创建采集任务失败: {e}")
        return {"code": -1, "msg": f"创建采集任务失败: {e}"}

    # Create the pool at most once and take a local reference while still holding the lock.
    with collection_executor_lock:
        if collection_executor is None:
            collection_executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
        executor = collection_executor

    # Inside a coroutine, get_running_loop() is the documented way to reach the current loop.
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(executor, functools.partial(handle.run_from_http_camera, sku))
        if isinstance(result, dict):
            return result
        logger.error(f"[data collection] 任务返回结果类型不支持: {type(result)}")
        return {"code": -1, "msg": f"任务返回结果类型不支持: {type(result)}"}
    except Exception as e:
        logger.error(f"[data collection] 采集任务执行失败: {e}\n{traceback.format_exc()}")
        return {"code": -1, "msg": f"任务执行失败: {e}"}
139
+
140
+
async def send_collection_callback(callback_url: str, task_name: str, sku: str, result: Dict[str, Any]) -> None:
    """POST the outcome of a data-collection job to *callback_url* as JSON.

    The payload combines the job's result dict (code, msg, dataset_id, ability_id,
    traj_path) with the job's current status from ``collection_task_store``.
    Best effort: a non-200 reply is logged as a warning and any exception is logged
    as an error. Nothing is raised to the caller.
    """
    import aiohttp

    try:
        msg = result.get("msg")
        body = {
            "code": result.get("code"),
            "status": collection_task_store[task_name].status,
            "message": f"{msg}",
            "sku": sku,
            "task_name": task_name,
            "dataset_id": result.get("dataset_id", ""),
            "ability_id": result.get("ability_id", ""),
            "traj_path": result.get("traj_path", ""),
        }
        post_options = {
            "json": body,
            "headers": {"Content-Type": "application/json"},
            "timeout": aiohttp.ClientTimeout(total=30),
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(callback_url, **post_options) as response:
                if response.status == 200:
                    return
                reply = await response.text()
                logger.warning(f"[data collection] 回调发送失败: {response.status} - {reply}")
    except Exception as e:
        logger.error(f"[data collection] 回调发送异常: {e}")
167
+
168
+
async def process_collection_task_background(algo_type: str, action_type: str, task_name: str, ability_id: str, dataset_id: str, sku: str, init_pose: Optional[List[float]], callback_url: Optional[str], speed: int = 40, sampling_rate: int = 20):
    """Background driver of one data-collection job.

    Marks the job "running", awaits :func:`execute_collection_task`, and records the
    outcome in ``collection_task_store[task_name]``. Code 0 maps to "completed";
    anything else maps to "failed". It then notifies *callback_url* if one was given.
    Unexpected exceptions mark the record "failed" and are logged, never raised.
    """
    try:
        pending = collection_task_store.get(task_name)
        if pending is not None:
            pending.status = "running"
            pending.message = "任务执行中..."

        result = await execute_collection_task(algo_type, action_type, ability_id, dataset_id, sku, speed, init_pose, sampling_rate)

        # Re-read after the await: the store entry may have been (re)created meanwhile.
        record = collection_task_store.setdefault(task_name, TaskResponse(task_name=task_name, status="running"))
        succeeded = result.get("code") == 0
        record.status = "completed" if succeeded else "failed"
        record.message = result.get("msg", "任务执行成功" if succeeded else "任务执行失败")
        record.result = result

        if callback_url:
            await send_collection_callback(callback_url, task_name, sku, result)

    except Exception as e:
        record = collection_task_store.get(task_name)
        if record is not None:
            record.status = "failed"
            record.message = f"任务执行异常: {str(e)}"
            record.result = {"code": -1, "msg": str(e)}
        logger.error(f"[data collection] 任务 {task_name} 执行异常: {e}\n{traceback.format_exc()}")
198
+
199
+
def _update_proc_status_no_throw(task_name: str) -> None:
    """Reconcile the status of *task_name*'s training subprocess if it has exited.

    Does nothing when no process is tracked or the process is still running.
    Otherwise it sets the status record:
      * "stopped"   if a stop was requested (status was "stopping"),
      * "completed" if the exit code is 0,
      * "failed"    otherwise.
    It then drops the process from ``training_processes``.
    """
    with training_lock:
        info = training_processes.get(task_name)
        if not info:
            return

        returncode = info["popen"].poll()
        if returncode is None:
            return

        state = training_status.get(task_name, {})
        if state.get("status") == "stopping":
            state.update(status="stopped", message="训练已被用户停止。")
        elif returncode == 0:
            state.update(status="completed", message="训练成功完成。")
        else:
            state.update(status="failed", message=f"训练失败,进程返回码: {returncode}。请检查训练日志。")

        training_processes.pop(task_name, None)
223
+
224
+
225
+ def _kill_process_group(pgid: int, term_timeout_sec: float = 5.0) -> None:
226
+ try:
227
+ os.killpg(pgid, signal.SIGTERM)
228
+ except ProcessLookupError:
229
+ return
230
+
231
+ deadline = time.time() + term_timeout_sec
232
+ while time.time() < deadline:
233
+ try:
234
+ os.killpg(pgid, 0)
235
+ except ProcessLookupError:
236
+ return
237
+ time.sleep(0.2)
238
+
239
+ try:
240
+ os.killpg(pgid, signal.SIGKILL)
241
+ except ProcessLookupError:
242
+ return
243
+
244
+
def load_model_for_inference(model_path: str, seq_len: int = 4, action_chunk: int = 8, device=None):
    """Build a ``SingleViewBNNActionPolicy``, load weights from *model_path*, and switch it to eval mode.

    Two checkpoint layouts are accepted:
      * a dict with ``"model_state_dict"``, optionally carrying ``"joint_mean"`` /
        ``"joint_std"`` normalization statistics;
      * a bare state dict.

    Missing (or ``None``) statistics fall back to identity normalization: zeros(6) / ones(6).

    Args:
        model_path: Checkpoint file. SECURITY: ``torch.load`` unpickles it, so only pass
            paths to trusted files.
        seq_len: Forwarded to the policy constructor.
        action_chunk: Forwarded to the policy constructor.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, joint_mean, joint_std)`` with both statistics as tensors on *device*.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SingleViewBNNActionPolicy(seq_len=seq_len, action_chunk=action_chunk).to(device)
    checkpoint = torch.load(model_path, map_location=device)

    joint_mean = joint_std = None
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        joint_mean = checkpoint.get("joint_mean")
        joint_std = checkpoint.get("joint_std")
    else:
        model.load_state_dict(checkpoint)

    # torch.as_tensor also accepts numpy arrays / lists (which have no ``.to``) and
    # returns existing tensors unchanged.
    joint_mean = torch.zeros(6) if joint_mean is None else torch.as_tensor(joint_mean)
    joint_std = torch.ones(6) if joint_std is None else torch.as_tensor(joint_std)

    model.eval()
    return model, joint_mean.to(device), joint_std.to(device)
260
+
261
+
def inference_single_step(model, bnn_pool, image_seq: np.ndarray, seq_len: int, device: torch.device):
    """Run one policy forward pass over a window of ``seq_len`` frames.

    Steps:
      * A batch dimension is prepended to *image_seq* and the result is encoded by
        ``model.encode_visual``.
      * Each time step's feature vector (divided by 10) is fed to
        ``bnn_pool.step_batch`` in order, once per step. This presumably advances the
        pool's recurrent state; confirm against PersistentBNNPool.
      * A ``None`` BNN output is replaced by an 80-element zero vector.
      * The stacked BNN outputs and the visual features go to ``model.decode_action``.

    Everything runs under ``torch.no_grad()``.

    Returns:
        The three outputs of ``model.decode_action``, as a tuple ``(p_j, p_g, p_s)``.
    """
    frames = torch.from_numpy(image_seq).unsqueeze(0).to(device)
    with torch.no_grad():
        visual_feats = model.encode_visual(frames)

        per_step = []
        for t in range(seq_len):
            feat_np = visual_feats[:, t, :].detach().cpu().numpy() / 10.0
            out = bnn_pool.step_batch([feat_np.reshape(1, -1)])[0]
            if out is None:
                out = np.zeros(80)
            elif hasattr(out, "shape") and out.shape == (1, 80):
                out = out.T
            per_step.append(np.array(out).squeeze())

        bnn_seq = torch.tensor(np.stack(per_step), device=device, dtype=torch.float32).unsqueeze(0)
        p_j, p_g, p_s = model.decode_action(visual_feats, bnn_seq)
        return p_j, p_g, p_s
279
+
280
+
@app.get("/")
async def root():
    """Banner endpoint: return the service name and API version string."""
    banner = {"message": "BNN 训练和测试服务", "version": "1.0"}
    return banner
284
+
285
+
@app.post("/train", response_model=TrainResponse)
async def train_model(request: TrainRequest):
    """Launch ``TRAIN_SCRIPT`` as a detached subprocess for *request*.

    Training data is read from ``/workspace/dataset/<dataset_id>/<ability_id>``. The
    child's stdout/stderr are appended to the run's log file, and it runs in its own
    session / process group so that /train/stop can signal the whole group.

    Raises:
        HTTPException: 500 if the dataset directory or training script is missing, or
            the process cannot be started; 409 if a process for the same task name is
            still running.
    """
    task_name = request.task_name
    root_dir = os.path.join("/workspace/dataset", request.dataset_id, request.ability_id)
    logger.info(f"[train] 接收到新的训练任务: algo_type={request.algo_type}, action_type={request.action_type}, task_id={request.task_id}, ability_id={request.ability_id}, model_id={request.model_id}, dataset_id={request.dataset_id}")
    try:
        if not os.path.exists(root_dir):
            logger.error(f"[train] 没有找到训练数据: {root_dir}")
            raise HTTPException(status_code=500, detail=f"Dataset not found: {root_dir}")
        if not os.path.exists(TRAIN_SCRIPT):
            logger.error(f"[train] 没有找到训练脚本: {TRAIN_SCRIPT}")
            raise HTTPException(status_code=500, detail=f"Train script not found: {TRAIN_SCRIPT}")

        with training_lock:
            if task_name in training_processes:
                # The tracked process may have exited since it was last checked.
                _update_proc_status_no_throw(task_name)
                if task_name in training_processes:
                    logger.error(f"[train] 任务{task_name}已经在训练")
                    raise HTTPException(status_code=409, detail=f"Task '{task_name}' is already running.")

        log_path = request.get_log_path()
        ckpt_dir = request.get_ckpt_dir()

        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        os.makedirs(ckpt_dir, exist_ok=True)

        cmd = [
            sys.executable, TRAIN_SCRIPT,
            "--task_name", task_name,
            "--root_dir", root_dir,
            "--ckpt_dir", ckpt_dir,
            "--batch_size", str(request.batch_size),
            "--seq_len", str(request.seq_len),
            "--action_chunk", str(request.action_chunk),
            "--lr", str(request.lr),
            "--num_epochs", str(request.num_epochs),
            "--start_epoch", str(request.start_epoch),
            "--lambda_joints", str(request.lambda_joints),
            "--lambda_grip", str(request.lambda_grip),
            "--lambda_success", str(request.lambda_success),
            "--success_mode", request.success_mode,
        ]
        if log_path:
            cmd.extend(["--log_path", log_path])
        if request.report_url:
            cmd.extend(["--report_url", request.report_url])

        # The child gets its own copies of the log file descriptor, so the parent's handle
        # is closed as soon as Popen returns (previously one fd leaked per request).
        # start_new_session=True runs setsid() in the child, just like preexec_fn=os.setsid,
        # but is safe in a multi-threaded parent such as this server.
        with open(log_path, "a", encoding="utf-8") as stdout_f:
            proc = subprocess.Popen(cmd, cwd=SCRIPT_DIR, stdout=stdout_f, stderr=stdout_f, start_new_session=True)

        with training_lock:
            st = training_status.get(task_name, {})
            st.update({"status": "running", "message": f"Process started (PID={proc.pid})"})
            training_status[task_name] = st
            training_processes[task_name] = {"popen": proc, "pid": proc.pid, "pgid": os.getpgid(proc.pid)}
        logger.info(f"[train] [{task_name}]: 已下发训练任务")
        return TrainResponse(status="started", message="已下发训练任务", task_name=task_name)
    except HTTPException:
        # Propagate intended status codes (e.g. 409) instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"[train] [{task_name}] Failed to start training: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
346
+
347
+
@app.get("/train/status/{task_id}/{ability_id}/{model_id}")
async def get_training_status(task_id: str, ability_id: str, model_id: str):
    """Return the status record of a training run, refreshed from its subprocess first.

    Raises:
        HTTPException: 404 if the run has never been started.
    """
    task_name = "-".join((task_id, ability_id, model_id))
    logger.info(f"[train status] 接收获取训练状态请求: task_id={task_id}, ability_id={ability_id}, model_id={model_id}")

    with training_lock:
        if task_name not in training_status:
            logger.error(f"[train status] 没有找到{task_name}训练任务")
            raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found.")

    _update_proc_status_no_throw(task_name)

    with training_lock:
        current = training_status[task_name]
        logger.info(f"[train status] [{task_name}]: {current}")
        return current
360
+
361
+
@app.post("/train/stop/{task_id}/{ability_id}/{model_id}")
async def stop_training(task_id: str, ability_id: str, model_id: str):
    """Stop a running training process group and return its resulting status record.

    Returns ``{"status": "already_finished", "message": ...}`` when the run has already
    ended (completed / failed / stopped).

    Raises:
        HTTPException: 404 if no running or finished training is known for the task.
    """
    import asyncio

    task_name = f"{task_id}-{ability_id}-{model_id}"
    logger.info(f"[train stop] 接收到停止训练请求: task_id={task_id}, ability_id={ability_id}, model_id={model_id}")
    with training_lock:
        proc_info = training_processes.get(task_name)
        if not proc_info:
            # No live process is tracked; report a terminal status if one was recorded.
            st = training_status.get(task_name, {})
            if st.get("status") in ["completed", "failed", "stopped"]:
                return {"status": "already_finished", "message": st.get("message")}
            logger.error(f"[train stop] 没有找到待停止任务: {task_name}")
            raise HTTPException(status_code=404, detail="Running task not found.")
        pgid = proc_info["pgid"]
        # Makes _update_proc_status_no_throw() report "stopped" rather than "failed".
        training_status.setdefault(task_name, {})["status"] = "stopping"

    # _kill_process_group sleeps while waiting for the group to exit (up to its SIGTERM
    # grace period). Calling it directly here would freeze the event loop and every
    # other request, so it runs on the default thread pool instead.
    await asyncio.get_running_loop().run_in_executor(None, _kill_process_group, pgid)
    _update_proc_status_no_throw(task_name)

    with training_lock:
        logger.info(f"[train stop] [{task_name}]: Task stopped.")
        return training_status.get(task_name, {"status": "stopped", "message": "Task stopped."})
384
+
385
+
@app.post("/data/collection", response_model=TaskResponse)
async def submit_collection_task(request: TaskRequest, background_tasks: BackgroundTasks):
    """Validate and enqueue a data-collection job; it runs after the response is sent.

    The job kind ``"<algo_type>-<action_type>"`` must be a key of ``global_config``;
    otherwise a "failed" TaskResponse is returned and nothing is enqueued. The job's
    record is stored under ``"<dataset_id>-<ability_id>"`` in ``collection_task_store``.
    """
    task_name = f"{request.dataset_id}-{request.ability_id}"
    logger.info(f"[data collection] 接收到数据采集任务提交请求: {task_name}, sku={request.sku}, algo_type={request.algo_type}, action_type={request.action_type}")

    valid_tasks = global_config.keys()
    func = f"{request.algo_type}-{request.action_type}"
    if func not in valid_tasks:
        logger.error(f"[data collection] 任务类型错误: {func}")
        return TaskResponse(
            task_name=task_name,
            status="failed",
            message=f"任务类型错误,可选值: {list(valid_tasks)}",
        )

    collection_task_store[task_name] = TaskResponse(
        task_name=task_name,
        status="pending",
        message="数据采集任务已提交",
    )

    # Keyword arguments: the worker's positional order is (algo_type, action_type,
    # task_name, ...). Passing task_name first, as before, rotated these three values,
    # so every job failed and the stored record stayed "pending" forever.
    background_tasks.add_task(
        process_collection_task_background,
        algo_type=request.algo_type,
        action_type=request.action_type,
        task_name=task_name,
        ability_id=request.ability_id,
        dataset_id=request.dataset_id,
        sku=request.sku,
        init_pose=request.init_pose,
        callback_url=request.callback_url,
        speed=request.speed if request.speed is not None else 40,
        sampling_rate=request.sampling_rate if request.sampling_rate is not None else 20,
    )

    logger.info(f"[data collection] 已提交新数据采集任务: {task_name}, sku={request.sku}, algo_type={request.algo_type}, action_type={request.action_type}")
    return collection_task_store[task_name]
423
+
424
+
async def send_test_callback(callback_url: str, payload: Dict[str, Any]) -> None:
    """POST a test-run notification *payload* to *callback_url* as JSON.

    Best effort: a non-200 reply is logged as a warning, and any exception (including
    the 10 s timeout) is logged as an error. Nothing is raised to the caller.
    """
    import aiohttp

    try:
        post_options = {
            "json": payload,
            "headers": {"Content-Type": "application/json"},
            "timeout": aiohttp.ClientTimeout(total=10),
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(callback_url, **post_options) as response:
                if response.status == 200:
                    return
                reply = await response.text()
                logger.warning(f"[test] 测试回调发送失败: {response.status} - {reply}")
    except Exception as e:
        logger.error(f"[test] 测试回调发送异常: {e}")
440
+
441
+
# Serializes test runs: every run resets and steps the single shared ``bnn_pool_cache``,
# so two runs executing on different worker threads must never interleave.
_test_run_lock = threading.Lock()


def _test_stop_requested(task_name: str) -> bool:
    """Return True once /test/stop has flagged *task_name* (read under ``test_tasks_lock``)."""
    with test_tasks_lock:
        return bool(test_tasks.get(task_name, {}).get("stop_requested", False))


def _capture_test_frame():
    """Fetch one frame from ``rgb_image_url`` and return it passed through ``preprocess_image``.

    Raises:
        RuntimeError: if ``get_image`` returns no image.
    """
    img, err = get_image(rgb_image_url)
    if img is None:
        logger.error(f"[test] 请求相机图像失败: {err}")
        raise RuntimeError(f"请求相机图像失败: {err}")
    return preprocess_image(img)


def _run_test_inference(task_name: str, request: TestRequest) -> Optional[int]:
    """Blocking body of a test run; intended to execute on a worker thread.

    Steps:
      * Load the checkpoint and reset the shared BNN pool.
      * Run ``request.step`` inference steps over a sliding window of
        ``request.seq_len`` camera frames.
      * After every step, publish the accumulated joints / gripper / success
        predictions into ``test_tasks[task_name]``.

    Returns:
        The step index at which a stop request was observed, or None if all steps ran.

    Raises:
        RuntimeError: invalid or missing model file, or a failed camera request.
    """
    global bnn_pool_cache

    ckpt_root = os.path.abspath("/workspace/checkpoints")
    model_path = os.path.abspath(
        os.path.join(ckpt_root, request.task_id, request.ability_id, request.model_id, request.model_name)
    )
    # The path parts come straight from the HTTP request. ".." segments, or an absolute
    # model_name (os.path.join then discards everything before it), could otherwise make
    # torch.load -- which unpickles its input -- read an arbitrary file.
    if os.path.commonpath([ckpt_root, model_path]) != ckpt_root:
        logger.error(f"[test] 测试模型路径非法: {model_path}")
        raise RuntimeError("Invalid model path.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(model_path):
        logger.error(f"[test] 测试模型{model_path}没有找到")
        raise RuntimeError("Model file not found.")

    with _test_run_lock:
        model, joint_mean, joint_std = load_model_for_inference(model_path, request.seq_len, request.action_chunk, device)

        if bnn_pool_cache is None:
            bnn_pool_cache = PersistentBNNPool(num_workers=1)
        bnn_pool_cache.reset_all(1)

        images = [_capture_test_frame() for _ in range(request.seq_len)]
        joints_out = []
        gripper_out = []
        success_out = []

        for i in range(int(request.step)):
            if _test_stop_requested(task_name):
                return i

            p_j, p_g, p_s = inference_single_step(model, bnn_pool_cache, np.stack(images), request.seq_len, device)

            # Joints are de-normalized with the checkpoint statistics;
            # gripper / success outputs are passed through a sigmoid.
            joints_out.append((p_j[0] * joint_std + joint_mean).cpu().tolist())
            gripper_out.extend(torch.sigmoid(p_g[0]).squeeze(-1).cpu().tolist())
            success_out.extend(torch.sigmoid(p_s[0]).squeeze(-1).cpu().tolist())

            with test_tasks_lock:
                test_tasks[task_name].update({
                    "current_step": i + 1,
                    "joints": joints_out,
                    "gripper": gripper_out,
                    "success": success_out,
                })

            # Slide the window: drop the oldest frame and append a fresh one.
            images = images[1:] + [_capture_test_frame()]

    return None


async def process_test_task_background(task_name: str, request: TestRequest):
    """Background driver of one model test run.

    Marks ``test_tasks[task_name]`` as "running" and runs :func:`_run_test_inference`
    on the default thread pool. It then records "completed", "stopped" (emergency stop
    requested via /test/stop) or "failed", and posts a matching notification
    (code 0 / -2 / -1) to ``request.callback_url`` if one was given.

    Running the blocking work (model load, camera requests, inference) off the event
    loop is what allows /test/stop to be served, and its stop flag to be seen, while a
    test is in progress.
    """
    import asyncio

    callback_url = request.callback_url
    identity = {
        "task_id": request.task_id,
        "ability_id": request.ability_id,
        "model_id": request.model_id,
        "model_name": request.model_name,
    }

    try:
        with test_tasks_lock:
            st = test_tasks.get(task_name, {})
            st.update({"status": "running", "message": "测试任务执行中...", "start_time": time.time()})
            test_tasks[task_name] = st

        loop = asyncio.get_running_loop()
        stopped_at = await loop.run_in_executor(None, _run_test_inference, task_name, request)

        if stopped_at is not None:
            with test_tasks_lock:
                test_tasks[task_name].update({"status": "stopped", "message": "测试任务被急停", "end_time": time.time()})
            if callback_url:
                await send_test_callback(
                    callback_url,
                    {"code": -2, "message": "测试任务被急停", **identity, "step": stopped_at},
                )
            return

        with test_tasks_lock:
            test_tasks[task_name].update({"status": "completed", "message": "测试完成", "end_time": time.time()})

        if callback_url:
            await send_test_callback(
                callback_url,
                {"code": 0, "message": "测试完成", **identity, "step": request.step},
            )

    except Exception as e:
        logger.error(f"[test] 测试任务执行失败: {e}\n{traceback.format_exc()}")
        with test_tasks_lock:
            test_tasks.setdefault(task_name, {}).update({"status": "failed", "message": str(e), "end_time": time.time()})
        if callback_url:
            await send_test_callback(
                callback_url,
                {"code": -1, "message": f"测试时发生未知错误: {e}", **identity},
            )
556
+
557
+
@app.post("/test/stop/{task_id}/{ability_id}/{model_id}/{model_name}")
async def stop_test(task_id: str, ability_id: str, model_id: str, model_name: str):
    """Emergency-stop the arm, then flag the matching test run to stop.

    ``post_arm_stop()`` is always attempted first, even when the test task is unknown.

    Raises:
        HTTPException: 500 if the arm stop call fails or returns a ``None`` code;
            404 if no test task with this name exists.
    """
    task_name = "-".join((task_id, ability_id, model_id, model_name))
    logger.info(f"[test stop] 接收到停止测试请求: task_id={task_id}, ability_id={ability_id}, model_id={model_id}, model_name={model_name}")

    failure = None
    try:
        code, msg = post_arm_stop()
        if code is None:
            failure = msg or "post_arm_stop failed"
    except Exception as exc:
        failure = exc
    if failure is not None:
        logger.error(f"[test stop] 急停失败: {failure}")
        raise HTTPException(status_code=500, detail=f"急停失败: {failure}")

    with test_tasks_lock:
        entry = test_tasks.get(task_name)
        if entry is None:
            logger.error(f"[test stop] 没有找到需要急停的{task_name}任务")
            raise HTTPException(status_code=404, detail="Running test task not found.")
        entry["stop_requested"] = True

    return {"code": 0, "message": "急停指令已发送"}
580
+
581
+
@app.post("/test")
async def test_model(request: TestRequest, background_tasks: BackgroundTasks):
    """Register a fresh test-run record and schedule the run as a background task.

    Raises:
        HTTPException: 409 if a run with the same name is still "starting" or "running".
    """
    task_name = "-".join((request.task_id, request.ability_id, request.model_id, request.model_name))
    logger.info(f"[test] 接收到测试任务: task_name={task_name}, model_name={request.model_name}, algo_type={request.algo_type}, action_type={request.action_type}, step={request.step}")

    initial_record = {
        "status": "starting",
        "message": "测试任务已提交",
        "stop_requested": False,
        "current_step": 0,
        "joints": [],
        "gripper": [],
        "success": [],
    }
    with test_tasks_lock:
        existing = test_tasks.get(task_name)
        if existing and existing.get("status") in ("starting", "running"):
            logger.error(f"[test] 任务{task_name}已经在测试中")
            raise HTTPException(status_code=409, detail=f"Test task '{task_name}' is already running.")
        test_tasks[task_name] = initial_record

    background_tasks.add_task(process_test_task_background, task_name, request)

    return {"code": 0, "status": "started", "message": "测试任务已提交", "task_name": task_name}
605
+
606
+
@app.get("/test/status/{task_id}/{ability_id}/{model_id}/{model_name}")
async def get_test_status(task_id: str, ability_id: str, model_id: str, model_name: str):
    """Return a snapshot of a test run's progress record.

    The record is copied while ``test_tasks_lock`` is held. Returning the live dict would
    let it be serialized after the lock is released, while the test runner may still be
    updating it and appending to its ``joints`` / ``gripper`` / ``success`` lists.

    Raises:
        HTTPException: 404 if no test run with this name exists.
    """
    task_name = f"{task_id}-{ability_id}-{model_id}-{model_name}"
    logger.info(f"[test status] 接收获取测试状态请求: task_id={task_id}, ability_id={ability_id}, model_id={model_id}, model_name={model_name}")
    with test_tasks_lock:
        st = test_tasks.get(task_name)
        if st is None:
            logger.error(f"[test status] 没有找到{task_name}测试任务")
            raise HTTPException(status_code=404, detail=f"Test task '{task_name}' not found.")
        # Shallow-copy the dict and its list values so later appends do not leak into this response.
        return {key: list(value) if isinstance(value, list) else value for key, value in st.items()}
617
+
618
+
@app.on_event("shutdown")
async def shutdown_event():
    """On server shutdown, terminate every tracked training process group.

    Each group gets a 2 s SIGTERM grace period before SIGKILL.
    """
    with training_lock:
        snapshot = list(training_processes.values())
    group_ids = [info["pgid"] for info in snapshot if info.get("pgid")]
    for group_id in group_ids:
        _kill_process_group(group_id, 2.0)
626
+
627
+
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8003, access_log=False)
File without changes