safenax 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -224,7 +224,13 @@ class FrozenLakeV2(environment.Environment):
224
224
 
225
225
  new_state = EnvState(pos=next_pos, time=new_time)
226
226
 
227
- return self.get_obs(new_state, params), new_state, reward, done, {"cost": cost}
227
+ return (
228
+ self.get_obs(new_state, params),
229
+ new_state,
230
+ reward,
231
+ done,
232
+ {"cost": cost, "tile_type": tile_type},
233
+ )
228
234
 
229
235
  def reset_env(
230
236
  self,
@@ -64,6 +64,11 @@ class EnvParams:
64
64
  garch_params: Dict[str, GARCHParams] # GARCH params for each asset
65
65
 
66
66
 
67
+ class ObsType(Enum):
68
+ EASY = "easy"
69
+ MARKET = "market"
70
+
71
+
67
72
  @jax.jit
68
73
  def _sample_garch(carry, x):
69
74
  """
@@ -125,6 +130,7 @@ class PortfolioOptimizationGARCH(Environment):
125
130
  step_size: int = 1,
126
131
  num_samples: int = 1_000_000,
127
132
  num_trajectories: int = 1,
133
+ obs_type: ObsType = ObsType.MARKET,
128
134
  ):
129
135
  """
130
136
  Initialize GARCH portfolio environment.
@@ -142,6 +148,7 @@ class PortfolioOptimizationGARCH(Environment):
142
148
  self.step_size = step_size
143
149
  self.num_samples = num_samples
144
150
  self.num_trajectories = num_trajectories
151
+ self.obs_type = obs_type
145
152
 
146
153
  # Store individual GARCH params for default_params property
147
154
  self._garch_params = {name: garch_params[name] for name in self.asset_names}
@@ -241,25 +248,28 @@ class PortfolioOptimizationGARCH(Environment):
241
248
  dtype=jnp.float32,
242
249
  )
243
250
 
244
- def observation_space(self, params: EnvParams) -> spaces.Box:
245
- """Observation: recent returns and volatilities for all assets."""
246
- # num_garch_params = (
247
- # self.num_assets * 2
248
- # + self.num_assets * self.vec_params.alpha.shape[1]
249
- # + self.num_assets * self.vec_params.beta.shape[1]
250
- # )
251
- # obs_shape = (
252
- # self.num_assets
253
- # + 1
254
- # + self.step_size * self.num_assets * 2
255
- # + num_garch_params,
256
- # )
251
+ def _obs_space_market(self, params: EnvParams) -> spaces.Box:
252
+ obs_shape = (self.num_assets,)
253
+ return spaces.Box(
254
+ low=-jnp.inf, high=jnp.inf, shape=obs_shape, dtype=jnp.float32
255
+ )
256
+
257
+ def _obs_space_easy(self, params: EnvParams) -> spaces.Box:
257
258
  obs_shape = (self.num_assets * 2,)
258
259
  return spaces.Box(
259
260
  low=-jnp.inf, high=jnp.inf, shape=obs_shape, dtype=jnp.float32
260
261
  )
261
262
 
262
- def get_obs_easy(self, state: EnvState, params: EnvParams) -> jax.Array:
263
+ def observation_space(self, params: EnvParams) -> spaces.Box:
264
+ """Return observation space based on configured observation type."""
265
+ if self.obs_type == ObsType.EASY:
266
+ return self._obs_space_easy(params)
267
+ elif self.obs_type == ObsType.MARKET:
268
+ return self._obs_space_market(params)
269
+ else:
270
+ raise ValueError(f"Unknown observation type: {self.obs_type}")
271
+
272
+ def _get_obs_easy(self, state: EnvState, params: EnvParams) -> jax.Array:
263
273
  next_time = state.time + self.step_size
264
274
  # Index into the correct trajectory: (num_trajectories, num_samples, num_assets)
265
275
  # Use dynamic_slice for JIT compatibility
@@ -274,7 +284,22 @@ class PortfolioOptimizationGARCH(Environment):
274
284
  obs = jnp.concatenate([next_vol.flatten(), mu.flatten()])
275
285
  return obs
276
286
 
277
- def get_obs(self, state: EnvState, params: EnvParams) -> jax.Array:
287
+ def _get_obs_market(self, state: EnvState, params: EnvParams) -> jax.Array:
288
+ # Extract recent returns and volatilities from pre-generated path
289
+ start_time_idx = jnp.maximum(0, state.time - self.step_size + 1)
290
+
291
+ # Index into correct trajectory using dynamic_slice
292
+ log_returns_window = jax.lax.dynamic_slice(
293
+ self.log_returns,
294
+ (state.trajectory_id, start_time_idx, 0),
295
+ (1, self.step_size, self.num_assets),
296
+ ).squeeze(0) # Remove trajectory dimension
297
+
298
+ step_log_return = log_returns_window.sum(axis=0)
299
+
300
+ return step_log_return
301
+
302
+ def _get_obs_full(self, state: EnvState, params: EnvParams) -> jax.Array:
278
303
  """Get observation from current state."""
279
304
  # Extract recent returns and volatilities from pre-generated path
