zyworkflow 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,72 @@
1
+ import os
2
+ import traceback
3
+ import pandas as pd
4
+ from zyworkflow.utils.logger_config import setup_data_collection_logger
5
+ logger = setup_data_collection_logger()
6
+
7
+
8
+ def process_dataset(dataset_root, record_count):
9
+ try:
10
+ traj = f"traj_{record_count:03d}"
11
+ logger.info(f"开始对轨迹{traj}进行后处理")
12
+
13
+ traj_path = os.path.join(dataset_root, traj)
14
+ old_img_dir = os.path.join(traj_path, traj)
15
+ new_img_dir = os.path.join(traj_path, "images")
16
+
17
+ old_txt = os.path.join(traj_path, f"{traj}.txt")
18
+ new_csv = os.path.join(traj_path, "actions.csv")
19
+
20
+ if os.path.isdir(old_img_dir) and not os.path.exists(new_img_dir):
21
+ os.rename(old_img_dir, new_img_dir)
22
+
23
+ if os.path.isfile(old_txt) and not os.path.exists(new_csv):
24
+ df = pd.read_csv(old_txt)
25
+ if 'Image_Filename' in df.columns:
26
+ filtered_df = df[df['Image_Filename'].notna() & (df['Image_Filename'].astype(str).str.strip() != '')]
27
+
28
+ if not filtered_df.empty:
29
+ filtered_df.to_csv(new_csv, index=False)
30
+ logger.info(f"CSV文件已创建,包含{len(filtered_df)}行数据")
31
+ else:
32
+ logger.info("没有符合条件的行,不创建CSV文件")
33
+ else:
34
+ logger.info("文件中没有Image_Filename列,不创建CSV文件")
35
+
36
+ logger.info("第一阶段完成:重命名 images & txt → csv")
37
+
38
+ rename_map = {
39
+ "j1(rad)": "j1",
40
+ "j2(rad)": "j2",
41
+ "j3(rad)": "j3",
42
+ "j4(rad)": "j4",
43
+ "j5(rad)": "j5",
44
+ "j6(rad)": "j6",
45
+ "Gripper_Set": "Gripper_Set_Position(‰)"
46
+ }
47
+
48
+ traj_path = os.path.join(dataset_root, traj)
49
+ csv_path = os.path.join(traj_path, "actions.csv")
50
+
51
+ if not os.path.isfile(csv_path):
52
+ logger.error(f"轨迹{traj}的actions.csv文件不存在")
53
+ return
54
+
55
+ df = pd.read_csv(csv_path)
56
+ df.rename(columns=rename_map, inplace=True)
57
+
58
+ df["success_flag"] = 0
59
+ if len(df) > 0:
60
+ df.loc[df.index[-1], "success_flag"] = 1
61
+
62
+ df.to_csv(csv_path, index=False)
63
+ logger.info("第二阶段完成:列名修改 & success_flag 添加")
64
+ logger.info(f"轨迹{traj}后处理完成")
65
+ except Exception as e:
66
+ logger.error(f"轨迹{traj}后处理失败: {e}\n{traceback.format_exc()}")
67
+ return
68
+
69
+
70
+ if __name__ == "__main__":
71
+ dataset_root = "/home/user/8T/caizewu/code/lnn/dataset/trajectory_data_pick_update"
72
+ process_dataset(dataset_root)
zyworkflow/doc/api.md ADDED
@@ -0,0 +1,461 @@
1
+ # ai-workflow API 接口文档
2
+
3
+ > 基于 `api_server.py`(FastAPI 实现),所有接口默认返回 `application/json`。
4
+
5
+ ---
6
+
7
+ ## 基本信息
8
+
9
+ - 服务地址:`http://<host>:8003`
10
+ - 在线 Swagger:`http://<host>:8003/docs`
11
+
12
+ 统一错误返回格式(FastAPI `HTTPException`):
13
+
14
+ ```json
15
+ { "detail": "错误原因描述" }
16
+ ```
17
+
18
+ ---
19
+
20
+ ## 数据目录规范
21
+
22
+ 本项目的数据、训练产物都默认落在 `/workspace` 下
23
+
24
+ ### 1) 数据采集目
25
+
26
+ - **采集根目录**:
27
+
28
+ ```text
29
+ /workspace/dataset/<dataset_id>/<ability_id>/
30
+ ```
31
+
32
+ - **每次采集一条轨迹**会生成一个 `traj_XXX/` 目录(XXX 为 3 位序号,从 001 递增):
33
+
34
+ ```text
35
+ /workspace/dataset/<dataset_id>/<ability_id>/
36
+ traj_001/
37
+ actions.csv
38
+ images/
39
+ 0.000000.png
40
+ 0.050000.png
41
+ ...
42
+ traj_002/
43
+ actions.csv
44
+ images/
45
+ ...
46
+ ```
47
+
48
+ - `actions.csv` 中关键字段:
49
+
50
+ ```text
51
+ Time(s), X(m),Y(m),Z(m), Rx(rad),Ry(rad),Rz(rad), j1..j6, Gripper_Set_Position(‰), Gripper_Real, SKU, Image_Filename, success_flag
52
+ ```
53
+
54
+ - `Image_Filename` 指向同目录下图片文件(位于 `traj_XXX/` 子目录中)。图片文件名为 `Time(s)` 的浮点字符串(保留 6 位小数)+ `.png`
55
+
56
+ ### 2) 训练数据目录
57
+
58
+ 训练请求会在服务端拼出训练数据根目录:
59
+
60
+ ```text
61
+ root_dir = /workspace/dataset/<dataset_id>/<ability_id>/
62
+ ```
63
+
64
+ 并把该 `root_dir` 传给 `policy/train_pick_policy.py` 的 `--root_dir`。
65
+
66
+ 训练脚本会在 `root_dir` 下扫描所有 `traj_*` 目录(见 `SingleViewRobotTrajectoryDataset.__init__`),因此训练目录期望结构为:
67
+
68
+ ```text
69
+ /workspace/dataset/<dataset_id>/<ability_id>/
70
+ traj_001/
71
+ actions.csv
72
+ images/
73
+ *.png
74
+ traj_002/
75
+ ...
76
+ ```
77
+
78
+ ### 3) 训练产物目录(日志与 checkpoint)
79
+
80
+ - **训练日志默认路径**(可通过 `TrainRequest.log_path` 覆盖):
81
+
82
+ ```text
83
+ /workspace/logs/<task_id>/<ability_id>/<model_id>/training_log.txt
84
+ ```
85
+
86
+ - **checkpoint 默认目录**(可通过 `TrainRequest.ckpt_dir` 覆盖):
87
+
88
+ ```text
89
+ /workspace/checkpoints/<task_id>/<ability_id>/<model_id>/
90
+ ```
91
+
92
+ - checkpoint 文件名:
93
+
94
+ ```text
95
+ epoch_<N>.pth
96
+ ```
97
+
98
+ 其中 `<N>` 为 1-based epoch
99
+
100
+ ### 4) 测试/推理模型路径
101
+
102
+ 测试任务会在服务端拼出模型文件完整路径:
103
+
104
+ ```text
105
+ model_path = /workspace/checkpoints/<task_id>/<ability_id>/<model_id>/<model_name>
106
+ ```
107
+
108
+ 因此 `model_name` 一般取值示例:
109
+
110
+ ```text
111
+ epoch_1.pth
112
+ epoch_10.pth
113
+ ...
114
+ ```
115
+
116
+ ---
117
+
118
+ ## 目录
119
+
120
+ - [服务信息 `GET /`](#服务信息-get-)
121
+ - [启动训练 `POST /train`](#启动训练-post-train)
122
+ - [查询训练状态 `GET /train/status/{task_id}/{ability_id}/{model_id}`](#查询训练状态-get-trainstatustask_idability_idmodel_id)
123
+ - [停止训练 `POST /train/stop/{task_id}/{ability_id}/{model_id}`](#停止训练-post-trainstoptask_idability_idmodel_id)
124
+ - [提交数据采集 `POST /data/collection`](#提交数据采集-post-datacollection)
125
+ - [启动测试 `POST /test`](#启动测试-post-test)
126
+ - [查询测试状态 `GET /test/status/{task_id}/{ability_id}/{model_id}/{model_name}`](#查询测试状态-get-teststatustask_idability_idmodel_idmodel_name)
127
+ - [停止测试/急停 `POST /test/stop/{task_id}/{ability_id}/{model_id}/{model_name}`](#停止测试急停-post-teststoptask_idability_idmodel_idmodel_name)
128
+
129
+ ---
130
+
131
+ ## 服务信息 `GET /`
132
+
133
+ 返回服务基本信息。
134
+
135
+ ### 请求参数
136
+
137
+
138
+
139
+ ### 响应示例
140
+
141
+ ```json
142
+ {
143
+ "message": "BNN 训练和测试服务",
144
+ "version": "1.0"
145
+ }
146
+ ```
147
+
148
+ ### CURL 示例
149
+
150
+ ```bash
151
+ curl -X GET "http://127.0.0.1:8003/"
152
+ ```
153
+
154
+ ---
155
+
156
+ ## 启动训练 `POST /train`
157
+
158
+ 后台以独立进程运行训练脚本(`policy/train_pick_policy.py`)。
159
+
160
+ ### 请求体(`TrainRequest`)
161
+
162
+ | 字段 | 类型 | 必填 | 默认值 | 说明 |
163
+ |------|------|------|--------|------|
164
+ | task_id | string | ✔ | - | 任务ID |
165
+ | dataset_id | string | ✔ | - | 数据集ID(与采集侧一致的概念) |
166
+ | model_id | string | ✔ | - | 模型版本ID |
167
+ | ability_id | string | ✔ | - | 原子动作ID |
168
+ | algo_type | string | ✔ | - | 算法类型(目前仅支持 `bnn`) |
169
+ | action_type | string | ✔ | - | 动作类型(`pick`/`place`) |
170
+ | batch_size | int | ✖ | 48 | 批大小 |
171
+ | seq_len | int | ✖ | 4 | 图像序列长度 |
172
+ | action_chunk | int | ✖ | 8 | 预测步长 |
173
+ | lr | float | ✖ | 1e-4 | 学习率 |
174
+ | num_epochs | int | ✖ | 500 | 训练轮数 |
175
+ | start_epoch | int | ✖ | 0 | 起始 epoch(断点续训) |
176
+ | lambda_joints | float | ✖ | 10.0 | joints 损失权重 |
177
+ | lambda_grip | float | ✖ | 5.0 | gripper 损失权重 |
178
+ | lambda_success | float | ✖ | 2.0 | success 损失权重 |
179
+ | log_path | string | ✖ | null | 日志保存路径(不传用默认规则生成) |
180
+ | ckpt_dir | string | ✖ | null | checkpoint 目录(不传用默认规则生成) |
181
+ | success_mode | string | ✖ | within_horizon | 成功率评估模式:`within_horizon`/`terminal_only` |
182
+ | report_url | string | ✖ | null | 训练过程上报回调地址 |
183
+
184
+ ### 训练进度回调(`report_url`)
185
+
186
+ 训练脚本每完成一个 epoch 会 POST JSON 到 `report_url`(见 `policy/train_pick_policy.py`),格式:
187
+
188
+ ```json
189
+ {
190
+ "task_name": "<task_id>-<ability_id>-<model_id>",
191
+ "epoch": 12, // 训练轮次
192
+ "duration_sec": 37.42, // 本轮所用时间
193
+ "avg_loss": 0.0123, // 本轮平均loss
194
+ "j_err": 0.045, // 本轮关节角平均误差
195
+ "msg": "Ep 12 Saved. Time: ...", // 本轮完整信息
196
+ "is_finished": false, // 是否训练完成
197
+ "model_path": "/workspace/checkpoints/<task_id>/<ability_id>/<model_id>/epoch_12.pth" // 模型地址
198
+ }
199
+ ```
200
+
201
+ ### 成功响应(`TrainResponse`)
202
+
203
+ ```json
204
+ {
205
+ "status": "started",
206
+ "message": "已下发训练任务",
207
+ "task_name": "<task_id>-<ability_id>-<model_id>"
208
+ }
209
+ ```
210
+
211
+ ### CURL 示例
212
+
213
+ ```bash
214
+ curl -X POST "http://127.0.0.1:8003/train" \
215
+ -H "Content-Type: application/json" \
216
+ -d '{
217
+ "task_id": "t1",
218
+ "dataset_id": "d1",
219
+ "model_id": "m1",
220
+ "ability_id": "a1",
221
+ "algo_type": "bnn",
222
+ "action_type": "pick",
223
+ "num_epochs": 20,
224
+ "batch_size": 64,
225
+ "lr": 0.0005,
226
+ "report_url": "http://127.0.0.1:9000/report"
227
+ }'
228
+ ```
229
+
230
+ ---
231
+
232
+ ## 查询训练状态 `GET /train/status/{task_id}/{ability_id}/{model_id}`
233
+
234
+ 返回指定训练任务的实时状态。
235
+
236
+ ### `status` 可能取值
237
+
238
+ 训练状态由服务端内存 `training_status[task_name]` 维护,并在训练进程退出时自动更新
239
+
240
+ - **`running`**:训练进程已启动并在运行中。
241
+ - **`stopping`**:已收到停止请求,服务端正在终止训练进程组。
242
+ - **`stopped`**:训练已被用户停止。
243
+ - **`completed`**:训练进程正常退出(return code = 0)。
244
+ - **`failed`**:训练进程异常退出(return code != 0)。
245
+
246
+ 说明:
247
+
248
+ - 训练任务不存在时会返回 **404**。
249
+
250
+ ### 响应示例
251
+
252
+ ```json
253
+ {
254
+ "status": "running",
255
+ "message": "Process started (PID=12345)"
256
+ }
257
+ ```
258
+
259
+ ### CURL 示例
260
+
261
+ ```bash
262
+ curl -X GET "http://127.0.0.1:8003/train/status/t1/a1/m1"
263
+ ```
264
+
265
+ ---
266
+
267
+ ## 停止训练 `POST /train/stop/{task_id}/{ability_id}/{model_id}`
268
+
269
+ 向进程组发送 SIGTERM
270
+
271
+ ### 响应示例
272
+
273
+ ```json
274
+ {
275
+ "status": "stopped",
276
+ "message": "训练已被用户停止。"
277
+ }
278
+ ```
279
+
280
+ ### CURL 示例
281
+
282
+ ```bash
283
+ curl -X POST "http://127.0.0.1:8003/train/stop/t1/a1/m1"
284
+ ```
285
+
286
+ ---
287
+
288
+ ## 提交数据采集 `POST /data/collection`
289
+
290
+ 提交采集任务,服务端异步执行,任务状态存储在内存中。
291
+
292
+ ### 请求体(`TaskRequest`)
293
+
294
+ | 字段 | 类型 | 必填 | 默认值 | 说明 |
295
+ |------|------|------|--------|------|
296
+ | sku | string | ✔ | - | 物体编号 |
297
+ | ability_id | string | ✔ | - | 原子动作ID |
298
+ | dataset_id | string | ✔ | - | 数据集ID |
299
+ | algo_type | string | ✔ | - | 算法类型(目前仅支持 `bnn`) |
300
+ | action_type | string | ✔ | - | 动作类型(与 `create("<algo>-<action>")` 对应注册名) |
301
+ | init_pose | list[float] | ✔ | - | 初始关节姿态 |
302
+ | speed | int | ✖ | 40 | 运动速度 |
303
+ | sampling_rate | int | ✖ | 20 | 采样率 Hz |
304
+ | callback_url | string | ✖ | null | 任务完成回调地址 |
305
+
306
+ ### 成功响应(`TaskResponse`)
307
+
308
+ ```json
309
+ {
310
+ "task_name": "<dataset_id>-<ability_id>",
311
+ "status": "pending",
312
+ "message": "数据采集任务已提交",
313
+ "result": null
314
+ }
315
+ ```
316
+
317
+ ### 采集任务回调 `callback_url`
318
+
319
+ 当采集任务结束(成功/失败)且 `callback_url` 非空时,服务端会 POST JSON:
320
+
321
+ ```json
322
+ {
323
+ "code": 0,
324
+ "status": "completed",
325
+ "message": "Execution completed",
326
+ "sku": "SKU123",
327
+ "task_name": "d1-a1",
328
+ "dataset_id": "d1",
329
+ "ability_id": "a1",
330
+ "traj_path": "/workspace/dataset/d1/a1/traj_001"
331
+ }
332
+ ```
333
+
334
+ - `code == 0` 表示成功
335
+ - `traj_path` 为该次轨迹目录路径
336
+
337
+ ### CURL 示例
338
+
339
+ ```bash
340
+ curl -X POST "http://127.0.0.1:8003/data/collection" \
341
+ -H "Content-Type: application/json" \
342
+ -d '{
343
+ "sku": "SKU123",
344
+ "ability_id": "a1",
345
+ "dataset_id": "d1",
346
+ "algo_type": "bnn",
347
+ "action_type": "pick",
348
+ "init_pose": [0,0,0,0,0,0],
349
+ "sampling_rate": 20,
350
+ "callback_url": "http://127.0.0.1:9000/collect_cb"
351
+ }'
352
+ ```
353
+
354
+
355
+ ## 启动测试 `POST /test`
356
+
357
+ 服务端会从相机拉取图片(`utils/utils.py` 中 `get_image(rgb_image_url)`)并进行多步推理,结果写入服务内存状态;如传 `callback_url` 会在结束时回调。
358
+
359
+ ### 请求体(`TestRequest`)
360
+
361
+ | 字段 | 类型 | 必填 | 默认值 | 说明 |
362
+ |------|------|------|--------|------|
363
+ | task_id | string | ✔ | - | 任务ID |
364
+ | ability_id | string | ✔ | - | 原子动作ID |
365
+ | model_id | string | ✔ | - | 模型版本ID |
366
+ | model_name | string | ✔ | - | 模型文件名(例如 `epoch_10.pth`) |
367
+ | algo_type | string | ✔ | - | 算法类型(目前仅支持 `bnn`) |
368
+ | action_type | string | ✔ | - | 动作类型(`pick`/`place`) |
369
+ | seq_len | int | ✖ | 4 | 序列长度 |
370
+ | action_chunk | int | ✖ | 8 | 动作块 |
371
+ | step | int | ✖ | 200 | 推理步数 |
372
+ | callback_url | string | ✖ | null | 测试结束回调 |
373
+
374
+ ### 成功响应
375
+
376
+ ```json
377
+ {
378
+ "code": 0,
379
+ "status": "started",
380
+ "message": "测试任务已提交",
381
+ "task_name": "<task_id>-<ability_id>-<model_id>-<model_name>"
382
+ }
383
+ ```
384
+
385
+ ### CURL 示例
386
+
387
+ ```bash
388
+ curl -X POST "http://127.0.0.1:8003/test" \
389
+ -H "Content-Type: application/json" \
390
+ -d '{
391
+ "task_id": "t1",
392
+ "ability_id": "a1",
393
+ "model_id": "m1",
394
+ "model_name": "epoch_10.pth",
395
+ "algo_type": "bnn",
396
+ "action_type": "pick",
397
+ "step": 200
398
+ }'
399
+ ```
400
+
401
+ ---
402
+
403
+ ## 查询测试状态 `GET /test/status/{task_id}/{ability_id}/{model_id}/{model_name}`
404
+
405
+ 返回服务端内存中的测试任务状态与过程数据。
406
+
407
+ ### `status` 可能取值
408
+
409
+ 测试状态由服务端内存 `test_tasks[task_name]` 维护
410
+
411
+ - **`starting`**:测试任务已提交,后台任务尚未进入主循环。
412
+ - **`running`**:测试任务执行中(会持续更新 `current_step/joints/gripper/success`)。
413
+ - **`completed`**:测试完成。
414
+ - **`failed`**:测试失败(通常 `message` 会包含异常原因)。
415
+ - **`stopped`**:测试任务被急停/停止(`POST /test/stop/...` 后后台检测到停止标志)。
416
+
417
+ 说明:
418
+
419
+ - 测试任务不存在时会返回 **404**。
420
+ - `stop_requested` 为布尔值,表示是否已收到停止请求。
421
+ - `current_step` 从 0 开始递增。
422
+
423
+ ### 响应示例
424
+
425
+ ```json
426
+ {
427
+ "status": "running",
428
+ "message": "测试任务执行中...",
429
+ "stop_requested": false,
430
+ "current_step": 12,
431
+ "joints": [[...], ...],
432
+ "gripper": [...],
433
+ "success": [...]
434
+ }
435
+ ```
436
+
437
+ ### CURL 示例
438
+
439
+ ```bash
440
+ curl -X GET "http://127.0.0.1:8003/test/status/t1/a1/m1/epoch_10.pth"
441
+ ```
442
+
443
+ ---
444
+
445
+ ## 停止测试/急停 `POST /test/stop/{task_id}/{ability_id}/{model_id}/{model_name}`
446
+
447
+ 会先调用机械臂急停 `post_arm_stop()`,然后设置 `stop_requested=true`,后台循环检测到后结束。
448
+
449
+ ### 响应示例
450
+
451
+ ```json
452
+ { "code": 0, "message": "急停指令已发送" }
453
+ ```
454
+
455
+ ### CURL 示例
456
+
457
+ ```bash
458
+ curl -X POST "http://127.0.0.1:8003/test/stop/t1/a1/m1/epoch_10.pth"
459
+ ```
460
+
461
+ ---
File without changes