mouse-env 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mouse/envs/__init__.py +11 -0
- mouse/envs/build.py +230 -0
- mouse/envs/config.py +188 -0
- mouse/envs/env_ids.py +4 -0
- mouse/envs/experts/__init__.py +13 -0
- mouse/envs/experts/action_star.py +678 -0
- mouse/envs/experts/value_iteration.py +180 -0
- mouse/envs/format.py +287 -0
- mouse/envs/integrations/__init__.py +0 -0
- mouse/envs/integrations/atari.py +33 -0
- mouse/envs/integrations/ns_gym.py +320 -0
- mouse/envs/utils.py +25 -0
- mouse/envs/worlds/__init__.py +0 -0
- mouse/envs/worlds/procedural_frozenlake.py +518 -0
- mouse/envs/worlds/synthetic.py +449 -0
- mouse/envs/wrappers.py +384 -0
- mouse_env-0.3.0.dist-info/METADATA +181 -0
- mouse_env-0.3.0.dist-info/RECORD +21 -0
- mouse_env-0.3.0.dist-info/WHEEL +5 -0
- mouse_env-0.3.0.dist-info/licenses/LICENSE +674 -0
- mouse_env-0.3.0.dist-info/top_level.txt +1 -0
mouse/envs/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""MOUSE environments — vector envs and rollout formatting for mouse-core."""
|
|
2
|
+
|
|
3
|
+
from mouse.envs.build import make_vector_env
|
|
4
|
+
from mouse.envs.config import EnvConfig
|
|
5
|
+
from mouse.envs.format import MouseVectorEnv
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"EnvConfig",
|
|
9
|
+
"make_vector_env",
|
|
10
|
+
"MouseVectorEnv",
|
|
11
|
+
]
|
mouse/envs/build.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""Build vector environments from :class:`EnvConfig`."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import gymnasium as gym
|
|
9
|
+
|
|
10
|
+
from mouse.envs.config import (
|
|
11
|
+
EnvConfig,
|
|
12
|
+
is_ns_gym_env,
|
|
13
|
+
normalize_group_id,
|
|
14
|
+
resolve_q_star_source_for_env,
|
|
15
|
+
)
|
|
16
|
+
from mouse.envs.experts.action_star import apply_q_star_source_env_kwargs
|
|
17
|
+
from mouse.envs.integrations.atari import (
|
|
18
|
+
ensure_ale_registered,
|
|
19
|
+
is_ale_env,
|
|
20
|
+
wrap_atari_preprocessing,
|
|
21
|
+
)
|
|
22
|
+
from mouse.envs.integrations.ns_gym import make_ns_env
|
|
23
|
+
from mouse.envs.env_ids import PROCEDURAL_FROZENLAKE_ENV_ID, SYNTHETIC_ENV_ID
|
|
24
|
+
from mouse.envs.format import MouseVectorEnv
|
|
25
|
+
from mouse.envs.wrappers import (
|
|
26
|
+
ConstructionSeedWrapper,
|
|
27
|
+
ObservationSliceWrapper,
|
|
28
|
+
build_vector_env_stack,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _resolve_group_ids(group_id: str, num_envs: int, group_ids: list[str] | None) -> list[str]:
|
|
33
|
+
if group_ids is not None:
|
|
34
|
+
if len(group_ids) != num_envs:
|
|
35
|
+
raise ValueError(
|
|
36
|
+
f"group_ids has {len(group_ids)} entries but num_envs={num_envs}. "
|
|
37
|
+
f"Provide exactly one id per vector index."
|
|
38
|
+
)
|
|
39
|
+
seen: dict[str, list[int]] = {}
|
|
40
|
+
for i, gid in enumerate(group_ids):
|
|
41
|
+
seen.setdefault(gid, []).append(i)
|
|
42
|
+
duplicates = {k: v for k, v in seen.items() if len(v) > 1}
|
|
43
|
+
if duplicates:
|
|
44
|
+
raise ValueError(
|
|
45
|
+
f"group_id must be unique per vector index — found duplicates: {duplicates}. "
|
|
46
|
+
f"Use a unique suffix per index (e.g. '{group_id}#0', '{group_id}#1')."
|
|
47
|
+
)
|
|
48
|
+
return list(group_ids)
|
|
49
|
+
|
|
50
|
+
if not group_id:
|
|
51
|
+
raise ValueError(
|
|
52
|
+
"group_id is required on EnvConfig but was not set. "
|
|
53
|
+
"Provide a non-empty group_id (e.g. 'CartPole-v1')."
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
if num_envs == 1:
|
|
57
|
+
return [group_id]
|
|
58
|
+
|
|
59
|
+
return [f"{group_id}#{i}" for i in range(num_envs)]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def make_vector_env(config: EnvConfig) -> MouseVectorEnv:
    """Build a :class:`MouseVectorEnv` described by ``config``.

    The NS-Gym builder is used when :func:`is_ns_gym_env` says so; otherwise
    the plain Gymnasium builder is used.

    Raises:
        ValueError: If ``config.max_episode_steps`` is ``None`` or
            ``config.num_envs`` is less than 1.
    """
    if config.max_episode_steps is None:
        raise ValueError(
            "max_episode_steps is required (used to normalise xformed_reward). "
            "Set max_episode_steps in the env config."
        )
    if config.num_envs < 1:
        raise ValueError(f"num_envs must be >= 1, got {config.num_envs}.")

    # Resolve everything both builders share before picking one.
    base_group_id = normalize_group_id(config.group_id)
    q_star = resolve_q_star_source_for_env(config.group_id, config.q_star_source)
    per_index_ids = _resolve_group_ids(base_group_id, config.num_envs, config.group_ids)

    if is_ns_gym_env(config.group_id, config.non_stationary_params):
        builder = _build_ns_vector_env
    else:
        builder = _build_plain_vector_env

    vector_env = builder(
        config=config,
        resolved_group_id=base_group_id,
        resolved_q_star_source=q_star,
        resolved_group_ids=per_index_ids,
    )
    return MouseVectorEnv(vector_env, per_index_ids)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _prepare_plain_env_kwargs(config: EnvConfig, *, atari_preprocessing: bool) -> dict[str, Any]:
    """Assemble the env keyword arguments for a plain (non NS-Gym) env.

    Starts from a copy of ``config.kwargs``, lets
    ``apply_q_star_source_env_kwargs`` adjust it, registers the custom env
    ids when needed, and applies the render / Atari tweaks.

    Args:
        config: Environment configuration.
        atari_preprocessing: Whether Atari preprocessing is enabled; when it
            is and the env is an ALE env, ``frameskip`` is forced to 1.

    Raises:
        ValueError: If ``random_map_wrapper`` is given but is not a dict, or
            if ``observation_indices`` is set for an ALE env.
    """
    env_id = config.group_id
    prepared = apply_q_star_source_env_kwargs(
        env_id=env_id,
        env_kwargs=dict(config.kwargs or {}),
        q_star_source=config.q_star_source,
    )

    if env_id == PROCEDURAL_FROZENLAKE_ENV_ID:
        from mouse.envs.worlds.procedural_frozenlake import ensure_procedural_frozenlake_registered

        ensure_procedural_frozenlake_registered()
        # A nested ``random_map_wrapper`` dict is flattened into the top-level kwargs.
        wrapper_cfg = prepared.pop("random_map_wrapper", None)
        if wrapper_cfg is not None:
            if not isinstance(wrapper_cfg, dict):
                raise ValueError("env_kwargs.random_map_wrapper must be a dict when provided.")
            prepared.update(dict(wrapper_cfg))
    elif env_id == SYNTHETIC_ENV_ID:
        from mouse.envs.worlds.synthetic import ensure_synthetic_env_registered

        ensure_synthetic_env_registered()

    # Only fill in a render mode when the caller did not choose one.
    if config.render:
        prepared.setdefault("render_mode", "human")

    if is_ale_env(env_id):
        ensure_ale_registered()
        if atari_preprocessing:
            prepared["frameskip"] = 1
        if config.observation_indices is not None:
            raise ValueError("observation_indices is not supported for ALE (Atari) envs.")
    return prepared
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _make_plain_single_env(
    config: EnvConfig,
    index: int,
    *,
    env_kwargs: dict[str, Any],
    atari_preprocessing: bool,
) -> gym.Env:
    """Build the sub-environment for vector position ``index``.

    The env is created through ``ConstructionSeedWrapper`` with seed
    ``config.seed + index``. For the synthetic and procedural FrozenLake ids
    the seed handed to the factory is also passed to ``gym.make`` as ``seed``.
    If ``config.observation_indices`` is set, the result is wrapped in
    ``ObservationSliceWrapper``.
    """
    env_id = config.group_id
    wants_seed_kwarg = env_id in (SYNTHETIC_ENV_ID, PROCEDURAL_FROZENLAKE_ENV_ID)
    preprocess = is_ale_env(env_id) and atari_preprocessing

    def _factory(s: int) -> gym.Env:
        # Copy so the shared kwargs dict is never mutated by a single env.
        make_kwargs = {**env_kwargs, "seed": s} if wants_seed_kwarg else {**env_kwargs}
        raw_env = gym.make(
            config.group_id,
            max_episode_steps=config.max_episode_steps,
            **make_kwargs,
        )
        return wrap_atari_preprocessing(
            raw_env,
            enabled=preprocess,
            preprocessing_kwargs=config.atari_preprocessing_kwargs,
        )

    wrapped = ConstructionSeedWrapper(_factory, seed=config.seed + index)
    if config.observation_indices is None:
        return wrapped
    return ObservationSliceWrapper(env=wrapped, indices=config.observation_indices)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _make_ns_single_env(config: EnvConfig) -> gym.Env:
    """Build one NS-Gym sub-environment from ``config``.

    Delegates construction to ``make_ns_env`` (an empty dict is passed when
    ``non_stationary_params`` is falsy) and, if ``config.observation_indices``
    is set, wraps the result in ``ObservationSliceWrapper``.
    """
    ns_env = make_ns_env(
        env_id=config.group_id,
        non_stationary_params=config.non_stationary_params or {},
        max_steps_per_episode=config.max_episode_steps,
        env_kwargs=config.kwargs,
        render=config.render,
    )
    if config.observation_indices is None:
        return ns_env
    return ObservationSliceWrapper(env=ns_env, indices=config.observation_indices)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _build_plain_vector_env(
    *,
    config: EnvConfig,
    resolved_group_id: str,
    resolved_q_star_source: dict[str, Any] | None,
    resolved_group_ids: list[str],
) -> gym.vector.VectorEnv:
    """Build the vector env stack for a plain (non NS-Gym) environment.

    Creates one env factory per vector index and hands them, together with
    the resolved ids, reward transform settings and Q* source, to
    ``build_vector_env_stack``.
    """
    use_atari_preprocessing = bool(config.atari_preprocessing)
    prepared = _prepare_plain_env_kwargs(config, atari_preprocessing=use_atari_preprocessing)

    # Envs seeded at construction get their seed per index, so any caller-supplied
    # ``seed`` kwarg is dropped from the shared kwargs (on a copy).
    if config.group_id in (SYNTHETIC_ENV_ID, PROCEDURAL_FROZENLAKE_ENV_ID):
        shared_kwargs = {key: value for key, value in prepared.items() if key != "seed"}
    else:
        shared_kwargs = prepared

    factories: list[Callable[[], gym.Env]] = []
    for position in range(config.num_envs):
        # Bind the index as a default argument so each factory keeps its own value.
        factories.append(
            lambda i=position: _make_plain_single_env(
                config,
                i,
                env_kwargs=shared_kwargs,
                atari_preprocessing=use_atari_preprocessing,
            )
        )

    if is_ale_env(config.group_id):
        observation_key = "observation_image"
    else:
        observation_key = "observation"

    return build_vector_env_stack(
        env_fns=factories,
        group_id=resolved_group_id,
        seed=config.seed,
        max_steps_per_episode=config.max_episode_steps,
        obs_key=observation_key,
        reward_scale=config.reward_scale,
        reward_shift=config.reward_shift,
        q_star_source=resolved_q_star_source,
        group_ids=resolved_group_ids,
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _build_ns_vector_env(
    *,
    config: EnvConfig,
    resolved_group_id: str,
    resolved_q_star_source: dict[str, Any] | None,
    resolved_group_ids: list[str],
) -> gym.vector.VectorEnv:
    """Build the vector env stack for an NS-Gym environment.

    Creates ``config.num_envs`` factories that each call
    ``_make_ns_single_env(config)`` and passes them to
    ``build_vector_env_stack`` with the ``"observation"`` key.
    """
    factories: list[Callable[[], gym.Env]] = []
    for _ in range(config.num_envs):
        # Every factory builds from the same config; no per-index argument is used.
        factories.append(lambda: _make_ns_single_env(config))

    stack_kwargs = {
        "env_fns": factories,
        "group_id": resolved_group_id,
        "seed": config.seed,
        "max_steps_per_episode": config.max_episode_steps,
        "obs_key": "observation",
        "reward_scale": config.reward_scale,
        "reward_shift": config.reward_shift,
        "q_star_source": resolved_q_star_source,
        "group_ids": resolved_group_ids,
    }
    return build_vector_env_stack(**stack_kwargs)
|
mouse/envs/config.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Environment configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Default Q* source used by the ``EnvConfig.cartpole`` preset when the caller
# does not pass ``q_star_source``.
DEFAULT_SB3_Q_STAR_CARTPOLE: dict[str, Any] = dict(
    provider="sb3_rl_zoo",
    algo="ppo",
    repo_id="sb3/ppo-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    deterministic=True,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def normalize_group_id(group_id: str) -> str:
    """Return ``group_id`` without a leading legacy ``NS-`` prefix.

    Only one leading ``NS-`` is removed; ids without the prefix are returned
    unchanged.
    """
    legacy_prefix = "NS-"
    return group_id[len(legacy_prefix):] if group_id.startswith(legacy_prefix) else group_id
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_ns_gym_env(
    _group_id: str,
    non_stationary_params: dict[str, Any] | None,
) -> bool:
    """Tell whether the NS-Gym backend should be used.

    The decision depends only on ``non_stationary_params``: a non-empty
    mapping selects NS-Gym, while ``None`` or an empty mapping does not.
    ``_group_id`` is accepted but not consulted.
    """
    if non_stationary_params:
        return True
    return False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class EnvConfig:
    """Settings consumed by :func:`mouse.envs.make_vector_env` to build a vector env."""

    # Environment id; also the base for per-index group ids.
    group_id: str
    # Base seed; plain sub-envs are seeded with ``seed + index``.
    seed: int
    # Number of sub-environments in the vector env.
    num_envs: int
    # Episode step limit; ``make_vector_env`` rejects ``None``.
    max_episode_steps: int | None
    # Extra keyword arguments for env construction.
    kwargs: dict | None = None
    # When true, plain envs default ``render_mode`` to ``"human"``.
    render: bool = False
    # Non-empty value selects the NS-Gym backend.
    non_stationary_params: dict | None = None
    # Expert / Q* metadata source description.
    q_star_source: dict[str, Any] | None = None
    # Enables Atari preprocessing for ALE envs.
    atari_preprocessing: bool | None = None
    # Keyword arguments forwarded to the Atari preprocessing wrapper.
    atari_preprocessing_kwargs: dict | None = None
    # Optional observation indices to keep (not supported for ALE envs).
    observation_indices: list[int] | None = None
    # Reward transform parameters passed to the vector env stack.
    reward_scale: float = 1.0
    reward_shift: float = 0.0
    # Optional explicit group ids, one per vector index.
    group_ids: list[str] | None = None

    def build(self):
        """Build the vector env for this config (same as ``make_vector_env(self)``)."""
        # Imported lazily, at call time rather than module import time.
        from mouse.envs.build import make_vector_env as _make_vector_env

        return _make_vector_env(self)

    @classmethod
    def cartpole(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 500,
        q_star_source: dict[str, Any] | None | object = ...,
        observation_indices: list[int] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Return a ``CartPole-v1`` config that uses the SB3 PPO Q* source by default.

        Leaving ``q_star_source`` unset selects a copy of
        ``DEFAULT_SB3_Q_STAR_CARTPOLE``; passing ``q_star_source=None``
        explicitly disables expert metadata.
        """
        # ``...`` is the "not provided" sentinel because ``None`` is a meaningful value.
        effective_q_star: dict[str, Any] | None = (
            dict(DEFAULT_SB3_Q_STAR_CARTPOLE) if q_star_source is ... else q_star_source
        )
        return cls(
            group_id="CartPole-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            q_star_source=effective_q_star,
            observation_indices=observation_indices,
            **kwargs,
        )

    @classmethod
    def ns_cartpole(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 500,
        non_stationary_params: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Return a non-stationary ``CartPole-v1`` config with optional physics schedules.

        ``non_stationary_params=None`` is stored as an empty dict.
        """
        schedule = {} if non_stationary_params is None else non_stationary_params
        return cls(
            group_id="CartPole-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            non_stationary_params=schedule,
            **kwargs,
        )

    @classmethod
    def procedural_frozenlake(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 200,
        env_kwargs: dict[str, Any] | None = None,
        q_star_source: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Return a ``Procedural-FrozenLake-v1`` config.

        A falsy ``q_star_source`` is replaced by the ``metadata_q_star`` provider.
        """
        if not q_star_source:
            q_star_source = {"provider": "metadata_q_star"}
        return cls(
            group_id="Procedural-FrozenLake-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            kwargs=env_kwargs,
            q_star_source=q_star_source,
            **kwargs,
        )

    @classmethod
    def synthetic(
        cls,
        *,
        seed: int = 0,
        num_envs: int = 1,
        max_episode_steps: int = 200,
        env_kwargs: dict[str, Any] | None = None,
        q_star_source: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> EnvConfig:
        """Return a ``SyntheticEnv-v1`` config.

        A falsy ``q_star_source`` is replaced by the ``metadata_q_star`` provider.
        """
        if not q_star_source:
            q_star_source = {"provider": "metadata_q_star"}
        return cls(
            group_id="SyntheticEnv-v1",
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            kwargs=env_kwargs,
            q_star_source=q_star_source,
            **kwargs,
        )

    @classmethod
    def atari(
        cls,
        env_id: str = "ALE/Pong-v5",
        *,
        seed: int = 0,
        num_envs: int = 4,
        max_episode_steps: int = 10000,
        frame_skip: int = 4,
        screen_size: int = 84,
        noop_max: int = 30,
        **kwargs: Any,
    ) -> EnvConfig:
        """Return an ALE Atari config with common ``AtariPreprocessing`` defaults only."""
        preprocessing_options = {
            "frame_skip": frame_skip,
            "screen_size": screen_size,
            "noop_max": noop_max,
        }
        return cls(
            group_id=env_id,
            seed=seed,
            num_envs=num_envs,
            max_episode_steps=max_episode_steps,
            atari_preprocessing=True,
            atari_preprocessing_kwargs=preprocessing_options,
            **kwargs,
        )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def resolve_q_star_source_for_env(
    _group_id: str,
    q_star_source: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Return the ``q_star_source`` config that applies to an env.

    A non-empty ``q_star_source`` is returned as a shallow copy; ``None`` or
    an empty mapping yields ``None``. ``_group_id`` is accepted but not
    consulted.
    """
    if not q_star_source:
        return None
    return dict(q_star_source)
|
mouse/envs/experts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Expert policies and MDP solvers for Q* metadata."""
|
|
2
|
+
|
|
3
|
+
from mouse.envs.experts.value_iteration import (
|
|
4
|
+
solve_tabular_mdp,
|
|
5
|
+
value_iteration_gymnasium_p,
|
|
6
|
+
value_iteration_tabular,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"solve_tabular_mdp",
|
|
11
|
+
"value_iteration_tabular",
|
|
12
|
+
"value_iteration_gymnasium_p",
|
|
13
|
+
]
|