mouse-env 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mouse/envs/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """MOUSE environments — vector envs and rollout formatting for mouse-core."""
2
+
3
+ from mouse.envs.build import make_vector_env
4
+ from mouse.envs.config import EnvConfig
5
+ from mouse.envs.format import MouseVectorEnv
6
+
7
+ __all__ = [
8
+ "EnvConfig",
9
+ "make_vector_env",
10
+ "MouseVectorEnv",
11
+ ]
mouse/envs/build.py ADDED
@@ -0,0 +1,230 @@
1
+ """Build vector environments from :class:`EnvConfig`."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ import gymnasium as gym
9
+
10
+ from mouse.envs.config import (
11
+ EnvConfig,
12
+ is_ns_gym_env,
13
+ normalize_group_id,
14
+ resolve_q_star_source_for_env,
15
+ )
16
+ from mouse.envs.experts.action_star import apply_q_star_source_env_kwargs
17
+ from mouse.envs.integrations.atari import (
18
+ ensure_ale_registered,
19
+ is_ale_env,
20
+ wrap_atari_preprocessing,
21
+ )
22
+ from mouse.envs.integrations.ns_gym import make_ns_env
23
+ from mouse.envs.env_ids import PROCEDURAL_FROZENLAKE_ENV_ID, SYNTHETIC_ENV_ID
24
+ from mouse.envs.format import MouseVectorEnv
25
+ from mouse.envs.wrappers import (
26
+ ConstructionSeedWrapper,
27
+ ObservationSliceWrapper,
28
+ build_vector_env_stack,
29
+ )
30
+
31
+
32
+ def _resolve_group_ids(group_id: str, num_envs: int, group_ids: list[str] | None) -> list[str]:
33
+ if group_ids is not None:
34
+ if len(group_ids) != num_envs:
35
+ raise ValueError(
36
+ f"group_ids has {len(group_ids)} entries but num_envs={num_envs}. "
37
+ f"Provide exactly one id per vector index."
38
+ )
39
+ seen: dict[str, list[int]] = {}
40
+ for i, gid in enumerate(group_ids):
41
+ seen.setdefault(gid, []).append(i)
42
+ duplicates = {k: v for k, v in seen.items() if len(v) > 1}
43
+ if duplicates:
44
+ raise ValueError(
45
+ f"group_id must be unique per vector index — found duplicates: {duplicates}. "
46
+ f"Use a unique suffix per index (e.g. '{group_id}#0', '{group_id}#1')."
47
+ )
48
+ return list(group_ids)
49
+
50
+ if not group_id:
51
+ raise ValueError(
52
+ "group_id is required on EnvConfig but was not set. "
53
+ "Provide a non-empty group_id (e.g. 'CartPole-v1')."
54
+ )
55
+
56
+ if num_envs == 1:
57
+ return [group_id]
58
+
59
+ return [f"{group_id}#{i}" for i in range(num_envs)]
60
+
61
+
62
def make_vector_env(config: EnvConfig) -> MouseVectorEnv:
    """Create a vector env from :class:`EnvConfig`.

    The config is validated, the group id / Q* source / per-index group ids
    are resolved, and the gym vector env is built by the NS-Gym backend when
    ``is_ns_gym_env`` says so and by the plain backend otherwise. The result
    is wrapped in :class:`MouseVectorEnv` together with the per-index ids.

    Raises:
        ValueError: If ``config.max_episode_steps`` is ``None`` or
            ``config.num_envs`` is below 1 (group-id resolution may raise
            ``ValueError`` as well).
    """
    if config.max_episode_steps is None:
        raise ValueError(
            "max_episode_steps is required (used to normalise xformed_reward). "
            "Set max_episode_steps in the env config."
        )
    if config.num_envs < 1:
        raise ValueError(f"num_envs must be >= 1, got {config.num_envs}.")

    base_group_id = normalize_group_id(config.group_id)
    q_star = resolve_q_star_source_for_env(config.group_id, config.q_star_source)
    per_index_ids = _resolve_group_ids(base_group_id, config.num_envs, config.group_ids)

    # Both builders share one keyword-only signature, so pick the backend
    # first and call it once.
    if is_ns_gym_env(config.group_id, config.non_stationary_params):
        build_backend = _build_ns_vector_env
    else:
        build_backend = _build_plain_vector_env

    vector_env = build_backend(
        config=config,
        resolved_group_id=base_group_id,
        resolved_q_star_source=q_star,
        resolved_group_ids=per_index_ids,
    )
    return MouseVectorEnv(vector_env, per_index_ids)
92
+
93
+
94
def _prepare_plain_env_kwargs(config: EnvConfig, *, atari_preprocessing: bool) -> dict[str, Any]:
    """Assemble the constructor kwargs for a plain (non-NS) environment.

    Starts from a copy of ``config.kwargs``, lets
    ``apply_q_star_source_env_kwargs`` adjust it, makes sure the custom or
    ALE env is registered, and applies the render / Atari tweaks.

    Args:
        config: Environment configuration.
        atari_preprocessing: Whether Atari preprocessing will be applied.

    Returns:
        The keyword arguments for the env constructor.

    Raises:
        ValueError: If ``random_map_wrapper`` is present but not a dict, or
            if ``observation_indices`` is set for an ALE env.
    """
    env_id = config.group_id
    prepared = apply_q_star_source_env_kwargs(
        env_id=env_id,
        env_kwargs=dict(config.kwargs or {}),
        q_star_source=config.q_star_source,
    )

    if env_id == PROCEDURAL_FROZENLAKE_ENV_ID:
        from mouse.envs.worlds.procedural_frozenlake import ensure_procedural_frozenlake_registered

        ensure_procedural_frozenlake_registered()
        wrapper_options = prepared.pop("random_map_wrapper", None)
        if wrapper_options is not None:
            if not isinstance(wrapper_options, dict):
                raise ValueError("env_kwargs.random_map_wrapper must be a dict when provided.")
            # Flatten the nested wrapper options into the top-level kwargs.
            prepared.update(dict(wrapper_options))
    elif env_id == SYNTHETIC_ENV_ID:
        from mouse.envs.worlds.synthetic import ensure_synthetic_env_registered

        ensure_synthetic_env_registered()

    if config.render:
        # Never override a render_mode the caller chose explicitly.
        prepared.setdefault("render_mode", "human")

    if is_ale_env(env_id):
        ensure_ale_registered()
    if is_ale_env(env_id) and atari_preprocessing:
        # NOTE(review): presumably frame skipping is left to the
        # preprocessing wrapper, hence no skipping in the base env — confirm.
        prepared["frameskip"] = 1
    if config.observation_indices is not None and is_ale_env(env_id):
        raise ValueError("observation_indices is not supported for ALE (Atari) envs.")

    return prepared
123
+
124
+
125
def _make_plain_single_env(
    config: EnvConfig,
    index: int,
    *,
    env_kwargs: dict[str, Any],
    atari_preprocessing: bool,
) -> gym.Env:
    """Build the plain-backend sub-environment for vector slot ``index``.

    Args:
        config: Environment configuration.
        index: Position of this env in the vector; added to ``config.seed``
            to give each slot its own seed.
        env_kwargs: Keyword arguments forwarded to ``gym.make`` (copied
            before use, never mutated).
        atari_preprocessing: Whether Atari preprocessing is requested; it
            is only enabled when the env is an ALE env.

    Returns:
        The env wrapped in ``ConstructionSeedWrapper`` and, when
        ``config.observation_indices`` is set, ``ObservationSliceWrapper``.
    """
    slot_seed = config.seed + index
    # These env ids take their seed as a constructor keyword.
    seed_via_constructor = config.group_id in (SYNTHETIC_ENV_ID, PROCEDURAL_FROZENLAKE_ENV_ID)
    preprocess = is_ale_env(config.group_id) and atari_preprocessing

    def env_fn(s: int) -> gym.Env:
        # ``s`` is supplied by ConstructionSeedWrapper.
        make_kwargs = {**env_kwargs, "seed": s} if seed_via_constructor else dict(env_kwargs)
        base_env = gym.make(
            config.group_id,
            max_episode_steps=config.max_episode_steps,
            **make_kwargs,
        )
        return wrap_atari_preprocessing(
            base_env,
            enabled=preprocess,
            preprocessing_kwargs=config.atari_preprocessing_kwargs,
        )

    seeded_env = ConstructionSeedWrapper(env_fn, seed=slot_seed)
    if config.observation_indices is None:
        return seeded_env
    return ObservationSliceWrapper(env=seeded_env, indices=config.observation_indices)
155
+
156
+
157
def _make_ns_single_env(config: EnvConfig) -> gym.Env:
    """Build one NS-Gym sub-environment from ``config``.

    ``make_ns_env`` receives the env id, the non-stationary parameters
    (an empty dict when unset), the step limit, the extra env kwargs and
    the render flag. When ``config.observation_indices`` is set the result
    is wrapped in ``ObservationSliceWrapper``.
    """
    ns_env = make_ns_env(
        env_id=config.group_id,
        non_stationary_params=config.non_stationary_params or {},
        max_steps_per_episode=config.max_episode_steps,
        env_kwargs=config.kwargs,
        render=config.render,
    )
    if config.observation_indices is None:
        return ns_env
    return ObservationSliceWrapper(env=ns_env, indices=config.observation_indices)
168
+
169
+
170
def _build_plain_vector_env(
    *,
    config: EnvConfig,
    resolved_group_id: str,
    resolved_q_star_source: dict[str, Any] | None,
    resolved_group_ids: list[str],
) -> gym.vector.VectorEnv:
    """Build the vector env for the plain (non-NS) backend.

    One factory per vector index is created and handed to
    ``build_vector_env_stack`` together with the resolved ids, the seed,
    the step limit and the reward / Q* settings. The observation key is
    ``"observation_image"`` for ALE envs and ``"observation"`` otherwise.
    """
    use_atari_preprocessing = bool(config.atari_preprocessing)
    prepared_kwargs = _prepare_plain_env_kwargs(config, atari_preprocessing=use_atari_preprocessing)

    shared_kwargs = prepared_kwargs
    if config.group_id in (SYNTHETIC_ENV_ID, PROCEDURAL_FROZENLAKE_ENV_ID):
        # These envs get a per-index seed at construction time, so any
        # "seed" entry in the prepared kwargs is dropped here.
        shared_kwargs = {key: value for key, value in prepared_kwargs.items() if key != "seed"}

    def factory_for(slot: int) -> Callable[[], gym.Env]:
        # A dedicated closure per slot pins the index at creation time.
        def factory() -> gym.Env:
            return _make_plain_single_env(
                config,
                slot,
                env_kwargs=shared_kwargs,
                atari_preprocessing=use_atari_preprocessing,
            )

        return factory

    env_fns: list[Callable[[], gym.Env]] = [factory_for(slot) for slot in range(config.num_envs)]

    if is_ale_env(config.group_id):
        obs_key = "observation_image"
    else:
        obs_key = "observation"

    return build_vector_env_stack(
        env_fns=env_fns,
        group_id=resolved_group_id,
        seed=config.seed,
        max_steps_per_episode=config.max_episode_steps,
        obs_key=obs_key,
        reward_scale=config.reward_scale,
        reward_shift=config.reward_shift,
        q_star_source=resolved_q_star_source,
        group_ids=resolved_group_ids,
    )
208
+
209
+
210
def _build_ns_vector_env(
    *,
    config: EnvConfig,
    resolved_group_id: str,
    resolved_q_star_source: dict[str, Any] | None,
    resolved_group_ids: list[str],
) -> gym.vector.VectorEnv:
    """Build the vector env for the NS-Gym backend.

    Every vector slot is produced by the same zero-argument factory, which
    calls ``_make_ns_single_env(config)``. The stack is assembled by
    ``build_vector_env_stack`` with the ``"observation"`` key.
    """

    def factory() -> gym.Env:
        return _make_ns_single_env(config)

    # The factory does not depend on the slot index, so it is reused as-is.
    env_fns: list[Callable[[], gym.Env]] = [factory for _ in range(config.num_envs)]

    return build_vector_env_stack(
        env_fns=env_fns,
        group_id=resolved_group_id,
        seed=config.seed,
        max_steps_per_episode=config.max_episode_steps,
        obs_key="observation",
        reward_scale=config.reward_scale,
        reward_shift=config.reward_shift,
        q_star_source=resolved_q_star_source,
        group_ids=resolved_group_ids,
    )
mouse/envs/config.py ADDED
@@ -0,0 +1,188 @@
1
+ """Environment configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+
9
# Q* source that ``EnvConfig.cartpole`` copies when the caller does not pass
# ``q_star_source``.
# NOTE(review): presumably ``repo_id`` / ``filename`` locate a pretrained
# Stable-Baselines3 PPO checkpoint — confirm against the expert loader.
DEFAULT_SB3_Q_STAR_CARTPOLE: dict[str, Any] = {
    "provider": "sb3_rl_zoo",
    "algo": "ppo",
    "repo_id": "sb3/ppo-CartPole-v1",
    "filename": "ppo-CartPole-v1.zip",
    "deterministic": True,
}
16
+
17
+
18
def normalize_group_id(group_id: str) -> str:
    """Strip the legacy ``NS-`` prefix before passing the id to ``gym.make``.

    Only one leading, case-sensitive ``NS-`` is removed; ids without it are
    returned unchanged.
    """
    legacy_prefix = "NS-"
    has_legacy_prefix = group_id.startswith(legacy_prefix)
    return group_id[len(legacy_prefix):] if has_legacy_prefix else group_id
23
+
24
+
25
def is_ns_gym_env(
    _group_id: str,
    non_stationary_params: dict[str, Any] | None,
) -> bool:
    """Return ``True`` when the env should use the NS-Gym backend.

    The decision rests only on ``non_stationary_params`` being truthy;
    ``_group_id`` is accepted but not consulted, so ``None`` and an empty
    mapping both yield ``False``.
    """
    if non_stationary_params:
        return True
    return False
31
+
32
+
33
@dataclass
class EnvConfig:
    """Configuration for building a vector environment via :func:`mouse.envs.make_vector_env`.

    Instances can be created directly or through the preset classmethods
    (:meth:`cartpole`, :meth:`ns_cartpole`, :meth:`procedural_frozenlake`,
    :meth:`synthetic`, :meth:`atari`), which fill in ``group_id`` and common
    defaults and forward any extra keyword arguments to the constructor.
    """

    # Environment id.
    group_id: str
    # Base seed.
    seed: int
    # Number of sub-environments in the vector env.
    num_envs: int
    # Per-episode step limit; the presets always set it.
    max_episode_steps: int | None
    # Extra keyword arguments for the environment constructor.
    kwargs: dict | None = None
    # Whether rendering is requested.
    render: bool = False
    # Non-stationary schedule parameters.
    non_stationary_params: dict | None = None
    # Expert / Q* metadata source description.
    q_star_source: dict[str, Any] | None = None
    # Whether Atari preprocessing is requested.
    atari_preprocessing: bool | None = None
    # Options for the Atari preprocessing step.
    atari_preprocessing_kwargs: dict | None = None
    # Optional observation indices to keep.
    observation_indices: list[int] | None = None
    # Reward scale and shift settings.
    reward_scale: float = 1.0
    reward_shift: float = 0.0
    # Optional explicit id per vector index.
    group_ids: list[str] | None = None

    def build(self):
        """Build a vector env from this config (alias for ``make_vector_env(self)``)."""
        # Imported lazily: mouse.envs.build itself imports this module.
        from mouse.envs.build import make_vector_env

        return make_vector_env(self)

    @classmethod
    def cartpole(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 500,
        q_star_source: dict[str, Any] | None | object = ...,
        observation_indices: list[int] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Preset for ``CartPole-v1`` with SB3 PPO Q* by default.

        Pass ``q_star_source=None`` explicitly to disable expert metadata.
        Extra ``kwargs`` are forwarded to the :class:`EnvConfig` constructor.
        """
        # ``...`` is the "not provided" sentinel because ``None`` is a
        # meaningful value here (it disables the expert metadata).
        if q_star_source is ...:
            resolved_q_star: dict[str, Any] | None = dict(DEFAULT_SB3_Q_STAR_CARTPOLE)
        else:
            resolved_q_star = q_star_source
        return cls(
            group_id="CartPole-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            q_star_source=resolved_q_star,
            observation_indices=observation_indices,
            **kwargs,
        )

    @classmethod
    def ns_cartpole(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 500,
        non_stationary_params: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Preset for non-stationary ``CartPole-v1`` with optional physics schedules.

        ``non_stationary_params=None`` is stored as an empty dict. Extra
        ``kwargs`` are forwarded to the :class:`EnvConfig` constructor.
        """
        # NOTE(review): an empty dict is falsy, so ``is_ns_gym_env`` reports
        # False for a preset built without schedules — confirm that falling
        # back to the plain backend is intended in that case.
        params = non_stationary_params if non_stationary_params is not None else {}
        return cls(
            group_id="CartPole-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            non_stationary_params=params,
            **kwargs,
        )

    @classmethod
    def procedural_frozenlake(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 200,
        env_kwargs: dict[str, Any] | None = None,
        q_star_source: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Preset for ``Procedural-FrozenLake-v1``.

        ``env_kwargs`` is stored as ``kwargs``. A falsy ``q_star_source``
        (``None`` or ``{}``) is replaced by the ``metadata_q_star`` provider.
        """
        return cls(
            group_id="Procedural-FrozenLake-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            kwargs=env_kwargs,
            q_star_source=q_star_source or {"provider": "metadata_q_star"},
            **kwargs,
        )

    @classmethod
    def synthetic(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 200,
        env_kwargs: dict[str, Any] | None = None,
        q_star_source: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Preset for ``SyntheticEnv-v1``.

        ``env_kwargs`` is stored as ``kwargs``. A falsy ``q_star_source``
        (``None`` or ``{}``) is replaced by the ``metadata_q_star`` provider.
        """
        return cls(
            group_id="SyntheticEnv-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            kwargs=env_kwargs,
            q_star_source=q_star_source or {"provider": "metadata_q_star"},
            **kwargs,
        )

    @classmethod
    def atari(
        cls,
        env_id: str = "ALE/Pong-v5",
        *,
        seed: int = 0,
        num_envs: int = 4,
        max_episode_steps: int = 10000,
        frame_skip: int = 4,
        screen_size: int = 84,
        noop_max: int = 30,
        **kwargs: Any,
    ) -> EnvConfig:
        """Preset for ALE Atari envs with common ``AtariPreprocessing`` defaults.

        ``frame_skip``, ``screen_size`` and ``noop_max`` seed the
        preprocessing options. An ``atari_preprocessing_kwargs`` mapping may
        be passed as well; its entries are merged on top of those three
        (and win on conflict), so further preprocessing options can be set
        without abandoning the preset. Remaining ``kwargs`` are forwarded
        to the :class:`EnvConfig` constructor.
        """
        # Pull the override out of ``kwargs`` first: forwarding it alongside
        # the explicit keyword below would raise a duplicate-keyword
        # TypeError. ``kwargs`` is this call's own dict, so popping is safe.
        overrides = kwargs.pop("atari_preprocessing_kwargs", None)
        preprocessing_kwargs: dict[str, Any] = {
            "frame_skip": frame_skip,
            "screen_size": screen_size,
            "noop_max": noop_max,
        }
        if overrides:
            # Merge into the fresh dict; the caller's mapping is not mutated.
            preprocessing_kwargs.update(overrides)
        return cls(
            group_id=env_id,
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            atari_preprocessing=True,
            atari_preprocessing_kwargs=preprocessing_kwargs,
            **kwargs,
        )
179
+
180
+
181
def resolve_q_star_source_for_env(
    _group_id: str,
    q_star_source: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Return the effective ``q_star_source`` config for an env.

    A falsy source (``None`` or an empty mapping) resolves to ``None``;
    otherwise a shallow copy is returned so the caller's mapping is not
    shared. ``_group_id`` is accepted but not consulted.
    """
    if not q_star_source:
        return None
    return dict(q_star_source)
mouse/envs/env_ids.py ADDED
@@ -0,0 +1,4 @@
1
+ """Registered custom environment id constants."""
2
+
3
# Id of the procedural FrozenLake environment.
PROCEDURAL_FROZENLAKE_ENV_ID = "Procedural-FrozenLake-v1"
# Id of the synthetic environment.
SYNTHETIC_ENV_ID = "SyntheticEnv-v1"
@@ -0,0 +1,13 @@
1
+ """Expert policies and MDP solvers for Q* metadata."""
2
+
3
+ from mouse.envs.experts.value_iteration import (
4
+ solve_tabular_mdp,
5
+ value_iteration_gymnasium_p,
6
+ value_iteration_tabular,
7
+ )
8
+
9
+ __all__ = [
10
+ "solve_tabular_mdp",
11
+ "value_iteration_tabular",
12
+ "value_iteration_gymnasium_p",
13
+ ]