280
305
  start_time_idx = jnp.maximum(0, state.time - self.step_size + 1)
@@ -310,6 +335,15 @@ class PortfolioOptimizationGARCH(Environment):
310
335
  )
311
336
  return obs
312
337
 
338
+ def get_obs(self, state: EnvState, params: EnvParams) -> jax.Array:
339
+ """Get observation based on configured observation type."""
340
+ if self.obs_type == ObsType.EASY:
341
+ return self._get_obs_easy(state, params)
342
+ elif self.obs_type == ObsType.MARKET:
343
+ return self._get_obs_market(state, params)
344
+ else:
345
+ raise ValueError(f"Unknown observation type: {self.obs_type}")
346
+
313
347
  def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
314
348
  """Check if episode is done."""
315
349
  max_steps_reached = state.step >= params.max_steps
@@ -393,7 +427,7 @@ class PortfolioOptimizationGARCH(Environment):
393
427
  total_value=new_total_value,
394
428
  )
395
429
 
396
- obs = self.get_obs_easy(next_state, params)
430
+ obs = self.get_obs(next_state, params)
397
431
  done = self.is_terminal(next_state, params)
398
432
  info = {"cost": -reward}
399
433
  return obs, next_state, reward, done, info
@@ -448,7 +482,7 @@ class PortfolioOptimizationGARCH(Environment):
448
482
  values=values,
449
483
  total_value=jnp.sum(values),
450
484
  )
451
- obs = self.get_obs_easy(state, params)
485
+ obs = self.get_obs(state, params)
452
486
  return obs, state
453
487
 
454
488
  def plot_garch(self, trajectory_id: int = 0):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: safenax
3
- Version: 0.4.5
3
+ Version: 0.4.7
4
4
  Summary: Constrained environments with a gymnax interface
5
5
  Project-URL: Homepage, https://github.com/0xprofessooor/safenax
6
6
  Project-URL: Repository, https://github.com/0xprofessooor/safenax
@@ -5,14 +5,14 @@ safenax/eco_ant/eco_ant_v1.py,sha256=G6YekTSSK2orcYjNR9QNVZkKpeIrqM56m7gmsNu4cOI
5
5
  safenax/eco_ant/eco_ant_v2.py,sha256=Aid3ySUJuzGHLiC4L93wLNRy9IrTTsEdP7Ii8aDxQqQ,2601
6
6
  safenax/frozen_lake/__init__.py,sha256=81aH7mpQiEWJeem4usZTbilSdlXDJybA7ePowxRyQhc,176
7
7
  safenax/frozen_lake/frozen_lake_v1.py,sha256=6Yy9tm4MbrNiYXDNi091Jsh9iwoEcKz3TrOn2sVVGTw,7887
8
- safenax/frozen_lake/frozen_lake_v2.py,sha256=fVe1ObHXU_PLnG1MwiAlZEG-VaUoIq5YYEPfiC0-tok,9634
8
+ safenax/frozen_lake/frozen_lake_v2.py,sha256=rXSJ8lJKPsF0ZVQcQpBm6ysBYq5I2G0GIfsfLD1-TmA,9731
9
9
  safenax/portfolio_optimization/__init__.py,sha256=tbtCF4fVfan2nfFJc2wNl24hCALSb0yON1OYboN5OGk,245
10
10
  safenax/portfolio_optimization/po_crypto.py,sha256=Bi4QCd4MoeQAnhag22MFWdqy1uQ5hVQdiwYymP9v7N4,7342
11
- safenax/portfolio_optimization/po_garch.py,sha256=f2kneV5NpH_ebG_IFcfUvc3qthzZHEZt5YwcKgaI9sI,33320
11
+ safenax/portfolio_optimization/po_garch.py,sha256=qIVWDJg1qw95pSqVViwQfZGqZ66NNAGyIa337qyMmbU,34690
12
12
  safenax/wrappers/__init__.py,sha256=v9wyHyR482ZEfmfTtcGabpf_lUHze4fy-NjrEaGv3zA,158
13
13
  safenax/wrappers/brax.py,sha256=svijcYVoWy5ej7RRLuN8VixDL_cMXKBK-veFsC57LRE,2985
14
14
  safenax/wrappers/log.py,sha256=jsjT0FJBo21rCM6D2Hx9fOwXLdwP1MW6PAx1BJBP2lA,2842
15
- safenax-0.4.5.dist-info/METADATA,sha256=J4fo_TeyPwOLsK_6NDqUbDECDjOhWYhAv77TUFJI_ZE,1202
16
- safenax-0.4.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
17
- safenax-0.4.5.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
18
- safenax-0.4.5.dist-info/RECORD,,
15
+ safenax-0.4.7.dist-info/METADATA,sha256=qYYCkUAu6eBJtMJw8uhM6aEQvSrJ3YMC2SjnsWHoCh0,1202
16
+ safenax-0.4.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
17
+ safenax-0.4.7.dist-info/licenses/LICENSE,sha256=BI7P9lDrJUcIUIX_4sCSE9pKHgCYIKWzHCOFyn85eKk,1077
18
+ safenax-0.4.7.dist-info/RECORD,,