vlagents 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/__init__.py ADDED
File without changes
@@ -0,0 +1,80 @@
1
+ import subprocess
2
+ from time import sleep
3
+
4
+ import numpy as np
5
+
6
+ from vlagents.client import RemoteAgent
7
+ from vlagents.evaluator_envs import start_server
8
+ from vlagents.policies import Act, Obs
9
+
10
+
11
def _test_connection(agent: RemoteAgent):
    """Round-trip one reset and two act calls through ``agent`` and check the echoes.

    The assertions expect the remote side to report back the instruction, the
    shape/dtype of every camera image and the image itself, and to answer the
    two consecutive ``act`` calls with a 7-element float32 action whose last
    entry goes from 0 to 1.
    """
    expected_shapes = {"rgb_side": [256, 256, 3]}
    expected_dtype = {"rgb_side": "uint8"}
    instruction = "do something"

    image = np.zeros((256, 256, 3), dtype=np.uint8)
    image[2, 0, 0] = 16

    reset_info = agent.reset(Obs(cameras={"rgb_side": image}), instruction)
    assert reset_info["instruction"] == instruction
    assert reset_info["shapes"] == expected_shapes
    assert reset_info["dtype"] == expected_dtype
    assert (reset_info["data"]["rgb_side"] == image).all()

    # Each step flips one more pixel so that a stale image on the remote side
    # would be detected, then checks the action returned for that call.
    act_steps = (
        ((0, 0, 2), [0, 0, 0, 0, 0, 0, 0]),
        ((0, 2, 0), [0, 0, 0, 0, 0, 0, 1]),
    )
    for pixel, expected_action in act_steps:
        image[pixel] = 1
        result = agent.act(Obs(cameras={"rgb_side": image}))
        assert result.info["shapes"] == expected_shapes
        assert result.info["dtype"] == expected_dtype
        assert (result.info["data"]["rgb_side"] == image).all()
        assert np.all(result.action == np.array(expected_action, dtype=np.float32))
        assert not result.done
37
+
38
+
39
def _test_connection_jpeg(agent: RemoteAgent):
    """Reset ``agent`` with an all-zero image and check what it reports back.

    Only ``reset`` is exercised here; the image is uniformly zero and is
    compared for exact equality with the data echoed by the remote side.
    """
    instruction = "do something"
    black_frame = np.zeros((256, 256, 3), dtype=np.uint8)

    info = agent.reset(Obs(cameras={"rgb_side": black_frame}), instruction)

    assert info["instruction"] == instruction
    assert info["shapes"] == {"rgb_side": [256, 256, 3]}
    assert info["dtype"] == {"rgb_side": "uint8"}
    assert (info["data"]["rgb_side"] == black_frame).all()
48
+
49
+
50
def test_connection_numpy_serialization():
    """Start a "test" agent server and run the connection checks with default transport."""
    # Function-scope import: the module only imports `subprocess`, and reaching
    # `signal` through `subprocess.signal` relies on an import side effect.
    import signal

    with start_server("test", {}, 8080, "localhost") as p:
        try:
            # presumably gives the server process time to come up — TODO confirm
            sleep(2)
            agent = RemoteAgent("localhost", 8080, "test")
            with agent:
                while not agent.is_initialized():
                    sleep(0.1)
                _test_connection(agent)
        finally:
            # Always interrupt the server, even when an assertion or the
            # connection attempt above fails; otherwise it may be left running.
            p.send_signal(signal.SIGINT)
59
+
60
+
61
def test_connection_numpy_shm():
    """Start a "test" agent server and run the connection checks with ``on_same_machine=True``."""
    # Function-scope import: the module only imports `subprocess`, and reaching
    # `signal` through `subprocess.signal` relies on an import side effect.
    import signal

    with start_server("test", {}, 8080, "localhost") as p:
        try:
            # presumably gives the server process time to come up — TODO confirm
            sleep(2)
            agent = RemoteAgent("localhost", 8080, "test", on_same_machine=True)
            with agent:
                while not agent.is_initialized():
                    sleep(0.1)
                _test_connection(agent)
        finally:
            # Always interrupt the server, even when an assertion or the
            # connection attempt above fails; otherwise it may be left running.
            p.send_signal(signal.SIGINT)
70
+
71
+
72
def test_connection_numpy_jpeg():
    """Start a "test" agent server and run the reset check with ``jpeg_encoding=True``."""
    # Function-scope import: the module only imports `subprocess`, and reaching
    # `signal` through `subprocess.signal` relies on an import side effect.
    import signal

    with start_server("test", {}, 8080, "localhost") as p:
        try:
            # presumably gives the server process time to come up — TODO confirm
            sleep(2)
            agent = RemoteAgent("localhost", 8080, "test", jpeg_encoding=True)
            with agent:
                while not agent.is_initialized():
                    sleep(0.1)
                _test_connection_jpeg(agent)
        finally:
            # Always interrupt the server, even when an assertion or the
            # connection attempt above fails; otherwise it may be left running.
            p.send_signal(signal.SIGINT)
vlagents/__init__.py ADDED
@@ -0,0 +1 @@
1
# `__all__` advertises `__version__`, so it has to exist: without it
# `from vlagents import *` raises AttributeError. Keep in sync with the
# version declared in the package metadata.
__version__ = "0.0.1"

__all__ = ["__doc__", "__version__", "policies", "client", "server"]
vlagents/__main__.py ADDED
@@ -0,0 +1,400 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+ import os
5
+ from multiprocessing import Pool
6
+ from pathlib import Path
7
+ from typing import Annotated
8
+
9
+ import numpy as np
10
+ import rpyc
11
+ import typer
12
+ import wandb
13
+
14
+ # to use non-inline backend, necessary for the case
15
+ # when started from jupyter notebook
16
+ os.environ["MPLBACKEND"] = "Agg"
17
+
18
+ from vlagents.evaluator_envs import AgentConfig, EvalConfig, evaluation, write_results
19
+ from vlagents.policies import AGENTS
20
+ from vlagents.server import AgentService
21
+
22
+ main_app = typer.Typer(help="CLI tool for the vlagents library.")
23
+
24
+
25
def wandb_log_git_diff(path: str):
    """Snapshot the git state into ``<path>/git`` and log that directory to wandb.

    Args:
        path: Base output directory; the ``git`` sub-directory is created
            (including missing parents) if it does not exist yet.
    """
    snapshot_dir = Path(path).joinpath("git")
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    log_git_diff(snapshot_dir)
    wandb.log_artifact(snapshot_dir, type="directory", name="git")
31
+
32
+
33
def log_git_diff(path: str):
    """Write the git state of the current working directory into ``path``.

    Three files are produced inside ``path``:

    - ``git_id.txt``: hash of the current commit,
    - ``git_submodules.txt``: output of ``git submodule status``,
    - ``git_diff.txt``: output of ``git diff --submodule=diff``.

    The commands are run without a shell and their stdout is redirected
    straight into the target files, so ``path`` may contain spaces or shell
    metacharacters. The function is best-effort: a failing git command leaves
    an empty or partial file, and an OS-level failure (git not installed,
    ``path`` missing) is logged as a warning instead of raised.

    Args:
        path: Existing directory that receives the three files.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level imports being extended.
    import subprocess

    commands = {
        "git_id.txt": ["git", "log", "--format=%H", "-n", "1"],
        "git_submodules.txt": ["git", "submodule", "status"],
        "git_diff.txt": ["git", "diff", "--submodule=diff"],
    }
    for filename, command in commands.items():
        target = os.path.join(path, filename)
        try:
            with open(target, "wb") as out:
                # check=False: a non-zero git exit status (e.g. not a
                # repository) is tolerated, as it was with os.system.
                subprocess.run(command, stdout=out, check=False)
        except OSError as err:
            logging.warning("could not record '%s' in %s: %s", " ".join(command), target, err)
45
+
46
+
47
@main_app.command()
def start_server(
    agent_name: Annotated[str, typer.Argument(help="Agent name to run.")],
    kwargs: Annotated[str, typer.Option(help="args to start the agent.")] = "{}",
    port: Annotated[int, typer.Option(help="Port to run the server on.")] = 8080,
    host: Annotated[str, typer.Option(help="Host to run the server on.")] = "localhost",
):
    """Runs eval server."""
    # The docstring above doubles as the CLI help text, so details live here:
    # `agent_name` selects a factory from AGENTS, `kwargs` is a JSON object
    # forwarded to it as keyword arguments, and the resulting agent is served
    # through an rpyc ThreadedServer on host:port until the server stops.
    agent_factory = AGENTS[agent_name]
    agent_kwargs = json.loads(kwargs)
    service = AgentService(agent_factory(**agent_kwargs), agent_name)

    protocol_config = {"allow_pickle": True, "allow_public_attrs": True}
    with service:
        server = rpyc.ThreadedServer(service, port=port, hostname=host, protocol_config=protocol_config)
        server.start()
62
+
63
+
64
def _per_process(
    args: tuple[int, AgentConfig, list[EvalConfig], int, int | None, int],
) -> tuple[np.ndarray, list[list[list[float]]], list[float], int]:
    """Evaluate one checkpoint step; meant to be mapped over a process pool.

    Args:
        args: ``(step, agent_cfg, eval_cfgs, episodes, n_processes, nth_gpu)``
            packed into a single tuple so the function can be used with
            ``Pool.map``.

    Returns:
        ``(last_step_results, rewards, mean_rewards, step)`` where the first
        two entries are exactly what ``evaluation`` returned, ``mean_rewards``
        holds one mean per environment over all rewards of all episodes
        (``0.0`` when an environment produced no rewards) and ``step`` is the
        input step passed through.

    Raises:
        KeyError: if the ``RUN_PATH`` environment variable is not set.
    """
    step, base_cfg, eval_cfgs, episodes, n_processes, nth_gpu = args
    logging.info(f"Starting evaluation for step {step}")

    # Pin this process to its GPU and point the video output at a per-step folder.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(nth_gpu)
    run_path = os.environ["RUN_PATH"]
    os.environ["CAM_PATH"] = f"{run_path}/videos/{step}"

    # Work on a private copy so the caller's config is not modified.
    agent_cfg = copy.deepcopy(base_cfg)
    agent_cfg.agent_kwargs["checkpoint_step"] = step

    last_step_results, rewards_per_env = evaluation(
        agent_cfg=agent_cfg, eval_cfgs=eval_cfgs, episodes=episodes, n_processes=n_processes
    )
    logging.info(f"Finished evaluation for step {step}")

    mean_rewards = []
    for env_rewards in rewards_per_env:
        flat = [reward for episode in env_rewards for reward in episode]
        mean_rewards.append(np.mean(flat) if flat else 0.0)

    logging.info("Returning results for step %s", step)
    return last_step_results, rewards_per_env, mean_rewards, step
82
+
83
+
84
@main_app.command()
def run_eval_post_training(
    wandb_project: Annotated[str, typer.Option(help="weights and biases logging project.")],
    wandb_entity: Annotated[str, typer.Option(help="weights and biases logging entity.")],
    wandb_note: Annotated[str, typer.Option(help="weights and biases logging note.")],
    wandb_name: Annotated[str, typer.Option(help="weights and biases logging name.")],
    output_path: Annotated[str, typer.Option(help="Path to store the run results.")],
    wandb_group: Annotated[str | None, typer.Option(help="weights and biases logging group.")] = None,
    steps: Annotated[str | None, typer.Option(help="steps to evaluate.")] = None,
    episodes: Annotated[int, typer.Option(help="Number of episodes to run.")] = 100,
    n_processes: Annotated[int | None, typer.Option(help="Number of processes to run.")] = None,
    n_gpus: Annotated[int, typer.Option(help="Number of gpus to run.")] = 1,
    eval_cfgs: Annotated[
        str, typer.Option(help="Evaluation configurations.")
    ] = '[{"env": "rcs/SimplePickUpSim-v0", "kwargs": {}}]',
    agent_cfg: Annotated[
        str, typer.Option(help="Agent configuration.")
    ] = '{"host": "localhost", "port": 8080, "agent_name": "Test", "agent_kwargs": {}, "python_path": "python"}',
):
    """
    post training eval which goes over all checkpoints
    - each checkpoint with many envs
    """
    # `steps`, `eval_cfgs` and `agent_cfg` arrive as JSON strings from the CLI.
    # Without `steps` a single evaluation with checkpoint step None is run.
    if steps is None:
        steps = [None]
    else:
        steps = json.loads(steps)

    # An empty group string is treated like "no group".
    if wandb_group == "":
        wandb_group = None

    wandb.init(
        entity=wandb_entity,
        resume="allow",
        project=wandb_project,
        notes=wandb_note,
        job_type="eval",
        name=wandb_name,
        group=wandb_group,
    )
    wandb_log_git_diff(output_path)
    wandb.run.log_code(".")

    eval_cfgs = [EvalConfig(**cfg) for cfg in json.loads(eval_cfgs)]

    # Every metric is plotted against "train_step"; the value is the summary
    # aggregation wandb keeps for it. Metrics exist once under "total" and
    # once per environment id.
    metric_summaries = {
        "success": "max",
        "last_step_reward": "max",
        "total_steps": "min",
        "mean_reward": "max",
    }
    for prefix in ["total", *(env.env_id for env in eval_cfgs)]:
        for metric_name, summary in metric_summaries.items():
            wandb.define_metric(
                f"{prefix}/{metric_name}",
                step_metric="train_step",
                overwrite=False,
                step_sync=False,
                hidden=False,
                summary=summary,
            )

    # distribute gpus equally (round-robin over the steps)
    gpus_ids = [i % n_gpus for i in range(len(steps))]

    # One agent config per step, each on its own port (base port + index), so
    # the evaluations running in parallel do not share a port.
    agent_cfgs = [AgentConfig(**json.loads(agent_cfg)) for _ in steps]
    for idx, cfg in enumerate(agent_cfgs):
        cfg.port += idx

    # spawn n processes and run in parallel; inside each worker `evaluation`
    # itself is asked to use a single process.
    args = [(step, agent_cfgs[idx], eval_cfgs, episodes, 1, gpus_ids[idx]) for idx, step in enumerate(steps)]
    with Pool(n_processes) as p:
        results = p.map(_per_process, args)
    logging.info("Finished evaluation")

    for per_env_results_last_reward, per_env_results_rewards, mean_rewards, step in results:
        step = step if step is not None else 0

        # The last axis holds (success, last step reward, total steps);
        # reduce over environments and episodes once instead of three times.
        overall = per_env_results_last_reward.mean(axis=(0, 1))
        wandb_log_dict = {
            "total/success": overall[0],
            "total/last_step_reward": overall[1],
            "total/total_steps": overall[2],
            "total/mean_reward": np.mean(mean_rewards),
            "train_step": step,
        }
        # log for each env
        for idx, env in enumerate(eval_cfgs):
            per_env = per_env_results_last_reward[idx].mean(axis=0)
            wandb_log_dict.update(
                {
                    f"{env.env_id}/success": per_env[0],
                    f"{env.env_id}/last_step_reward": per_env[1],
                    f"{env.env_id}/total_steps": per_env[2],
                    f"{env.env_id}/mean_reward": mean_rewards[idx],
                }
            )
        wandb.log(wandb_log_dict, step=step, commit=True)

        # NOTE(review): the first agent config is used for every step, so the
        # written config carries neither this step's port offset nor its
        # checkpoint step — confirm that this is what write_results expects.
        path = write_results(
            per_env_results_last_reward,
            per_env_results_rewards,
            eval_cfgs,
            agent_cfg=agent_cfgs[0],
            out=output_path,
        )
        wandb.log_artifact(path, type="file", name="results", aliases=[f"step_{step}"])
238
+
239
+
240
@main_app.command()
def run_eval_during_training(
    wandb_id: Annotated[str, typer.Option(help="weights and biases logging id.")],
    wandb_group: Annotated[str, typer.Option(help="weights and biases logging group.")],
    wandb_project: Annotated[str, typer.Option(help="weights and biases logging project.")],
    wandb_entity: Annotated[str, typer.Option(help="weights and biases logging entity.")],
    wandb_note: Annotated[str, typer.Option(help="weights and biases logging note.")],
    wandb_name: Annotated[str, typer.Option(help="weights and biases logging name.")],
    output_path: Annotated[str, typer.Option(help="Path to store the run results.")],
    wandb_first: Annotated[bool, typer.Option(help="whether its the first eval.")] = False,
    episodes: Annotated[int, typer.Option(help="Number of episodes to run.")] = 100,
    n_processes: Annotated[int | None, typer.Option(help="Number of processes to run.")] = None,
    eval_cfgs: Annotated[
        str, typer.Option(help="Evaluation configurations.")
    ] = '[{"env": "rcs/SimplePickUpSim-v0", "kwargs": {}}]',
    agent_cfg: Annotated[
        str, typer.Option(help="Agent configuration.")
    ] = '{"host": "localhost", "port": 8080, "agent_name": "Test", "agent_kwargs": {}, "python_path": "python"}',
):
    """
    during training eval, all need to use the same id
    - just for one model, but many envs
    - can be new run but at least in the same project and same group as the training
    """
    # `agent_cfg` is still a JSON *string* here, so it has to be parsed before
    # its agent name can be inspected (indexing the string raised TypeError).
    # The placeholder name "Test" of the default means no config was passed.
    agent_cfg_dict = json.loads(agent_cfg)
    if agent_cfg_dict.get("agent_name") == "Test":
        raise typer.BadParameter("agent_cfg needs to be passed as a json argument. See the default for an example.")

    if wandb_first:
        wandb.init(
            id=wandb_id,
            entity=wandb_entity,
            resume="allow",
            group=wandb_group,
            project=wandb_project,
            notes=wandb_note,
            job_type="eval",
            name=wandb_name,
        )

        wandb_log_git_diff(output_path)
        wandb.run.log_code(".")
    else:
        # Later evaluations must attach to the run created by the first one.
        wandb.init(id=wandb_id, entity=wandb_entity, resume="must", project=wandb_project)

    eval_cfgs = [EvalConfig(**cfg) for cfg in json.loads(eval_cfgs)]

    agent_cfg = AgentConfig(**agent_cfg_dict)
    # A missing or None checkpoint step is logged as step 0.
    step = agent_cfg.agent_kwargs.get("checkpoint_step", 0)
    step = step if step is not None else 0

    per_env_results_last_reward, per_env_results_rewards = evaluation(
        agent_cfg=agent_cfg, eval_cfgs=eval_cfgs, episodes=episodes, n_processes=n_processes
    )

    # return is [envs, episodes, 3(success, reward, steps)], [envs, episodes, rewards for all steps in the episode]

    # One mean per environment over all rewards of all its episodes.
    mean_rewards = []
    for env_rewards in per_env_results_rewards:
        flat = [reward for episode in env_rewards for reward in episode]
        mean_rewards.append(np.mean(flat) if flat else 0.0)

    # these new define metric to also not work with several jobs
    # wandb says that logging can only be done in run groups
    if wandb_first:
        # Every metric is plotted against "train_step"; the value is the
        # summary aggregation wandb keeps for it. Metrics exist once under
        # "total" and once per environment id.
        metric_summaries = {
            "success": "max",
            "last_step_reward": "max",
            "total_steps": "min",
            "mean_reward": "max",
        }
        for prefix in ["total", *(env.env_id for env in eval_cfgs)]:
            for metric_name, summary in metric_summaries.items():
                wandb.define_metric(
                    f"{prefix}/{metric_name}",
                    step_metric="train_step",
                    overwrite=False,
                    step_sync=False,
                    hidden=False,
                    summary=summary,
                )

    # Reduce over environments and episodes once instead of three times.
    overall = per_env_results_last_reward.mean(axis=(0, 1))
    wandb_log_dict = {
        "total/success": overall[0],
        "total/last_step_reward": overall[1],
        "total/total_steps": overall[2],
        "total/mean_reward": np.mean(mean_rewards),
        "train_step": step,
    }

    # log for each env
    for idx, env in enumerate(eval_cfgs):
        per_env = per_env_results_last_reward[idx].mean(axis=0)
        wandb_log_dict.update(
            {
                f"{env.env_id}/success": per_env[0],
                f"{env.env_id}/last_step_reward": per_env[1],
                f"{env.env_id}/total_steps": per_env[2],
                f"{env.env_id}/mean_reward": mean_rewards[idx],
            }
        )
    wandb.log(wandb_log_dict, step=step, commit=True)
    path = write_results(
        per_env_results_last_reward,
        per_env_results_rewards,
        eval_cfgs,
        agent_cfg=agent_cfg,
        out=output_path,
    )
    wandb.log_artifact(path, type="file", name="results", aliases=[f"step_{step}"])
397
+
398
+
399
if __name__ == "__main__":
    # Module executed directly: hand control to the typer CLI.
    main_app()
vlagents/client.py ADDED
@@ -0,0 +1,107 @@
1
+ import base64
2
+ import dataclasses
3
+ from dataclasses import asdict
4
+ from multiprocessing import shared_memory
5
+ from typing import Any
6
+
7
+ import json_numpy
8
+ import numpy as np
9
+ import rpyc
10
+ import simplejpeg
11
+
12
+ from vlagents.policies import Act, Agent, CameraDataType, Obs, SharedMemoryPayload
13
+
14
+
15
def dataclass_from_dict(klass, d):
    """Recursively build an instance of dataclass ``klass`` from the dict ``d``.

    Every key of ``d`` is looked up among the fields of ``klass`` and its
    value is converted with the declared type of that field, so nested
    dataclasses are rebuilt as well.

    Args:
        klass: Target type. Anything that is not a dataclass makes the
            function return ``d`` unchanged.
        d: Value to convert, normally a dict keyed by field name.

    Returns:
        An instance of ``klass``, or ``d`` itself when the conversion is not
        possible (``klass`` is not a dataclass, ``d`` has unknown keys, the
        constructor rejects the arguments, ...).
    """
    # https://stackoverflow.com/questions/53376099/python-dataclass-from-a-nested-dict
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except Exception:
        # Deliberately broad: any failure means "leave the value as it is".
        # It must not be a bare `except`, which would also swallow
        # KeyboardInterrupt and SystemExit.
        return d  # Not a dataclass field
22
+
23
+
24
class RemoteAgent(Agent):
    """Agent that forwards ``reset``/``act`` calls to an agent service over rpyc.

    Observations are serialized with ``json_numpy``. Camera images can
    optionally be handed over through shared memory or as base64-encoded
    JPEG strings instead of being embedded in the serialized observation.
    """

    def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False, jpeg_encoding: bool = False):
        """Connect to a remote agent service.

        Args:
            host (str): Hostname or IP address of the remote agent service.
            port (int): Port number of the remote agent service.
            model (str): Name of the model to connect to.
            on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
                shared memory for more efficient communication. Defaults to False.
            jpeg_encoding (bool, optional): If True the image data is jpeg encoded for smaller transfer size.
                Defaults to False.

        Raises:
            AssertionError: If the service reports a different model name than ``model``.
        """
        self.on_same_machine = on_same_machine
        self.jpeg_encoding = jpeg_encoding
        # One shared memory segment per camera name, created lazily in _process.
        self._shm: dict[str, shared_memory.SharedMemory] = {}
        self.c = rpyc.connect(
            host, port, config={"allow_pickle": True, "allow_public_attrs": True, "sync_request_timeout": 300}
        )
        remote_name = self.c.root.name()
        if model != remote_name:
            # Raised explicitly (instead of `assert`) so the check survives
            # `python -O`; the exception type is kept for existing callers.
            # Close first so a mismatch does not leak the connection.
            self.c.close()
            raise AssertionError(f"expected remote agent '{model}' but the service reports '{remote_name}'")

    def _process(self, obs: Obs) -> Obs:
        """Return the observation in the form that is sent to the service.

        With ``on_same_machine`` every camera image is copied into a shared
        memory segment and replaced by a ``SharedMemoryPayload`` describing
        it. Otherwise, with ``jpeg_encoding``, every image is replaced by a
        base64 (urlsafe) encoded JPEG string. Without either option ``obs`` is
        returned as is.

        The caller's ``obs`` is never modified; a converted copy is returned,
        so the same observation object can be passed in again.
        """
        if self.on_same_machine:
            camera_dict = {}
            for camera_name, camera_data in obs.cameras.items():
                assert isinstance(camera_data, np.ndarray)
                shm = self._shm.get(camera_name)
                if shm is not None and shm.size < camera_data.nbytes:
                    # The image grew since the segment was created: release
                    # the old segment and allocate one that is large enough.
                    shm.close()
                    shm.unlink()
                    shm = None
                if shm is None:
                    shm = shared_memory.SharedMemory(create=True, size=camera_data.nbytes)
                    self._shm[camera_name] = shm
                camera_shared = np.ndarray(camera_data.shape, buffer=shm.buf, dtype=camera_data.dtype)
                camera_shared[:] = camera_data[:]
                camera_dict[camera_name] = SharedMemoryPayload(
                    shm_name=shm.name,
                    shape=camera_data.shape,
                    dtype=camera_data.dtype.name,
                )
            # assumes Obs is a dataclass whose `cameras` and `camera_data_type`
            # are regular init fields (it is passed to `asdict` by the callers
            # below) — TODO confirm
            return dataclasses.replace(obs, cameras=camera_dict, camera_data_type=CameraDataType.SHARED_MEMORY)
        if self.jpeg_encoding:
            camera_dict = {}
            for camera_name, camera_data in obs.cameras.items():
                assert isinstance(camera_data, np.ndarray)
                camera_dict[camera_name] = base64.urlsafe_b64encode(
                    simplejpeg.encode_jpeg(np.ascontiguousarray(camera_data))
                ).decode("utf-8")
            return dataclasses.replace(obs, cameras=camera_dict, camera_data_type=CameraDataType.JPEG_ENCODED)
        return obs

    def act(self, obs: Obs) -> Act:
        """Send ``obs`` to the service and return its answer as an ``Act``."""
        payload = json_numpy.dumps(asdict(self._process(obs)))
        # action, done, info
        return dataclass_from_dict(Act, json_numpy.loads(self.c.root.act(payload)))

    def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
        """Reset the remote agent with an initial observation and an instruction.

        Extra keyword arguments are serialized and forwarded to the service.

        Returns:
            The info dict decoded from the service's answer.
        """
        obs_dict = asdict(self._process(obs))
        # info
        return json_numpy.loads(self.c.root.reset(json_numpy.dumps((obs_dict, instruction, kwargs))))

    def git_status(self) -> str:
        """Return the decoded result of the service's ``git_status`` call."""
        return json_numpy.loads(self.c.root.git_status())

    def is_initialized(self) -> bool:
        """Return what the service reports for ``is_initialized``."""
        return self.c.root.is_initialized()

    def close(self):
        """Release all shared memory segments and close the connection."""
        try:
            for shm in self._shm.values():
                shm.close()
                shm.unlink()
        finally:
            # Close the connection even if releasing a segment failed.
            self._shm = {}
            self.c.close()
98
+
99
+
100
if __name__ == "__main__":
    # to test the connection: expects a "test" agent service on localhost:8080
    agent = RemoteAgent("localhost", 8080, "test")
    try:
        obs = Obs(cameras={"rgb_side": np.zeros((256, 256, 3), dtype=np.uint8)})
        instruction = "do something"
        agent.reset(obs, instruction)
        print(agent.act(obs))
        print(agent.act(obs))
    finally:
        # Release the connection (and any shared memory) even if a call fails.
        agent.close()