vlagents 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +0 -0
- tests/test_connection.py +80 -0
- vlagents/__init__.py +1 -0
- vlagents/__main__.py +400 -0
- vlagents/client.py +107 -0
- vlagents/evaluator_envs.py +541 -0
- vlagents/policies.py +694 -0
- vlagents/server.py +114 -0
- vlagents/wrappers.py +32 -0
- vlagents-0.0.1.dist-info/METADATA +231 -0
- vlagents-0.0.1.dist-info/RECORD +14 -0
- vlagents-0.0.1.dist-info/WHEEL +5 -0
- vlagents-0.0.1.dist-info/licenses/LICENSE +201 -0
- vlagents-0.0.1.dist-info/top_level.txt +2 -0
tests/__init__.py
ADDED
|
File without changes
|
tests/test_connection.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
from time import sleep
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from vlagents.client import RemoteAgent
|
|
7
|
+
from vlagents.evaluator_envs import start_server
|
|
8
|
+
from vlagents.policies import Act, Obs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _test_connection(agent: RemoteAgent):
    """Run a reset and two act round trips against the "test" agent.

    After every call the agent's info dict is expected to echo the shapes,
    dtypes and raw data of the camera frames it received, which verifies that
    the configured transport delivered the frame intact.
    """

    def check_echo(info: dict, frame: np.ndarray) -> None:
        # the server side must report exactly the frame that was sent
        assert info["shapes"] == {"rgb_side": [256, 256, 3]}
        assert info["dtype"] == {"rgb_side": "uint8"}
        assert (info["data"]["rgb_side"] == frame).all()

    frame = np.zeros((256, 256, 3), dtype=np.uint8)
    frame[2, 0, 0] = 16
    instruction = "do something"

    info = agent.reset(Obs(cameras=dict(rgb_side=frame)), instruction)
    assert info["instruction"] == instruction
    check_echo(info, frame)

    # (pixel set to 1 before acting, action the agent is expected to return)
    act_steps = [
        ((0, 0, 2), [0, 0, 0, 0, 0, 0, 0]),
        ((0, 2, 0), [0, 0, 0, 0, 0, 0, 1]),
    ]
    for pixel, expected_action in act_steps:
        frame[pixel] = 1
        act = agent.act(Obs(cameras=dict(rgb_side=frame)))
        check_echo(act.info, frame)
        assert np.all(act.action == np.array(expected_action, dtype=np.float32))
        assert not act.done
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _test_connection_jpeg(agent: RemoteAgent):
    """Check that a JPEG-encoded frame survives a reset round trip.

    An all-black frame is used, presumably so that the lossy JPEG encoding
    still compares exactly equal after decoding.
    """
    instruction = "do something"
    frame = np.zeros((256, 256, 3), dtype=np.uint8)

    info = agent.reset(Obs(cameras=dict(rgb_side=frame)), instruction)

    expected = {
        "instruction": instruction,
        "shapes": {"rgb_side": [256, 256, 3]},
        "dtype": {"rgb_side": "uint8"},
    }
    for key, value in expected.items():
        assert info[key] == value
    assert (info["data"]["rgb_side"] == frame).all()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _run_against_test_server(check, timeout: float = 60.0, **agent_kwargs):
    """Start the "test" agent server, connect a RemoteAgent and run ``check`` on it.

    The server is always stopped with SIGINT afterwards, also when ``check``
    fails, so it does not keep running (and holding port 8080) for later tests.

    Args:
        check: callable receiving the connected, initialized RemoteAgent.
        timeout: seconds to wait for the agent to report that it is initialized.
        **agent_kwargs: extra keyword arguments forwarded to RemoteAgent
            (e.g. ``on_same_machine=True`` or ``jpeg_encoding=True``).

    Raises:
        TimeoutError: if the agent is not initialized within ``timeout`` seconds.
    """
    import signal
    from time import monotonic

    with start_server("test", {}, 8080, "localhost") as p:
        try:
            # give the server process time to come up before connecting
            sleep(2)
            agent = RemoteAgent("localhost", 8080, "test", **agent_kwargs)
            with agent:
                deadline = monotonic() + timeout
                while not agent.is_initialized():
                    if monotonic() > deadline:
                        raise TimeoutError(f"remote agent not initialized after {timeout}s")
                    sleep(0.1)
                check(agent)
        finally:
            p.send_signal(signal.SIGINT)


