imt-ring 1.6.14__py3-none-any.whl → 1.6.15__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.14
3
+ Version: 1.6.15
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -55,8 +55,8 @@ ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
- ring/ml/rnno_v1.py,sha256=pciltx7dR_yhFv2g8BvfV0VPgKm0HxXDxL2YjBPBaKQ,1400
58
+ ring/ml/ringnet.py,sha256=Oud23uKmcvFtwNKdEu2KMMvNAFzJM_yBSRNz2a3CjL4,8670
59
+ ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
@@ -84,7 +84,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
84
84
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
85
85
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
86
86
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
87
- imt_ring-1.6.14.dist-info/METADATA,sha256=ZWWkJ_qO8O2KkW1m_HnXoz22H7ivmg7Qk4j7hEIgy9k,3821
88
- imt_ring-1.6.14.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
- imt_ring-1.6.14.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
- imt_ring-1.6.14.dist-info/RECORD,,
87
+ imt_ring-1.6.15.dist-info/METADATA,sha256=zG-f_woph73I5ErczEeJaYXZLNahC-oNXDtzj26f1Po,3821
88
+ imt_ring-1.6.15.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
+ imt_ring-1.6.15.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
+ imt_ring-1.6.15.dist-info/RECORD,,
ring/ml/ringnet.py CHANGED
@@ -175,7 +175,7 @@ class LSTM(hk.RNNCore):
175
175
  prev_state: jax.Array,
176
176
  ):
177
177
  if len(inputs.shape) > 2 or not inputs.shape:
178
- raise ValueError("LSTM input must be rank-1 or rank-2.")
178
+ raise ValueError(f"LSTM input must be rank-1 or rank-2; not {inputs.shape}")
179
179
  prev_state_h = prev_state[: self.hidden_size]
180
180
  prev_state_c = prev_state[self.hidden_size :]
181
181
  x_and_h = jnp.concatenate([inputs, prev_state_h], axis=-1)
ring/ml/rnno_v1.py CHANGED
@@ -33,13 +33,12 @@ def rnno_v1_forward_factory(
33
33
  @hk.transform_with_state
34
34
  def forward_fn(X):
35
35
  assert X.shape[-2] == 1
36
+ X = X[..., 0, :]
36
37
 
37
38
  for i, n_units in enumerate(rnn_layers):
38
- state = hk.get_state(
39
- f"rnn_{i}", shape=[1, n_units * _factor], init=jnp.zeros
40
- )
41
- X, state = hk.dynamic_unroll(_cell(n_units), X[..., 0, :], state[0])
42
- hk.set_state(f"rnn_{i}", state[None])
39
+ state = hk.get_state(f"rnn_{i}", shape=[n_units * _factor], init=jnp.zeros)
40
+ X, state = hk.dynamic_unroll(_cell(n_units), X, state)
41
+ hk.set_state(f"rnn_{i}", state)
43
42
 
44
43
  if layernorm:
45
44
  X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)