zyworkflow 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zyworkflow/__init__.py +0 -0
- zyworkflow/api_server.py +630 -0
- zyworkflow/data/__init__.py +0 -0
- zyworkflow/data/collection.py +1241 -0
- zyworkflow/data/process.py +72 -0
- zyworkflow/doc/api.md +461 -0
- zyworkflow/example/__init__.py +0 -0
- zyworkflow/example/train_client.py +301 -0
- zyworkflow/example/train_client_example.py +43 -0
- zyworkflow/policy/__init__.py +0 -0
- zyworkflow/policy/train_pick_policy.py +834 -0
- zyworkflow/utils/__init__.py +0 -0
- zyworkflow/utils/logger_config.py +50 -0
- zyworkflow/utils/pose.py +131 -0
- zyworkflow/utils/utils.py +264 -0
- zyworkflow-0.0.1.dist-info/METADATA +11 -0
- zyworkflow-0.0.1.dist-info/RECORD +19 -0
- zyworkflow-0.0.1.dist-info/WHEEL +5 -0
- zyworkflow-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import json
|
|
3
|
+
import time
|
|
4
|
+
import requests
|
|
5
|
+
import argparse
|
|
6
|
+
from typing import Optional, Literal
|
|
7
|
+
from zyworkflow.utils.logger_config import setup_train_client_logger
|
|
8
|
+
|
|
9
|
+
# Module-level logger shared by TrainClient and main(); it is created by
# setup_train_client_logger() when this module is imported.
logger = setup_train_client_logger()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TrainClient:
    """Thin HTTP client for the training service's ``/train`` endpoints.

    It wraps three requests made with ``requests``:

    * ``POST {base_url}/train`` to start a job,
    * ``GET  {base_url}/train/status/{task_id}/{ability_id}/{model_id}``,
    * ``POST {base_url}/train/stop/{task_id}/{ability_id}/{model_id}``,

    plus a polling loop (:meth:`monitor_training`) built on the status call.
    Every request failure is logged, echoed to stdout and then re-raised.
    """

    # Status values after which monitor_training() stops polling.
    _TERMINAL_STATUSES = ("completed", "failed", "stopped")

    def __init__(self, base_url: str = "http://localhost:8001"):
        """Remember the service address.

        Args:
            base_url: Root URL of the API service; trailing slashes are
                stripped so endpoint paths can be appended with a single "/".
        """
        self.base_url = base_url.rstrip('/')
        self.train_endpoint = f"{self.base_url}/train"

    def _task_url(self, action: str, task_id: str, ability_id: str, model_id: str) -> str:
        """Build ``{base_url}/train/{action}/{task_id}/{ability_id}/{model_id}``.

        Each ID is percent-encoded as a single path segment so that characters
        such as spaces, "/", "?" or "#" cannot change the shape of the URL.
        IDs made only of letters, digits and ``_.-~`` are left untouched.
        """
        # Local import: keeps this block self-contained without touching the
        # file's top-level import list.
        from urllib.parse import quote

        segments = "/".join(
            quote(str(part), safe="") for part in (task_id, ability_id, model_id)
        )
        return f"{self.base_url}/train/{action}/{segments}"

    @staticmethod
    def _report_request_error(prefix: str, error: Exception) -> None:
        """Log and print a failed request, including the response body if any.

        Args:
            prefix: Human-readable label placed before the error text.
            error: The caught ``requests`` exception. If it carries a
                ``response``, the body is shown as pretty-printed JSON, or as
                raw text when it cannot be parsed as JSON.
        """
        logger.error(f"{prefix}: {error}")
        print(f"{prefix}: {error}")

        response = getattr(error, "response", None)
        if response is None:
            return

        try:
            detail = json.dumps(response.json(), indent=2, ensure_ascii=False)
        except Exception:
            # Body is not valid JSON: fall back to the raw text.
            logger.error(f"响应内容: {response.text}")
            print(f"响应内容: {response.text}")
            return

        logger.error(f"错误详情: {detail}")
        print(f"错误详情: {detail}")

    def start_training(
        self,
        algo_type: Literal["bnn"],
        action_type: Literal["pick", "place"],
        task_id: str,
        ability_id: str,
        dataset_id: str,
        model_id: str,
        batch_size: int = 48,
        seq_len: int = 4,
        action_chunk: int = 8,
        lr: float = 1e-4,
        num_epochs: int = 500,
        start_epoch: int = 0,
        lambda_joints: float = 10.0,
        lambda_grip: float = 5.0,
        lambda_success: float = 2.0,
        log_path: Optional[str] = None,
        ckpt_dir: Optional[str] = None,
        success_mode: str = "within_horizon",
        report_url: Optional[str] = None,
    ) -> dict:
        """Send a training request and return the service's JSON reply.

        All arguments are forwarded unchanged as fields of the JSON payload
        posted to ``self.train_endpoint``. ``log_path``, ``ckpt_dir`` and
        ``report_url`` are only included when they are not ``None``.

        Returns:
            The decoded JSON response. Its ``task_name``, ``status`` and
            ``message`` entries (if present) are printed to stdout.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts (30 s) or non-2xx responses, after the failure has
                been logged and printed.
        """
        payload = {
            "algo_type": algo_type,
            "action_type": action_type,
            "task_id": task_id,
            "ability_id": ability_id,
            "dataset_id": dataset_id,
            "model_id": model_id,
            "batch_size": batch_size,
            "seq_len": seq_len,
            "action_chunk": action_chunk,
            "lr": lr,
            "num_epochs": num_epochs,
            "start_epoch": start_epoch,
            "lambda_joints": lambda_joints,
            "lambda_grip": lambda_grip,
            "lambda_success": lambda_success,
            "success_mode": success_mode,
        }

        # Optional fields are omitted entirely (rather than sent as null)
        # when the caller did not provide them.
        optional_fields = {
            "log_path": log_path,
            "ckpt_dir": ckpt_dir,
            "report_url": report_url,
        }
        payload.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )

        try:
            logger.info(f"正在向 {self.train_endpoint} 发送训练请求...")
            logger.debug(f"请求参数: {json.dumps(payload, indent=2, ensure_ascii=False)}")

            response = requests.post(self.train_endpoint, json=payload, timeout=30)
            response.raise_for_status()

            result = response.json()
            logger.info(f"训练任务已启动: {result.get('task_name')}")
            print("训练任务已启动!")
            print(f"任务名称: {result.get('task_name')}")
            print(f"状态: {result.get('status')}")
            print(f"消息: {result.get('message')}")

            return result

        except requests.exceptions.RequestException as e:
            self._report_request_error("请求失败", e)
            raise

    def get_training_status(self, task_id: str, ability_id: str, model_id: str) -> dict:
        """Fetch the current status of one training job.

        Returns:
            The decoded JSON body of the status endpoint.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts (10 s) or non-2xx responses, after the failure has
                been logged and printed.
        """
        try:
            url = self._task_url("status", task_id, ability_id, model_id)
            logger.debug(f"查询训练状态: {task_id}-{ability_id}-{model_id}")
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            status = response.json()
            logger.debug(f"训练状态: {status}")
            return status

        except requests.exceptions.RequestException as e:
            self._report_request_error("查询状态失败", e)
            raise

    def stop_training(self, task_id: str, ability_id: str, model_id: str) -> dict:
        """Ask the service to stop one training job.

        Returns:
            The decoded JSON body of the stop endpoint.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts (10 s) or non-2xx responses, after the failure has
                been logged and printed.
        """
        try:
            url = self._task_url("stop", task_id, ability_id, model_id)
            logger.info(f"发送停止训练请求: {task_id}-{ability_id}-{model_id}")
            response = requests.post(url, timeout=10)
            response.raise_for_status()
            result = response.json()
            logger.info(f"停止训练结果: {result}")
            return result

        except requests.exceptions.RequestException as e:
            self._report_request_error("停止训练失败", e)
            raise

    def monitor_training(self, task_id: str, ability_id: str, model_id: str, interval: int = 10):
        """Poll the job status until it finishes or the user interrupts.

        The status is queried every ``interval`` seconds and printed with a
        timestamp. Polling ends when the reported status is one of
        ``completed``, ``failed`` or ``stopped``, or on Ctrl-C. Any other
        error during a poll is logged and the loop retries after ``interval``
        seconds.

        Args:
            task_id: Task ID of the job to watch.
            ability_id: Ability ID of the job to watch.
            model_id: Model ID of the job to watch.
            interval: Seconds to wait between two status queries.
        """
        task_name = f"{task_id}-{ability_id}-{model_id}"
        logger.info(f"开始监控训练任务: {task_name}, 查询间隔: {interval} 秒")
        print(f"\n开始监控训练任务: {task_name}")
        print(f"查询间隔: {interval} 秒")
        print("-" * 60)

        # The KeyboardInterrupt handler wraps the whole loop so that Ctrl-C is
        # handled gracefully no matter where it lands, including the sleep
        # that follows a failed poll.
        try:
            while True:
                try:
                    status = self.get_training_status(task_id, ability_id, model_id)
                    state = status.get('status')
                    status_msg = f"状态: {state} | 消息: {status.get('message', 'N/A')}"
                    logger.info(f"[{task_name}] {status_msg}")
                    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {status_msg}")

                    if state in self._TERMINAL_STATUSES:
                        logger.info(f"训练任务结束: {task_name}, 状态: {state}")
                        print("\n" + "=" * 60)
                        print(f"训练任务结束: {state}")
                        print(f"最终消息: {status.get('message', 'N/A')}")
                        return
                except Exception as e:
                    # Best effort: report the failure and retry on the next tick.
                    logger.error(f"监控出错: {e}")
                    print(f"监控出错: {e}")

                time.sleep(interval)
        except KeyboardInterrupt:
            logger.warning("监控已中断")
            print("\n\n监控已中断")
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def main(argv: Optional[list] = None) -> int:
    """Command-line entry point of the training client.

    Depending on the flags it either queries a job (``--status``), stops a job
    (``--stop``) or, by default, starts a new training job and optionally
    monitors it (``--monitor``). ``--status`` takes precedence over ``--stop``
    when both are given.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses the
            process command line, exactly as before; passing a list allows
            calling ``main`` programmatically.

    Returns:
        0 on success; 1 when required IDs are missing or the request fails.
    """
    parser = argparse.ArgumentParser(description="训练客户端 - 向 FastAPI 服务发送训练请求")

    parser.add_argument("--url", type=str, default="http://localhost:8001",
                        help="API 服务地址 (默认: http://localhost:8001)")
    parser.add_argument("--algo-type", type=str, default="bnn", choices=["bnn"],
                        help="算法类型 (默认: bnn)")
    parser.add_argument("--action-type", type=str, default="pick", choices=["pick", "place"],
                        help="动作类型 (默认: pick)")
    parser.add_argument("--task-id", type=str, default=None,
                        help="任务ID（启动训练/查询状态/停止训练时必填）")
    parser.add_argument("--ability-id", type=str, default=None,
                        help="原子动作ID（启动训练/查询状态/停止训练时必填）")
    parser.add_argument("--dataset-id", type=str, default=None,
                        help="数据集ID（启动训练时必填）")
    parser.add_argument("--model-id", type=str, default=None,
                        help="模型ID（启动训练/查询状态/停止训练时必填）")

    parser.add_argument("--batch_size", type=int, default=48,
                        help="批次大小 (默认: 48)")
    parser.add_argument("--seq_len", type=int, default=4,
                        help="序列长度 (默认: 4)")
    parser.add_argument("--action_chunk", type=int, default=8,
                        help="动作块大小 (默认: 8)")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="学习率 (默认: 1e-4)")
    parser.add_argument("--num_epochs", type=int, default=500,
                        help="训练轮数 (默认: 500)")
    parser.add_argument("--start_epoch", type=int, default=0,
                        help="起始轮数，用于断点续训 (默认: 0)")
    parser.add_argument("--lambda_joints", type=float, default=10.0,
                        help="关节损失权重 (默认: 10.0)")
    parser.add_argument("--lambda_grip", type=float, default=5.0,
                        help="夹爪损失权重 (默认: 5.0)")
    parser.add_argument("--lambda_success", type=float, default=2.0,
                        help="成功标志损失权重 (默认: 2.0)")

    parser.add_argument("--log_path", type=str, default=None,
                        help="日志文件路径 (可选)")
    parser.add_argument("--ckpt_dir", type=str, default=None,
                        help="检查点保存目录 (可选)")
    parser.add_argument("--success_mode", type=str, default="within_horizon",
                        choices=["within_horizon", "terminal_only"],
                        help="成功标志模式 (默认: within_horizon)")
    parser.add_argument("--report_url", type=str, default=None,
                        help="训练过程上报回调地址 (可选)")

    parser.add_argument("--monitor", action="store_true",
                        help="启动训练后监控任务状态")
    parser.add_argument("--monitor_interval", type=int, default=10,
                        help="监控查询间隔（秒）(默认: 10)")

    parser.add_argument("--status", action="store_true",
                        help="查询任务状态（不启动新训练）")
    parser.add_argument("--stop", action="store_true",
                        help="停止指定任务的训练")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Validate the IDs required by the selected action before building the
    # client, so a bad invocation fails without touching anything else.
    job_ids = (args.task_id, args.ability_id, args.model_id)
    if args.status:
        required = job_ids
        missing_message = "查询状态需要提供 --task-id、--ability-id 和 --model-id"
    elif args.stop:
        required = job_ids
        missing_message = "停止训练需要提供 --task-id、--ability-id 和 --model-id"
    else:
        required = job_ids + (args.dataset_id,)
        missing_message = "启动训练需要提供 --task-id、--ability-id、--dataset-id 和 --model-id"

    # Empty strings count as missing, same as omitted options.
    if not all(required):
        print(missing_message)
        return 1

    client = TrainClient(base_url=args.url)

    if args.status:
        try:
            status = client.get_training_status(*job_ids)
            print(json.dumps(status, indent=2, ensure_ascii=False))
        except Exception as e:
            logger.error(f"查询失败: {e}")
            print(f"查询失败: {e}")
            return 1
        return 0

    if args.stop:
        try:
            result = client.stop_training(*job_ids)
            print(json.dumps(result, indent=2, ensure_ascii=False))
        except Exception as e:
            logger.error(f"停止失败: {e}")
            print(f"停止失败: {e}")
            return 1
        return 0

    try:
        client.start_training(
            algo_type=args.algo_type,
            action_type=args.action_type,
            task_id=args.task_id,
            ability_id=args.ability_id,
            dataset_id=args.dataset_id,
            model_id=args.model_id,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            action_chunk=args.action_chunk,
            lr=args.lr,
            num_epochs=args.num_epochs,
            start_epoch=args.start_epoch,
            lambda_joints=args.lambda_joints,
            lambda_grip=args.lambda_grip,
            lambda_success=args.lambda_success,
            log_path=args.log_path,
            ckpt_dir=args.ckpt_dir,
            success_mode=args.success_mode,
            report_url=args.report_url,
        )

        if args.monitor:
            client.monitor_training(args.task_id, args.ability_id, args.model_id, interval=args.monitor_interval)
        else:
            # Tell the user how to check on the job later.
            print("\n提示: 使用以下命令查询训练状态:")
            print(f" python train_client.py --status --task-id {args.task_id} --ability-id {args.ability_id} --model-id {args.model_id} --url {args.url}")

    except Exception as e:
        logger.error(f"启动训练失败: {e}")
        print(f"启动训练失败: {e}")
        return 1

    return 0
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
# When executed as a script, exit the process with main()'s return value
# as the status code.
if __name__ == "__main__":
    raise SystemExit(main())
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
from zyworkflow.example.train_client import TrainClient
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Example request parameters that main() below passes to
# TrainClient.start_training().
ALGO_TYPE = "bnn"
ACTION_TYPE = "pick"
TASK_ID = "t1"
ABILITY_ID = "a1"
DATASET_ID = "d1"
MODEL_ID = "m1"
# Base URL handed to TrainClient.
API_URL = "http://localhost:8001"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def main():
    """Start one training job with the example constants and print its name.

    Builds a TrainClient for API_URL, submits a training request using the
    module-level example IDs, then reports the ``task_name`` returned by the
    service (or a notice when the reply carries none).
    """
    training_options = {
        "algo_type": ALGO_TYPE,
        "action_type": ACTION_TYPE,
        "task_id": TASK_ID,
        "ability_id": ABILITY_ID,
        "dataset_id": DATASET_ID,
        "model_id": MODEL_ID,
        "batch_size": 48,
        "seq_len": 4,
        "action_chunk": 8,
        "lr": 1e-4,
        "num_epochs": 500,
        "start_epoch": 0,
        "lambda_joints": 10.0,
        "lambda_grip": 5.0,
        "lambda_success": 2.0,
        "success_mode": "within_horizon",
    }

    trainer = TrainClient(base_url=API_URL)
    reply = trainer.start_training(**training_options)

    name = reply.get("task_name")
    if not name:
        print("未能获取任务名称")
        return
    print(f"训练任务已启动，任务名称: {name}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
File without changes
|