imt-ring 1.6.12__py3-none-any.whl → 1.6.13__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.12
3
+ Version: 1.6.13
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -53,10 +53,10 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
53
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
54
54
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
- ring/ml/ml_utils.py,sha256=siiRWbUpjYQz1nAlARm47oqR2K74YTiE1syCoOEmiWw,6370
56
+ ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
- ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
59
+ ring/ml/rnno_v1.py,sha256=ujyIkDxMSTag9iRFEmoHqfqSrlOFjcZs9_rBbLd8p9Q,1380
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
@@ -84,7 +84,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
84
84
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
85
85
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
86
86
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
87
- imt_ring-1.6.12.dist-info/METADATA,sha256=NIcGCBCzA9jwqxvyHYHl5QdfiaFLLxdnQjOk17YX0bA,3821
88
- imt_ring-1.6.12.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
- imt_ring-1.6.12.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
- imt_ring-1.6.12.dist-info/RECORD,,
87
+ imt_ring-1.6.13.dist-info/METADATA,sha256=wMwfHX8PsYaxXZRldgqd71fGltIfUo_W9xVE9DSz5o0,3821
88
+ imt_ring-1.6.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
+ imt_ring-1.6.13.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
+ imt_ring-1.6.13.dist-info/RECORD,,
ring/ml/ml_utils.py CHANGED
@@ -184,6 +184,8 @@ def on_cluster() -> bool:
184
184
 
185
185
 
186
186
  def unique_id() -> str:
187
+ if wandb.run is not None:
188
+ wandb.config.setdefault("unique_id", ring._UNIQUE_ID)
187
189
  return ring._UNIQUE_ID
188
190
 
189
191
 
ring/ml/rnno_v1.py CHANGED
@@ -4,6 +4,8 @@ import haiku as hk
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
 
7
+ from .ringnet import LSTM
8
+
7
9
 
8
10
  def rnno_v1_forward_factory(
9
11
  output_dim: int,
@@ -13,18 +15,29 @@ def rnno_v1_forward_factory(
13
15
  act_fn_linear=jax.nn.relu,
14
16
  act_fn_rnn=jax.nn.elu,
15
17
  lam: Optional[tuple[int]] = None,
18
+ celltype: str = "gru",
16
19
  ):
17
20
  # unused
18
21
  del lam
19
22
 
23
+ if celltype == "gru":
24
+ _cell = hk.GRU
25
+ _factor = 1
26
+ elif celltype == "lstm":
27
+ _cell = LSTM
28
+ _factor = 2
29
+ else:
30
+ raise NotImplementedError
31
+
20
32
  @hk.without_apply_rng
21
33
  @hk.transform_with_state
22
34
  def forward_fn(X):
23
35
  assert X.shape[-2] == 1
24
36
 
25
37
  for i, n_units in enumerate(rnn_layers):
38
+ n_units = _factor * n_units
26
39
  state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
27
- X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
40
+ X, state = hk.dynamic_unroll(_cell(n_units), X, state)
28
41
  hk.set_state(f"rnn_{i}", state)
29
42
 
30
43
  if layernorm: