safenax 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
safenax/__init__.py CHANGED
@@ -8,7 +8,7 @@ from safenax.portfolio_optimization import (
     PortfolioOptimizationGARCH,
     PortfolioOptimizationCrypto,
 )
-from safenax.frozen_lake import FrozenLake
+from safenax.frozen_lake import FrozenLakeV1, FrozenLakeV2
 from safenax.eco_ant import EcoAntV1, EcoAntV2
 
 
@@ -16,7 +16,8 @@ __all__ = [
     "FragileAnt",
     "PortfolioOptimizationCrypto",
     "PortfolioOptimizationGARCH",
-    "FrozenLake",
+    "FrozenLakeV1",
+    "FrozenLakeV2",
     "EcoAntV1",
     "EcoAntV2",
 ]
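
With this change the old top-level name is gone: FrozenLakeV1 is the former FrozenLake (renamed below), while FrozenLakeV2 is the new cost-augmented variant. A minimal migration sketch, assuming nothing beyond the import split shown above:

    # safenax <= 0.4.4
    # from safenax import FrozenLake
    # env = FrozenLake(map_name="4x4")

    # safenax >= 0.4.6
    from safenax import FrozenLakeV1, FrozenLakeV2

    env_v1 = FrozenLakeV1(map_name="4x4", is_slippery=True)   # original behaviour
    env_v2 = FrozenLakeV2(map_name="4x4", thin_shock_prob=0.05, thin_shock_val=20.0)  # adds per-step costs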

safenax/frozen_lake/__init__.py ADDED
@@ -0,0 +1,8 @@
+from safenax.frozen_lake.frozen_lake_v1 import FrozenLakeV1
+from safenax.frozen_lake.frozen_lake_v2 import FrozenLakeV2
+
+
+__all__ = [
+    "FrozenLakeV1",
+    "FrozenLakeV2",
+]

safenax/frozen_lake.py → safenax/frozen_lake/frozen_lake_v1.py RENAMED
@@ -1,4 +1,4 @@
-"""JAX-compatible FrozenLake environment following the gymnax interface."""
+"""JAX-compatible FrozenLakeV1 environment following the gymnax interface."""
 
 from typing import Optional, Tuple, Union, List
 
@@ -69,7 +69,7 @@ MAPS = {
 }
 
 
-class FrozenLake(environment.Environment):
+class FrozenLakeV1(environment.Environment):
     """
     JAX-compatible FrozenLake environment.
 
@@ -81,7 +81,7 @@ class FrozenLake(environment.Environment):
     def __init__(
         self,
         map_name: str = "4x4",
-        desc: Optional[chex.Array] = None,
+        desc: Optional[jax.Array] = None,
         is_slippery: bool = True,
         success_rate: float = 1.0 / 3.0,
         reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
@@ -250,9 +250,9 @@ def make_frozen_lake(
     is_slippery: bool = True,
     success_rate: float = 1.0 / 3.0,
     reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
-) -> FrozenLake:
+) -> FrozenLakeV1:
     """Factory function to easier initialization."""
-    return FrozenLake(
+    return FrozenLakeV1(
         map_name=map_name,
         is_slippery=is_slippery,
         success_rate=success_rate,
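
Apart from the rename and the desc annotation switching from chex.Array to jax.Array, these hunks are the only changes shown for the renamed module, so existing rollout code should only need the new class name. A rough usage sketch, assuming gymnax's standard Environment.reset/step interface (not itself part of this diff):

    import jax
    from safenax.frozen_lake import FrozenLakeV1

    env = FrozenLakeV1(map_name="4x4", is_slippery=True)
    params = env.default_params

    key, reset_key, step_key = jax.random.split(jax.random.PRNGKey(0), 3)
    obs, state = env.reset(reset_key, params)                              # gymnax-style reset
    obs, state, reward, done, info = env.step(step_key, state, 2, params)  # action 2 = RIGHT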

safenax/frozen_lake/frozen_lake_v2.py ADDED
@@ -0,0 +1,302 @@
+"""JAX-compatible FrozenLakeV2 environment following the gymnax interface."""
+
+from typing import Optional, Tuple, Union, List
+
+import chex
+import jax
+import jax.numpy as jnp
+from flax import struct
+from gymnax.environments import environment, spaces
+
+# --- Constants ---
+# Actions
+LEFT = 0
+DOWN = 1
+RIGHT = 2
+UP = 3
+
+# Tile Types (using ASCII values for readability during debugging)
+TILE_START = ord("S")
+TILE_FROZEN = ord("F")
+TILE_THIN = ord("T")
+TILE_GOAL = ord("G")
+
+
+@struct.dataclass
+class EnvState:
+    """Environment state for FrozenLake."""
+
+    pos: int
+    time: int
+
+
+@struct.dataclass
+class EnvParams:
+    """Environment parameters for FrozenLake."""
+
+    desc: jax.Array  # Map description as integer array (nrow, ncol)
+    nrow: int
+    ncol: int
+    is_slippery: bool
+    success_rate: float
+    reward_schedule: jax.Array  # (goal_reward, thin_reward, frozen_reward)
+    max_steps_in_episode: int
+    safe_cost_mean: float
+    safe_cost_std: float
+    thin_cost_base: float
+    thin_shock_prob: float
+    thin_shock_val: float
+
+
+# --- Helper Functions ---
+
+
+def string_map_to_array(map_desc: List[str]) -> jax.Array:
+    """Convert a list of strings into a JAX-friendly integer array."""
+    return jnp.array([[ord(c) for c in row] for row in map_desc], dtype=jnp.int32)
+
+
+# Predefined maps
+MAPS = {
+    "4x4": string_map_to_array(["SFFF", "FTFT", "FFFT", "TFFG"]),
+    "8x8": string_map_to_array(
+        [
+            "SFFFFFFF",
+            "FFFFFFFF",
+            "FFFTFFFF",
+            "FFFFFTFF",
+            "FFFTFFFF",
+            "FTTFFFTF",
+            "FTFFTFTF",
+            "FFFTFFFG",
+        ]
+    ),
+}
+
+
+class FrozenLakeV2(environment.Environment):
+    """
+    JAX-compatible FrozenLake environment with "Thin Ice" cost augmentation.
+
+    The agent controls the movement of a character in a grid world to reach a Goal (G).
+    The map contains two types of traversable terrain that introduce a Mean-Variance trade-off:
+
+    1. Safe Ice ('F'):
+       - Represents a stable but arduous path (deep snow is hard to traverse).
+       - Cost: High Mean (e.g., 2.0), Small Variance (e.g., std=0.1).
+       - Risk: Deterministically expensive, but safe.
+
+    2. Thin Ice ('T', formerly Holes):
+       - Represents a dangerous shortcut, prone to cracking or falling but can glide over with little effort.
+       - Cost: Low Mean (e.g., 1.5), High Variance Shocks (e.g. std=4.35) (ice cracks).
+       - Risk: Usually cheap to traverse, but prone to random, catastrophic cost spikes (cracking ice).
+
+    The agent can also slip perpendicular to its intended direction based on the `is_slippery` and `success_rate` parameters.
+    """
+
+    def __init__(
+        self,
+        map_name: str = "4x4",
+        desc: Optional[jax.Array] = None,
+        is_slippery: bool = True,
+        success_rate: float = 1.0 / 3.0,
+        reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
+        safe_cost_mean: float = 2.0,
+        safe_cost_std: float = 0.1,
+        thin_cost_base: float = 0.5,
+        thin_shock_prob: float = 0.05,
+        thin_shock_val: float = 20.0,
+    ):
+        super().__init__()
+
+        if desc is None:
+            desc = MAPS[map_name]
+
+        self.desc = desc
+        self.nrow, self.ncol = desc.shape
+        self.is_slippery = is_slippery
+        self.success_rate = success_rate
+        self.reward_schedule = jnp.array(reward_schedule)
+        self.safe_cost_mean = safe_cost_mean
+        self.safe_cost_std = safe_cost_std
+        self.thin_cost_base = thin_cost_base
+        self.thin_shock_prob = thin_shock_prob
+        self.thin_shock_val = thin_shock_val
+
+        # Determine max steps based on map size convention
+        if map_name == "4x4" or (desc.shape[0] == 4 and desc.shape[1] == 4):
+            self.max_steps = 100
+        else:
+            self.max_steps = 200
+
+    @property
+    def default_params(self) -> EnvParams:
+        return EnvParams(
+            desc=self.desc,
+            nrow=self.nrow,
+            ncol=self.ncol,
+            is_slippery=self.is_slippery,
+            success_rate=self.success_rate,
+            reward_schedule=self.reward_schedule,
+            max_steps_in_episode=self.max_steps,
+            safe_cost_mean=self.safe_cost_mean,
+            safe_cost_std=self.safe_cost_std,
+            thin_cost_base=self.thin_cost_base,
+            thin_shock_prob=self.thin_shock_prob,
+            thin_shock_val=self.thin_shock_val,
+        )
+
+    def step_env(
+        self,
+        key: jax.Array,
+        state: EnvState,
+        action: Union[int, float],
+        params: EnvParams,
+    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
+        """Perform a single environment step with JIT compatibility."""
+        action = jnp.int32(action)
+
+        # 1. Determine the actual direction (handling slippery logic)
+        rng_slip, rng_cost, key = jax.random.split(key, 3)
+
+        def get_slippery_action(k):
+            # 0: intended, 1: perpendicular left, 2: perpendicular right
+            # Probabilities: [success, (1-success)/2, (1-success)/2]
+            fail_prob = (1.0 - params.success_rate) / 2.0
+            probs = jnp.array([params.success_rate, fail_prob, fail_prob])
+
+            # Map samples (0, 1, 2) to actions (action, action-1, action+1)
+            delta_idx = jax.random.choice(k, jnp.arange(3), p=probs)
+
+            # (action - 1) % 4 <-- Perpendicular Left
+            # (action)         <-- Intended
+            # (action + 1) % 4 <-- Perpendicular Right
+
+            # Using a lookup array for cleaner mapping
+            candidates = jnp.array(
+                [
+                    action,  # Index 0: Success
+                    (action - 1) % 4,  # Index 1: Fail Left
+                    (action + 1) % 4,  # Index 2: Fail Right
+                ]
+            )
+            return candidates[delta_idx]
+
+        actual_action = jax.lax.cond(
+            params.is_slippery,
+            get_slippery_action,
+            lambda k: action,  # Deterministic branch
+            rng_slip,
+        )
+
+        # 2. Calculate Movement (Optimized Vectorized Approach)
+        row = state.pos // params.ncol
+        col = state.pos % params.ncol
+
+        next_row, next_col = self._apply_action(row, col, actual_action, params)
+        next_pos = next_row * params.ncol + next_col
+
+        # 3. Check Rewards and Termination
+        tile_type = params.desc[next_row, next_col]
+
+        is_goal = tile_type == TILE_GOAL
+        is_thin = tile_type == TILE_THIN
+
+        # Reward Schedule: [Goal, Thin, Frozen/Start]
+        # We select index 0, 1, or 2 based on tile type
+        tile_idx = jnp.where(is_goal, 0, jnp.where(is_thin, 1, 2))
+        reward = params.reward_schedule[tile_idx]
+
+        safe_noise = jax.random.normal(rng_cost) * params.safe_cost_std
+        cost_safe = params.safe_cost_mean + safe_noise
+
+        is_shock = jax.random.bernoulli(rng_cost, p=params.thin_shock_prob)
+        cost_thin = params.thin_cost_base + (is_shock * params.thin_shock_val)
+
+        cost = jnp.select([is_goal, is_thin], [0.0, cost_thin], default=cost_safe)
+
+        # 4. Update State
+        # Time limit truncation is handled in Gymnax wrappers usually,
+        # but we track it here for the 'done' flag consistency.
+        new_time = state.time + 1
+        truncated = new_time >= params.max_steps_in_episode
+        done = is_goal | truncated
+
+        new_state = EnvState(pos=next_pos, time=new_time)
+
+        return (
+            self.get_obs(new_state, params),
+            new_state,
+            reward,
+            done,
+            {"cost": cost, "tile_type": tile_type},
+        )
+
+    def reset_env(
+        self,
+        key: chex.PRNGKey,
+        params: EnvParams,
+    ) -> Tuple[chex.Array, EnvState]:
+        """Reset environment to the start position."""
+        # By definition, FrozenLake always starts at (0,0)
+        # If dynamic start positions are needed, one would scan params.desc for 'S'
+        state = EnvState(pos=0, time=0)
+        return self.get_obs(state, params), state
+
+    def get_obs(self, state: EnvState, params: EnvParams) -> chex.Array:
+        """Return scalar observation (flat position index)."""
+        return jnp.array(state.pos, dtype=jnp.int32)
+
+    def _apply_action(
+        self, row: int, col: int, action: int, params: EnvParams
+    ) -> Tuple[int, int]:
+        """Optimized movement logic using delta arrays and clipping."""
+        # Gym Direction Mapping:
+        # 0: Left  (0, -1)
+        # 1: Down  (1, 0)
+        # 2: Right (0, 1)
+        # 3: Up    (-1, 0)
+
+        dr = jnp.array([0, 1, 0, -1])
+        dc = jnp.array([-1, 0, 1, 0])
+
+        new_row = row + dr[action]
+        new_col = col + dc[action]
+
+        # Ensure we stay within the grid
+        new_row = jnp.clip(new_row, 0, params.nrow - 1)
+        new_col = jnp.clip(new_col, 0, params.ncol - 1)
+
+        return new_row, new_col
+
+    @property
+    def name(self) -> str:
+        return "FrozenLake-v2"
+
+    @property
+    def num_actions(self) -> int:
+        return 4
+
+    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
+        return spaces.Discrete(4)
+
+    def observation_space(self, params: EnvParams) -> spaces.Discrete:
+        return spaces.Discrete(params.nrow * params.ncol)
+
+
+# --- Factory Helper ---
+
+
+def make_frozen_lake(
+    map_name: str = "4x4",
+    is_slippery: bool = True,
+    success_rate: float = 1.0 / 3.0,
+    reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
+) -> FrozenLakeV2:
+    """Factory function to easier initialization."""
+    return FrozenLakeV2(
+        map_name=map_name,
+        is_slippery=is_slippery,
+        success_rate=success_rate,
+        reward_schedule=reward_schedule,
+    )
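
FrozenLakeV2 keeps the sparse goal reward and returns the per-step traversal cost separately in the info dict (info["cost"]). A quick back-of-the-envelope check of the default cost parameters above, written as plain Python arithmetic rather than package code:

    # Safe ice ('F'): roughly Normal(mean=2.0, std=0.1) per step
    safe_mean = 2.0

    # Thin ice ('T'): 0.5 base cost plus a 20.0 shock with probability 0.05
    thin_mean = 0.5 + 0.05 * 20.0           # = 1.5, cheaper on average
    thin_std = (0.05 * 0.95) ** 0.5 * 20.0  # ~4.36, matching the "std=4.35" quoted in the docstring

    # Mean-variance trade-off: 'T' tiles are cheaper in expectation (1.5 vs 2.0)
    # but carry rare, large cost spikes, so risk-averse agents should prefer 'F'.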

safenax-0.4.4.dist-info/METADATA → safenax-0.4.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: safenax
-Version: 0.4.4
+Version: 0.4.6
 Summary: Constrained environments with a gymnax interface
 Project-URL: Homepage, https://github.com/0xprofessooor/safenax
 Project-URL: Repository, https://github.com/0xprofessooor/safenax

safenax-0.4.4.dist-info/RECORD → safenax-0.4.6.dist-info/RECORD RENAMED
@@ -1,16 +1,18 @@
-safenax/__init__.py,sha256=alSl90zOUHuYSvdq9zsPO_nx91cWI-tojPgz_CDFjfI,557
+safenax/__init__.py,sha256=yD78xGJagBtjTM9fVskf9OULJdXSLxPfLSWTvxEINOY,595
 safenax/fragile_ant.py,sha256=10XAOYFrEmi9mjeVurk6W4fIQj5zaUXlQYthQQxSm14,6751
-safenax/frozen_lake.py,sha256=YZT5KWbiFyWY6gUo6vPn9owW_flinqnZIkmFtPclGs0,7880
 safenax/eco_ant/__init__.py,sha256=USsps574Jfc7A5N9PWiY2zbwiphy3URJinzvfsSYjEg,144
 safenax/eco_ant/eco_ant_v1.py,sha256=G6YekTSSK2orcYjNR9QNVZkKpeIrqM56m7gmsNu4cOI,2743
 safenax/eco_ant/eco_ant_v2.py,sha256=Aid3ySUJuzGHLiC4L93wLNRy9IrTTsEdP7Ii8aDxQqQ,2601
+safenax/frozen_lake/__init__.py,sha256=81aH7mpQiEWJeem4usZTbilSdlXDJybA7ePowxRyQhc,176
+safenax/frozen_lake/frozen_lake_v1.py,sha256=6Yy9tm4MbrNiYXDNi091Jsh9iwoEcKz3TrOn2sVVGTw,7887
+safenax/frozen_lake/frozen_lake_v2.py,sha256=rXSJ8lJKPsF0ZVQcQpBm6ysBYq5I2G0GIfsfLD1-TmA,9731
 safenax/portfolio_optimization/__init__.py,sha256=tbtCF4fVfan2nfFJc2wNl24hCALSb0yON1OYboN5OGk,245
 safenax/portfolio_optimization/po_crypto.py,sha256=Bi4QCd4MoeQAnhag22MFWdqy1uQ5hVQdiwYymP9v7N4,7342
 safenax/portfolio_optimization/po_garch.py,sha256=f2kneV5NpH_ebG_IFcfUvc3qthzZHEZt5YwcKgaI9sI,33320
 safenax/wrappers/__init__.py,sha256=v9wyHyR482ZEfmfTtcGabpf_lUHze4fy-NjrEaGv3zA,158
 safenax/wrappers/brax.py,sha256=svijcYVoWy5ej7RRLuN8VixDL_cMXKBK-veFsC57LRE,2985
 safenax/wrappers/log.py,sha256=jsjT0FJBo21rCM6D2Hx9fOwXLdwP1MW6PAx1BJBP2lA,2842
-safenax-0.4.4.dist-info/METADATA,sha256=n7cC3e2QGtQaoBP9MKXhQQhu2v1lU9rualxWyhqpUZA,1202
-safenax-0.4.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-safenax-0.4.4.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
-safenax-0.4.4.dist-info/RECORD,,
+safenax-0.4.6.dist-info/METADATA,sha256=efFG_7pLH-Z5Fd8a1YzmR681Pl5mLaQIXgQqNmUx9V8,1202
+safenax-0.4.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+safenax-0.4.6.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
+safenax-0.4.6.dist-info/RECORD,,