imt-ring 1.6.13__py3-none-any.whl → 1.6.15__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.13.dist-info → imt_ring-1.6.15.dist-info}/METADATA +1 -1
- {imt_ring-1.6.13.dist-info → imt_ring-1.6.15.dist-info}/RECORD +6 -6
- ring/ml/ringnet.py +1 -1
- ring/ml/rnno_v1.py +2 -2
- {imt_ring-1.6.13.dist-info → imt_ring-1.6.15.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.13.dist-info → imt_ring-1.6.15.dist-info}/top_level.txt +0 -0
@@ -55,8 +55,8 @@ ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
56
|
ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
|
-
ring/ml/ringnet.py,sha256=
|
59
|
-
ring/ml/rnno_v1.py,sha256=
|
58
|
+
ring/ml/ringnet.py,sha256=Oud23uKmcvFtwNKdEu2KMMvNAFzJM_yBSRNz2a3CjL4,8670
|
59
|
+
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
60
60
|
ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
|
61
61
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
@@ -84,7 +84,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
84
84
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
85
85
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
86
86
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
87
|
-
imt_ring-1.6.
|
88
|
-
imt_ring-1.6.
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
87
|
+
imt_ring-1.6.15.dist-info/METADATA,sha256=zG-f_woph73I5ErczEeJaYXZLNahC-oNXDtzj26f1Po,3821
|
88
|
+
imt_ring-1.6.15.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
89
|
+
imt_ring-1.6.15.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
90
|
+
imt_ring-1.6.15.dist-info/RECORD,,
|
ring/ml/ringnet.py
CHANGED
@@ -175,7 +175,7 @@ class LSTM(hk.RNNCore):
|
|
175
175
|
prev_state: jax.Array,
|
176
176
|
):
|
177
177
|
if len(inputs.shape) > 2 or not inputs.shape:
|
178
|
-
raise ValueError("LSTM input must be rank-1 or rank-2.")
|
178
|
+
raise ValueError(f"LSTM input must be rank-1 or rank-2; not {inputs.shape}")
|
179
179
|
prev_state_h = prev_state[: self.hidden_size]
|
180
180
|
prev_state_c = prev_state[self.hidden_size :]
|
181
181
|
x_and_h = jnp.concatenate([inputs, prev_state_h], axis=-1)
|
ring/ml/rnno_v1.py
CHANGED
@@ -33,10 +33,10 @@ def rnno_v1_forward_factory(
|
|
33
33
|
@hk.transform_with_state
|
34
34
|
def forward_fn(X):
|
35
35
|
assert X.shape[-2] == 1
|
36
|
+
X = X[..., 0, :]
|
36
37
|
|
37
38
|
for i, n_units in enumerate(rnn_layers):
|
38
|
-
|
39
|
-
state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
|
39
|
+
state = hk.get_state(f"rnn_{i}", shape=[n_units * _factor], init=jnp.zeros)
|
40
40
|
X, state = hk.dynamic_unroll(_cell(n_units), X, state)
|
41
41
|
hk.set_state(f"rnn_{i}", state)
|
42
42
|
|
File without changes
|
File without changes
|