rlmesh 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rlmesh/__init__.py +4 -0
- rlmesh/client/__init__.py +3 -0
- rlmesh/client/print.py +66 -0
- rlmesh/client/session.py +203 -0
- rlmesh/env/__init__.py +4 -0
- rlmesh/env/make_env.py +44 -0
- rlmesh/env/remote_env.py +192 -0
- rlmesh/env/remote_vector_env.py +54 -0
- rlmesh/gym/__init__.py +29 -0
- rlmesh/gym/env_codec.py +109 -0
- rlmesh/gym/spaces_codec.py +315 -0
- rlmesh/gym/tensor_codec.py +19 -0
- rlmesh/gym/utils.py +64 -0
- rlmesh/ping/__init__.py +3 -0
- rlmesh/ping/gcping.py +51 -0
- rlmesh/proto/__init__.py +14 -0
- rlmesh/proto/control_pb2.py +69 -0
- rlmesh/proto/control_pb2.pyi +120 -0
- rlmesh/proto/control_pb2_grpc.py +312 -0
- rlmesh/proto/env_agent_pb2_grpc.py +24 -0
- rlmesh/proto/env_pb2.py +61 -0
- rlmesh/proto/env_pb2.pyi +97 -0
- rlmesh/proto/env_pb2_grpc.py +97 -0
- rlmesh/proto/gym_pb2.py +61 -0
- rlmesh/proto/gym_pb2.pyi +152 -0
- rlmesh/proto/gym_pb2_grpc.py +24 -0
- rlmesh/proto/internal_env_pb2.py +42 -0
- rlmesh/proto/internal_env_pb2.pyi +25 -0
- rlmesh/proto/internal_env_pb2_grpc.py +98 -0
- rlmesh/proto/internal_session_pb2.py +48 -0
- rlmesh/proto/internal_session_pb2.pyi +49 -0
- rlmesh/proto/internal_session_pb2_grpc.py +141 -0
- rlmesh/utils/__init__.py +5 -0
- rlmesh/utils/background.py +17 -0
- rlmesh/utils/config.py +45 -0
- rlmesh/utils/env.py +40 -0
- rlmesh-0.0.1.dist-info/METADATA +20 -0
- rlmesh-0.0.1.dist-info/RECORD +39 -0
- rlmesh-0.0.1.dist-info/WHEEL +4 -0
rlmesh/__init__.py
ADDED
rlmesh/client/print.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from math import ceil
|
|
2
|
+
|
|
3
|
+
from rich import print
|
|
4
|
+
from rich.table import Table
|
|
5
|
+
from rich.panel import Panel
|
|
6
|
+
from rich.console import Group, Console
|
|
7
|
+
|
|
8
|
+
from rlmesh.proto import control_pb2
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _into_columns(items, cols):
|
|
12
|
+
if cols < 1:
|
|
13
|
+
cols = 1
|
|
14
|
+
rows = ceil(len(items) / cols)
|
|
15
|
+
out = []
|
|
16
|
+
for c in range(cols):
|
|
17
|
+
start = c * rows
|
|
18
|
+
out.append(items[start : start + rows])
|
|
19
|
+
return out, rows
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def print_envs(res: control_pb2.ListEnvironmentsRes):
    """Render the environments in ``res`` as nested rich panels.

    One inner panel is built per entry of ``res.packages`` that has at least
    one environment id; its ids are laid out column-major in a grid of 1-4
    columns chosen from the terminal width. All inner panels are then printed
    inside a single "Available Environments" panel.

    Args:
        res: Response whose ``packages`` mapping goes from package name to a
            message exposing ``env_ids``.
    """
    console = Console()
    package_panels = []

    for package_name, package in res.packages.items():
        env_ids = list(package.env_ids)
        if not env_ids:
            # Packages without environments get no panel at all.
            continue

        longest = max(map(len, env_ids))
        term_width = console.size.width or 100
        # Column width: longest id plus padding, clamped to [18, 40].
        column_width = min(max(longest + 2, 18), 40)
        # Column count: as many as fit, clamped to [1, 4].
        cols = min(max(term_width // column_width, 1), 4)

        columns, row_count = _into_columns(env_ids, cols)

        grid = Table.grid(expand=True, pad_edge=False)
        for _column in columns:
            grid.add_column(ratio=1, justify="left", no_wrap=True)

        for row_idx in range(row_count):
            # Short trailing columns are padded with empty cells.
            cells = [
                column[row_idx] if row_idx < len(column) else ""
                for column in columns
            ]
            grid.add_row(*cells)

        package_panel = Panel(
            grid,
            title=f"[green]{package_name}[/green]",
            border_style="bright_black",
            padding=(0, 1),
        )
        package_panels.append(package_panel)

    outer_panel = Panel(
        Group(*package_panels),
        title="[bold yellow]Available Environments[/bold yellow]",
        border_style="bold yellow",
        expand=True,
    )
    print(outer_panel)
|
rlmesh/client/session.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from typing import Iterator, Optional, Literal, Callable, Tuple
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import grpc
|
|
5
|
+
import asyncio
|
|
6
|
+
import gymnasium as gym
|
|
7
|
+
|
|
8
|
+
from gymnasium.wrappers import HumanRendering
|
|
9
|
+
|
|
10
|
+
from rlmesh.client.print import print_envs
|
|
11
|
+
from rlmesh.env import RemoteEnv
|
|
12
|
+
from rlmesh.ping import compare_gcp_regions_async
|
|
13
|
+
from rlmesh.proto import control_pb2_grpc, control_pb2
|
|
14
|
+
from rlmesh.utils import config
|
|
15
|
+
|
|
16
|
+
# Default "host:port" of the control plane; used as the default value of
# RLMeshClient's ``endpoint`` argument.
DEFAULT_CONTROLPLANE_ENDPOINT = "34.130.104.52:4510"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RLMeshClient:
    """Synchronous client for the RLMesh control plane (``SessionAPI``).

    The client lists regions/environments, creates and watches sessions, and
    turns a ready session into a ``RemoteEnv``. At most one environment is
    tracked at a time: creating a new one closes the previous one first.

    NOTE: only plaintext gRPC channels are implemented, so the API token
    (sent as ``x-rlmesh-token`` metadata) travels unencrypted.
    """

    def __init__(
        self,
        api_key: str | None = None,
        *,
        endpoint: str | None = DEFAULT_CONTROLPLANE_ENDPOINT,
        region: str | None = None,
        insecure: bool = True,
    ):
        """Open a channel to the control plane.

        Args:
            api_key: Optional token attached to every call as
                ``x-rlmesh-token`` metadata.
            endpoint: Control-plane address; a falsy value falls back to
                ``config.api_url``.
            region: Location used for new sessions. When ``None`` the closest
                region is looked up on the first ``create_session`` call.
            insecure: Must be ``True``; TLS is not implemented.

        Raises:
            NotImplementedError: If ``insecure`` is ``False``.
        """
        self._addr = endpoint or config.api_url
        self._token = api_key
        self._region = region

        self._chan = grpc.insecure_channel(self._addr) if insecure else None
        if self._chan is None:
            raise NotImplementedError("TLS not wired yet")

        self._stub = control_pb2_grpc.SessionAPIStub(self._chan)
        self._metadata = []

        if self._token:
            self._metadata.append(("x-rlmesh-token", self._token))

        # Most recently created environment, closed before a new one is made.
        self._env = None

    @contextlib.contextmanager
    def channel(self):
        """Yield the underlying gRPC channel and close it on exit.

        The client's stub is bound to this same channel, so the client should
        not be used for further calls once the ``with`` block has exited.
        """
        try:
            yield self._chan
        finally:
            if self._chan is not None:
                self._chan.close()

    def find_closest_region(self):
        """Return the result of pinging all listed regions, or ``None``.

        Uses ``asyncio.run``, so it must not be called from a thread that
        already has a running event loop.
        """
        regions = list(self.list_regions().regions)
        if regions:
            return asyncio.run(compare_gcp_regions_async(regions))

        return None

    def list_regions(self, deadline_s: float = 5.0) -> control_pb2.ListRegionsRes:
        """Call ``ListRegions`` with a timeout of ``deadline_s`` seconds."""
        req = control_pb2.ListRegionsReq()
        return self._stub.ListRegions(req, timeout=deadline_s, metadata=self._metadata)

    def list_environments(
        self, deadline_s: float = 5.0
    ) -> control_pb2.ListEnvironmentsRes:
        """Call ``ListEnvironments`` with a timeout of ``deadline_s`` seconds."""
        req = control_pb2.ListEnvironmentsReq()
        return self._stub.ListEnvironments(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def create_session(
        self, gym_id: str, *, num_envs: int = 1, deadline_s: float = 5.0
    ) -> control_pb2.Session:
        """Create an interactive session for ``gym_id``.

        If no region was configured, the closest one is resolved first and
        remembered on the client.

        Args:
            gym_id: Environment id to run.
            num_envs: Number of environments requested for the session.
            deadline_s: RPC timeout in seconds.

        Returns:
            The ``Session`` returned by the control plane.

        Raises:
            grpc.RpcError: If the RPC fails. For ``NOT_FOUND`` the available
                environments are printed (best effort) before re-raising.
        """
        if self._region is None:
            self._region = self.find_closest_region()

        req = control_pb2.InteractiveSessionSpec(
            gym_id=gym_id, num_envs=num_envs, location=self._region
        )
        try:
            return self._stub.CreateSession(
                req, timeout=deadline_s, metadata=self._metadata
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                try:
                    details = e.details() if hasattr(e, "details") else str(e)
                except Exception:
                    details = str(e)
                print(f"[rlmesh] CreateSession failed: {details}")
                print("[rlmesh] Available environments:")
                try:
                    self.print_envs()
                except grpc.RpcError as list_err:
                    # The listing is only a hint; its failure must not hide
                    # the original NOT_FOUND error re-raised below.
                    print(f"[rlmesh] could not list environments: {list_err}")
            raise

    def watch_session(
        self, session_id: str, *, deadline_s: float = 120.0
    ) -> Iterator[control_pb2.Session]:
        """Yield session events streamed by ``WatchSession``.

        ``deadline_s`` is the timeout applied to the whole streaming call.
        """
        req = control_pb2.WatchReq(session_id=session_id)
        stream = self._stub.WatchSession(
            req, timeout=deadline_s, metadata=self._metadata
        )
        for evt in stream:
            yield evt

    def wait_until_ready(
        self, session_id: str, *, deadline_s: float = 600.0
    ) -> control_pb2.Session:
        """Block until the session reports ``READY`` and return that event.

        Raises:
            RuntimeError: If the watch stream ends without a ``READY`` event.
        """
        for evt in self.watch_session(session_id, deadline_s=deadline_s):
            if evt.state == control_pb2.SessionState.READY:
                return evt
        raise RuntimeError("session stream ended without READY")

    def describe(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.DescribeRes:
        """Call ``DescribeSession`` for ``session_id``."""
        req = control_pb2.DescribeReq(session_id=session_id)
        return self._stub.DescribeSession(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def get_results(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.GetResultsRes:
        """Call ``GetResults`` for ``session_id``."""
        req = control_pb2.GetResultsReq(session_id=session_id)
        return self._stub.GetResults(req, timeout=deadline_s, metadata=self._metadata)

    def _connect_env(
        self, *, session: control_pb2.Session, desc: control_pb2.DescribeRes
    ) -> RemoteEnv:
        """Build a ``RemoteEnv`` talking to the node that hosts ``session``.

        Raises:
            ValueError: If the session carries no node address.
        """
        node_addr = session.address
        if not node_addr:
            raise ValueError("session is missing node address")

        # Copy so the per-session header never leaks into control-plane calls.
        metadata = self._metadata.copy()
        metadata.append(("x-rlmesh-session", session.id))

        return RemoteEnv(
            desc.env,
            transport_addr=node_addr,
            transport_metadata=metadata,
        )

    def _create_and_connect(
        self, gym_id: str, num_envs: int = 1
    ) -> Tuple[RemoteEnv, control_pb2.Session]:
        """Close any previous env, then create, await and connect a session.

        Returns:
            The connected ``RemoteEnv`` and the ready ``Session``.
        """
        if self._env:
            try:
                self._env.close()
            finally:
                self._env = None

        sess = self.create_session(gym_id, num_envs=num_envs)
        sess = self.wait_until_ready(sess.id)
        desc = self.describe(sess.id)

        env = self._connect_env(session=sess, desc=desc)
        env.render_mode = "rgb_array"
        self._env = env
        return env, sess

    def make(
        self, gym_id: str, render_mode: Literal["human", "rgb_array"] = "rgb_array"
    ) -> gym.Env:
        """Create a remote environment for ``gym_id``.

        With ``render_mode="human"`` the env is wrapped in
        ``HumanRendering``; otherwise the ``RemoteEnv`` is returned as is.
        """
        env, _ = self._create_and_connect(gym_id)

        if render_mode == "human":
            env = HumanRendering(env)

        self._env = env
        return env

    def evaluate(self, gym_id: str, get_action: Callable | None = None, render=False):
        """Run one episode on a fresh session and return its results.

        Args:
            gym_id: Environment id to evaluate on.
            get_action: Policy called with the latest observation; when
                ``None`` actions are sampled from the action space.
            render: Wrap the env in ``HumanRendering`` when true.

        Returns:
            A dict with ``score``, ``duration`` and ``timestamp`` (JSON
            string) taken from the session's ``GetResults`` response.
        """
        env, sess = self._create_and_connect(gym_id)

        try:
            if render:
                env = HumanRendering(env)

            # The episode is always started with the fixed seed 42.
            obs, _ = env.reset(seed=42)
            done = False

            while not done:
                action = (
                    env.action_space.sample() if get_action is None else get_action(obs)
                )
                obs, _reward, term, trunc, _ = env.step(action)
                done = term or trunc
        finally:
            # Always release the remote env, even if the policy or a step
            # raised, so the connection is not left open.
            try:
                env.close()
            finally:
                self._env = None

        results = self.get_results(sess.id)
        return {
            "score": results.score,
            "duration": results.duration,
            "timestamp": results.timestamp.ToJsonString(),
        }

    def print_envs(self):
        """Fetch the environment list and pretty-print it."""
        res = self.list_environments()
        return print_envs(res)
|
rlmesh/env/__init__.py
ADDED
rlmesh/env/make_env.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import Literal, Any, Optional
|
|
2
|
+
|
|
3
|
+
import gymnasium as gym
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def make_env_factory(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    gym_vars: Optional[dict[str, Any]] = None,
):
    """Return a callable that builds the environment ``gym_id``.

    Args:
        gym_id: Environment id to create.
        gym_type: Which API the environment implements. ``"gymnasium"`` uses
            ``gym.make(gym_id)`` directly; ``"gym-v26"`` / ``"gym-v21"`` go
            through the ``GymV26Environment-v0`` / ``GymV21Environment-v0``
            entries with ``env_id=gym_id``.
        gym_vars: Base keyword arguments for ``gym.make``. ``None`` (the
            default) means no extra arguments.

    Returns:
        ``env_factory(index=0, **kwargs)``. Keyword arguments are merged in
        the order ``gym_vars``, then ``render_mode="rgb_array"``, then
        ``kwargs``, so later sources override earlier ones. ``index`` is
        accepted but not forwarded to the environment. The factory raises
        ``EnvironmentError`` when ``gym_type`` is not supported.
    """

    def env_factory(index=0, **kwargs):
        env_vars = {
            **(gym_vars or {}),
            "render_mode": "rgb_array",
            **kwargs,
        }

        if gym_type == "gymnasium":
            return gym.make(gym_id, **env_vars)
        if gym_type == "gym-v26":
            return gym.make("GymV26Environment-v0", env_id=gym_id, **env_vars)
        if gym_type == "gym-v21":
            return gym.make("GymV21Environment-v0", env_id=gym_id, **env_vars)

        raise EnvironmentError(
            f"Unsupported environment type: {gym_type}. "
            "Please use a supported environment."
        )

    return env_factory
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def make_env(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    *,
    render_mode: Optional[str] = None,
    **kwargs,
):
    """Build a single environment via ``make_env_factory``.

    ``render_mode`` is always forwarded as a keyword argument, so the default
    ``None`` replaces the factory's own ``"rgb_array"`` default. Remaining
    ``kwargs`` are passed through to the factory call unchanged.
    """
    build = make_env_factory(gym_id, gym_type)
    return build(render_mode=render_mode, **kwargs)
|
rlmesh/env/remote_env.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import grpc
|
|
5
|
+
import gymnasium as gym
|
|
6
|
+
import gymnasium.envs.registration as reg
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from google.protobuf.json_format import MessageToDict
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
from rlmesh.gym import deserialize_space, tensor_from_ndarray, ndarray_from_tensor
|
|
13
|
+
from rlmesh.proto import gym_pb2, env_pb2_grpc, env_pb2
|
|
14
|
+
from rlmesh.utils import LoopRunner
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AsyncEnvTransport:
    """Async client for the ``EnvService.Connect`` stream.

    Each operation writes one ``ClientMsg`` and reads one ``ServerMsg``;
    a lock keeps these request/response pairs from interleaving. Operations
    wait until ``start()`` has opened the stream.
    """

    def __init__(self, addr: str, metadata: list):
        """Store connection settings; nothing is opened until ``start()``.

        Args:
            addr: Address of the environment node.
            metadata: gRPC metadata attached to the ``Connect`` call.
        """
        self._addr = addr
        self._metadata = metadata

        self._chan: grpc.aio.Channel | None = None
        self._stub: env_pb2_grpc.EnvServiceStub | None = None
        self._stream: grpc.aio.StreamStreamCall | None = None
        self._ready = asyncio.Event()
        self._lock = asyncio.Lock()

    async def start(self):
        """Open an insecure channel, start the stream and mark it ready."""
        self._chan = grpc.aio.insecure_channel(self._addr)
        self._stub = env_pb2_grpc.EnvServiceStub(self._chan)
        self._stream = self._stub.Connect(metadata=self._metadata)
        self._ready.set()

    @staticmethod
    def _first_info(infos) -> dict:
        """Return the first info message as a dict, or ``{}`` if there is none."""
        return MessageToDict(infos[0]) if infos else {}

    async def _read_and_check_msg(
        self, expected: Optional[str] = None
    ) -> env_pb2.ServerMsg:
        """Read the next server message and validate its kind.

        Args:
            expected: Required value of the ``kind`` oneof, or ``None`` to
                accept any non-error message.

        Raises:
            RuntimeError: If the stream has ended, the server sent an error,
                or the message kind differs from ``expected``.
        """
        assert self._stream is not None
        msg = await self._stream.read()

        # grpc.aio reports end-of-stream with the EOF sentinel, not None.
        if msg is None or msg is grpc.aio.EOF:
            raise RuntimeError("stream closed while waiting for server response")

        assert isinstance(msg, env_pb2.ServerMsg)
        kind = msg.WhichOneof("kind")
        if kind == "error":
            raise RuntimeError(
                f"recieved an error from the server: {msg.error.message}"
            )

        if expected is not None and kind != expected:
            raise RuntimeError(f"expected {expected}, got {kind}")

        return msg

    async def close(self):
        """Send a close request (if a stream exists) and close the channel.

        The channel is closed even when writing the close request fails.
        """
        try:
            if self._stream is not None:
                await self._stream.write(env_pb2.ClientMsg(close=env_pb2.CloseReq()))
                await self._stream.done_writing()
        finally:
            if self._chan is not None:
                await self._chan.close()

    async def reset(
        self, *, seed: int | None, options: dict | None = None
    ) -> tuple[np.ndarray, dict]:
        """Reset the remote environment.

        Args:
            seed: Seed sent to the server; ``None`` sends an empty seed list.
            options: Accepted for API compatibility; not sent to the server.

        Returns:
            The decoded observation and the first info entry (``{}`` if the
            server sent none).
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            seeds = [] if seed is None else [int(seed)]
            reset = env_pb2.ResetReq(seeds=seeds)

            await self._stream.write(env_pb2.ClientMsg(reset=reset))
            msg = await self._read_and_check_msg("reset_ok")

            obs = ndarray_from_tensor(msg.reset_ok.obs)
            info = self._first_info(msg.reset_ok.infos)

            return obs, info

    async def step(
        self, action_msg: env_pb2.ClientMsg
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Send a prepared step message and decode the reply.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. A ``reset_ok``
            reply is reported as a terminated step with reward ``0.0`` and an
            empty info dict.

        Raises:
            RuntimeError: On a server error, a closed stream, or a reply that
                is neither ``step_ok`` nor ``reset_ok``.
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(action_msg)
            # Server-side errors are already raised by _read_and_check_msg.
            msg = await self._read_and_check_msg()

            k = msg.WhichOneof("kind")
            if k == "step_ok":
                s = msg.step_ok
                obs = ndarray_from_tensor(s.obs)
                # Only the first entry of each per-env field is used.
                rew = float(ndarray_from_tensor(s.rewards).reshape(-1)[0])
                # Lowest bit of the first mask element carries the flag.
                term = bool(s.terminated_mask and (s.terminated_mask[0] & 0x01))
                trunc = bool(s.truncated_mask and (s.truncated_mask[0] & 0x01))
                info = self._first_info(s.infos)
                return obs, rew, term, trunc, info

            if k == "reset_ok":
                obs = ndarray_from_tensor(msg.reset_ok.obs)
                return obs, 0.0, True, False, {}

            raise RuntimeError(f"unexpected server msg kind: {k}")

    async def render(self) -> np.ndarray | None:
        """Request a frame from the server and return it as an array."""
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(env_pb2.ClientMsg(render=env_pb2.RenderReq()))
            msg = await self._read_and_check_msg("render_ok")
            frame = ndarray_from_tensor(msg.render_ok.frame)
            return frame
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class RemoteEnv(gym.Env):
    """Gymnasium environment whose calls are forwarded to a remote node.

    Spaces and metadata come from the serialized ``EnvSpec``; ``reset``,
    ``step``, ``render`` and ``close`` are executed synchronously by running
    the matching ``AsyncEnvTransport`` coroutine through a ``LoopRunner``.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        env_spec: gym_pb2.EnvSpec,
        *,
        transport_addr: str,
        transport_metadata: list,
    ):
        """Build spaces from ``env_spec`` and open the transport.

        Args:
            env_spec: Serialized description of the remote environment.
            transport_addr: Address handed to ``AsyncEnvTransport``.
            transport_metadata: gRPC metadata handed to ``AsyncEnvTransport``.
        """
        super().__init__()

        spaces = env_spec.spaces
        self.observation_space = deserialize_space(spaces.obs)
        self.action_space = deserialize_space(spaces.action)

        remote_meta = env_spec.metadata
        self.metadata = {
            "render_modes": list(remote_meta.render_modes),
            # A falsy fps falls back to 30.
            "render_fps": int(remote_meta.render_fps or 30),
        }

        self.spec = reg.EnvSpec(id=env_spec.gym_id)

        self._runner = LoopRunner()
        self._tx = AsyncEnvTransport(transport_addr, transport_metadata)
        self._runner.run(self._tx.start())
        self._closed = False

        # Last observation seen, reused by render() for ALE environments.
        self._obs = None
        self._is_ale = env_spec.package == "ale_py"

    def _ensure_open(self):
        """Raise ``RuntimeError`` if the environment was already closed."""
        if self._closed:
            raise RuntimeError("env closed")

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Reset the remote env and return ``(observation, info)``."""
        self._ensure_open()
        pending = self._tx.reset(seed=seed, options=options)
        observation, info = self._runner.run(pending)
        self._obs = observation

        return observation, info

    def step(self, action):
        """Validate ``action`` locally, then step the remote env.

        Raises:
            RuntimeError: If the env is closed.
            ValueError: If ``action`` is not contained in ``action_space``.
        """
        self._ensure_open()
        is_valid = self.action_space.contains(action)
        if not is_valid:
            raise ValueError("action outside action_space")

        payload = tensor_from_ndarray(np.asarray(action))
        request = env_pb2.ClientMsg(step=env_pb2.StepReq(actions=payload))
        result = self._runner.run(self._tx.step(request))
        observation, reward, terminated, truncated, info = result
        self._obs = observation
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return a frame for the current state.

        For ``ale_py`` environments the last observation is returned without
        a network round trip (assumes ALE observations are the RGB frame);
        otherwise the frame is requested from the server.
        """
        self._ensure_open()

        cached = self._obs
        if self._is_ale and cached is not None:
            return cached

        return self._runner.run(self._tx.render())

    def close(self):
        """Close the transport and stop the runner; later calls are no-ops."""
        if self._closed:
            return
        try:
            self._runner.run(self._tx.close())
        finally:
            # Stop the runner and mark closed even if the transport failed.
            self._runner.stop()
            self._closed = True
|
rlmesh/env/remote_vector_env.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# import gymnasium as gym
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# class RemoteVectorEnv(gym.vector.VectorEnv):
|
|
5
|
+
# metadata: dict[str, Any]
|
|
6
|
+
|
|
7
|
+
# def __init__(
|
|
8
|
+
# self,
|
|
9
|
+
# env_spec: gym_pb2.EnvSpec,
|
|
10
|
+
# *,
|
|
11
|
+
# gateway_addr: str,
|
|
12
|
+
# session_id: str,
|
|
13
|
+
# bearer: str | None,
|
|
14
|
+
# ):
|
|
15
|
+
# super().__init__()
|
|
16
|
+
# # build spaces from spec
|
|
17
|
+
# self.observation_space = deserialize_space(env_spec.spaces.obs)
|
|
18
|
+
# self.action_space = deserialize_space(env_spec.spaces.action)
|
|
19
|
+
# self.metadata = {
|
|
20
|
+
# "render_modes": list(env_spec.metadata.render_modes),
|
|
21
|
+
# "render_fps": int(env_spec.metadata.render_fps or 30),
|
|
22
|
+
# }
|
|
23
|
+
|
|
24
|
+
# self._runner = _LoopRunner()
|
|
25
|
+
# self._tx = _AsyncEnvTransport(
|
|
26
|
+
# gateway_addr, session_id=session_id, bearer=bearer
|
|
27
|
+
# )
|
|
28
|
+
# self._runner.run(self._tx.start())
|
|
29
|
+
# self._closed = False
|
|
30
|
+
|
|
31
|
+
# def reset(self, *, seed: int | None = None, options: dict | None = None):
|
|
32
|
+
# if self._closed:
|
|
33
|
+
# raise RuntimeError("env closed")
|
|
34
|
+
# obs, info = self._runner.run(self._tx.reset(seed=seed, options=options))
|
|
35
|
+
# return obs, info
|
|
36
|
+
|
|
37
|
+
# def step(self, action):
|
|
38
|
+
# if self._closed:
|
|
39
|
+
# raise RuntimeError("env closed")
|
|
40
|
+
# if not self.action_space.contains(action):
|
|
41
|
+
# raise ValueError("action outside action_space")
|
|
42
|
+
|
|
43
|
+
# arr = np.asarray(action)
|
|
44
|
+
# msg = env_pb2.ClientMsg(step=env_pb2.StepReq(actions=tensor_from_ndarray(arr)))
|
|
45
|
+
# obs, r, term, trunc, info = self._runner.run(self._tx.step(msg))
|
|
46
|
+
# return obs, r, term, trunc, info
|
|
47
|
+
|
|
48
|
+
# def close(self):
|
|
49
|
+
# if not self._closed:
|
|
50
|
+
# try:
|
|
51
|
+
# self._runner.run(self._tx.close())
|
|
52
|
+
# finally:
|
|
53
|
+
# self._runner.stop()
|
|
54
|
+
# self._closed = True
|
rlmesh/gym/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .env_codec import (
|
|
2
|
+
serialize_env,
|
|
3
|
+
deserialize_env,
|
|
4
|
+
env_spec_to_json,
|
|
5
|
+
env_spec_from_json,
|
|
6
|
+
)
|
|
7
|
+
from .spaces_codec import (
|
|
8
|
+
serialize_space,
|
|
9
|
+
deserialize_space,
|
|
10
|
+
space_spec_to_json,
|
|
11
|
+
space_spec_from_json,
|
|
12
|
+
)
|
|
13
|
+
from .tensor_codec import (
|
|
14
|
+
tensor_from_ndarray,
|
|
15
|
+
ndarray_from_tensor,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public API of the package: the names re-exported from the codec submodules.
__all__ = [
    # from .env_codec
    "serialize_env",
    "deserialize_env",
    "env_spec_to_json",
    "env_spec_from_json",
    # from .spaces_codec
    "serialize_space",
    "deserialize_space",
    "space_spec_to_json",
    "space_spec_from_json",
    # from .tensor_codec
    "tensor_from_ndarray",
    "ndarray_from_tensor",
]
|