def test_connection_numpy_serialization():
    _run_against_test_server(_test_connection)


def test_connection_numpy_shm():
    _run_against_test_server(_test_connection, on_same_machine=True)


def test_connection_numpy_jpeg():
    _run_against_test_server(_test_connection_jpeg, jpeg_encoding=True)
|
vlagents/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""vlagents: serve agents over rpyc, connect to them remotely and evaluate them on environments."""

import importlib.metadata as _metadata

try:
    # "__version__" is exported via __all__, so it must exist for ``from vlagents import *``
    __version__ = _metadata.version("vlagents")
except _metadata.PackageNotFoundError:
    # imported from a source tree without installed package metadata
    __version__ = "unknown"

__all__ = ["__doc__", "__version__", "policies", "client", "server"]
|
vlagents/__main__.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from multiprocessing import Pool
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Annotated
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import rpyc
|
|
11
|
+
import typer
|
|
12
|
+
import wandb
|
|
13
|
+
|
|
14
|
+
# Force matplotlib's non-interactive "Agg" backend. This is needed when the CLI is
# launched from a Jupyter notebook, whose inline backend would otherwise be inherited.
# It has to run before the vlagents imports below, which may import matplotlib.
os.environ.update(MPLBACKEND="Agg")
|
|
17
|
+
|
|
18
|
+
from vlagents.evaluator_envs import AgentConfig, EvalConfig, evaluation, write_results
|
|
19
|
+
from vlagents.policies import AGENTS
|
|
20
|
+
from vlagents.server import AgentService
|
|
21
|
+
|
|
22
|
+
# Root typer application. Every function decorated with ``@main_app.command()`` in this
# module becomes a subcommand; the app is invoked by the ``__main__`` guard at the bottom
# (i.e. ``python -m vlagents <command> ...``).
main_app = typer.Typer(help="CLI tool for the vlagents library.")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def wandb_log_git_diff(path: str):
    """Snapshot the git state into ``<path>/git`` and upload it to wandb.

    The ``git`` sub-directory is created if needed, filled by ``log_git_diff``
    and then logged as a wandb artifact of type "directory" named "git".

    Args:
        path: base output directory of the run.
    """
    target = Path(path) / "git"
    target.mkdir(parents=True, exist_ok=True)
    log_git_diff(target)
    wandb.log_artifact(target, type="directory", name="git")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def log_git_diff(path: str):
|
|
34
|
+
# git id
|
|
35
|
+
git_id = os.path.join(path, "git_id.txt")
|
|
36
|
+
os.system(f'git log --format="%H" -n 1 > {git_id}')
|
|
37
|
+
|
|
38
|
+
# submodule git ids
|
|
39
|
+
git_submodules = os.path.join(path, "git_submodules.txt")
|
|
40
|
+
os.system(f"git submodule status > {git_submodules}")
|
|
41
|
+
|
|
42
|
+
# get git diff
|
|
43
|
+
git_diff = os.path.join(path, "git_diff.txt")
|
|
44
|
+
os.system(f"git diff --submodule=diff > {git_diff}")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@main_app.command()
def start_server(
    agent_name: Annotated[str, typer.Argument(help="Agent name to run.")],
    kwargs: Annotated[str, typer.Option(help="args to start the agent.")] = "{}",
    port: Annotated[int, typer.Option(help="Port to run the server on.")] = 8080,
    host: Annotated[str, typer.Option(help="Host to run the server on.")] = "localhost",
):
    """Runs eval server."""
    # ``kwargs`` is a JSON object string with the agent's constructor arguments
    agent_kwargs = json.loads(kwargs)
    agent = AGENTS[agent_name](**agent_kwargs)

    # NOTE(review): allow_pickle enables pickle-based transfers, which is unsafe with
    # untrusted peers -- only expose ``host`` on trusted networks.
    protocol_config = {"allow_pickle": True, "allow_public_attrs": True}

    service = AgentService(agent, agent_name)
    with service:
        # blocks serving clients until the server is stopped
        server = rpyc.ThreadedServer(service, port=port, hostname=host, protocol_config=protocol_config)
        server.start()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _per_process(
    args: tuple[int, AgentConfig, list[EvalConfig], int, int | None, int],
) -> tuple[np.ndarray, list[list[list[float]]], list[float], int]:
    """Evaluate a single checkpoint; meant to run as a ``multiprocessing.Pool`` worker.

    Args:
        args: ``(step, agent_cfg, eval_cfgs, episodes, n_processes, nth_gpu)``. ``step``
            is stored as ``checkpoint_step`` in a deep copy of ``agent_cfg`` (the input
            config is not modified); ``nth_gpu`` is exported as ``CUDA_VISIBLE_DEVICES``.
            Requires the ``RUN_PATH`` environment variable (raises KeyError otherwise).

    Returns:
        ``(per_env_results_last_reward, per_env_results_rewards, mean_rewards, step)``,
        where ``mean_rewards`` holds, per env, the mean over every step reward of every
        episode (0.0 for an env without rewards).
    """
    step, base_agent_cfg, eval_cfgs, episodes, n_processes, nth_gpu = args
    logging.info("Starting evaluation for step %s", step)

    # process-wide settings for this worker: visible GPU and video output directory
    os.environ["CUDA_VISIBLE_DEVICES"] = str(nth_gpu)
    os.environ["CAM_PATH"] = f"{os.environ['RUN_PATH']}/videos/{step}"

    agent_cfg = copy.deepcopy(base_agent_cfg)
    agent_cfg.agent_kwargs["checkpoint_step"] = step

    last_rewards, rewards = evaluation(
        agent_cfg=agent_cfg, eval_cfgs=eval_cfgs, episodes=episodes, n_processes=n_processes
    )
    logging.info("Finished evaluation for step %s", step)

    mean_rewards = []
    for env_rewards in rewards:
        flat = [reward for episode in env_rewards for reward in episode]
        mean_rewards.append(np.mean(flat) if flat else 0.0)

    logging.info("Returning results for step %s", step)
    return last_rewards, rewards, mean_rewards, step
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# (metric name, wandb summary aggregation); each metric is logged under "total/" and
# under "<env_id>/" for every evaluated environment.
_EVAL_METRICS = (
    ("success", "max"),
    ("last_step_reward", "max"),
    ("total_steps", "min"),
    ("mean_reward", "max"),
)


def _define_eval_metrics(eval_cfgs: list[EvalConfig]) -> None:
    """Register all evaluation metrics with wandb, plotted against ``train_step``.

    Defines the ``total/`` metrics first, then the metrics of every env in ``eval_cfgs``.
    """
    prefixes = ["total", *(cfg.env_id for cfg in eval_cfgs)]
    for prefix in prefixes:
        for metric, summary in _EVAL_METRICS:
            wandb.define_metric(
                f"{prefix}/{metric}",
                step_metric="train_step",
                overwrite=False,
                step_sync=False,
                hidden=False,
                summary=summary,
            )


def _mean_step_rewards(per_env_results_rewards) -> list[float]:
    """Per env: mean over all step rewards of all episodes (0.0 for an env without rewards)."""
    means = []
    for env_rewards in per_env_results_rewards:
        flat = [reward for episode in env_rewards for reward in episode]
        means.append(np.mean(flat) if flat else 0.0)
    return means


def _build_log_dict(
    per_env_results_last_reward: np.ndarray,
    mean_rewards: list[float],
    eval_cfgs: list[EvalConfig],
    step: int,
) -> dict:
    """Assemble the wandb payload with aggregate and per-env metrics for one checkpoint.

    ``per_env_results_last_reward`` is indexed ``[env, episode, (success, reward, steps)]``.
    """
    total = per_env_results_last_reward.mean(axis=(0, 1))
    log_dict = {
        "total/success": total[0],
        "total/last_step_reward": total[1],
        "total/total_steps": total[2],
        "total/mean_reward": np.mean(mean_rewards),
        "train_step": step,
    }
    for idx, env in enumerate(eval_cfgs):
        env_mean = per_env_results_last_reward[idx].mean(axis=0)
        log_dict[f"{env.env_id}/success"] = env_mean[0]
        log_dict[f"{env.env_id}/last_step_reward"] = env_mean[1]
        log_dict[f"{env.env_id}/total_steps"] = env_mean[2]
        log_dict[f"{env.env_id}/mean_reward"] = mean_rewards[idx]
    return log_dict


def _log_and_store_results(
    per_env_results_last_reward,
    per_env_results_rewards,
    mean_rewards: list[float],
    eval_cfgs: list[EvalConfig],
    agent_cfg: AgentConfig,
    step: int,
    output_path: str,
) -> None:
    """Log one checkpoint's metrics to wandb, write its results and upload them as an artifact."""
    wandb.log(_build_log_dict(per_env_results_last_reward, mean_rewards, eval_cfgs, step), step=step, commit=True)
    path = write_results(
        per_env_results_last_reward,
        per_env_results_rewards,
        eval_cfgs,
        agent_cfg=agent_cfg,
        out=output_path,
    )
    wandb.log_artifact(path, type="file", name="results", aliases=[f"step_{step}"])


