rlmesh 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. rlmesh-0.0.1/.gitignore +129 -0
  2. rlmesh-0.0.1/PKG-INFO +20 -0
  3. rlmesh-0.0.1/pyproject.toml +33 -0
  4. rlmesh-0.0.1/src/rlmesh/__init__.py +4 -0
  5. rlmesh-0.0.1/src/rlmesh/client/__init__.py +3 -0
  6. rlmesh-0.0.1/src/rlmesh/client/print.py +66 -0
  7. rlmesh-0.0.1/src/rlmesh/client/session.py +203 -0
  8. rlmesh-0.0.1/src/rlmesh/env/__init__.py +4 -0
  9. rlmesh-0.0.1/src/rlmesh/env/make_env.py +44 -0
  10. rlmesh-0.0.1/src/rlmesh/env/remote_env.py +192 -0
  11. rlmesh-0.0.1/src/rlmesh/env/remote_vector_env.py +54 -0
  12. rlmesh-0.0.1/src/rlmesh/gym/__init__.py +29 -0
  13. rlmesh-0.0.1/src/rlmesh/gym/env_codec.py +109 -0
  14. rlmesh-0.0.1/src/rlmesh/gym/spaces_codec.py +315 -0
  15. rlmesh-0.0.1/src/rlmesh/gym/tensor_codec.py +19 -0
  16. rlmesh-0.0.1/src/rlmesh/gym/utils.py +64 -0
  17. rlmesh-0.0.1/src/rlmesh/ping/__init__.py +3 -0
  18. rlmesh-0.0.1/src/rlmesh/ping/gcping.py +51 -0
  19. rlmesh-0.0.1/src/rlmesh/proto/__init__.py +14 -0
  20. rlmesh-0.0.1/src/rlmesh/proto/control_pb2.py +69 -0
  21. rlmesh-0.0.1/src/rlmesh/proto/control_pb2.pyi +120 -0
  22. rlmesh-0.0.1/src/rlmesh/proto/control_pb2_grpc.py +312 -0
  23. rlmesh-0.0.1/src/rlmesh/proto/env_agent_pb2_grpc.py +24 -0
  24. rlmesh-0.0.1/src/rlmesh/proto/env_pb2.py +61 -0
  25. rlmesh-0.0.1/src/rlmesh/proto/env_pb2.pyi +97 -0
  26. rlmesh-0.0.1/src/rlmesh/proto/env_pb2_grpc.py +97 -0
  27. rlmesh-0.0.1/src/rlmesh/proto/gym_pb2.py +61 -0
  28. rlmesh-0.0.1/src/rlmesh/proto/gym_pb2.pyi +152 -0
  29. rlmesh-0.0.1/src/rlmesh/proto/gym_pb2_grpc.py +24 -0
  30. rlmesh-0.0.1/src/rlmesh/proto/internal_env_pb2.py +42 -0
  31. rlmesh-0.0.1/src/rlmesh/proto/internal_env_pb2.pyi +25 -0
  32. rlmesh-0.0.1/src/rlmesh/proto/internal_env_pb2_grpc.py +98 -0
  33. rlmesh-0.0.1/src/rlmesh/proto/internal_session_pb2.py +48 -0
  34. rlmesh-0.0.1/src/rlmesh/proto/internal_session_pb2.pyi +49 -0
  35. rlmesh-0.0.1/src/rlmesh/proto/internal_session_pb2_grpc.py +141 -0
  36. rlmesh-0.0.1/src/rlmesh/utils/__init__.py +5 -0
  37. rlmesh-0.0.1/src/rlmesh/utils/background.py +17 -0
  38. rlmesh-0.0.1/src/rlmesh/utils/config.py +45 -0
  39. rlmesh-0.0.1/src/rlmesh/utils/env.py +40 -0
  40. rlmesh-0.0.1/tests/gym/test_env_codec.py +0 -0
  41. rlmesh-0.0.1/tests/gym/test_spaces_codec.py +0 -0
@@ -0,0 +1,129 @@
1
+ .DS_Store
2
+ __MACOSX
3
+
4
+ # C++ and CMake stuff:
5
+ *.a
6
+ *.bin
7
+ *.o
8
+ /.ccls-cache/
9
+ /arrow/
10
+ **/CMakeFiles/
11
+ **/CMakeCache.txt
12
+ **/Makefile
13
+ **/cmake_install.cmake
14
+ _deps
15
+ **/.cache/
16
+ **/rerun_cpp/docs/html
17
+ **/rerun_cpp/docs/xml
18
+ *.tgz
19
+
20
+ # Rust compile target directory:
21
+ **/target
22
+ **/target_pixi
23
+ **/target_pixi_wasm
24
+ **/target_ra
25
+ **/target_wasm
26
+
27
+ # Python virtual environment:
28
+ **/venv*
29
+ **/.venv*
30
+ /env/
31
+ !.python-version
32
+
33
+ # Python build artifacts:
34
+ __pycache__
35
+ *.pyc
36
+ *.pyd
37
+ *.so
38
+ **/.pytest_cache
39
+ **/.ipynb_checkpoints/
40
+
41
+ # Pixi environment
42
+ .pixi
43
+
44
+ .gdb_history
45
+ perf.data*
46
+
47
+ **/dataset/
48
+
49
+ # Screenshots from samples etc.
50
+ screenshot*.png
51
+
52
+ # Saved example `.rrd` files
53
+ example_data
54
+
55
+ # Various builds
56
+ dist
57
+ wheels
58
+
59
+ # Screenshot comparison build
60
+ /compare_screenshot
61
+ **/tests/snapshots/**/*.diff.png
62
+ **/tests/snapshots/**/*.new.png
63
+ **/tests/snapshots/**/*.old.png
64
+
65
+ # Mesa install
66
+ mesa
67
+ mesa.7z
68
+ mesa.tar.xz
69
+ icd.json
70
+
71
+ *.rrd
72
+
73
+ /meilisearch
74
+
75
+ .pixi/*
76
+
77
+ # heaptrack files
78
+ *.zst
79
+
80
+ # IDE stuff
81
+ /.idea
82
+
83
+ # Local env files
84
+ .env
85
+ .env.*
86
+ .env.local
87
+ .env.development.local
88
+ .env.test.local
89
+ .env.production.local
90
+ .envrc
91
+
92
+ # Testing
93
+ coverage
94
+
95
+ # Misc
96
+ .DS_Store
97
+ *.pem
98
+
99
+ # Logs
100
+ logs
101
+ *.log
102
+ npm-debug.log*
103
+ yarn-debug.log*
104
+ yarn-error.log*
105
+ pnpm-debug.log*
106
+ lerna-debug.log*
107
+
108
+ node_modules
109
+ dist
110
+ dist-ssr
111
+ *.local
112
+
113
+ # Editor directories and files
114
+ .vscode/*
115
+ !.vscode/settings.json
116
+ !.vscode/extensions.json
117
+
118
+ .idea
119
+ .DS_Store
120
+ *.suo
121
+ *.ntvs*
122
+ *.njsproj
123
+ *.sln
124
+ *.sw?
125
+
126
+ *storybook.log
127
+ bin
128
+ docs/.obsidian/workspace.json
129
+ !certs/*.pem
rlmesh-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.4
2
+ Name: rlmesh
3
+ Version: 0.0.1
4
+ Summary: SDK for interfacing with SAI RL Infra.
5
+ Project-URL: Homepage, https://competesai.com
6
+ Project-URL: Documentation, https://docs.competesai.com
7
+ Author-email: ArenaX Labs <research@competesai.com>
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Python: >=3.10
13
+ Requires-Dist: asyncio<5.0.0,>=4.0.0
14
+ Requires-Dist: grpcio>=1.74
15
+ Requires-Dist: gymnasium<1.2.0,>=1.0.0
16
+ Requires-Dist: httpx[http2]>=0.28.0
17
+ Requires-Dist: numpy<3.0,>=2.0.0
18
+ Requires-Dist: protobuf>=6
19
+ Requires-Dist: requests<3.0.0,>=2.20.0
20
+ Requires-Dist: tomli>=1.1.0; python_version < '3.11'
@@ -0,0 +1,33 @@
1
+ [project]
2
+ name = "rlmesh"
3
+ version = "0.0.1"
4
+ description = "SDK for interfacing with SAI RL Infra."
5
+ authors = [{ name = "ArenaX Labs", email = "research@competesai.com" }]
6
+ requires-python = ">=3.10"
7
+ classifiers = [
8
+ "Programming Language :: Python :: 3",
9
+ "Programming Language :: Python :: 3.10",
10
+ "Programming Language :: Python :: 3.11",
11
+ "Programming Language :: Python :: 3.12",
12
+ ]
13
+ dependencies = [
14
+ "gymnasium>=1.0.0,<1.2.0",
15
+ "numpy>=2.0.0,<3.0",
16
+ "requests>=2.20.0,<3.0.0",
17
+ "asyncio>=4.0.0,<5.0.0",
18
+ "httpx[http2]>=0.28.0",
19
+ "protobuf>=6",
20
+ "grpcio>=1.74",
21
+ "tomli >= 1.1.0 ; python_version < '3.11'",
22
+ ]
23
+
24
+ [project.urls]
25
+ Homepage = "https://competesai.com"
26
+ Documentation = "https://docs.competesai.com"
27
+
28
+ [build-system]
29
+ requires = ["hatchling"]
30
+ build-backend = "hatchling.build"
31
+
32
+ [tool.hatch.build.targets.wheel]
33
+ packages = ["src/rlmesh"]
@@ -0,0 +1,4 @@
1
+ from rlmesh.client import RLMeshClient
2
+
3
+ __version__ = "0.0.1"
4
+ __all__ = ["RLMeshClient"]
@@ -0,0 +1,3 @@
1
+ from .session import RLMeshClient
2
+
3
+ __all__ = ["RLMeshClient"]
@@ -0,0 +1,66 @@
1
+ from math import ceil
2
+
3
+ from rich import print
4
+ from rich.table import Table
5
+ from rich.panel import Panel
6
+ from rich.console import Group, Console
7
+
8
+ from rlmesh.proto import control_pb2
9
+
10
+
11
+ def _into_columns(items, cols):
12
+ if cols < 1:
13
+ cols = 1
14
+ rows = ceil(len(items) / cols)
15
+ out = []
16
+ for c in range(cols):
17
+ start = c * rows
18
+ out.append(items[start : start + rows])
19
+ return out, rows
20
+
21
+
22
def print_envs(res: control_pb2.ListEnvironmentsRes):
    """Pretty-print the environments in *res*, grouped by package.

    Each entry of ``res.packages`` that has at least one ``env_id`` becomes a
    panel titled with the package name. The ids are laid out column-major in
    up to four columns sized to the terminal. All package panels are wrapped
    in a single "Available Environments" panel and printed.

    NOTE(review): this module imports ``rich``, which is not among the
    dependencies declared in pyproject.toml — confirm it is installed.
    """
    console = Console()
    term_width = console.size.width or 100
    # The columns sit inside two nested panels. Each panel spends 2 border
    # characters plus 2 characters of horizontal padding: (0, 1) is set
    # explicitly on the inner panel and is rich's default for the outer one.
    # Ignoring this made columns narrower than planned, so no_wrap names were
    # truncated with an ellipsis.
    usable_width = max(1, term_width - 8)

    panels = []
    for pkg, envs in res.packages.items():
        env_list = list(envs.env_ids)
        if not env_list:
            continue

        max_len = max(len(s) for s in env_list)
        min_col_width = max(18, min(40, max_len + 2))
        cols = max(1, min(4, usable_width // min_col_width))

        cols_data, rows = _into_columns(env_list, cols)
        # Drop trailing empty columns. Ratio columns split the width evenly,
        # so empty ones would still squeeze the columns that hold names.
        cols_data = [col for col in cols_data if col]

        grid = Table.grid(expand=True, pad_edge=False)
        for _ in cols_data:
            grid.add_column(ratio=1, justify="left", no_wrap=True)

        for r in range(rows):
            grid.add_row(*[col[r] if r < len(col) else "" for col in cols_data])

        panels.append(
            Panel(
                grid,
                title=f"[green]{pkg}[/green]",
                border_style="bright_black",
                padding=(0, 1),
            )
        )

    if not panels:
        # Avoid printing a bare, empty frame when nothing is available.
        panels.append("[dim]No environments available.[/dim]")

    console.print(
        Panel(
            Group(*panels),
            title="[bold yellow]Available Environments[/bold yellow]",
            border_style="bold yellow",
            expand=True,
        )
    )
@@ -0,0 +1,203 @@
1
+ from typing import Iterator, Optional, Literal, Callable, Tuple
2
+
3
+ import contextlib
4
+ import grpc
5
+ import asyncio
6
+ import gymnasium as gym
7
+
8
+ from gymnasium.wrappers import HumanRendering
9
+
10
+ from rlmesh.client.print import print_envs
11
+ from rlmesh.env import RemoteEnv
12
+ from rlmesh.ping import compare_gcp_regions_async
13
+ from rlmesh.proto import control_pb2_grpc, control_pb2
14
+ from rlmesh.utils import config
15
+
16
# Control-plane address used when no endpoint is passed to RLMeshClient.
# NOTE(review): hard-coded IP; `config.api_url` is only consulted when
# `endpoint` is explicitly passed as a falsy value.
DEFAULT_CONTROLPLANE_ENDPOINT = "34.130.104.52:4510"


class RLMeshClient:
    """Synchronous client for the RLMesh control-plane ``SessionAPI``.

    The client can:

    - list regions and environments;
    - create interactive sessions and wait for them to become READY;
    - connect a :class:`RemoteEnv` to the node that serves a session.

    When ``api_key`` is given, it is sent with every call as
    ``x-rlmesh-token`` metadata.

    The client owns one gRPC channel. Call :meth:`close`, or use the client as
    a context manager, to release the channel together with the most recently
    created environment.
    """

    def __init__(
        self,
        api_key: str | None = None,
        *,
        endpoint: str | None = DEFAULT_CONTROLPLANE_ENDPOINT,
        region: str | None = None,
        insecure: bool = True,
    ):
        """Open the control-plane channel.

        Args:
            api_key: Optional token attached to every request.
            endpoint: ``host:port`` of the control plane. When falsy,
                ``config.api_url`` is used instead.
            region: Location for new sessions. When ``None``, the closest
                region is looked up on the first :meth:`create_session`.
            insecure: Must be ``True``; TLS channels are not implemented.

        Raises:
            NotImplementedError: if ``insecure`` is ``False``.
        """
        self._addr = endpoint or config.api_url
        self._token = api_key
        self._region = region

        if not insecure:
            raise NotImplementedError("TLS not wired yet")
        self._chan = grpc.insecure_channel(self._addr)

        self._stub = control_pb2_grpc.SessionAPIStub(self._chan)
        self._metadata = []
        if self._token:
            self._metadata.append(("x-rlmesh-token", self._token))

        # Most recently created environment; it is closed before a new one is
        # connected, and by close().
        self._env = None

    def __enter__(self) -> "RLMeshClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the active environment (if any), then the gRPC channel."""
        try:
            if self._env is not None:
                self._env.close()
        finally:
            self._env = None
            self._chan.close()

    @contextlib.contextmanager
    def channel(self):
        """Yield the underlying gRPC channel and close it on exit."""
        try:
            yield self._chan
        finally:
            if self._chan is not None:
                self._chan.close()

    def find_closest_region(self):
        """Return the region chosen by ``compare_gcp_regions_async``.

        The chosen region is presumably the lowest-latency one. Returns
        ``None`` when the control plane lists no regions. Uses
        ``asyncio.run``, so it cannot be called while an event loop is
        already running in this thread.
        """
        regions = list(self.list_regions().regions)
        if not regions:
            return None
        return asyncio.run(compare_gcp_regions_async(regions))

    def list_regions(self, deadline_s: float = 5.0) -> control_pb2.ListRegionsRes:
        """Return the regions known to the control plane."""
        req = control_pb2.ListRegionsReq()
        return self._stub.ListRegions(req, timeout=deadline_s, metadata=self._metadata)

    def list_environments(
        self, deadline_s: float = 5.0
    ) -> control_pb2.ListEnvironmentsRes:
        """Return the available environments, grouped by package."""
        req = control_pb2.ListEnvironmentsReq()
        return self._stub.ListEnvironments(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def create_session(
        self, gym_id: str, *, num_envs: int = 1, deadline_s: float = 5.0
    ) -> control_pb2.Session:
        """Request a new interactive session for *gym_id*.

        If no region was configured, the closest one is resolved first and
        remembered for later calls.

        Raises:
            grpc.RpcError: if the RPC fails. On ``NOT_FOUND`` the available
                environments are printed before the error is re-raised.
        """
        if self._region is None:
            self._region = self.find_closest_region()

        req = control_pb2.InteractiveSessionSpec(
            gym_id=gym_id, num_envs=num_envs, location=self._region
        )
        try:
            return self._stub.CreateSession(
                req, timeout=deadline_s, metadata=self._metadata
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                try:
                    details = e.details() if hasattr(e, "details") else str(e)
                except Exception:
                    details = str(e)
                print(f"[rlmesh] CreateSession failed: {details}")
                print("[rlmesh] Available environments:")
                try:
                    self.print_envs()
                except grpc.RpcError:
                    # The listing is only a hint. Never let its failure hide
                    # the original NOT_FOUND error, which is re-raised below.
                    print("[rlmesh] (could not list environments)")
            raise

    def watch_session(
        self, session_id: str, *, deadline_s: float = 120.0
    ) -> Iterator[control_pb2.Session]:
        """Yield session updates streamed by the control plane.

        The server stream is cancelled as soon as the caller stops iterating.
        Without this, it would stay open until the deadline or until garbage
        collection.
        """
        req = control_pb2.WatchReq(session_id=session_id)
        stream = self._stub.WatchSession(
            req, timeout=deadline_s, metadata=self._metadata
        )
        try:
            yield from stream
        finally:
            stream.cancel()

    def wait_until_ready(
        self, session_id: str, *, deadline_s: float = 600.0
    ) -> control_pb2.Session:
        """Block until the session reports READY and return that update.

        Raises:
            RuntimeError: if the stream ends without a READY update.
            grpc.RpcError: if the watch RPC fails, e.g. on deadline expiry.
        """
        for evt in self.watch_session(session_id, deadline_s=deadline_s):
            if evt.state == control_pb2.SessionState.READY:
                return evt
        raise RuntimeError("session stream ended without READY")

    def describe(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.DescribeRes:
        """Return the session description, including its environment spec."""
        req = control_pb2.DescribeReq(session_id=session_id)
        return self._stub.DescribeSession(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def get_results(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.GetResultsRes:
        """Return the recorded results for *session_id*."""
        req = control_pb2.GetResultsReq(session_id=session_id)
        return self._stub.GetResults(req, timeout=deadline_s, metadata=self._metadata)

    def _connect_env(
        self, *, session: control_pb2.Session, desc: control_pb2.DescribeRes
    ) -> RemoteEnv:
        """Open a RemoteEnv on the node serving *session*.

        The RemoteEnv reuses the client's metadata plus an
        ``x-rlmesh-session`` entry.
        """
        node_addr = session.address
        if not node_addr:
            raise ValueError("session is missing node address")

        metadata = [*self._metadata, ("x-rlmesh-session", session.id)]
        return RemoteEnv(
            desc.env,
            transport_addr=node_addr,
            transport_metadata=metadata,
        )

    def _create_and_connect(
        self, gym_id: str, num_envs: int = 1
    ) -> Tuple[RemoteEnv, control_pb2.Session]:
        """Close any previous env, then create, await and connect a new session."""
        if self._env:
            try:
                self._env.close()
            finally:
                self._env = None

        sess = self.create_session(gym_id, num_envs=num_envs)
        sess = self.wait_until_ready(sess.id)
        desc = self.describe(sess.id)

        env = self._connect_env(session=sess, desc=desc)
        # Frames are always fetched as arrays; "human" display is layered on
        # top with HumanRendering by the callers.
        env.render_mode = "rgb_array"
        self._env = env
        return env, sess

    def make(
        self, gym_id: str, render_mode: Literal["human", "rgb_array"] = "rgb_array"
    ) -> gym.Env:
        """Create a session for *gym_id* and return a connected environment."""
        env, _ = self._create_and_connect(gym_id)

        if render_mode == "human":
            env = HumanRendering(env)

        self._env = env
        return env

    def evaluate(
        self,
        gym_id: str,
        get_action: Callable | None = None,
        render=False,
        *,
        seed: int | None = 42,
    ) -> dict:
        """Run one episode of *gym_id* and return the session's results.

        Args:
            gym_id: Environment id to evaluate.
            get_action: Policy mapping an observation to an action. When
                ``None``, actions are sampled from the action space.
            render: Wrap the env in ``HumanRendering`` to display frames.
            seed: Seed passed to ``reset``. The default keeps the previously
                hard-coded value of 42.

        Returns:
            A dict with ``score``, ``duration`` and ``timestamp`` (JSON string).
        """
        env, sess = self._create_and_connect(gym_id)

        if render:
            env = HumanRendering(env)

        # Close the env even if the policy or the remote step raises;
        # otherwise the transport stays open.
        try:
            obs, _ = env.reset(seed=seed)
            done = False
            while not done:
                action = (
                    env.action_space.sample() if get_action is None else get_action(obs)
                )
                obs, _, term, trunc, _ = env.step(action)
                done = term or trunc
        finally:
            try:
                env.close()
            finally:
                self._env = None

        results = self.get_results(sess.id)
        return {
            "score": results.score,
            "duration": results.duration,
            "timestamp": results.timestamp.ToJsonString(),
        }

    def print_envs(self):
        """Fetch and pretty-print the available environments."""
        res = self.list_environments()
        return print_envs(res)
@@ -0,0 +1,4 @@
1
+ from .make_env import make_env_factory, make_env
2
+ from .remote_env import RemoteEnv
3
+
4
+ __all__ = ["make_env_factory", "make_env", "RemoteEnv"]
@@ -0,0 +1,44 @@
1
+ from typing import Literal, Any, Optional
2
+
3
+ import gymnasium as gym
4
+
5
+
6
def make_env_factory(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    gym_vars: Optional[dict[str, Any]] = None,
):
    """Return a zero-argument-capable factory that builds *gym_id* environments.

    Args:
        gym_id: Environment id passed to ``gym.make``. For the ``gym-v26`` and
            ``gym-v21`` compatibility types, it is passed as ``env_id``.
        gym_type: Which registry/compatibility shim to construct through.
        gym_vars: Extra keyword arguments passed to every ``gym.make`` call.

    The factory accepts ``index`` (currently ignored) and keyword overrides.
    Its keyword arguments are merged in this order, later winning:
    ``gym_vars``, then ``render_mode="rgb_array"``, then the call's kwargs.

    Raises (from the factory):
        EnvironmentError: if *gym_type* is not one of the supported values.
    """
    # None instead of a shared mutable ``{}`` default.
    base_vars = dict(gym_vars) if gym_vars else {}

    def env_factory(index=0, **kwargs):
        env_vars = {
            **base_vars,
            "render_mode": "rgb_array",
            **kwargs,
        }

        if gym_type == "gymnasium":
            return gym.make(gym_id, **env_vars)
        if gym_type == "gym-v26":
            return gym.make("GymV26Environment-v0", env_id=gym_id, **env_vars)
        if gym_type == "gym-v21":
            return gym.make("GymV21Environment-v0", env_id=gym_id, **env_vars)
        raise EnvironmentError(
            f"Unsupported environment type: {gym_type}. "
            "Please use a supported environment."
        )

    return env_factory
35
+
36
+
37
def make_env(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    *,
    render_mode: Optional[str] = None,
    **kwargs,
):
    """Build one environment immediately via :func:`make_env_factory`.

    The explicit ``render_mode`` (default ``None``) is forwarded as a keyword
    override. It therefore replaces the factory's ``"rgb_array"`` default even
    when left as ``None``. Remaining keyword arguments are forwarded unchanged.
    """
    build = make_env_factory(gym_id, gym_type)
    return build(render_mode=render_mode, **kwargs)
@@ -0,0 +1,192 @@
1
+ from typing import Any, Optional
2
+
3
+ import asyncio
4
+ import grpc
5
+ import gymnasium as gym
6
+ import gymnasium.envs.registration as reg
7
+ import numpy as np
8
+
9
+ from google.protobuf.json_format import MessageToDict
10
+
11
+
12
+ from rlmesh.gym import deserialize_space, tensor_from_ndarray, ndarray_from_tensor
13
+ from rlmesh.proto import gym_pb2, env_pb2_grpc, env_pb2
14
+ from rlmesh.utils import LoopRunner
15
+
16
+
17
class AsyncEnvTransport:
    """Async request/response client over the ``EnvService.Connect`` stream.

    Every operation writes one ``ClientMsg`` and reads one ``ServerMsg``.
    ``_lock`` keeps each write and its matching read together when several
    coroutines share the transport. ``_ready`` makes operations wait until
    :meth:`start` has opened the stream.
    """

    def __init__(self, addr: str, metadata: list):
        """Store connection settings; nothing is opened until :meth:`start`."""
        self._addr = addr
        self._metadata = metadata

        self._chan: grpc.aio.Channel | None = None
        self._stub: env_pb2_grpc.EnvServiceStub | None = None
        self._stream: grpc.aio.StreamStreamCall | None = None
        self._ready = asyncio.Event()
        self._lock = asyncio.Lock()

    async def start(self):
        """Open an insecure channel and the bidirectional ``Connect`` stream."""
        self._chan = grpc.aio.insecure_channel(self._addr)
        self._stub = env_pb2_grpc.EnvServiceStub(self._chan)
        self._stream = self._stub.Connect(metadata=self._metadata)
        self._ready.set()

    async def _read_and_check_msg(
        self, expected: Optional[str] = None
    ) -> env_pb2.ServerMsg:
        """Read the next server message and validate it.

        Args:
            expected: If given, the ``kind`` oneof value the message must have.

        Raises:
            RuntimeError: if the stream has ended, the reply is not a
                ``ServerMsg``, the server sent an ``error`` message, or the
                kind does not match ``expected``.
        """
        assert self._stream is not None
        msg = await self._stream.read()

        # grpc.aio reports end-of-stream with the ``grpc.aio.EOF`` sentinel,
        # not ``None``. Treat both as the server hanging up, for any operation.
        if msg is None or msg is grpc.aio.EOF:
            raise RuntimeError("stream closed by server")

        if not isinstance(msg, env_pb2.ServerMsg):
            raise RuntimeError(f"unexpected message type: {type(msg).__name__}")

        kind = msg.WhichOneof("kind")
        if kind == "error":
            raise RuntimeError(
                f"received an error from the server: {msg.error.message}"
            )
        if expected is not None and kind != expected:
            raise RuntimeError(f"expected {expected}, got {kind}")

        return msg

    async def close(self):
        """Ask the server to close, half-close the stream, then close the channel.

        The channel is closed even if sending the close request fails.
        """
        try:
            if self._stream is not None:
                await self._stream.write(env_pb2.ClientMsg(close=env_pb2.CloseReq()))
                await self._stream.done_writing()
        finally:
            if self._chan is not None:
                await self._chan.close()

    async def reset(
        self, *, seed: int | None, options: dict | None = None
    ) -> tuple[np.ndarray, dict]:
        """Reset the remote env and return ``(obs, info)``.

        ``options`` is accepted for Gymnasium API compatibility but is not
        included in the ``ResetReq`` sent to the server. Only the first
        ``infos`` entry is converted to the returned ``info`` dict.
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            seeds = [] if seed is None else [int(seed)]
            reset = env_pb2.ResetReq(seeds=seeds)

            await self._stream.write(env_pb2.ClientMsg(reset=reset))
            msg = await self._read_and_check_msg("reset_ok")

            obs = ndarray_from_tensor(msg.reset_ok.obs)

            info = {}
            if msg.reset_ok.infos:
                info = MessageToDict(msg.reset_ok.infos[0])

            return obs, info

    async def step(
        self, action_msg: env_pb2.ClientMsg
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Send a prepared step request and decode the reply.

        Returns ``(obs, reward, terminated, truncated, info)`` built from the
        first reward, bit 0 of the first byte of each mask, and the first
        ``infos`` entry. The masks presumably pack one flag per sub-env;
        confirm against the server.

        If the server answers with ``reset_ok`` instead, its observation is
        returned with reward ``0.0`` and ``terminated=True``.

        Raises:
            RuntimeError: on server errors, a closed stream, or any other kind.
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(action_msg)
            msg = await self._read_and_check_msg()

            kind = msg.WhichOneof("kind")
            if kind == "step_ok":
                s = msg.step_ok
                obs = ndarray_from_tensor(s.obs)
                rew = float(ndarray_from_tensor(s.rewards).reshape(-1)[0])
                term = bool(s.terminated_mask and (s.terminated_mask[0] & 0x01))
                trunc = bool(s.truncated_mask and (s.truncated_mask[0] & 0x01))
                info = MessageToDict(s.infos[0]) if s.infos else {}
                return obs, rew, term, trunc, info

            if kind == "reset_ok":
                obs = ndarray_from_tensor(msg.reset_ok.obs)
                return obs, 0.0, True, False, {}

            # "error" replies are already raised by _read_and_check_msg.
            raise RuntimeError(f"unexpected server msg kind: {kind}")

    async def render(self) -> np.ndarray | None:
        """Request a frame from the server and return it as an array."""
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(env_pb2.ClientMsg(render=env_pb2.RenderReq()))
            msg = await self._read_and_check_msg("render_ok")
            frame = ndarray_from_tensor(msg.render_ok.frame)
            return frame
123
+
124
+
125
class RemoteEnv(gym.Env):
    """Gymnasium environment whose simulation runs on a remote RLMesh node.

    Observation/action spaces, render metadata and the spec id are rebuilt
    from the ``EnvSpec`` received from the control plane. ``reset``, ``step``
    and ``render`` are forwarded over an :class:`AsyncEnvTransport`. Its
    coroutines are run to completion by a private ``LoopRunner``, so the
    public API stays synchronous.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        env_spec: gym_pb2.EnvSpec,
        *,
        transport_addr: str,
        transport_metadata: list,
    ):
        """Build spaces from *env_spec* and open the transport stream.

        Args:
            env_spec: Environment description (e.g. ``DescribeRes.env``).
            transport_addr: ``host:port`` of the node serving the session.
            transport_metadata: gRPC metadata sent when opening the stream.
        """
        super().__init__()

        self.observation_space = deserialize_space(env_spec.spaces.obs)
        self.action_space = deserialize_space(env_spec.spaces.action)
        self.metadata = {
            "render_modes": list(env_spec.metadata.render_modes),
            # Fall back to 30 FPS when the spec leaves render_fps unset (0).
            "render_fps": int(env_spec.metadata.render_fps or 30),
        }

        self.spec = reg.EnvSpec(id=env_spec.gym_id)

        self._obs = None
        # ALE observations are the RGB frame, so render() can reuse them.
        self._is_ale = env_spec.package == "ale_py"
        self._closed = False

        self._runner = LoopRunner()
        self._tx = AsyncEnvTransport(transport_addr, transport_metadata)
        try:
            self._runner.run(self._tx.start())
        except BaseException:
            # Don't leave the runner alive for an env that can never be used.
            self._closed = True
            self._runner.stop()
            raise

    def _check_open(self) -> None:
        """Raise ``RuntimeError`` if :meth:`close` has already run."""
        if self._closed:
            raise RuntimeError("env closed")

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Reset the remote environment and return ``(obs, info)``.

        ``seed`` is sent to the server. Per the Gymnasium contract, it also
        seeds this env's local ``np_random`` via ``gym.Env.reset``.
        """
        self._check_open()
        super().reset(seed=seed)
        obs, info = self._runner.run(self._tx.reset(seed=seed, options=options))
        self._obs = obs

        return obs, info

    def step(self, action):
        """Validate *action*, send it, and return the Gymnasium 5-tuple.

        Raises:
            RuntimeError: if the env is closed.
            ValueError: if *action* is not contained in ``action_space``.
        """
        self._check_open()
        if not self.action_space.contains(action):
            raise ValueError("action outside action_space")

        tensor = tensor_from_ndarray(np.asarray(action))
        msg = env_pb2.ClientMsg(step=env_pb2.StepReq(actions=tensor))
        obs, r, term, trunc, info = self._runner.run(self._tx.step(msg))
        self._obs = obs
        return obs, r, term, trunc, info

    def render(self):
        """Return the current frame.

        For ALE envs the last observation already is the RGB frame, so it is
        returned without another round trip once an observation exists.
        """
        self._check_open()

        if self._is_ale and self._obs is not None:
            return self._obs

        return self._runner.run(self._tx.render())

    def close(self):
        """Close the transport and stop the runner; later calls are no-ops."""
        if self._closed:
            return
        try:
            self._runner.run(self._tx.close())
        finally:
            self._runner.stop()
            self._closed = True