imt-ring 1.6.13__py3-none-any.whl → 1.6.14__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.13
3
+ Version: 1.6.14
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
- ring/ml/rnno_v1.py,sha256=ujyIkDxMSTag9iRFEmoHqfqSrlOFjcZs9_rBbLd8p9Q,1380
59
+ ring/ml/rnno_v1.py,sha256=pciltx7dR_yhFv2g8BvfV0VPgKm0HxXDxL2YjBPBaKQ,1400
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
@@ -84,7 +84,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
84
84
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
85
85
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
86
86
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
87
- imt_ring-1.6.13.dist-info/METADATA,sha256=wMwfHX8PsYaxXZRldgqd71fGltIfUo_W9xVE9DSz5o0,3821
88
- imt_ring-1.6.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
- imt_ring-1.6.13.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
- imt_ring-1.6.13.dist-info/RECORD,,
87
+ imt_ring-1.6.14.dist-info/METADATA,sha256=ZWWkJ_qO8O2KkW1m_HnXoz22H7ivmg7Qk4j7hEIgy9k,3821
88
+ imt_ring-1.6.14.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
+ imt_ring-1.6.14.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
+ imt_ring-1.6.14.dist-info/RECORD,,
ring/ml/rnno_v1.py CHANGED
@@ -35,10 +35,11 @@ def rnno_v1_forward_factory(
35
35
  assert X.shape[-2] == 1
36
36
 
37
37
  for i, n_units in enumerate(rnn_layers):
38
- n_units = _factor * n_units
39
- state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
40
- X, state = hk.dynamic_unroll(_cell(n_units), X, state)
41
- hk.set_state(f"rnn_{i}", state)
38
+ state = hk.get_state(
39
+ f"rnn_{i}", shape=[1, n_units * _factor], init=jnp.zeros
40
+ )
41
+ X, state = hk.dynamic_unroll(_cell(n_units), X[..., 0, :], state[0])
42
+ hk.set_state(f"rnn_{i}", state[None])
42
43
 
43
44
  if layernorm:
44
45
  X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)