@main_app.command()
def run_eval_post_training(
    wandb_project: Annotated[str, typer.Option(help="weights and biases logging project.")],
    wandb_entity: Annotated[str, typer.Option(help="weights and biases logging entity.")],
    wandb_note: Annotated[str, typer.Option(help="weights and biases logging note.")],
    wandb_name: Annotated[str, typer.Option(help="weights and biases logging name.")],
    output_path: Annotated[str, typer.Option(help="Path to store the run results.")],
    wandb_group: Annotated[str | None, typer.Option(help="weights and biases logging group.")] = None,
    steps: Annotated[str | None, typer.Option(help="JSON list of checkpoint steps to evaluate.")] = None,
    episodes: Annotated[int, typer.Option(help="Number of episodes to run.")] = 100,
    n_processes: Annotated[int | None, typer.Option(help="Number of processes to run.")] = None,
    n_gpus: Annotated[int, typer.Option(help="Number of gpus to run.")] = 1,
    eval_cfgs: Annotated[
        str, typer.Option(help="Evaluation configurations.")
    ] = '[{"env": "rcs/SimplePickUpSim-v0", "kwargs": {}}]',
    agent_cfg: Annotated[
        str, typer.Option(help="Agent configuration.")
    ] = '{"host": "localhost", "port": 8080, "agent_name": "Test", "agent_kwargs": {}, "python_path": "python"}',
):
    """
    post training eval which goes over all checkpoints
    - each checkpoint with many envs, one worker process per checkpoint
    - results of every checkpoint are logged to one wandb run against train_step
    """
    # parse all JSON inputs first so malformed arguments fail before a wandb run is created
    checkpoint_steps = [None] if steps is None else json.loads(steps)
    eval_cfg_list = [EvalConfig(**cfg) for cfg in json.loads(eval_cfgs)]
    # one config per checkpoint, each on its own port (port + index)
    agent_cfgs = [AgentConfig(**json.loads(agent_cfg)) for _ in checkpoint_steps]
    for idx, cfg in enumerate(agent_cfgs):
        cfg.port += idx

    if wandb_group == "":
        wandb_group = None

    wandb.init(
        entity=wandb_entity,
        resume="allow",
        project=wandb_project,
        notes=wandb_note,
        job_type="eval",
        name=wandb_name,
        group=wandb_group,
    )
    wandb_log_git_diff(output_path)
    wandb.run.log_code(".")
    _define_eval_metrics(eval_cfg_list)

    # distribute the checkpoints round-robin over the gpus
    gpu_ids = [i % n_gpus for i in range(len(checkpoint_steps))]

    # one worker per checkpoint; each worker evaluates with a single process (n_processes=1)
    worker_args = [
        (step, agent_cfgs[idx], eval_cfg_list, episodes, 1, gpu_ids[idx]) for idx, step in enumerate(checkpoint_steps)
    ]
    with Pool(n_processes) as pool:
        results = pool.map(_per_process, worker_args)
    logging.info("Finished evaluation")

    # Pool.map preserves input order, so results line up with agent_cfgs
    for cfg, (per_env_results_last_reward, per_env_results_rewards, mean_rewards, step) in zip(agent_cfgs, results):
        step = step if step is not None else 0
        _log_and_store_results(
            per_env_results_last_reward,
            per_env_results_rewards,
            mean_rewards,
            eval_cfg_list,
            cfg,
            step,
            output_path,
        )


@main_app.command()
def run_eval_during_training(
    wandb_id: Annotated[str, typer.Option(help="weights and biases logging id.")],
    wandb_group: Annotated[str, typer.Option(help="weights and biases logging group.")],
    wandb_project: Annotated[str, typer.Option(help="weights and biases logging project.")],
    wandb_entity: Annotated[str, typer.Option(help="weights and biases logging entity.")],
    wandb_note: Annotated[str, typer.Option(help="weights and biases logging note.")],
    wandb_name: Annotated[str, typer.Option(help="weights and biases logging name.")],
    output_path: Annotated[str, typer.Option(help="Path to store the run results.")],
    wandb_first: Annotated[bool, typer.Option(help="whether its the first eval.")] = False,
    episodes: Annotated[int, typer.Option(help="Number of episodes to run.")] = 100,
    n_processes: Annotated[int | None, typer.Option(help="Number of processes to run.")] = None,
    eval_cfgs: Annotated[
        str, typer.Option(help="Evaluation configurations.")
    ] = '[{"env": "rcs/SimplePickUpSim-v0", "kwargs": {}}]',
    agent_cfg: Annotated[
        str, typer.Option(help="Agent configuration.")
    ] = '{"host": "localhost", "port": 8080, "agent_name": "Test", "agent_kwargs": {}, "python_path": "python"}',
):
    """
    during training eval, all need to use the same id
    - just for one model, but many envs
    - can be new run but at least in the same project and same group as the training
    """
    # agent_cfg arrives as a JSON string; it has to be parsed before it can be inspected
    agent_cfg_dict = json.loads(agent_cfg)
    if agent_cfg_dict.get("agent_name") == "Test":
        raise ValueError("agent_cfg needs to be passed as a json argument. See the default for an example.")
    eval_cfg_list = [EvalConfig(**cfg) for cfg in json.loads(eval_cfgs)]
    parsed_agent_cfg = AgentConfig(**agent_cfg_dict)
    step = parsed_agent_cfg.agent_kwargs.get("checkpoint_step", 0)
    step = step if step is not None else 0

    if wandb_first:
        wandb.init(
            id=wandb_id,
            entity=wandb_entity,
            resume="allow",
            group=wandb_group,
            project=wandb_project,
            notes=wandb_note,
            job_type="eval",
            name=wandb_name,
        )
        wandb_log_git_diff(output_path)
        wandb.run.log_code(".")
    else:
        # later evals must attach to the run created by the first one
        wandb.init(id=wandb_id, entity=wandb_entity, resume="must", project=wandb_project)

    # returns [envs, episodes, 3 (success, reward, steps)] and
    # [envs, episodes, rewards for all steps in the episode]
    per_env_results_last_reward, per_env_results_rewards = evaluation(
        agent_cfg=parsed_agent_cfg, eval_cfgs=eval_cfg_list, episodes=episodes, n_processes=n_processes
    )
    mean_rewards = _mean_step_rewards(per_env_results_rewards)

    # NOTE: define_metric reportedly does not work well when several jobs log into the
    # same run (wandb recommends run groups for that), so only the first job defines them.
    if wandb_first:
        _define_eval_metrics(eval_cfg_list)

    _log_and_store_results(
        per_env_results_last_reward,
        per_env_results_rewards,
        mean_rewards,
        eval_cfg_list,
        parsed_agent_cfg,
        step,
        output_path,
    )
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
if __name__ == "__main__":
    # entry point for ``python -m vlagents <command> ...``
    main_app()
|
vlagents/client.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import dataclasses
|
|
3
|
+
from dataclasses import asdict
|
|
4
|
+
from multiprocessing import shared_memory
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import json_numpy
|
|
8
|
+
import numpy as np
|
|
9
|
+
import rpyc
|
|
10
|
+
import simplejpeg
|
|
11
|
+
|
|
12
|
+
from vlagents.policies import Act, Agent, CameraDataType, Obs, SharedMemoryPayload
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dataclass_from_dict(klass, d):
    """Recursively rebuild a (possibly nested) dataclass instance from a plain dict.

    Adapted from https://stackoverflow.com/questions/53376099/python-dataclass-from-a-nested-dict

    Args:
        klass: target type, typically a dataclass such as ``Act``. For any non-dataclass
            type (e.g. ``np.ndarray``, ``dict``, or a string annotation) ``d`` is returned
            unchanged.
        d: decoded value, typically a dict mapping field names to values.

    Returns:
        An instance of ``klass`` if ``klass`` is a dataclass type and ``d`` is a dict whose
        keys all name fields of ``klass`` and satisfy its constructor; otherwise ``d``.
    """
    if not (isinstance(klass, type) and dataclasses.is_dataclass(klass) and isinstance(d, dict)):
        return d  # Not a dataclass field (or nothing to build one from)
    fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
    try:
        return klass(**{name: dataclass_from_dict(fieldtypes[name], value) for name, value in d.items()})
    except (KeyError, TypeError):
        # unknown field name or arguments not matching the constructor: keep the raw dict
        return d
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RemoteAgent(Agent):
    """Agent proxy that forwards ``reset``/``act`` calls to a remote agent service over rpyc.

    Observations are JSON-encoded with ``json_numpy``. Camera frames can optionally be
    transferred via shared memory (agent on the same host) or as base64 JPEG strings.
    """

    def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False, jpeg_encoding: bool = False):
        """Connect to a remote agent service.

        Args:
            host (str): Hostname or IP address of the remote agent service.
            port (int): Port number of the remote agent service.
            model (str): Name of the model to connect to; must match the name the service reports.
            on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
                shared memory for more efficient communication. Takes precedence over ``jpeg_encoding``.
                Defaults to False.
            jpeg_encoding (bool, optional): If True the image data is jpeg encoded for smaller transfer size.
                Defaults to False.

        Raises:
            ValueError: If the service reports a model name different from ``model``.
        """
        self.on_same_machine = on_same_machine
        self.jpeg_encoding = jpeg_encoding
        # one reusable shared-memory segment per camera name
        self._shm: dict[str, shared_memory.SharedMemory] = {}
        self.c = rpyc.connect(
            host, port, config={"allow_pickle": True, "allow_public_attrs": True, "sync_request_timeout": 300}
        )
        served_model = self.c.root.name()
        if model != served_model:
            # don't leak the connection when refusing to use it
            self.c.close()
            raise ValueError(f"Remote service serves model {served_model!r}, expected {model!r}.")

    def _to_shared_memory(self, camera_name: str, camera_data: np.ndarray) -> SharedMemoryPayload:
        """Copy one frame into this camera's shared-memory segment and describe it."""
        if not isinstance(camera_data, np.ndarray):
            raise TypeError(f"camera {camera_name!r}: expected np.ndarray, got {type(camera_data).__name__}")
        shm = self._shm.get(camera_name)
        if shm is not None and shm.size < camera_data.nbytes:
            # frame grew (e.g. resolution change): the existing segment is too small
            shm.close()
            shm.unlink()
            shm = None
        if shm is None:
            shm = shared_memory.SharedMemory(create=True, size=camera_data.nbytes)
            self._shm[camera_name] = shm
        camera_shared = np.ndarray(camera_data.shape, buffer=shm.buf, dtype=camera_data.dtype)
        camera_shared[:] = camera_data[:]
        return SharedMemoryPayload(
            shm_name=shm.name,
            shape=camera_data.shape,
            dtype=camera_data.dtype.name,
        )

    @staticmethod
    def _to_jpeg(camera_name: str, camera_data: np.ndarray) -> str:
        """JPEG-encode one frame and return it as a URL-safe base64 string."""
        if not isinstance(camera_data, np.ndarray):
            raise TypeError(f"camera {camera_name!r}: expected np.ndarray, got {type(camera_data).__name__}")
        encoded = simplejpeg.encode_jpeg(np.ascontiguousarray(camera_data))
        return base64.urlsafe_b64encode(encoded).decode("utf-8")

    def _process(self, obs: Obs) -> Obs:
        """Return ``obs`` with its camera frames converted to the configured transport format.

        A shallow copy is returned when frames are converted, so the caller's ``obs`` keeps
        its numpy frames and can be sent again. Without shm/jpeg, ``obs`` is returned as-is.
        """
        import copy

        if self.on_same_machine:
            cameras = {name: self._to_shared_memory(name, data) for name, data in obs.cameras.items()}
            data_type = CameraDataType.SHARED_MEMORY
        elif self.jpeg_encoding:
            cameras = {name: self._to_jpeg(name, data) for name, data in obs.cameras.items()}
            data_type = CameraDataType.JPEG_ENCODED
        else:
            return obs
        processed = copy.copy(obs)
        processed.cameras = cameras
        processed.camera_data_type = data_type
        return processed

    def act(self, obs: Obs) -> Act:
        """Send ``obs`` to the remote agent and return its action."""
        payload = json_numpy.dumps(asdict(self._process(obs)))
        # the service replies with a JSON-encoded Act (action, done, info)
        return dataclass_from_dict(Act, json_numpy.loads(self.c.root.act(payload)))

    def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
        """Reset the remote agent with an observation and instruction; returns the service's info dict."""
        obs_dict = asdict(self._process(obs))
        return json_numpy.loads(self.c.root.reset(json_numpy.dumps((obs_dict, instruction, kwargs))))

    def git_status(self) -> str:
        """Git status reported by the remote service."""
        return json_numpy.loads(self.c.root.git_status())

    def is_initialized(self) -> bool:
        """Whether the remote service reports its agent as initialized."""
        return self.c.root.is_initialized()

    def close(self):
        """Free all shared-memory segments created by this client and close the connection."""
        for shm in self._shm.values():
            shm.close()
            try:
                shm.unlink()
            except FileNotFoundError:
                # segment already removed; nothing left to free
                pass
        self._shm = {}
        self.c.close()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
if __name__ == "__main__":
    # manual smoke test: expects a "test" agent server on localhost:8080
    from time import sleep

    agent = RemoteAgent("localhost", 8080, "test")
    # the context manager closes the connection (and any shared memory) on exit
    with agent:
        # the service may still be loading its agent right after accepting the connection
        while not agent.is_initialized():
            sleep(0.1)
        obs = Obs(cameras={"rgb_side": np.zeros((256, 256, 3), dtype=np.uint8)})
        instruction = "do something"
        agent.reset(obs, instruction)
        print(agent.act(obs))
        print(agent.act(obs))
|