rlmesh 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rlmesh/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from rlmesh.client import RLMeshClient
2
+
3
+ __version__ = "0.0.1"
4
+ __all__ = ["RLMeshClient"]
@@ -0,0 +1,3 @@
1
+ from .session import RLMeshClient
2
+
3
+ __all__ = ["RLMeshClient"]
rlmesh/client/print.py ADDED
@@ -0,0 +1,66 @@
1
+ from math import ceil
2
+
3
+ from rich import print
4
+ from rich.table import Table
5
+ from rich.panel import Panel
6
+ from rich.console import Group, Console
7
+
8
+ from rlmesh.proto import control_pb2
9
+
10
+
11
+ def _into_columns(items, cols):
12
+ if cols < 1:
13
+ cols = 1
14
+ rows = ceil(len(items) / cols)
15
+ out = []
16
+ for c in range(cols):
17
+ start = c * rows
18
+ out.append(items[start : start + rows])
19
+ return out, rows
20
+
21
+
22
def print_envs(res: control_pb2.ListEnvironmentsRes):
    """Render the available environments as nested rich panels.

    One inner panel is produced per package in ``res.packages`` that has at
    least one environment id; ids are laid out column-major in up to four
    columns sized from the terminal width. All inner panels are wrapped in a
    single outer panel which is printed with ``rich.print``.
    """
    console = Console()
    package_panels = []

    for package_name, package_envs in res.packages.items():
        env_ids = list(package_envs.env_ids)
        if len(env_ids) == 0:
            # Packages without environments get no panel at all.
            continue

        longest_id = max(map(len, env_ids))
        available_width = console.size.width or 100

        # Column width is the longest id plus padding, clamped to [18, 40].
        column_width = min(40, longest_id + 2)
        if column_width < 18:
            column_width = 18

        # Between one and four columns, depending on what fits.
        n_cols = min(4, available_width // column_width)
        if n_cols < 1:
            n_cols = 1

        columns, n_rows = _into_columns(env_ids, n_cols)

        grid = Table.grid(expand=True, pad_edge=False)
        for _ in range(n_cols):
            grid.add_column(ratio=1, justify="left", no_wrap=True)

        for row_idx in range(n_rows):
            cells = []
            for col_idx in range(n_cols):
                column = columns[col_idx]
                # Short trailing columns are padded with empty cells.
                cells.append(column[row_idx] if row_idx < len(column) else "")
            grid.add_row(*cells)

        package_panel = Panel(
            grid,
            title=f"[green]{package_name}[/green]",
            border_style="bright_black",
            padding=(0, 1),
        )
        package_panels.append(package_panel)

    outer = Panel(
        Group(*package_panels),
        title="[bold yellow]Available Environments[/bold yellow]",
        border_style="bold yellow",
        expand=True,
    )
    print(outer)
@@ -0,0 +1,203 @@
1
+ from typing import Iterator, Optional, Literal, Callable, Tuple
2
+
3
+ import contextlib
4
+ import grpc
5
+ import asyncio
6
+ import gymnasium as gym
7
+
8
+ from gymnasium.wrappers import HumanRendering
9
+
10
+ from rlmesh.client.print import print_envs
11
+ from rlmesh.env import RemoteEnv
12
+ from rlmesh.ping import compare_gcp_regions_async
13
+ from rlmesh.proto import control_pb2_grpc, control_pb2
14
+ from rlmesh.utils import config
15
+
16
+ DEFAULT_CONTROLPLANE_ENDPOINT = "34.130.104.52:4510"
17
+
18
+
19
class RLMeshClient:
    """Blocking client for the RLMesh control-plane ``SessionAPI`` gRPC service.

    The client can list regions and environments, create sessions, wait for
    them to become ready and expose the resulting remote environment through
    the Gymnasium ``Env`` interface.
    """

    def __init__(
        self,
        api_key: str | None = None,
        *,
        endpoint: str | None = DEFAULT_CONTROLPLANE_ENDPOINT,
        region: str | None = None,
        insecure: bool = True,
    ):
        """Open a channel to the control plane.

        Args:
            api_key: Optional token sent as ``x-rlmesh-token`` metadata on
                every call.
            endpoint: Control-plane address; a falsy value falls back to
                ``config.api_url``.
            region: Region to create sessions in. When ``None`` it is chosen
                lazily by :meth:`find_closest_region` on first session
                creation.
            insecure: Must be ``True``; TLS channels are not implemented.

        Raises:
            NotImplementedError: If ``insecure`` is ``False``.
        """
        self._addr = endpoint or config.api_url
        self._token = api_key
        self._region = region

        self._chan = grpc.insecure_channel(self._addr) if insecure else None
        if self._chan is None:
            raise NotImplementedError("TLS not wired yet")

        self._stub = control_pb2_grpc.SessionAPIStub(self._chan)
        self._metadata = []

        if self._token:
            self._metadata.append(("x-rlmesh-token", self._token))

        # Most recently created environment; closed before a new one is made.
        self._env = None

    @contextlib.contextmanager
    def channel(self):
        """Yield the underlying gRPC channel and close it on exit."""
        try:
            yield self._chan
        finally:
            if self._chan is not None:
                self._chan.close()

    def find_closest_region(self):
        """Return the result of comparing the listed regions, or ``None``.

        Delegates to ``compare_gcp_regions_async`` (presumably a latency
        probe returning the best region name; confirm against ``rlmesh.ping``).
        Returns ``None`` when the control plane reports no regions.
        """
        regions = list(self.list_regions().regions)
        if regions:
            return asyncio.run(compare_gcp_regions_async(regions))

        return None

    def list_regions(self, deadline_s: float = 5.0) -> control_pb2.ListRegionsRes:
        """Call ``ListRegions`` with a timeout of ``deadline_s`` seconds."""
        req = control_pb2.ListRegionsReq()
        return self._stub.ListRegions(req, timeout=deadline_s, metadata=self._metadata)

    def list_environments(
        self, deadline_s: float = 5.0
    ) -> control_pb2.ListEnvironmentsRes:
        """Call ``ListEnvironments`` with a timeout of ``deadline_s`` seconds."""
        req = control_pb2.ListEnvironmentsReq()
        return self._stub.ListEnvironments(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def create_session(
        self, gym_id: str, *, num_envs: int = 1, deadline_s: float = 5.0
    ) -> control_pb2.Session:
        """Create a session for ``gym_id`` in the client's region.

        If no region was configured, the closest one is resolved first and
        cached on the client.

        Raises:
            grpc.RpcError: Propagated from ``CreateSession``. For
                ``NOT_FOUND`` the failure details and the list of available
                environments are printed before re-raising.
        """
        if self._region is None:
            self._region = self.find_closest_region()

        req = control_pb2.InteractiveSessionSpec(
            gym_id=gym_id, num_envs=num_envs, location=self._region
        )
        try:
            return self._stub.CreateSession(
                req, timeout=deadline_s, metadata=self._metadata
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                try:
                    details = e.details() if hasattr(e, "details") else str(e)
                except Exception:
                    details = str(e)
                print(f"[rlmesh] CreateSession failed: {details}")
                print("[rlmesh] Available environments:")
                try:
                    self.print_envs()
                except grpc.RpcError:
                    # The listing is a best-effort diagnostic; a failure here
                    # must not replace the original NOT_FOUND error.
                    pass
            raise

    def watch_session(
        self, session_id: str, *, deadline_s: float = 120.0
    ) -> Iterator[control_pb2.Session]:
        """Yield session updates streamed by ``WatchSession``."""
        req = control_pb2.WatchReq(session_id=session_id)
        stream = self._stub.WatchSession(
            req, timeout=deadline_s, metadata=self._metadata
        )
        for evt in stream:
            yield evt

    def wait_until_ready(
        self, session_id: str, *, deadline_s: float = 600.0
    ) -> control_pb2.Session:
        """Block until the session reports ``READY`` and return that update.

        Raises:
            RuntimeError: If the watch stream ends before a ``READY`` update.
        """
        ready_evt: Optional[control_pb2.Session] = None
        for evt in self.watch_session(session_id, deadline_s=deadline_s):
            if evt.state == control_pb2.SessionState.READY:
                ready_evt = evt
                break
        if not ready_evt:
            raise RuntimeError("session stream ended without READY")
        return ready_evt

    def describe(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.DescribeRes:
        """Call ``DescribeSession`` for ``session_id``."""
        req = control_pb2.DescribeReq(session_id=session_id)
        return self._stub.DescribeSession(
            req, timeout=deadline_s, metadata=self._metadata
        )

    def get_results(
        self, session_id: str, *, deadline_s: float = 5.0
    ) -> control_pb2.GetResultsRes:
        """Call ``GetResults`` for ``session_id``."""
        req = control_pb2.GetResultsReq(session_id=session_id)
        return self._stub.GetResults(req, timeout=deadline_s, metadata=self._metadata)

    def _connect_env(
        self, *, session: control_pb2.Session, desc: control_pb2.DescribeRes
    ) -> RemoteEnv:
        """Build a ``RemoteEnv`` connected to the node serving ``session``.

        Raises:
            ValueError: If the session carries no node address.
        """
        node_addr = getattr(session, "address")
        if not node_addr:
            raise ValueError("session is missing node address")

        # Copy so the per-session header never leaks into control-plane calls.
        metadata = self._metadata.copy()
        metadata.append(("x-rlmesh-session", session.id))

        return RemoteEnv(
            desc.env,
            transport_addr=node_addr,
            transport_metadata=metadata,
        )

    def _create_and_connect(
        self, gym_id: str, num_envs: int = 1
    ) -> Tuple[RemoteEnv, control_pb2.Session]:
        """Close any previous env, then create, await and connect a new one.

        Returns:
            The connected ``RemoteEnv`` and the ready session.
        """
        if self._env:
            try:
                self._env.close()
            finally:
                self._env = None

        sess = self.create_session(gym_id, num_envs=num_envs)
        sess = self.wait_until_ready(sess.id)
        desc = self.describe(sess.id)

        env = self._connect_env(session=sess, desc=desc)
        env.render_mode = "rgb_array"
        self._env = env
        return env, sess

    def make(
        self, gym_id: str, render_mode: Literal["human", "rgb_array"] = "rgb_array"
    ) -> gym.Env:
        """Create a remote environment for ``gym_id``.

        With ``render_mode="human"`` the env is wrapped in ``HumanRendering``.
        The returned env is remembered and closed on the next creation.
        """
        env, _ = self._create_and_connect(gym_id)

        if render_mode == "human":
            env = HumanRendering(env)

        self._env = env
        return env

    def evaluate(self, gym_id: str, get_action: Callable | None = None, render=False):
        """Run one episode on a fresh remote env and return its results.

        Args:
            gym_id: Environment to evaluate on.
            get_action: Callable mapping an observation to an action; when
                ``None`` actions are sampled from the action space.
            render: Wrap the env in ``HumanRendering`` when true.

        Returns:
            A dict with ``score``, ``duration`` and ``timestamp`` taken from
            the session's ``GetResults`` response.
        """
        env, sess = self._create_and_connect(gym_id)

        try:
            if render:
                env = HumanRendering(env)

            obs, _ = env.reset(seed=42)
            done = False

            while not done:
                action = (
                    env.action_space.sample() if get_action is None else get_action(obs)
                )
                obs, _reward, term, trunc, _ = env.step(action)
                done = term or trunc
        finally:
            # Always release the remote env, even if the policy callback or a
            # remote call raised mid-episode.
            env.close()

        results = self.get_results(sess.id)
        return {
            "score": results.score,
            "duration": results.duration,
            "timestamp": results.timestamp.ToJsonString(),
        }

    def print_envs(self):
        """Fetch the environment list and pretty-print it."""
        res = self.list_environments()
        return print_envs(res)
rlmesh/env/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .make_env import make_env_factory, make_env
2
+ from .remote_env import RemoteEnv
3
+
4
+ __all__ = ["make_env_factory", "make_env", "RemoteEnv"]
rlmesh/env/make_env.py ADDED
@@ -0,0 +1,44 @@
1
+ from typing import Literal, Any, Optional
2
+
3
+ import gymnasium as gym
4
+
5
+
6
def make_env_factory(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    gym_vars: Optional[dict[str, Any]] = None,
):
    """Return a factory that builds the environment ``gym_id``.

    Args:
        gym_id: Environment id to create.
        gym_type: Which API the environment implements. ``"gym-v26"`` and
            ``"gym-v21"`` go through Gymnasium's compatibility environments.
        gym_vars: Extra keyword arguments applied to every created env.
            ``None`` (the default) means no extras; a ``None`` default is used
            instead of a shared mutable ``{}``.

    Returns:
        ``env_factory(index=0, **kwargs)``, which creates the env. ``kwargs``
        override ``render_mode`` (default ``"rgb_array"``) and ``gym_vars``.
        The factory raises ``EnvironmentError`` for an unsupported
        ``gym_type``.
    """

    def env_factory(index=0, **kwargs):
        # Later entries win: gym_vars < default render_mode < call kwargs.
        env_vars = {
            **(gym_vars or {}),
            "render_mode": "rgb_array",
            # "index": index,
            **kwargs,
        }

        if gym_type == "gymnasium":
            return gym.make(gym_id, **env_vars)
        if gym_type == "gym-v26":
            return gym.make("GymV26Environment-v0", env_id=gym_id, **env_vars)
        if gym_type == "gym-v21":
            return gym.make("GymV21Environment-v0", env_id=gym_id, **env_vars)

        raise EnvironmentError(
            f"Unsupported environment type: {gym_type}. "
            "Please use a supported environment."
        )

    return env_factory
35
+
36
+
37
def make_env(
    gym_id: str,
    gym_type: Literal["gymnasium", "gym-v26", "gym-v21"] = "gymnasium",
    *,
    render_mode: Optional[str] = None,
    **kwargs,
):
    """Create a single environment via ``make_env_factory``.

    ``render_mode`` is always forwarded to the factory call, so the default
    ``None`` replaces the factory's own ``"rgb_array"`` default. Remaining
    keyword arguments are passed through unchanged.
    """
    build = make_env_factory(gym_id, gym_type)
    return build(render_mode=render_mode, **kwargs)
@@ -0,0 +1,192 @@
1
+ from typing import Any, Optional
2
+
3
+ import asyncio
4
+ import grpc
5
+ import gymnasium as gym
6
+ import gymnasium.envs.registration as reg
7
+ import numpy as np
8
+
9
+ from google.protobuf.json_format import MessageToDict
10
+
11
+
12
+ from rlmesh.gym import deserialize_space, tensor_from_ndarray, ndarray_from_tensor
13
+ from rlmesh.proto import gym_pb2, env_pb2_grpc, env_pb2
14
+ from rlmesh.utils import LoopRunner
15
+
16
+
17
class AsyncEnvTransport:
    """Async request/response transport over the ``EnvService.Connect`` stream.

    Each operation writes one ``ClientMsg`` and reads one ``ServerMsg``. A lock
    keeps write/read pairs from interleaving on the shared stream.
    """

    def __init__(self, addr: str, metadata: list):
        """Store connection parameters; no I/O happens until :meth:`start`."""
        self._addr = addr
        self._metadata = metadata

        self._chan: grpc.aio.Channel | None = None
        self._stub: env_pb2_grpc.EnvServiceStub | None = None
        self._stream: grpc.aio.StreamStreamCall | None = None
        self._ready = asyncio.Event()
        self._lock = asyncio.Lock()

    async def start(self):
        """Open the channel and bidirectional stream, then mark ready."""
        self._chan = grpc.aio.insecure_channel(self._addr)
        self._stub = env_pb2_grpc.EnvServiceStub(self._chan)
        self._stream = self._stub.Connect(metadata=self._metadata)
        self._ready.set()

    async def _read_and_check_msg(
        self, expected: Optional[str] = None
    ) -> env_pb2.ServerMsg:
        """Read the next server message and validate its kind.

        Args:
            expected: Required ``kind`` oneof name, or ``None`` to accept any
                non-error message.

        Raises:
            RuntimeError: If the stream has ended, the server sent an error
                message, or the message kind differs from ``expected``.
        """
        assert self._stream is not None
        msg = await self._stream.read()

        # grpc.aio signals end-of-stream with the EOF sentinel rather than
        # None; the None check is kept for defensiveness.
        if msg is None or msg is grpc.aio.EOF:
            raise RuntimeError("stream closed by server")

        assert isinstance(msg, env_pb2.ServerMsg)
        kind = msg.WhichOneof("kind")
        if kind == "error":
            raise RuntimeError(
                f"recieved an error from the server: {msg.error.message}"
            )

        if expected is not None and kind != expected:
            raise RuntimeError(f"expected {expected}, got {kind}")

        return msg

    async def close(self):
        """Send a close request, finish writing and close the channel.

        The channel is closed even if notifying the server fails.
        """
        try:
            if self._stream is not None:
                await self._stream.write(env_pb2.ClientMsg(close=env_pb2.CloseReq()))
                await self._stream.done_writing()
        finally:
            if self._chan is not None:
                await self._chan.close()

    async def reset(
        self, *, seed: int | None, options: dict | None = None
    ) -> tuple[np.ndarray, dict]:
        """Reset the remote env and return ``(observation, info)``.

        ``options`` is accepted for interface compatibility but is not sent
        to the server. ``info`` is built from the first entry of the reply's
        ``infos`` and is empty when there are none.
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            seeds = [] if seed is None else [int(seed)]
            reset = env_pb2.ResetReq(seeds=seeds)

            await self._stream.write(env_pb2.ClientMsg(reset=reset))
            msg = await self._read_and_check_msg("reset_ok")

            obs = ndarray_from_tensor(msg.reset_ok.obs)

            info = {}
            if msg.reset_ok.infos:
                info = MessageToDict(msg.reset_ok.infos[0])

            return obs, info

    async def step(
        self, action_msg: env_pb2.ClientMsg
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Send a step message and return ``(obs, reward, term, trunc, info)``.

        Only the first element of the reward tensor and of the termination
        and truncation masks is used.

        Raises:
            RuntimeError: On a server error, a closed stream or an unexpected
                reply kind.
        """
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(action_msg)
            # Error replies are already turned into RuntimeError by the read
            # helper, so only success kinds need handling below.
            msg = await self._read_and_check_msg()

            k = msg.WhichOneof("kind")
            if k == "step_ok":
                s = msg.step_ok
                obs = ndarray_from_tensor(s.obs)
                rew = float(ndarray_from_tensor(s.rewards).reshape(-1)[0])
                term = bool(s.terminated_mask and (s.terminated_mask[0] & 0x01))
                trunc = bool(s.truncated_mask and (s.truncated_mask[0] & 0x01))
                info = {}
                if s.infos:
                    info = MessageToDict(s.infos[0])
                return obs, rew, term, trunc, info

            if k == "reset_ok":
                # A reset reply to a step is reported as a terminated step
                # carrying the new observation (presumably a server-side
                # auto-reset; confirm against the server).
                obs = ndarray_from_tensor(msg.reset_ok.obs)
                return obs, 0.0, True, False, {}

            raise RuntimeError(f"unexpected server msg kind: {k}")

    async def render(self) -> np.ndarray | None:
        """Request a rendered frame and return it as an array."""
        await self._ready.wait()
        async with self._lock:
            assert self._stream is not None
            await self._stream.write(env_pb2.ClientMsg(render=env_pb2.RenderReq()))
            msg = await self._read_and_check_msg("render_ok")
            frame = ndarray_from_tensor(msg.render_ok.frame)
            return frame
123
+
124
+
125
class RemoteEnv(gym.Env):
    """Gymnasium ``Env`` whose reset/step/render run on a remote node.

    Calls are synchronous; each one is executed on a private ``LoopRunner``
    that drives an ``AsyncEnvTransport``.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        env_spec: gym_pb2.EnvSpec,
        *,
        transport_addr: str,
        transport_metadata: list,
    ):
        """Build spaces and metadata from ``env_spec`` and open the transport.

        Args:
            env_spec: Serialized description of the remote environment.
            transport_addr: Address of the node serving the environment.
            transport_metadata: gRPC metadata sent when opening the stream.
        """
        super().__init__()

        self.observation_space = deserialize_space(env_spec.spaces.obs)
        self.action_space = deserialize_space(env_spec.spaces.action)
        self.metadata = {
            "render_modes": list(env_spec.metadata.render_modes),
            "render_fps": int(env_spec.metadata.render_fps or 30),
        }

        self.spec = reg.EnvSpec(id=env_spec.gym_id)

        self._closed = False
        self._obs = None
        self._is_ale = env_spec.package == "ale_py"

        self._runner = LoopRunner()
        self._tx = AsyncEnvTransport(transport_addr, transport_metadata)
        try:
            self._runner.run(self._tx.start())
        except BaseException:
            # Do not leave the runner alive behind an env that was never
            # handed to the caller.
            self._runner.stop()
            self._closed = True
            raise

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Reset the remote env and return ``(observation, info)``.

        Raises:
            RuntimeError: If the env has been closed.
        """
        if self._closed:
            raise RuntimeError("env closed")
        obs, info = self._runner.run(self._tx.reset(seed=seed, options=options))
        self._obs = obs

        return obs, info

    def step(self, action):
        """Apply ``action`` remotely and return the Gymnasium 5-tuple.

        Raises:
            RuntimeError: If the env has been closed.
            ValueError: If ``action`` is not contained in ``action_space``.
        """
        if self._closed:
            raise RuntimeError("env closed")
        if not self.action_space.contains(action):
            raise ValueError("action outside action_space")

        arr = np.asarray(action)
        msg = env_pb2.ClientMsg(step=env_pb2.StepReq(actions=tensor_from_ndarray(arr)))
        obs, r, term, trunc, info = self._runner.run(self._tx.step(msg))
        self._obs = obs
        return obs, r, term, trunc, info

    def render(self):
        """Return the current frame.

        Raises:
            RuntimeError: If the env has been closed.
        """
        if self._closed:
            raise RuntimeError("env closed")

        # Special Case for ALE Environment
        # - since we know the obs is the rgb_array, we can skip sending it over twice
        if self._is_ale and self._obs is not None:
            return self._obs

        frame = self._runner.run(self._tx.render())
        return frame

    def close(self):
        """Close the transport and stop the runner; later calls do nothing."""
        if not self._closed:
            try:
                self._runner.run(self._tx.close())
            finally:
                # Mark closed even if the transport close failed, so a retry
                # never touches the stopped runner.
                self._runner.stop()
                self._closed = True
@@ -0,0 +1,54 @@
1
+ # import gymnasium as gym
2
+
3
+
4
+ # class RemoteVectorEnv(gym.vector.VectorEnv):
5
+ # metadata: dict[str, Any]
6
+
7
+ # def __init__(
8
+ # self,
9
+ # env_spec: gym_pb2.EnvSpec,
10
+ # *,
11
+ # gateway_addr: str,
12
+ # session_id: str,
13
+ # bearer: str | None,
14
+ # ):
15
+ # super().__init__()
16
+ # # build spaces from spec
17
+ # self.observation_space = deserialize_space(env_spec.spaces.obs)
18
+ # self.action_space = deserialize_space(env_spec.spaces.action)
19
+ # self.metadata = {
20
+ # "render_modes": list(env_spec.metadata.render_modes),
21
+ # "render_fps": int(env_spec.metadata.render_fps or 30),
22
+ # }
23
+
24
+ # self._runner = _LoopRunner()
25
+ # self._tx = _AsyncEnvTransport(
26
+ # gateway_addr, session_id=session_id, bearer=bearer
27
+ # )
28
+ # self._runner.run(self._tx.start())
29
+ # self._closed = False
30
+
31
+ # def reset(self, *, seed: int | None = None, options: dict | None = None):
32
+ # if self._closed:
33
+ # raise RuntimeError("env closed")
34
+ # obs, info = self._runner.run(self._tx.reset(seed=seed, options=options))
35
+ # return obs, info
36
+
37
+ # def step(self, action):
38
+ # if self._closed:
39
+ # raise RuntimeError("env closed")
40
+ # if not self.action_space.contains(action):
41
+ # raise ValueError("action outside action_space")
42
+
43
+ # arr = np.asarray(action)
44
+ # msg = env_pb2.ClientMsg(step=env_pb2.StepReq(actions=tensor_from_ndarray(arr)))
45
+ # obs, r, term, trunc, info = self._runner.run(self._tx.step(msg))
46
+ # return obs, r, term, trunc, info
47
+
48
+ # def close(self):
49
+ # if not self._closed:
50
+ # try:
51
+ # self._runner.run(self._tx.close())
52
+ # finally:
53
+ # self._runner.stop()
54
+ # self._closed = True
rlmesh/gym/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from .env_codec import (
2
+ serialize_env,
3
+ deserialize_env,
4
+ env_spec_to_json,
5
+ env_spec_from_json,
6
+ )
7
+ from .spaces_codec import (
8
+ serialize_space,
9
+ deserialize_space,
10
+ space_spec_to_json,
11
+ space_spec_from_json,
12
+ )
13
+ from .tensor_codec import (
14
+ tensor_from_ndarray,
15
+ ndarray_from_tensor,
16
+ )
17
+
18
+ __all__ = [
19
+ "serialize_env",
20
+ "deserialize_env",
21
+ "env_spec_to_json",
22
+ "env_spec_from_json",
23
+ "serialize_space",
24
+ "deserialize_space",
25
+ "space_spec_to_json",
26
+ "space_spec_from_json",
27
+ "tensor_from_ndarray",
28
+ "ndarray_from_tensor",
29
+ ]