safenax 0.4.4__tar.gz → 0.4.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {safenax-0.4.4 → safenax-0.4.6}/PKG-INFO +1 -1
  2. {safenax-0.4.4 → safenax-0.4.6}/pyproject.toml +1 -1
  3. {safenax-0.4.4 → safenax-0.4.6}/safenax/__init__.py +3 -2
  4. safenax-0.4.6/safenax/frozen_lake/__init__.py +8 -0
  5. safenax-0.4.4/safenax/frozen_lake.py → safenax-0.4.6/safenax/frozen_lake/frozen_lake_v1.py +5 -5
  6. safenax-0.4.6/safenax/frozen_lake/frozen_lake_v2.py +302 -0
  7. safenax-0.4.4/tests/test_frozen_lake.py → safenax-0.4.6/tests/test_frozen_lake_v1.py +17 -11
  8. safenax-0.4.6/tests/test_frozen_lake_v2.py +167 -0
  9. {safenax-0.4.4 → safenax-0.4.6}/uv.lock +1 -1
  10. {safenax-0.4.4 → safenax-0.4.6}/.gitignore +0 -0
  11. {safenax-0.4.4 → safenax-0.4.6}/.pre-commit-config.yaml +0 -0
  12. {safenax-0.4.4 → safenax-0.4.6}/.python-version +0 -0
  13. {safenax-0.4.4 → safenax-0.4.6}/LICENSE +0 -0
  14. {safenax-0.4.4 → safenax-0.4.6}/PUBLISHING.md +0 -0
  15. {safenax-0.4.4 → safenax-0.4.6}/README.md +0 -0
  16. {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/__init__.py +0 -0
  17. {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/eco_ant_v1.py +0 -0
  18. {safenax-0.4.4 → safenax-0.4.6}/safenax/eco_ant/eco_ant_v2.py +0 -0
  19. {safenax-0.4.4 → safenax-0.4.6}/safenax/fragile_ant.py +0 -0
  20. {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/__init__.py +0 -0
  21. {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/po_crypto.py +0 -0
  22. {safenax-0.4.4 → safenax-0.4.6}/safenax/portfolio_optimization/po_garch.py +0 -0
  23. {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/__init__.py +0 -0
  24. {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/brax.py +0 -0
  25. {safenax-0.4.4 → safenax-0.4.6}/safenax/wrappers/log.py +0 -0
  26. {safenax-0.4.4 → safenax-0.4.6}/scripts/setup_dev.sh +0 -0
  27. {safenax-0.4.4 → safenax-0.4.6}/tests/__init__.py +0 -0
  28. {safenax-0.4.4 → safenax-0.4.6}/tests/test_eco_ant_v1.py +0 -0
  29. {safenax-0.4.4 → safenax-0.4.6}/tests/test_eco_ant_v2.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safenax
3
- Version: 0.4.4
3
+ Version: 0.4.6
4
4
  Summary: Constrained environments with a gymnax interface
5
5
  Project-URL: Homepage, https://github.com/0xprofessooor/safenax
6
6
  Project-URL: Repository, https://github.com/0xprofessooor/safenax
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "safenax"
7
- version = "0.4.4"
7
+ version = "0.4.6"
8
8
  description = "Constrained environments with a gymnax interface"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.11"
@@ -8,7 +8,7 @@ from safenax.portfolio_optimization import (
8
8
  PortfolioOptimizationGARCH,
9
9
  PortfolioOptimizationCrypto,
10
10
  )
11
- from safenax.frozen_lake import FrozenLake
11
+ from safenax.frozen_lake import FrozenLakeV1, FrozenLakeV2
12
12
  from safenax.eco_ant import EcoAntV1, EcoAntV2
13
13
 
14
14
 
@@ -16,7 +16,8 @@ __all__ = [
16
16
  "FragileAnt",
17
17
  "PortfolioOptimizationCrypto",
18
18
  "PortfolioOptimizationGARCH",
19
- "FrozenLake",
19
+ "FrozenLakeV1",
20
+ "FrozenLakeV2",
20
21
  "EcoAntV1",
21
22
  "EcoAntV2",
22
23
  ]
@@ -0,0 +1,8 @@
1
+ from safenax.frozen_lake.frozen_lake_v1 import FrozenLakeV1
2
+ from safenax.frozen_lake.frozen_lake_v2 import FrozenLakeV2
3
+
4
+
5
+ __all__ = [
6
+ "FrozenLakeV1",
7
+ "FrozenLakeV2",
8
+ ]
@@ -1,4 +1,4 @@
1
- """JAX-compatible FrozenLake environment following the gymnax interface."""
1
+ """JAX-compatible FrozenLakeV1 environment following the gymnax interface."""
2
2
 
3
3
  from typing import Optional, Tuple, Union, List
4
4
 
@@ -69,7 +69,7 @@ MAPS = {
69
69
  }
70
70
 
71
71
 
72
- class FrozenLake(environment.Environment):
72
+ class FrozenLakeV1(environment.Environment):
73
73
  """
74
74
  JAX-compatible FrozenLake environment.
75
75
 
@@ -81,7 +81,7 @@ class FrozenLake(environment.Environment):
81
81
  def __init__(
82
82
  self,
83
83
  map_name: str = "4x4",
84
- desc: Optional[chex.Array] = None,
84
+ desc: Optional[jax.Array] = None,
85
85
  is_slippery: bool = True,
86
86
  success_rate: float = 1.0 / 3.0,
87
87
  reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
@@ -250,9 +250,9 @@ def make_frozen_lake(
250
250
  is_slippery: bool = True,
251
251
  success_rate: float = 1.0 / 3.0,
252
252
  reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
253
- ) -> FrozenLake:
253
+ ) -> FrozenLakeV1:
254
254
  """Factory function to easier initialization."""
255
- return FrozenLake(
255
+ return FrozenLakeV1(
256
256
  map_name=map_name,
257
257
  is_slippery=is_slippery,
258
258
  success_rate=success_rate,
@@ -0,0 +1,302 @@
1
+ """JAX-compatible FrozenLakeV2 environment following the gymnax interface."""
2
+
3
+ from typing import Optional, Tuple, Union, List
4
+
5
+ import chex
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from flax import struct
9
+ from gymnax.environments import environment, spaces
10
+
11
+ # --- Constants ---
12
+ # Actions
13
+ LEFT = 0
14
+ DOWN = 1
15
+ RIGHT = 2
16
+ UP = 3
17
+
18
+ # Tile Types (using ASCII values for readability during debugging)
19
+ TILE_START = ord("S")
20
+ TILE_FROZEN = ord("F")
21
+ TILE_THIN = ord("T")
22
+ TILE_GOAL = ord("G")
23
+
24
+
25
+ @struct.dataclass
26
+ class EnvState:
27
+ """Environment state for FrozenLake."""
28
+
29
+ pos: int
30
+ time: int
31
+
32
+
33
+ @struct.dataclass
34
+ class EnvParams:
35
+ """Environment parameters for FrozenLake."""
36
+
37
+ desc: jax.Array # Map description as integer array (nrow, ncol)
38
+ nrow: int
39
+ ncol: int
40
+ is_slippery: bool
41
+ success_rate: float
42
+ reward_schedule: jax.Array # (goal_reward, thin_reward, frozen_reward)
43
+ max_steps_in_episode: int
44
+ safe_cost_mean: float
45
+ safe_cost_std: float
46
+ thin_cost_base: float
47
+ thin_shock_prob: float
48
+ thin_shock_val: float
49
+
50
+
51
+ # --- Helper Functions ---
52
+
53
+
54
+ def string_map_to_array(map_desc: List[str]) -> jax.Array:
55
+ """Convert a list of strings into a JAX-friendly integer array."""
56
+ return jnp.array([[ord(c) for c in row] for row in map_desc], dtype=jnp.int32)
57
+
58
+
59
+ # Predefined maps
60
+ MAPS = {
61
+ "4x4": string_map_to_array(["SFFF", "FTFT", "FFFT", "TFFG"]),
62
+ "8x8": string_map_to_array(
63
+ [
64
+ "SFFFFFFF",
65
+ "FFFFFFFF",
66
+ "FFFTFFFF",
67
+ "FFFFFTFF",
68
+ "FFFTFFFF",
69
+ "FTTFFFTF",
70
+ "FTFFTFTF",
71
+ "FFFTFFFG",
72
+ ]
73
+ ),
74
+ }
75
+
76
+
77
+ class FrozenLakeV2(environment.Environment):
78
+ """
79
+ JAX-compatible FrozenLake environment with "Thin Ice" cost augmentation.
80
+
81
+ The agent controls the movement of a character in a grid world to reach a Goal (G).
82
+ The map contains two types of traversable terrain that introduce a Mean-Variance trade-off:
83
+
84
+ 1. Safe Ice ('F'):
85
+ - Represents a stable but arduous path (deep snow is hard to traverse).
86
+ - Cost: High Mean (e.g., 2.0), Small Variance (e.g., std=0.1).
87
+ - Risk: Deterministically expensive, but safe.
88
+
89
+ 2. Thin Ice ('T', formerly Holes):
90
+ - Represents a dangerous shortcut, prone to cracking, but one the agent can usually glide over with little effort.
91
+ - Cost: Low Mean (e.g., 1.5), High Variance Shocks (e.g. std=4.35) (ice cracks).
92
+ - Risk: Usually cheap to traverse, but prone to random, catastrophic cost spikes (cracking ice).
93
+
94
+ The agent can also slip perpendicular to its intended direction based on the `is_slippery` and `success_rate` parameters.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ map_name: str = "4x4",
100
+ desc: Optional[jax.Array] = None,
101
+ is_slippery: bool = True,
102
+ success_rate: float = 1.0 / 3.0,
103
+ reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
104
+ safe_cost_mean: float = 2.0,
105
+ safe_cost_std: float = 0.1,
106
+ thin_cost_base: float = 0.5,
107
+ thin_shock_prob: float = 0.05,
108
+ thin_shock_val: float = 20.0,
109
+ ):
110
+ super().__init__()
111
+
112
+ if desc is None:
113
+ desc = MAPS[map_name]
114
+
115
+ self.desc = desc
116
+ self.nrow, self.ncol = desc.shape
117
+ self.is_slippery = is_slippery
118
+ self.success_rate = success_rate
119
+ self.reward_schedule = jnp.array(reward_schedule)
120
+ self.safe_cost_mean = safe_cost_mean
121
+ self.safe_cost_std = safe_cost_std
122
+ self.thin_cost_base = thin_cost_base
123
+ self.thin_shock_prob = thin_shock_prob
124
+ self.thin_shock_val = thin_shock_val
125
+
126
+ # Determine max steps based on map size convention
127
+ if map_name == "4x4" or (desc.shape[0] == 4 and desc.shape[1] == 4):
128
+ self.max_steps = 100
129
+ else:
130
+ self.max_steps = 200
131
+
132
+ @property
133
+ def default_params(self) -> EnvParams:
134
+ return EnvParams(
135
+ desc=self.desc,
136
+ nrow=self.nrow,
137
+ ncol=self.ncol,
138
+ is_slippery=self.is_slippery,
139
+ success_rate=self.success_rate,
140
+ reward_schedule=self.reward_schedule,
141
+ max_steps_in_episode=self.max_steps,
142
+ safe_cost_mean=self.safe_cost_mean,
143
+ safe_cost_std=self.safe_cost_std,
144
+ thin_cost_base=self.thin_cost_base,
145
+ thin_shock_prob=self.thin_shock_prob,
146
+ thin_shock_val=self.thin_shock_val,
147
+ )
148
+
149
+ def step_env(
150
+ self,
151
+ key: jax.Array,
152
+ state: EnvState,
153
+ action: Union[int, float],
154
+ params: EnvParams,
155
+ ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
156
+ """Perform a single environment step with JIT compatibility."""
157
+ action = jnp.int32(action)
158
+
159
+ # 1. Determine the actual direction (handling slippery logic)
160
+ rng_slip, rng_cost, key = jax.random.split(key, 3)
161
+
162
+ def get_slippery_action(k):
163
+ # 0: intended, 1: perpendicular left, 2: perpendicular right
164
+ # Probabilities: [success, (1-success)/2, (1-success)/2]
165
+ fail_prob = (1.0 - params.success_rate) / 2.0
166
+ probs = jnp.array([params.success_rate, fail_prob, fail_prob])
167
+
168
+ # Map samples (0, 1, 2) to actions (action, action-1, action+1)
169
+ delta_idx = jax.random.choice(k, jnp.arange(3), p=probs)
170
+
171
+ # (action - 1) % 4 <-- Perpendicular Left
172
+ # (action) <-- Intended
173
+ # (action + 1) % 4 <-- Perpendicular Right
174
+
175
+ # Using a lookup array for cleaner mapping
176
+ candidates = jnp.array(
177
+ [
178
+ action, # Index 0: Success
179
+ (action - 1) % 4, # Index 1: Fail Left
180
+ (action + 1) % 4, # Index 2: Fail Right
181
+ ]
182
+ )
183
+ return candidates[delta_idx]
184
+
185
+ actual_action = jax.lax.cond(
186
+ params.is_slippery,
187
+ get_slippery_action,
188
+ lambda k: action, # Deterministic branch
189
+ rng_slip,
190
+ )
191
+
192
+ # 2. Calculate Movement (Optimized Vectorized Approach)
193
+ row = state.pos // params.ncol
194
+ col = state.pos % params.ncol
195
+
196
+ next_row, next_col = self._apply_action(row, col, actual_action, params)
197
+ next_pos = next_row * params.ncol + next_col
198
+
199
+ # 3. Check Rewards and Termination
200
+ tile_type = params.desc[next_row, next_col]
201
+
202
+ is_goal = tile_type == TILE_GOAL
203
+ is_thin = tile_type == TILE_THIN
204
+
205
+ # Reward Schedule: [Goal, Thin, Frozen/Start]
206
+ # We select index 0, 1, or 2 based on tile type
207
+ tile_idx = jnp.where(is_goal, 0, jnp.where(is_thin, 1, 2))
208
+ reward = params.reward_schedule[tile_idx]
209
+
210
+ safe_noise = jax.random.normal(rng_cost) * params.safe_cost_std
211
+ cost_safe = params.safe_cost_mean + safe_noise
212
+
213
+ is_shock = jax.random.bernoulli(rng_cost, p=params.thin_shock_prob)
214
+ cost_thin = params.thin_cost_base + (is_shock * params.thin_shock_val)
215
+
216
+ cost = jnp.select([is_goal, is_thin], [0.0, cost_thin], default=cost_safe)
217
+
218
+ # 4. Update State
219
+ # Time limit truncation is handled in Gymnax wrappers usually,
220
+ # but we track it here for the 'done' flag consistency.
221
+ new_time = state.time + 1
222
+ truncated = new_time >= params.max_steps_in_episode
223
+ done = is_goal | truncated
224
+
225
+ new_state = EnvState(pos=next_pos, time=new_time)
226
+
227
+ return (
228
+ self.get_obs(new_state, params),
229
+ new_state,
230
+ reward,
231
+ done,
232
+ {"cost": cost, "tile_type": tile_type},
233
+ )
234
+
235
+ def reset_env(
236
+ self,
237
+ key: chex.PRNGKey,
238
+ params: EnvParams,
239
+ ) -> Tuple[chex.Array, EnvState]:
240
+ """Reset environment to the start position."""
241
+ # By definition, FrozenLake always starts at (0,0)
242
+ # If dynamic start positions are needed, one would scan params.desc for 'S'
243
+ state = EnvState(pos=0, time=0)
244
+ return self.get_obs(state, params), state
245
+
246
+ def get_obs(self, state: EnvState, params: EnvParams) -> chex.Array:
247
+ """Return scalar observation (flat position index)."""
248
+ return jnp.array(state.pos, dtype=jnp.int32)
249
+
250
+ def _apply_action(
251
+ self, row: int, col: int, action: int, params: EnvParams
252
+ ) -> Tuple[int, int]:
253
+ """Optimized movement logic using delta arrays and clipping."""
254
+ # Gym Direction Mapping:
255
+ # 0: Left (0, -1)
256
+ # 1: Down (1, 0)
257
+ # 2: Right (0, 1)
258
+ # 3: Up (-1, 0)
259
+
260
+ dr = jnp.array([0, 1, 0, -1])
261
+ dc = jnp.array([-1, 0, 1, 0])
262
+
263
+ new_row = row + dr[action]
264
+ new_col = col + dc[action]
265
+
266
+ # Ensure we stay within the grid
267
+ new_row = jnp.clip(new_row, 0, params.nrow - 1)
268
+ new_col = jnp.clip(new_col, 0, params.ncol - 1)
269
+
270
+ return new_row, new_col
271
+
272
+ @property
273
+ def name(self) -> str:
274
+ return "FrozenLake-v2"
275
+
276
+ @property
277
+ def num_actions(self) -> int:
278
+ return 4
279
+
280
+ def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
281
+ return spaces.Discrete(4)
282
+
283
+ def observation_space(self, params: EnvParams) -> spaces.Discrete:
284
+ return spaces.Discrete(params.nrow * params.ncol)
285
+
286
+
287
+ # --- Factory Helper ---
288
+
289
+
290
+ def make_frozen_lake(
291
+ map_name: str = "4x4",
292
+ is_slippery: bool = True,
293
+ success_rate: float = 1.0 / 3.0,
294
+ reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
295
+ ) -> FrozenLakeV2:
296
+ """Factory function to easier initialization."""
297
+ return FrozenLakeV2(
298
+ map_name=map_name,
299
+ is_slippery=is_slippery,
300
+ success_rate=success_rate,
301
+ reward_schedule=reward_schedule,
302
+ )
@@ -1,8 +1,8 @@
1
1
  import pytest
2
2
  import jax
3
3
  import jax.numpy as jnp
4
- from safenax.frozen_lake import FrozenLake, EnvParams
5
- from safenax.frozen_lake import (
4
+ from safenax.frozen_lake.frozen_lake_v1 import FrozenLakeV1, EnvParams
5
+ from safenax.frozen_lake.frozen_lake_v1 import (
6
6
  TILE_START,
7
7
  TILE_FROZEN,
8
8
  TILE_HOLE,
@@ -16,7 +16,7 @@ from safenax.frozen_lake import (
16
16
  @pytest.fixture
17
17
  def env():
18
18
  """Initializes the environment."""
19
- return FrozenLake()
19
+ return FrozenLakeV1()
20
20
 
21
21
 
22
22
  @pytest.fixture
@@ -58,7 +58,7 @@ def params_slippery(params_deterministic: EnvParams) -> EnvParams:
58
58
 
59
59
 
60
60
  def test_jit_compilation(
61
- env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams
61
+ env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
62
62
  ):
63
63
  """Verifies that reset and step can be JIT compiled without errors."""
64
64
  reset_jit = jax.jit(env.reset)
@@ -85,7 +85,7 @@ def test_map_parsing():
85
85
 
86
86
 
87
87
  def test_deterministic_reach_goal(
88
- env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams
88
+ env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
89
89
  ):
90
90
  """
91
91
  Path: Start(0) -> Right(1) -> Down(3/Goal).
@@ -119,7 +119,7 @@ def test_deterministic_reach_goal(
119
119
 
120
120
 
121
121
  def test_deterministic_fall_in_hole(
122
- env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams
122
+ env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
123
123
  ):
124
124
  """
125
125
  Path: Start(0) -> Down(2/Hole).
@@ -142,7 +142,7 @@ def test_deterministic_fall_in_hole(
142
142
 
143
143
 
144
144
  def test_custom_reward_schedule(
145
- env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams
145
+ env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
146
146
  ):
147
147
  """
148
148
  Verifies that changing reward_schedule affects the output.
@@ -167,7 +167,9 @@ def test_custom_reward_schedule(
167
167
  assert reward == -5.0
168
168
 
169
169
 
170
- def test_slippery_dynamics(env: FrozenLake, rng: jax.Array, params_slippery: EnvParams):
170
+ def test_slippery_dynamics(
171
+ env: FrozenLakeV1, rng: jax.Array, params_slippery: EnvParams
172
+ ):
171
173
  """
172
174
  Statistical test for slippery logic.
173
175
  We start at (0,0) and try to move Right.
@@ -201,7 +203,7 @@ def test_slippery_dynamics(env: FrozenLake, rng: jax.Array, params_slippery: Env
201
203
  assert 1 in unique_positions
202
204
 
203
205
 
204
- def test_truncation(env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams):
206
+ def test_truncation(env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams):
205
207
  """Test that episode ends when time limit is reached."""
206
208
  # Set max steps to 2
207
209
  short_params = params_deterministic.replace(max_steps_in_episode=2)
@@ -224,7 +226,9 @@ def test_truncation(env: FrozenLake, rng: jax.Array, params_deterministic: EnvPa
224
226
  assert state.time == 0
225
227
 
226
228
 
227
- def test_jit_rollout_scan(env: FrozenLake, rng: jax.Array, params_slippery: EnvParams):
229
+ def test_jit_rollout_scan(
230
+ env: FrozenLakeV1, rng: jax.Array, params_slippery: EnvParams
231
+ ):
228
232
  """
229
233
  Verifies that the environment can be run in a fully compiled jax.lax.scan loop.
230
234
  This ensures no shape mismatches or control flow issues exist in the step logic.
@@ -282,7 +286,9 @@ def test_jit_rollout_scan(env: FrozenLake, rng: jax.Array, params_slippery: EnvP
282
286
  assert jnp.all(jnp.isin(rollout_data["reward"], jnp.array([0.0, 1.0])))
283
287
 
284
288
 
285
- def test_cost_signal(env: FrozenLake, rng: jax.Array, params_deterministic: EnvParams):
289
+ def test_cost_signal(
290
+ env: FrozenLakeV1, rng: jax.Array, params_deterministic: EnvParams
291
+ ):
286
292
  """Ensure cost is 1.0 ONLY when falling into a hole."""
287
293
  step_jit = jax.jit(env.step)
288
294
  reset_jit = jax.jit(env.reset)
@@ -0,0 +1,167 @@
1
+ """Unit tests for FrozenLakeV2 environment."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import pytest
6
+ from safenax.frozen_lake.frozen_lake_v2 import (
7
+ FrozenLakeV2,
8
+ EnvState,
9
+ TILE_FROZEN,
10
+ TILE_THIN,
11
+ TILE_GOAL,
12
+ )
13
+
14
+ # Fix random key for reproducibility
15
+ KEY = jax.random.PRNGKey(42)
16
+
17
+
18
+ @pytest.fixture
19
+ def env():
20
+ """Fixture to create a standard 4x4 FrozenLakeV2 environment."""
21
+ return FrozenLakeV2(map_name="4x4", is_slippery=True)
22
+
23
+
24
+ @pytest.fixture
25
+ def params(env):
26
+ """Fixture for default environment parameters."""
27
+ return env.default_params
28
+
29
+
30
+ def test_initialization(env, params):
31
+ """Test that the environment initializes and resets correctly."""
32
+ key = jax.random.PRNGKey(0)
33
+ obs, state = env.reset_env(key, params)
34
+
35
+ # Check initial state
36
+ assert state.pos == 0
37
+ assert state.time == 0
38
+ assert obs.shape == () # Scalar observation
39
+ assert obs == 0
40
+
41
+
42
+ def test_step_mechanics(env, params):
43
+ """Test basic stepping functionality."""
44
+ key = jax.random.PRNGKey(0)
45
+ _, state = env.reset_env(key, params)
46
+
47
+ # Take an action (e.g., RIGHT = 2)
48
+ action = 2
49
+ obs, next_state, reward, done, info = env.step_env(key, state, action, params)
50
+
51
+ # Basic shape/type checks
52
+ assert isinstance(next_state, EnvState)
53
+ assert next_state.time == 1
54
+ assert "cost" in info
55
+ assert info["cost"] >= 0
56
+
57
+
58
+ def test_slippery_dynamics(env, params):
59
+ """
60
+ Test that the agent actually slips when is_slippery=True.
61
+ Starting at (1,1), taking action RIGHT (2) should stochastically result
62
+ in moving RIGHT, UP, or DOWN.
63
+ """
64
+ key = jax.random.PRNGKey(0)
65
+ keys = jax.random.split(key, 100) # Run 100 steps
66
+
67
+ # Force state to be at (1,1) [pos=5 for 4x4] to allow movement in all directions
68
+ # Map:
69
+ # S F F F (0, 1, 2, 3)
70
+ # F T F T (4, 5, 6, 7)
71
+ state = EnvState(pos=5, time=0)
72
+ action = 2 # RIGHT
73
+
74
+ # Vmap the step function to run in parallel
75
+ def step_fn(k):
76
+ return env.step_env(k, state, action, params)[1].pos
77
+
78
+ next_positions = jax.vmap(step_fn)(keys)
79
+
80
+ # From 5 (1,1):
81
+ # Right -> 6 (1,2) [Intended]
82
+ # Down -> 9 (2,1) [Slip Right]
83
+ # Up -> 1 (0,1) [Slip Left]
84
+
85
+ unique_positions = jnp.unique(next_positions)
86
+
87
+ # Assert that we ended up in more than just the intended position
88
+ assert len(unique_positions) > 1
89
+ assert 6 in unique_positions # Intended
90
+ assert (9 in unique_positions) or (1 in unique_positions) # Slips
91
+
92
+
93
+ def test_cost_mean_variance(env, params):
94
+ """
95
+ CRITICAL TEST: Verify the Mean-Variance trade-off for VaR-CPO.
96
+ Safe Ice should have Mean~2.0, Std~0.1.
97
+ Thin Ice should have Mean~1.5, Std~4.35.
98
+ """
99
+ N = 5000 # High sample count for statistical significance
100
+ key = jax.random.PRNGKey(0)
101
+ keys = jax.random.split(key, N)
102
+
103
+ # We will force the map description to be entirely one type to isolate the cost function logic
104
+ # regardless of where the agent moves.
105
+
106
+ def get_batch_costs(tile_type):
107
+ """Helper to get costs for N steps on a map filled with `tile_type`."""
108
+ # Create a dummy map filled with the specific tile type
109
+ custom_desc = jnp.full((4, 4), tile_type, dtype=jnp.int32)
110
+ custom_params = params.replace(desc=custom_desc)
111
+
112
+ def step_fn(k):
113
+ # Start at 0, take action 0. Destination will be same tile type.
114
+ _, _, _, _, info = env.step_env(k, EnvState(0, 0), 0, custom_params)
115
+ return info["cost"]
116
+
117
+ return jax.vmap(step_fn)(keys)
118
+
119
+ # 1. Test Safe Ice ('F')
120
+ safe_costs = get_batch_costs(TILE_FROZEN)
121
+
122
+ # Check Safe Statistics: Mean ~ 2.0, Std ~ 0.1
123
+ assert jnp.abs(jnp.mean(safe_costs) - 2.0) < 0.1, (
124
+ f"Safe Mean {jnp.mean(safe_costs)} != 2.0"
125
+ )
126
+ assert jnp.std(safe_costs) < 0.2, f"Safe Std {jnp.std(safe_costs)} is too high"
127
+
128
+ # 2. Test Thin Ice ('T')
129
+ thin_costs = get_batch_costs(TILE_THIN)
130
+
131
+ # Check Thin Statistics: Mean ~ 1.5, Std ~ 4.35
132
+ assert jnp.abs(jnp.mean(thin_costs) - 1.5) < 0.1, (
133
+ f"Thin Mean {jnp.mean(thin_costs)} != 1.5"
134
+ )
135
+ assert jnp.abs(jnp.std(thin_costs) - 4.35) < 0.1, (
136
+ f"Thin Std {jnp.std(thin_costs)} is too low (expected ~4.35)"
137
+ )
138
+
139
+ print(
140
+ f"\nStats Verification:\nSafe: Mean={jnp.mean(safe_costs):.2f}, Std={jnp.std(safe_costs):.2f}"
141
+ )
142
+ print(f"Thin: Mean={jnp.mean(thin_costs):.2f}, Std={jnp.std(thin_costs):.2f}")
143
+
144
+
145
+ def test_termination(env, params):
146
+ """Test episode termination logic."""
147
+ key = jax.random.PRNGKey(0)
148
+
149
+ # 1. Test Goal Termination
150
+ # Force map to be all GOAL
151
+ goal_desc = jnp.full((4, 4), TILE_GOAL, dtype=jnp.int32)
152
+ goal_params = params.replace(desc=goal_desc)
153
+
154
+ # Agent takes a step on a Goal tile -> should be done
155
+ _, _, _, done, _ = env.step_env(key, EnvState(0, 0), 0, goal_params)
156
+ assert done == True
157
+
158
+ # 2. Test Time Truncation
159
+ # Set max_steps to 1
160
+ short_params = params.replace(max_steps_in_episode=1)
161
+
162
+ # Step 1: time becomes 1 >= max_steps_in_episode (1) -> truncated -> done=True
163
+ state = EnvState(0, 0)
164
+ _, next_state, _, done, _ = env.step_env(key, state, 0, short_params)
165
+
166
+ assert next_state.time == 1
167
+ assert done == True
@@ -1501,7 +1501,7 @@ wheels = [
1501
1501
 
1502
1502
  [[package]]
1503
1503
  name = "safenax"
1504
- version = "0.4.3"
1504
+ version = "0.4.6"
1505
1505
  source = { editable = "." }
1506
1506
  dependencies = [
1507
1507
  { name = "brax" },
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes