safenax 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- safenax/frozen_lake/frozen_lake_v2.py +7 -1
- safenax/portfolio_optimization/po_garch.py +51 -17
- {safenax-0.4.5.dist-info → safenax-0.4.7.dist-info}/METADATA +1 -1
- {safenax-0.4.5.dist-info → safenax-0.4.7.dist-info}/RECORD +6 -6
- {safenax-0.4.5.dist-info → safenax-0.4.7.dist-info}/WHEEL +0 -0
- {safenax-0.4.5.dist-info → safenax-0.4.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -224,7 +224,13 @@ class FrozenLakeV2(environment.Environment):
|
|
|
224
224
|
|
|
225
225
|
new_state = EnvState(pos=next_pos, time=new_time)
|
|
226
226
|
|
|
227
|
-
return
|
|
227
|
+
return (
|
|
228
|
+
self.get_obs(new_state, params),
|
|
229
|
+
new_state,
|
|
230
|
+
reward,
|
|
231
|
+
done,
|
|
232
|
+
{"cost": cost, "tile_type": tile_type},
|
|
233
|
+
)
|
|
228
234
|
|
|
229
235
|
def reset_env(
|
|
230
236
|
self,
|
|
@@ -64,6 +64,11 @@ class EnvParams:
|
|
|
64
64
|
garch_params: Dict[str, GARCHParams] # GARCH params for each asset
|
|
65
65
|
|
|
66
66
|
|
|
67
|
+
class ObsType(Enum):
|
|
68
|
+
EASY = "easy"
|
|
69
|
+
MARKET = "market"
|
|
70
|
+
|
|
71
|
+
|
|
67
72
|
@jax.jit
|
|
68
73
|
def _sample_garch(carry, x):
|
|
69
74
|
"""
|
|
@@ -125,6 +130,7 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
125
130
|
step_size: int = 1,
|
|
126
131
|
num_samples: int = 1_000_000,
|
|
127
132
|
num_trajectories: int = 1,
|
|
133
|
+
obs_type: ObsType = ObsType.MARKET,
|
|
128
134
|
):
|
|
129
135
|
"""
|
|
130
136
|
Initialize GARCH portfolio environment.
|
|
@@ -142,6 +148,7 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
142
148
|
self.step_size = step_size
|
|
143
149
|
self.num_samples = num_samples
|
|
144
150
|
self.num_trajectories = num_trajectories
|
|
151
|
+
self.obs_type = obs_type
|
|
145
152
|
|
|
146
153
|
# Store individual GARCH params for default_params property
|
|
147
154
|
self._garch_params = {name: garch_params[name] for name in self.asset_names}
|
|
@@ -241,25 +248,28 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
241
248
|
dtype=jnp.float32,
|
|
242
249
|
)
|
|
243
250
|
|
|
244
|
-
def
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
# obs_shape = (
|
|
252
|
-
# self.num_assets
|
|
253
|
-
# + 1
|
|
254
|
-
# + self.step_size * self.num_assets * 2
|
|
255
|
-
# + num_garch_params,
|
|
256
|
-
# )
|
|
251
|
+
def _obs_space_market(self, params: EnvParams) -> spaces.Box:
|
|
252
|
+
obs_shape = (self.num_assets,)
|
|
253
|
+
return spaces.Box(
|
|
254
|
+
low=-jnp.inf, high=jnp.inf, shape=obs_shape, dtype=jnp.float32
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
def _obs_space_easy(self, params: EnvParams) -> spaces.Box:
|
|
257
258
|
obs_shape = (self.num_assets * 2,)
|
|
258
259
|
return spaces.Box(
|
|
259
260
|
low=-jnp.inf, high=jnp.inf, shape=obs_shape, dtype=jnp.float32
|
|
260
261
|
)
|
|
261
262
|
|
|
262
|
-
def
|
|
263
|
+
def observation_space(self, params: EnvParams) -> spaces.Box:
|
|
264
|
+
"""Return observation space based on configured observation type."""
|
|
265
|
+
if self.obs_type == ObsType.EASY:
|
|
266
|
+
return self._obs_space_easy(params)
|
|
267
|
+
elif self.obs_type == ObsType.MARKET:
|
|
268
|
+
return self._obs_space_market(params)
|
|
269
|
+
else:
|
|
270
|
+
raise ValueError(f"Unknown observation type: {self.obs_type}")
|
|
271
|
+
|
|
272
|
+
def _get_obs_easy(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
263
273
|
next_time = state.time + self.step_size
|
|
264
274
|
# Index into the correct trajectory: (num_trajectories, num_samples, num_assets)
|
|
265
275
|
# Use dynamic_slice for JIT compatibility
|
|
@@ -274,7 +284,22 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
274
284
|
obs = jnp.concatenate([next_vol.flatten(), mu.flatten()])
|
|
275
285
|
return obs
|
|
276
286
|
|
|
277
|
-
def
|
|
287
|
+
def _get_obs_market(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
288
|
+
# Extract recent returns and volatilities from pre-generated path
|
|
289
|
+
start_time_idx = jnp.maximum(0, state.time - self.step_size + 1)
|
|
290
|
+
|
|
291
|
+
# Index into correct trajectory using dynamic_slice
|
|
292
|
+
log_returns_window = jax.lax.dynamic_slice(
|
|
293
|
+
self.log_returns,
|
|
294
|
+
(state.trajectory_id, start_time_idx, 0),
|
|
295
|
+
(1, self.step_size, self.num_assets),
|
|
296
|
+
).squeeze(0) # Remove trajectory dimension
|
|
297
|
+
|
|
298
|
+
step_log_return = log_returns_window.sum(axis=0)
|
|
299
|
+
|
|
300
|
+
return step_log_return
|
|
301
|
+
|
|
302
|
+
def _get_obs_full(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
278
303
|
"""Get observation from current state."""
|
|
279
304
|
# Extract recent returns and volatilities from pre-generated path
|
|
280
305
|
start_time_idx = jnp.maximum(0, state.time - self.step_size + 1)
|
|
@@ -310,6 +335,15 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
310
335
|
)
|
|
311
336
|
return obs
|
|
312
337
|
|
|
338
|
+
def get_obs(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
339
|
+
"""Get observation based on configured observation type."""
|
|
340
|
+
if self.obs_type == ObsType.EASY:
|
|
341
|
+
return self._get_obs_easy(state, params)
|
|
342
|
+
elif self.obs_type == ObsType.MARKET:
|
|
343
|
+
return self._get_obs_market(state, params)
|
|
344
|
+
else:
|
|
345
|
+
raise ValueError(f"Unknown observation type: {self.obs_type}")
|
|
346
|
+
|
|
313
347
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
314
348
|
"""Check if episode is done."""
|
|
315
349
|
max_steps_reached = state.step >= params.max_steps
|
|
@@ -393,7 +427,7 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
393
427
|
total_value=new_total_value,
|
|
394
428
|
)
|
|
395
429
|
|
|
396
|
-
obs = self.
|
|
430
|
+
obs = self.get_obs(next_state, params)
|
|
397
431
|
done = self.is_terminal(next_state, params)
|
|
398
432
|
info = {"cost": -reward}
|
|
399
433
|
return obs, next_state, reward, done, info
|
|
@@ -448,7 +482,7 @@ class PortfolioOptimizationGARCH(Environment):
|
|
|
448
482
|
values=values,
|
|
449
483
|
total_value=jnp.sum(values),
|
|
450
484
|
)
|
|
451
|
-
obs = self.
|
|
485
|
+
obs = self.get_obs(state, params)
|
|
452
486
|
return obs, state
|
|
453
487
|
|
|
454
488
|
def plot_garch(self, trajectory_id: int = 0):
|
|
@@ -5,14 +5,14 @@ safenax/eco_ant/eco_ant_v1.py,sha256=G6YekTSSK2orcYjNR9QNVZkKpeIrqM56m7gmsNu4cOI
|
|
|
5
5
|
safenax/eco_ant/eco_ant_v2.py,sha256=Aid3ySUJuzGHLiC4L93wLNRy9IrTTsEdP7Ii8aDxQqQ,2601
|
|
6
6
|
safenax/frozen_lake/__init__.py,sha256=81aH7mpQiEWJeem4usZTbilSdlXDJybA7ePowxRyQhc,176
|
|
7
7
|
safenax/frozen_lake/frozen_lake_v1.py,sha256=6Yy9tm4MbrNiYXDNi091Jsh9iwoEcKz3TrOn2sVVGTw,7887
|
|
8
|
-
safenax/frozen_lake/frozen_lake_v2.py,sha256=
|
|
8
|
+
safenax/frozen_lake/frozen_lake_v2.py,sha256=rXSJ8lJKPsF0ZVQcQpBm6ysBYq5I2G0GIfsfLD1-TmA,9731
|
|
9
9
|
safenax/portfolio_optimization/__init__.py,sha256=tbtCF4fVfan2nfFJc2wNl24hCALSb0yON1OYboN5OGk,245
|
|
10
10
|
safenax/portfolio_optimization/po_crypto.py,sha256=Bi4QCd4MoeQAnhag22MFWdqy1uQ5hVQdiwYymP9v7N4,7342
|
|
11
|
-
safenax/portfolio_optimization/po_garch.py,sha256=
|
|
11
|
+
safenax/portfolio_optimization/po_garch.py,sha256=qIVWDJg1qw95pSqVViwQfZGqZ66NNAGyIa337qyMmbU,34690
|
|
12
12
|
safenax/wrappers/__init__.py,sha256=v9wyHyR482ZEfmfTtcGabpf_lUHze4fy-NjrEaGv3zA,158
|
|
13
13
|
safenax/wrappers/brax.py,sha256=svijcYVoWy5ej7RRLuN8VixDL_cMXKBK-veFsC57LRE,2985
|
|
14
14
|
safenax/wrappers/log.py,sha256=jsjT0FJBo21rCM6D2Hx9fOwXLdwP1MW6PAx1BJBP2lA,2842
|
|
15
|
-
safenax-0.4.
|
|
16
|
-
safenax-0.4.
|
|
17
|
-
safenax-0.4.
|
|
18
|
-
safenax-0.4.
|
|
15
|
+
safenax-0.4.7.dist-info/METADATA,sha256=qYYCkUAu6eBJtMJw8uhM6aEQvSrJ3YMC2SjnsWHoCh0,1202
|
|
16
|
+
safenax-0.4.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
17
|
+
safenax-0.4.7.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
|
|
18
|
+
safenax-0.4.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|