safenax 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- safenax/__init__.py +3 -2
- safenax/frozen_lake/__init__.py +8 -0
- safenax/{frozen_lake.py → frozen_lake/frozen_lake_v1.py} +5 -5
- safenax/frozen_lake/frozen_lake_v2.py +296 -0
- {safenax-0.4.4.dist-info → safenax-0.4.5.dist-info}/METADATA +1 -1
- {safenax-0.4.4.dist-info → safenax-0.4.5.dist-info}/RECORD +8 -6
- {safenax-0.4.4.dist-info → safenax-0.4.5.dist-info}/WHEEL +0 -0
- {safenax-0.4.4.dist-info → safenax-0.4.5.dist-info}/licenses/LICENSE +0 -0
safenax/__init__.py
CHANGED
|
@@ -8,7 +8,7 @@ from safenax.portfolio_optimization import (
|
|
|
8
8
|
PortfolioOptimizationGARCH,
|
|
9
9
|
PortfolioOptimizationCrypto,
|
|
10
10
|
)
|
|
11
|
-
from safenax.frozen_lake import
|
|
11
|
+
from safenax.frozen_lake import FrozenLakeV1, FrozenLakeV2
|
|
12
12
|
from safenax.eco_ant import EcoAntV1, EcoAntV2
|
|
13
13
|
|
|
14
14
|
|
|
@@ -16,7 +16,8 @@ __all__ = [
|
|
|
16
16
|
"FragileAnt",
|
|
17
17
|
"PortfolioOptimizationCrypto",
|
|
18
18
|
"PortfolioOptimizationGARCH",
|
|
19
|
-
"
|
|
19
|
+
"FrozenLakeV1",
|
|
20
|
+
"FrozenLakeV2",
|
|
20
21
|
"EcoAntV1",
|
|
21
22
|
"EcoAntV2",
|
|
22
23
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""JAX-compatible
|
|
1
|
+
"""JAX-compatible FrozenLakeV1 environment following the gymnax interface."""
|
|
2
2
|
|
|
3
3
|
from typing import Optional, Tuple, Union, List
|
|
4
4
|
|
|
@@ -69,7 +69,7 @@ MAPS = {
|
|
|
69
69
|
}
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
class
|
|
72
|
+
class FrozenLakeV1(environment.Environment):
|
|
73
73
|
"""
|
|
74
74
|
JAX-compatible FrozenLake environment.
|
|
75
75
|
|
|
@@ -81,7 +81,7 @@ class FrozenLake(environment.Environment):
|
|
|
81
81
|
def __init__(
|
|
82
82
|
self,
|
|
83
83
|
map_name: str = "4x4",
|
|
84
|
-
desc: Optional[
|
|
84
|
+
desc: Optional[jax.Array] = None,
|
|
85
85
|
is_slippery: bool = True,
|
|
86
86
|
success_rate: float = 1.0 / 3.0,
|
|
87
87
|
reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
|
|
@@ -250,9 +250,9 @@ def make_frozen_lake(
|
|
|
250
250
|
is_slippery: bool = True,
|
|
251
251
|
success_rate: float = 1.0 / 3.0,
|
|
252
252
|
reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
|
|
253
|
-
) ->
|
|
253
|
+
) -> FrozenLakeV1:
|
|
254
254
|
"""Factory function to easier initialization."""
|
|
255
|
-
return
|
|
255
|
+
return FrozenLakeV1(
|
|
256
256
|
map_name=map_name,
|
|
257
257
|
is_slippery=is_slippery,
|
|
258
258
|
success_rate=success_rate,
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
"""JAX-compatible FrozenLakeV2 environment following the gymnax interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, Union, List
|
|
4
|
+
|
|
5
|
+
import chex
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from flax import struct
|
|
9
|
+
from gymnax.environments import environment, spaces
|
|
10
|
+
|
|
11
|
+
# --- Constants ---
# Actions, in Gym's ordering: 0 = LEFT, 1 = DOWN, 2 = RIGHT, 3 = UP.
LEFT, DOWN, RIGHT, UP = range(4)

# Tile Types: each map character is stored as its ASCII code, so a raw
# `desc` array can still be read by eye while debugging.
TILE_START, TILE_FROZEN, TILE_THIN, TILE_GOAL = (ord(ch) for ch in "SFTG")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@struct.dataclass
class EnvState:
    """Environment state for FrozenLake.

    Holds only the agent's position and the step counter; the map itself
    lives in ``EnvParams.desc``.
    """

    # Flat row-major cell index: step_env recovers row = pos // ncol and
    # col = pos % ncol.
    pos: int
    # Number of steps taken in the current episode; compared against
    # EnvParams.max_steps_in_episode for truncation.
    time: int
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@struct.dataclass
class EnvParams:
    """Environment parameters for FrozenLake.

    Captures the map, the slip dynamics, the reward schedule, the episode
    limit and the parameters of the per-step "thin ice" cost model.
    """

    desc: jax.Array  # Map description as integer array (nrow, ncol); entries are ASCII tile codes (TILE_*)
    nrow: int  # Number of map rows (desc.shape[0])
    ncol: int  # Number of map columns (desc.shape[1])
    is_slippery: bool  # When True, the executed action may deviate from the chosen one
    success_rate: float  # Probability of moving in the intended direction when slippery
    reward_schedule: jax.Array  # (goal_reward, thin_reward, frozen_reward)
    max_steps_in_episode: int  # Episode is marked done once this many steps have been taken
    safe_cost_mean: float  # Mean of the Gaussian cost paid when entering an 'S'/'F' tile
    safe_cost_std: float  # Std of the Gaussian cost paid when entering an 'S'/'F' tile
    thin_cost_base: float  # Cost always paid when entering a thin-ice ('T') tile
    thin_shock_prob: float  # Probability that entering thin ice adds a shock cost
    thin_shock_val: float  # Extra cost added on a thin-ice shock
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# --- Helper Functions ---
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def string_map_to_array(map_desc: List[str]) -> jax.Array:
    """Encode a map given as row strings into an int32 array of character codes.

    Every character is replaced by its ``ord`` value, so the result can be
    compared directly against the ``TILE_*`` constants. For rectangular
    input the array has one row per string and one column per character.
    """
    encoded_rows = []
    for line in map_desc:
        encoded_rows.append(list(map(ord, line)))
    return jnp.array(encoded_rows, dtype=jnp.int32)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Predefined maps
|
|
60
|
+
MAPS = {
    map_key: string_map_to_array(map_rows)
    for map_key, map_rows in (
        ("4x4", ["SFFF", "FTFT", "FFFT", "TFFG"]),
        (
            "8x8",
            [
                "SFFFFFFF",
                "FFFFFFFF",
                "FFFTFFFF",
                "FFFFFTFF",
                "FFFTFFFF",
                "FTTFFFTF",
                "FTFFTFTF",
                "FFFTFFFG",
            ],
        ),
    )
}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class FrozenLakeV2(environment.Environment):
    """
    JAX-compatible FrozenLake environment with "Thin Ice" cost augmentation.

    The agent controls the movement of a character in a grid world to reach a Goal (G).
    The map contains two types of traversable terrain that introduce a Mean-Variance trade-off:

    1. Safe Ice ('F', and the start tile 'S'):
        - Represents a stable but arduous path (deep snow is hard to traverse).
        - Cost: Gaussian with mean ``safe_cost_mean`` and std ``safe_cost_std``
          (defaults: mean 2.0, std 0.1), i.e. high mean, small variance.
        - Risk: Almost deterministically expensive, but safe.

    2. Thin Ice ('T', formerly Holes):
        - Represents a dangerous shortcut that is easy to glide over but prone to cracking.
        - Cost: ``thin_cost_base`` plus ``thin_shock_val`` with probability
          ``thin_shock_prob`` (defaults: 0.5 + 20.0 w.p. 0.05, i.e. mean 1.5, std ~4.36).
        - Risk: Usually cheap to traverse, but prone to random, catastrophic cost spikes.
        - Unlike holes in the classic environment, thin ice does not end the episode.

    Entering the goal tile costs 0 and ends the episode; episodes are also
    marked done after ``max_steps_in_episode`` steps (100 for 4x4 maps, 200
    otherwise). The per-step cost is returned in ``info["cost"]`` and the
    reward is looked up in ``reward_schedule`` by the type of tile entered.

    The agent can also slip perpendicular to its intended direction based on the
    `is_slippery` and `success_rate` parameters.
    """

    def __init__(
        self,
        map_name: str = "4x4",
        desc: Optional[Union[jax.Array, List[str]]] = None,
        is_slippery: bool = True,
        success_rate: float = 1.0 / 3.0,
        reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
        safe_cost_mean: float = 2.0,
        safe_cost_std: float = 0.1,
        thin_cost_base: float = 0.5,
        thin_shock_prob: float = 0.05,
        thin_shock_val: float = 20.0,
    ):
        """Build the environment.

        Args:
            map_name: Key into ``MAPS`` ("4x4" or "8x8"); used only when
                ``desc`` is None.
            desc: Custom map, either an integer array of tile codes (as
                produced by ``string_map_to_array``) or a list of row strings.
            is_slippery: Whether the executed action may differ from the chosen one.
            success_rate: Probability of moving in the intended direction when
                slippery; the remainder is split evenly between the two
                perpendicular directions.
            reward_schedule: (goal_reward, thin_reward, frozen_reward).
            safe_cost_mean: Mean cost of entering a safe ('S'/'F') tile.
            safe_cost_std: Std of the cost of entering a safe tile.
            thin_cost_base: Base cost of entering a thin-ice ('T') tile.
            thin_shock_prob: Probability of a cost shock on thin ice.
            thin_shock_val: Extra cost added by a shock.

        Raises:
            KeyError: If ``desc`` is None and ``map_name`` is not in ``MAPS``.
        """
        super().__init__()

        if desc is None:
            desc = MAPS[map_name]
        elif isinstance(desc, (list, tuple)):
            # Accept the human-readable list-of-strings form as well.
            desc = string_map_to_array(list(desc))

        self.desc = desc
        self.nrow, self.ncol = desc.shape
        self.is_slippery = is_slippery
        self.success_rate = success_rate
        self.reward_schedule = jnp.array(reward_schedule)
        self.safe_cost_mean = safe_cost_mean
        self.safe_cost_std = safe_cost_std
        self.thin_cost_base = thin_cost_base
        self.thin_shock_prob = thin_shock_prob
        self.thin_shock_val = thin_shock_val

        # Episode limit by map size: 100 steps for 4x4, 200 otherwise. Decide
        # from the actual map shape, because `map_name` defaults to "4x4" and
        # would otherwise mislabel any custom `desc` of a different size.
        self.max_steps = 100 if (self.nrow, self.ncol) == (4, 4) else 200

    @property
    def default_params(self) -> EnvParams:
        """Bundle the constructor arguments into an ``EnvParams`` instance."""
        return EnvParams(
            desc=self.desc,
            nrow=self.nrow,
            ncol=self.ncol,
            is_slippery=self.is_slippery,
            success_rate=self.success_rate,
            reward_schedule=self.reward_schedule,
            max_steps_in_episode=self.max_steps,
            safe_cost_mean=self.safe_cost_mean,
            safe_cost_std=self.safe_cost_std,
            thin_cost_base=self.thin_cost_base,
            thin_shock_prob=self.thin_shock_prob,
            thin_shock_val=self.thin_shock_val,
        )

    def step_env(
        self,
        key: jax.Array,
        state: EnvState,
        action: Union[int, float],
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Perform a single environment step with JIT compatibility.

        Returns:
            ``(obs, new_state, reward, done, info)`` where ``info["cost"]`` is
            the sampled cost of the tile that was entered.
        """
        action = jnp.int32(action)

        # Independent sub-keys: one for slipping, one per cost distribution.
        # Re-using a single key for the Gaussian and the Bernoulli draw would
        # derive both samples from the same random bits.
        rng_slip, rng_safe, rng_shock = jax.random.split(key, 3)

        # 1. Determine the actual direction (handling slippery logic)
        def get_slippery_action(k):
            # Probabilities: [success, (1-success)/2, (1-success)/2]
            fail_prob = (1.0 - params.success_rate) / 2.0
            probs = jnp.array([params.success_rate, fail_prob, fail_prob])
            candidates = jnp.array(
                [
                    action,  # Index 0: intended direction
                    (action - 1) % 4,  # Index 1: perpendicular one way
                    (action + 1) % 4,  # Index 2: perpendicular the other way
                ]
            )
            delta_idx = jax.random.choice(k, jnp.arange(3), p=probs)
            return candidates[delta_idx]

        actual_action = jax.lax.cond(
            params.is_slippery,
            get_slippery_action,
            lambda k: action,  # Deterministic branch
            rng_slip,
        )

        # 2. Calculate movement (clipped at the grid border)
        row = state.pos // params.ncol
        col = state.pos % params.ncol
        next_row, next_col = self._apply_action(row, col, actual_action, params)
        next_pos = next_row * params.ncol + next_col

        # 3. Rewards, costs and termination depend on the tile entered
        tile_type = params.desc[next_row, next_col]
        is_goal = tile_type == TILE_GOAL
        is_thin = tile_type == TILE_THIN

        # Reward schedule layout: [Goal, Thin, Frozen/Start]
        tile_idx = jnp.where(is_goal, 0, jnp.where(is_thin, 1, 2))
        reward = params.reward_schedule[tile_idx]

        # Safe tiles: Gaussian cost (high mean, low variance).
        safe_noise = jax.random.normal(rng_safe) * params.safe_cost_std
        cost_safe = params.safe_cost_mean + safe_noise

        # Thin ice: cheap base cost plus a rare, large shock.
        is_shock = jax.random.bernoulli(rng_shock, p=params.thin_shock_prob)
        cost_thin = params.thin_cost_base + (is_shock * params.thin_shock_val)

        cost = jnp.select([is_goal, is_thin], [0.0, cost_thin], default=cost_safe)

        # 4. Update state. Truncation is tracked here so that `done` also
        # fires once the step limit is reached.
        new_time = state.time + 1
        truncated = new_time >= params.max_steps_in_episode
        done = is_goal | truncated

        new_state = EnvState(pos=next_pos, time=new_time)

        return self.get_obs(new_state, params), new_state, reward, done, {"cost": cost}

    def reset_env(
        self,
        key: chex.PRNGKey,
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState]:
        """Reset environment to the start position.

        The episode starts on the first 'S' tile of ``params.desc`` in
        row-major order, or at position 0 if the map has no 'S' tile. For the
        built-in maps this is always (0, 0).
        """
        # argmax over an all-False mask returns 0, giving the fallback for free.
        is_start = params.desc.reshape(-1) == TILE_START
        start_pos = jnp.argmax(is_start).astype(jnp.int32)
        state = EnvState(pos=start_pos, time=0)
        return self.get_obs(state, params), state

    def get_obs(self, state: EnvState, params: EnvParams) -> chex.Array:
        """Return scalar observation (flat position index)."""
        return jnp.array(state.pos, dtype=jnp.int32)

    def _apply_action(
        self, row: int, col: int, action: int, params: EnvParams
    ) -> Tuple[int, int]:
        """Move one cell in direction `action`, clipping at the grid border."""
        # Gym Direction Mapping:
        # 0: Left (0, -1)
        # 1: Down (1, 0)
        # 2: Right (0, 1)
        # 3: Up (-1, 0)
        dr = jnp.array([0, 1, 0, -1])
        dc = jnp.array([-1, 0, 1, 0])

        new_row = row + dr[action]
        new_col = col + dc[action]

        # Ensure we stay within the grid
        new_row = jnp.clip(new_row, 0, params.nrow - 1)
        new_col = jnp.clip(new_col, 0, params.ncol - 1)

        return new_row, new_col

    @property
    def name(self) -> str:
        """Environment name."""
        return "FrozenLake-v2"

    @property
    def num_actions(self) -> int:
        """Number of discrete actions (LEFT, DOWN, RIGHT, UP)."""
        return 4

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Discrete space over the four movement actions."""
        return spaces.Discrete(4)

    def observation_space(self, params: EnvParams) -> spaces.Discrete:
        """Discrete space over all flat cell indices of the map."""
        return spaces.Discrete(params.nrow * params.ncol)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# --- Factory Helper ---
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def make_frozen_lake(
    map_name: str = "4x4",
    is_slippery: bool = True,
    success_rate: float = 1.0 / 3.0,
    reward_schedule: Tuple[float, float, float] = (1.0, 0.0, 0.0),
    desc: Optional[jax.Array] = None,
    safe_cost_mean: float = 2.0,
    safe_cost_std: float = 0.1,
    thin_cost_base: float = 0.5,
    thin_shock_prob: float = 0.05,
    thin_shock_val: float = 20.0,
) -> FrozenLakeV2:
    """Convenience factory for ``FrozenLakeV2``.

    Forwards every constructor argument, including the custom map and the
    thin-ice cost model. The added parameters are keyword arguments with
    the same defaults as ``FrozenLakeV2``, so existing positional calls
    are unaffected.

    Args:
        map_name: Built-in map to use when ``desc`` is None.
        is_slippery: Whether the executed action may differ from the chosen one.
        success_rate: Probability of moving in the intended direction when slippery.
        reward_schedule: (goal_reward, thin_reward, frozen_reward).
        desc: Optional custom map of tile codes; overrides ``map_name``.
        safe_cost_mean: Mean cost of entering a safe tile.
        safe_cost_std: Std of the cost of entering a safe tile.
        thin_cost_base: Base cost of entering a thin-ice tile.
        thin_shock_prob: Probability of a cost shock on thin ice.
        thin_shock_val: Extra cost added by a shock.

    Returns:
        A configured ``FrozenLakeV2`` instance.
    """
    return FrozenLakeV2(
        map_name=map_name,
        desc=desc,
        is_slippery=is_slippery,
        success_rate=success_rate,
        reward_schedule=reward_schedule,
        safe_cost_mean=safe_cost_mean,
        safe_cost_std=safe_cost_std,
        thin_cost_base=thin_cost_base,
        thin_shock_prob=thin_shock_prob,
        thin_shock_val=thin_shock_val,
    )
|
|
@@ -1,16 +1,18 @@
|
|
|
1
|
-
safenax/__init__.py,sha256=
|
|
1
|
+
safenax/__init__.py,sha256=yD78xGJagBtjTM9fVskf9OULJdXSLxPfLSWTvxEINOY,595
|
|
2
2
|
safenax/fragile_ant.py,sha256=10XAOYFrEmi9mjeVurk6W4fIQj5zaUXlQYthQQxSm14,6751
|
|
3
|
-
safenax/frozen_lake.py,sha256=YZT5KWbiFyWY6gUo6vPn9owW_flinqnZIkmFtPclGs0,7880
|
|
4
3
|
safenax/eco_ant/__init__.py,sha256=USsps574Jfc7A5N9PWiY2zbwiphy3URJinzvfsSYjEg,144
|
|
5
4
|
safenax/eco_ant/eco_ant_v1.py,sha256=G6YekTSSK2orcYjNR9QNVZkKpeIrqM56m7gmsNu4cOI,2743
|
|
6
5
|
safenax/eco_ant/eco_ant_v2.py,sha256=Aid3ySUJuzGHLiC4L93wLNRy9IrTTsEdP7Ii8aDxQqQ,2601
|
|
6
|
+
safenax/frozen_lake/__init__.py,sha256=81aH7mpQiEWJeem4usZTbilSdlXDJybA7ePowxRyQhc,176
|
|
7
|
+
safenax/frozen_lake/frozen_lake_v1.py,sha256=6Yy9tm4MbrNiYXDNi091Jsh9iwoEcKz3TrOn2sVVGTw,7887
|
|
8
|
+
safenax/frozen_lake/frozen_lake_v2.py,sha256=fVe1ObHXU_PLnG1MwiAlZEG-VaUoIq5YYEPfiC0-tok,9634
|
|
7
9
|
safenax/portfolio_optimization/__init__.py,sha256=tbtCF4fVfan2nfFJc2wNl24hCALSb0yON1OYboN5OGk,245
|
|
8
10
|
safenax/portfolio_optimization/po_crypto.py,sha256=Bi4QCd4MoeQAnhag22MFWdqy1uQ5hVQdiwYymP9v7N4,7342
|
|
9
11
|
safenax/portfolio_optimization/po_garch.py,sha256=f2kneV5NpH_ebG_IFcfUvc3qthzZHEZt5YwcKgaI9sI,33320
|
|
10
12
|
safenax/wrappers/__init__.py,sha256=v9wyHyR482ZEfmfTtcGabpf_lUHze4fy-NjrEaGv3zA,158
|
|
11
13
|
safenax/wrappers/brax.py,sha256=svijcYVoWy5ej7RRLuN8VixDL_cMXKBK-veFsC57LRE,2985
|
|
12
14
|
safenax/wrappers/log.py,sha256=jsjT0FJBo21rCM6D2Hx9fOwXLdwP1MW6PAx1BJBP2lA,2842
|
|
13
|
-
safenax-0.4.
|
|
14
|
-
safenax-0.4.
|
|
15
|
-
safenax-0.4.
|
|
16
|
-
safenax-0.4.
|
|
15
|
+
safenax-0.4.5.dist-info/METADATA,sha256=J4fo_TeyPwOLsK_6NDqUbDECDjOhWYhAv77TUFJI_ZE,1202
|
|
16
|
+
safenax-0.4.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
17
|
+
safenax-0.4.5.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
|
|
18
|
+
safenax-0.4.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|