safenax 0.4.4__tar.gz → 0.4.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {safenax-0.4.4 → safenax-0.4.6}/PKG-INFO +1 -1
- {safenax-0.4.4 → safenax-0.4.6}/pyproject.toml +1 -1
- {safenax-0.4.4 → safenax-0.4.6}/safenax/__init__.py +3 -2
- safenax-0.4.6/safenax/frozen_lake/__init__.py +8 -0
- safenax-0.4.4/safenax/frozen_lake.py → safenax-0.4.6/safenax/frozen_lake/frozen_lake_v1.py +5 -5
- safenax-0.4.6/safenax/frozen_lake/frozen_lake_v2.py +302 -0
- safenax-0.4.4/tests/test_frozen_lake.py → safenax-0.4.6/tests/test_frozen_lake_v1.py +17 -11
- safenax-0.4.6/tests/test_frozen_lake_v2.py +167 -0
- {safenax-0.4.4 → safenax-0.4.6}/uv.lock +1 -1
- {safenax-0.4.4 → safenax-0.4.6}/.gitignore +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/.pre-commit-config.yaml +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/.python-version +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/LICENSE +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/PUBLISHING.md +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/README.md +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/__init__.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/eco_ant_v1.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/eco_ant_v2.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/fragile_ant.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/__init__.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/po_crypto.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/po_garch.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/__init__.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/brax.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/log.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/scripts/setup_dev.sh +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/tests/__init__.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/tests/test_eco_ant_v1.py +0 -0
- {safenax-0.4.4 → safenax-0.4.6}/tests/test_eco_ant_v2.py +0 -0
|
@@ -8,7 +8,7 @@ from safenax.portfolio_optimization import (
|
|
|
8
8
|
PortfolioOptimizationGARCH,
|
|
9
9
|
PortfolioOptimizationCrypto,
|
|
10
10
|
)
|
|
11
|
-
from safenax.frozen_lake import
|
|
11
|
+
from safenax.frozen_lake import FrozenLakeV1, FrozenLakeV2
|
|
12
12
|
from safenax.eco_ant import EcoAntV1, EcoAntV2
|
|
13
13
|
|
|
14
14
|
|
|
@@ -16,7 +16,8 @@ __all__ = [
|
|
|
16
16
|
"FragileAnt",
|
|
17
17
|
"PortfolioOptimizationCrypto",
|
|
18
18
|
"PortfolioOptimizationGARCH",
|
|
19
|
-
"
|
|
19
|
+
"FrozenLakeV1",
|
|
20
|
+
"FrozenLakeV2",
|
|
20
21
|
"EcoAntV1",
|
|
21
22
|
"EcoAntV2",
|
|
22
23
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""JAX-compatible
|
|
1
|
+
"""JAX-compatible FrozenLakeV1 environment following the gymnax interface."""
|
|
2
2
|
|
|
3
3
|
from typing import Optional, Tuple, Union, List
|
|
4
4
|
|
|
@@ -69,7 +69,7 @@ MAPS = {
|
|
|
69
69
|
}
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
class
|
|
72
|
+
class FrozenLakeV1(environment.Environment):
|
|
73
73
|
"""
|
|
74
74
|
JAX-compatible FrozenLake environment.
|
|
75
75
|
|
|
@@ -81,7 +81,7 @@ class FrozenLake(environment.Environment):
|
|
|
81
81
|
def __init__(
|
|
82
82
|
self,
|
|
83
83
|
map_name: str = "4x4",
|
|
84
|
-
desc: Optional[
|
|
84
|
+
desc: Optional[jax.Array] = None,
|
|
85
85
|
is_slippery: bool = True,
|
|
86
86
|
success_rate: float = 1.0 / 3.0,
|
|
87
87
|
reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
|
|
@@ -250,9 +250,9 @@ def make_frozen_lake(
|
|
|
250
250
|
is_slippery: bool = True,
|
|
251
251
|
success_rate: float = 1.0 / 3.0,
|
|
252
252
|
reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
|
|
253
|
-
) ->
|
|
253
|
+
) -> FrozenLakeV1:
|
|
254
254
|
"""Factory function to easier initialization."""
|
|
255
|
-
return
|
|
255
|
+
return FrozenLakeV1(
|
|
256
256
|
map_name=map_name,
|
|
257
257
|
is_slippery=is_slippery,
|
|
258
258
|
success_rate=success_rate,
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""JAX-compatible FrozenLakeV2 environment following the gymnax interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, Union, List
|
|
4
|
+
|
|
5
|
+
import chex
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from flax import struct
|
|
9
|
+
from gymnax.environments import environment, spaces
|
|
10
|
+
|
|
11
|
+
# --- Constants ---
|
|
12
|
+
# Actions
|
|
13
|
+
LEFT = 0
|
|
14
|
+
DOWN = 1
|
|
15
|
+
RIGHT = 2
|
|
16
|
+
UP = 3
|
|
17
|
+
|
|
18
|
+
# Tile Types (using ASCII values for readability during debugging)
|
|
19
|
+
TILE_START = ord("S")
|
|
20
|
+
TILE_FROZEN = ord("F")
|
|
21
|
+
TILE_THIN = ord("T")
|
|
22
|
+
TILE_GOAL = ord("G")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@struct.dataclass
class EnvState:
    """Environment state for FrozenLake."""

    pos: int  # flat grid index of the agent: row * ncol + col
    time: int  # number of steps taken in the current episode
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@struct.dataclass
class EnvParams:
    """Environment parameters for FrozenLake."""

    desc: jax.Array  # Map description as integer array (nrow, ncol)
    nrow: int  # number of grid rows
    ncol: int  # number of grid columns
    is_slippery: bool  # if True, a move may slip perpendicular to the intent
    success_rate: float  # probability the intended direction is taken when slippery
    reward_schedule: jax.Array  # (goal_reward, thin_reward, frozen_reward)
    max_steps_in_episode: int  # time limit used for the truncation flag
    safe_cost_mean: float  # mean per-step cost on safe ('F') ice
    safe_cost_std: float  # std of the Gaussian noise on the safe cost
    thin_cost_base: float  # base per-step cost on thin ('T') ice
    thin_shock_prob: float  # probability of a catastrophic shock on thin ice
    thin_shock_val: float  # extra cost added when a shock occurs
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# --- Helper Functions ---
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def string_map_to_array(map_desc: List[str]) -> jax.Array:
    """Encode a list of equal-length map rows as an int32 grid of character codes."""
    encoded_rows = [list(map(ord, row)) for row in map_desc]
    return jnp.array(encoded_rows, dtype=jnp.int32)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Predefined maps.
# Tile legend: S = start, F = safe (frozen) ice, T = thin ice, G = goal.
MAPS = {
    "4x4": string_map_to_array(["SFFF", "FTFT", "FFFT", "TFFG"]),
    "8x8": string_map_to_array(
        [
            "SFFFFFFF",
            "FFFFFFFF",
            "FFFTFFFF",
            "FFFFFTFF",
            "FFFTFFFF",
            "FTTFFFTF",
            "FTFFTFTF",
            "FFFTFFFG",
        ]
    ),
}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class FrozenLakeV2(environment.Environment):
    """
    JAX-compatible FrozenLake environment with "Thin Ice" cost augmentation.

    The agent controls the movement of a character in a grid world to reach a Goal (G).
    The map contains two types of traversable terrain that introduce a Mean-Variance trade-off:

    1. Safe Ice ('F'):
        - Represents a stable but arduous path (deep snow is hard to traverse).
        - Cost: High Mean (e.g., 2.0), Small Variance (e.g., std=0.1).
        - Risk: Deterministically expensive, but safe.

    2. Thin Ice ('T', formerly Holes):
        - Represents a dangerous shortcut, prone to cracking or falling but can glide over with little effort.
        - Cost: Low Mean (e.g., 1.5), High Variance Shocks (e.g. std=4.35) (ice cracks).
        - Risk: Usually cheap to traverse, but prone to random, catastrophic cost spikes (cracking ice).

    The agent can also slip perpendicular to its intended direction based on the
    `is_slippery` and `success_rate` parameters.
    """

    def __init__(
        self,
        map_name: str = "4x4",
        desc: Optional[jax.Array] = None,
        is_slippery: bool = True,
        success_rate: float = 1.0 / 3.0,
        reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
        safe_cost_mean: float = 2.0,
        safe_cost_std: float = 0.1,
        thin_cost_base: float = 0.5,
        thin_shock_prob: float = 0.05,
        thin_shock_val: float = 20.0,
    ):
        """Initialize the environment.

        Args:
            map_name: Key into the predefined ``MAPS`` (used when ``desc`` is None).
            desc: Optional explicit map as an int32 array of character codes;
                overrides ``map_name`` when provided.
            is_slippery: Whether moves may slip perpendicular to the intent.
            success_rate: Probability of moving in the intended direction when slippery.
            reward_schedule: Rewards for stepping on (goal, thin, frozen) tiles.
            safe_cost_mean: Mean per-step cost on safe ('F') ice.
            safe_cost_std: Std of the Gaussian noise added to the safe cost.
            thin_cost_base: Base per-step cost on thin ('T') ice.
            thin_shock_prob: Probability of a catastrophic shock on thin ice.
            thin_shock_val: Extra cost added when a shock occurs.
        """
        super().__init__()

        if desc is None:
            desc = MAPS[map_name]

        self.desc = desc
        self.nrow, self.ncol = desc.shape
        self.is_slippery = is_slippery
        self.success_rate = success_rate
        self.reward_schedule = jnp.array(reward_schedule)
        self.safe_cost_mean = safe_cost_mean
        self.safe_cost_std = safe_cost_std
        self.thin_cost_base = thin_cost_base
        self.thin_shock_prob = thin_shock_prob
        self.thin_shock_val = thin_shock_val

        # Determine max steps from the ACTUAL grid size (4x4 -> 100, else 200).
        # BUGFIX: the previous check (`map_name == "4x4" or ...`) granted the
        # 4x4 limit to any custom `desc` whenever `map_name` kept its default.
        if (self.nrow, self.ncol) == (4, 4):
            self.max_steps = 100
        else:
            self.max_steps = 200

    @property
    def default_params(self) -> EnvParams:
        """Bundle the constructor settings into a JIT-compatible EnvParams pytree."""
        return EnvParams(
            desc=self.desc,
            nrow=self.nrow,
            ncol=self.ncol,
            is_slippery=self.is_slippery,
            success_rate=self.success_rate,
            reward_schedule=self.reward_schedule,
            max_steps_in_episode=self.max_steps,
            safe_cost_mean=self.safe_cost_mean,
            safe_cost_std=self.safe_cost_std,
            thin_cost_base=self.thin_cost_base,
            thin_shock_prob=self.thin_shock_prob,
            thin_shock_val=self.thin_shock_val,
        )

    def step_env(
        self,
        key: jax.Array,
        state: EnvState,
        action: Union[int, float],
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Perform a single environment step with JIT compatibility.

        Returns:
            (obs, state, reward, done, info) where ``info["cost"]`` is the
            per-step safety cost and ``info["tile_type"]`` the tile stepped onto.
        """
        action = jnp.int32(action)

        # 1. Determine the actual direction (handling slippery logic).
        # BUGFIX: use an independent key for each random draw. The original
        # reused `rng_cost` for both the Gaussian safe-cost noise and the
        # Bernoulli shock (correlating the streams) while discarding the
        # third key from the split — JAX PRNG keys must never be reused.
        rng_slip, rng_safe, rng_shock = jax.random.split(key, 3)

        def get_slippery_action(k):
            # 0: intended, 1: perpendicular left, 2: perpendicular right
            # Probabilities: [success, (1-success)/2, (1-success)/2]
            fail_prob = (1.0 - params.success_rate) / 2.0
            probs = jnp.array([params.success_rate, fail_prob, fail_prob])

            # Map samples (0, 1, 2) to actions (action, action-1, action+1)
            delta_idx = jax.random.choice(k, jnp.arange(3), p=probs)

            # Using a lookup array for cleaner mapping
            candidates = jnp.array(
                [
                    action,  # Index 0: Success
                    (action - 1) % 4,  # Index 1: Fail Left
                    (action + 1) % 4,  # Index 2: Fail Right
                ]
            )
            return candidates[delta_idx]

        actual_action = jax.lax.cond(
            params.is_slippery,
            get_slippery_action,
            lambda k: action,  # Deterministic branch
            rng_slip,
        )

        # 2. Calculate Movement (vectorized deltas + border clipping)
        row = state.pos // params.ncol
        col = state.pos % params.ncol

        next_row, next_col = self._apply_action(row, col, actual_action, params)
        next_pos = next_row * params.ncol + next_col

        # 3. Check Rewards and Termination
        tile_type = params.desc[next_row, next_col]

        is_goal = tile_type == TILE_GOAL
        is_thin = tile_type == TILE_THIN

        # Reward Schedule: [Goal, Thin, Frozen/Start]
        # We select index 0, 1, or 2 based on tile type
        tile_idx = jnp.where(is_goal, 0, jnp.where(is_thin, 1, 2))
        reward = params.reward_schedule[tile_idx]

        # Safe ice: Gaussian cost around safe_cost_mean (small variance).
        safe_noise = jax.random.normal(rng_safe) * params.safe_cost_std
        cost_safe = params.safe_cost_mean + safe_noise

        # Thin ice: cheap base cost plus a rare catastrophic shock.
        is_shock = jax.random.bernoulli(rng_shock, p=params.thin_shock_prob)
        cost_thin = params.thin_cost_base + (is_shock * params.thin_shock_val)

        # Goal incurs no cost; otherwise pick the cost matching the tile type.
        cost = jnp.select([is_goal, is_thin], [0.0, cost_thin], default=cost_safe)

        # 4. Update State
        # Time limit truncation is handled in Gymnax wrappers usually,
        # but we track it here for the 'done' flag consistency.
        new_time = state.time + 1
        truncated = new_time >= params.max_steps_in_episode
        done = is_goal | truncated

        new_state = EnvState(pos=next_pos, time=new_time)

        return (
            self.get_obs(new_state, params),
            new_state,
            reward,
            done,
            {"cost": cost, "tile_type": tile_type},
        )

    def reset_env(
        self,
        key: chex.PRNGKey,
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState]:
        """Reset environment to the start position."""
        # By definition, FrozenLake always starts at (0,0).
        # If dynamic start positions are needed, one would scan params.desc for 'S'.
        state = EnvState(pos=0, time=0)
        return self.get_obs(state, params), state

    def get_obs(self, state: EnvState, params: EnvParams) -> chex.Array:
        """Return scalar observation (flat position index)."""
        return jnp.array(state.pos, dtype=jnp.int32)

    def _apply_action(
        self, row: int, col: int, action: int, params: EnvParams
    ) -> Tuple[int, int]:
        """Movement logic using delta arrays and clipping at the grid borders."""
        # Gym Direction Mapping:
        # 0: Left (0, -1)   1: Down (1, 0)   2: Right (0, 1)   3: Up (-1, 0)
        dr = jnp.array([0, 1, 0, -1])
        dc = jnp.array([-1, 0, 1, 0])

        new_row = row + dr[action]
        new_col = col + dc[action]

        # Ensure we stay within the grid
        new_row = jnp.clip(new_row, 0, params.nrow - 1)
        new_col = jnp.clip(new_col, 0, params.ncol - 1)

        return new_row, new_col

    @property
    def name(self) -> str:
        """Environment identifier string."""
        return "FrozenLake-v2"

    @property
    def num_actions(self) -> int:
        """Number of discrete actions (left/down/right/up)."""
        return 4

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Discrete(4) action space."""
        return spaces.Discrete(4)

    def observation_space(self, params: EnvParams) -> spaces.Discrete:
        """Discrete flat-position observation space of size nrow * ncol."""
        return spaces.Discrete(params.nrow * params.ncol)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# --- Factory Helper ---
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def make_frozen_lake(
    map_name: str = "4x4",
    is_slippery: bool = True,
    success_rate: float = 1.0 / 3.0,
    reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
    safe_cost_mean: float = 2.0,
    safe_cost_std: float = 0.1,
    thin_cost_base: float = 0.5,
    thin_shock_prob: float = 0.05,
    thin_shock_val: float = 20.0,
) -> FrozenLakeV2:
    """Factory function for easier initialization.

    All arguments mirror ``FrozenLakeV2.__init__``. The cost-model parameters
    were previously not forwarded and are now exposed with identical defaults,
    so existing call sites keep their behavior.
    """
    return FrozenLakeV2(
        map_name=map_name,
        is_slippery=is_slippery,
        success_rate=success_rate,
        reward_schedule=reward_schedule,
        safe_cost_mean=safe_cost_mean,
        safe_cost_std=safe_cost_std,
        thin_cost_base=thin_cost_base,
        thin_shock_prob=thin_shock_prob,
        thin_shock_val=thin_shock_val,
    )
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import pytest
|
|
2
2
|
import jax
|
|
3
3
|
import jax.numpy as jnp
|
|
4
|
-
from safenax.frozen_lake import
|
|
5
|
-
from safenax.frozen_lake import (
|
|
4
|
+
from safenax.frozen_lake.frozen_lake_v1 import FrozenLakeV1, EnvParams
|
|
5
|
+
from safenax.frozen_lake.frozen_lake_v1 import (
|
|
6
6
|
TILE_START,
|
|
7
7
|
TILE_FROZEN,
|
|
8
8
|
TILE_HOLE,
|
|
@@ -16,7 +16,7 @@ from safenax.frozen_lake import (
|
|
|
16
16
|
@pytest.fixture
|
|
17
17
|
def env():
|
|
18
18
|
"""Initializes the environment."""
|
|
19
|
-
return
|
|
19
|
+
return FrozenLakeV1()
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
@pytest.fixture
|
|
@@ -58,7 +58,7 @@ def params_slippery(params_deterministic: EnvParams) -> EnvParams:
|
|
|
58
58
|
|
|
59
59
|
|
|
60
60
|
def test_jit_compilation(
|
|
61
|
-
env:
|
|
61
|
+
env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
|
|
62
62
|
):
|
|
63
63
|
"""Verifies that reset and step can be JIT compiled without errors."""
|
|
64
64
|
reset_jit = jax.jit(env.reset)
|
|
@@ -85,7 +85,7 @@ def test_map_parsing():
|
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
def test_deterministic_reach_goal(
|
|
88
|
-
env:
|
|
88
|
+
env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
|
|
89
89
|
):
|
|
90
90
|
"""
|
|
91
91
|
Path: Start(0) -> Right(1) -> Down(3/Goal).
|
|
@@ -119,7 +119,7 @@ def test_deterministic_reach_goal(
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
def test_deterministic_fall_in_hole(
|
|
122
|
-
env:
|
|
122
|
+
env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
|
|
123
123
|
):
|
|
124
124
|
"""
|
|
125
125
|
Path: Start(0) -> Down(2/Hole).
|
|
@@ -142,7 +142,7 @@ def test_deterministic_fall_in_hole(
|
|
|
142
142
|
|
|
143
143
|
|
|
144
144
|
def test_custom_reward_schedule(
|
|
145
|
-
env:
|
|
145
|
+
env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
|
|
146
146
|
):
|
|
147
147
|
"""
|
|
148
148
|
Verifies that changing reward_schedule affects the output.
|
|
@@ -167,7 +167,9 @@ def test_custom_reward_schedule(
|
|
|
167
167
|
assert reward == -5.0
|
|
168
168
|
|
|
169
169
|
|
|
170
|
-
def test_slippery_dynamics(
|
|
170
|
+
def test_slippery_dynamics(
|
|
171
|
+
env: FrozenLakeV1, rng: jax.Array, params_slippery: EnvParams
|
|
172
|
+
):
|
|
171
173
|
"""
|
|
172
174
|
Statistical test for slippery logic.
|
|
173
175
|
We start at (0,0) and try to move Right.
|
|
@@ -201,7 +203,7 @@ def test_slippery_dynamics(env: FrozenLake, rng: jax.Array, params_slippery: Env
|
|
|
201
203
|
assert 1 in unique_positions
|
|
202
204
|
|
|
203
205
|
|
|
204
|
-
def test_truncation(env:
|
|
206
|
+
def test_truncation(env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams):
|
|
205
207
|
"""Test that episode ends when time limit is reached."""
|
|
206
208
|
# Set max steps to 2
|
|
207
209
|
short_params = params_deterministic.replace(max_steps_in_episode=2)
|
|
@@ -224,7 +226,9 @@ def test_truncation(env: FrozenLake, rng: jax.Array, params_deterministic: EnvPa
|
|
|
224
226
|
assert state.time == 0
|
|
225
227
|
|
|
226
228
|
|
|
227
|
-
def test_jit_rollout_scan(
|
|
229
|
+
def test_jit_rollout_scan(
|
|
230
|
+
env: FrozenLakeV1, rng: jax.Array, params_slippery: EnvParams
|
|
231
|
+
):
|
|
228
232
|
"""
|
|
229
233
|
Verifies that the environment can be run in a fully compiled jax.lax.scan loop.
|
|
230
234
|
This ensures no shape mismatches or control flow issues exist in the step logic.
|
|
@@ -282,7 +286,9 @@ def test_jit_rollout_scan(env: FrozenLake, rng: jax.Array, params_slippery: EnvP
|
|
|
282
286
|
assert jnp.all(jnp.isin(rollout_data["reward"], jnp.array([0.0, 1.0])))
|
|
283
287
|
|
|
284
288
|
|
|
285
|
-
def test_cost_signal(
|
|
289
|
+
def test_cost_signal(
|
|
290
|
+
env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
|
|
291
|
+
):
|
|
286
292
|
"""Ensure cost is 1.0 ONLY when falling into a hole."""
|
|
287
293
|
step_jit = jax.jit(env.step)
|
|
288
294
|
reset_jit = jax.jit(env.reset)
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""Unit tests for FrozenLakeV2 environment."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import pytest
|
|
6
|
+
from safenax.frozen_lake.frozen_lake_v2 import (
|
|
7
|
+
FrozenLakeV2,
|
|
8
|
+
EnvState,
|
|
9
|
+
TILE_FROZEN,
|
|
10
|
+
TILE_THIN,
|
|
11
|
+
TILE_GOAL,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Fix random key for reproducibility
|
|
15
|
+
KEY = jax.random.PRNGKey(42)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def env():
    """Standard slippery 4x4 FrozenLakeV2 instance shared by the tests below."""
    lake = FrozenLakeV2(map_name="4x4", is_slippery=True)
    return lake
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def params(env):
    """Default EnvParams derived from the `env` fixture."""
    defaults = env.default_params
    return defaults
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_initialization(env, params):
    """Reset places the agent at the start: position 0, time 0, scalar obs 0."""
    rng = jax.random.PRNGKey(0)
    obs, state = env.reset_env(rng, params)

    # Initial state is the top-left corner at t=0.
    assert state.pos == 0
    assert state.time == 0

    # The observation is the flat grid index, returned as a scalar array.
    assert obs.shape == ()
    assert obs == 0
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_step_mechanics(env, params):
    """Test basic stepping functionality."""
    key = jax.random.PRNGKey(0)
    # BUGFIX: use distinct keys for reset and step instead of reusing one
    # key for both calls (JAX PRNG keys must not be reused).
    key_reset, key_step = jax.random.split(key)
    _, state = env.reset_env(key_reset, params)

    # Take an action (e.g., RIGHT = 2)
    action = 2
    obs, next_state, reward, done, info = env.step_env(key_step, state, action, params)

    # Basic shape/type checks
    assert isinstance(next_state, EnvState)
    assert next_state.time == 1
    # Every step reports a non-negative safety cost in the info dict.
    assert "cost" in info
    assert info["cost"] >= 0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_slippery_dynamics(env, params):
    """
    Test that the agent actually slips when is_slippery=True.
    Starting at (1,1), taking action RIGHT (2) should stochastically result
    in moving RIGHT, UP, or DOWN.
    """
    base_key = jax.random.PRNGKey(0)
    step_keys = jax.random.split(base_key, 100)  # Run 100 steps

    # Force state to be at (1,1) [pos=5 for 4x4] to allow movement in all directions
    # Map:
    # S F F F (0, 1, 2, 3)
    # F T F T (4, 5, 6, 7)
    start_state = EnvState(pos=5, time=0)
    go_right = 2  # RIGHT

    # Run the step in parallel over all keys, keeping only the landing position.
    landed = jax.vmap(
        lambda k: env.step_env(k, start_state, go_right, params)[1].pos
    )(step_keys)

    # From 5 (1,1):
    # Right -> 6 (1,2) [Intended]
    # Down  -> 9 (2,1) [Slip]
    # Up    -> 1 (0,1) [Slip]
    distinct_positions = jnp.unique(landed)

    # Assert that we ended up in more than just the intended position
    assert len(distinct_positions) > 1
    assert 6 in distinct_positions  # Intended
    assert (9 in distinct_positions) or (1 in distinct_positions)  # Slips
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_cost_mean_variance(env, params):
    """
    CRITICAL TEST: Verify the Mean-Variance trade-off for VaR-CPO.
    Safe Ice should have Mean~2.0, Std~0.1.
    Thin Ice should have Mean~1.5, Std~4.35.
    """
    # NOTE(review): expected thin stats follow from the env defaults:
    # mean = thin_cost_base + thin_shock_prob * thin_shock_val = 0.5 + 0.05*20 = 1.5
    # std  = thin_shock_val * sqrt(p*(1-p)) ~= 20 * sqrt(0.0475) ~= 4.36
    N = 5000  # High sample count for statistical significance
    key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, N)

    # We will force the map description to be entirely one type to isolate the cost function logic
    # regardless of where the agent moves.

    def get_batch_costs(tile_type):
        """Helper to get costs for N steps on a map filled with `tile_type`."""
        # Create a dummy map filled with the specific tile type
        custom_desc = jnp.full((4, 4), tile_type, dtype=jnp.int32)
        custom_params = params.replace(desc=custom_desc)

        def step_fn(k):
            # Start at 0, take action 0. Destination will be same tile type.
            _, _, _, _, info = env.step_env(k, EnvState(0, 0), 0, custom_params)
            return info["cost"]

        # One independent key per sampled step keeps the draws i.i.d.
        return jax.vmap(step_fn)(keys)

    # 1. Test Safe Ice ('F')
    safe_costs = get_batch_costs(TILE_FROZEN)

    # Check Safe Statistics: Mean ~ 2.0, Std ~ 0.1
    assert jnp.abs(jnp.mean(safe_costs) - 2.0) < 0.1, (
        f"Safe Mean {jnp.mean(safe_costs)} != 2.0"
    )
    assert jnp.std(safe_costs) < 0.2, f"Safe Std {jnp.std(safe_costs)} is too high"

    # 2. Test Thin Ice ('T')
    thin_costs = get_batch_costs(TILE_THIN)

    # Check Thin Statistics: Mean ~ 1.5, Std ~ 4.35
    assert jnp.abs(jnp.mean(thin_costs) - 1.5) < 0.1, (
        f"Thin Mean {jnp.mean(thin_costs)} != 1.5"
    )
    assert jnp.abs(jnp.std(thin_costs) - 4.35) < 0.1, (
        f"Thin Std {jnp.std(thin_costs)} is too low (expected ~4.35)"
    )

    print(
        f"\nStats Verification:\nSafe: Mean={jnp.mean(safe_costs):.2f}, Std={jnp.std(safe_costs):.2f}"
    )
    print(f"Thin: Mean={jnp.mean(thin_costs):.2f}, Std={jnp.std(thin_costs):.2f}")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_termination(env, params):
    """Test episode termination logic (goal reached and time truncation)."""
    key = jax.random.PRNGKey(0)

    # 1. Test Goal Termination
    # Force map to be all GOAL so any step lands on a goal tile.
    goal_desc = jnp.full((4, 4), TILE_GOAL, dtype=jnp.int32)
    goal_params = params.replace(desc=goal_desc)

    # Agent takes a step onto a Goal tile -> should be done.
    _, _, _, done, _ = env.step_env(key, EnvState(0, 0), 0, goal_params)
    # FIX: `assert done == True` (E712) replaced with an explicit truthiness
    # check; `done` is a JAX scalar array, so coerce it with bool().
    assert bool(done)

    # 2. Test Time Truncation
    # Set max_steps to 1
    short_params = params.replace(max_steps_in_episode=1)

    # Step 1: time becomes 1 -> time >= max_steps (1) -> truncated -> done=True
    state = EnvState(0, 0)
    _, next_state, _, done, _ = env.step_env(key, state, 0, short_params)

    assert next_state.time == 1
    assert bool(done)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|