imt-ring 1.6.12__py3-none-any.whl → 1.6.13__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.12.dist-info → imt_ring-1.6.13.dist-info}/METADATA +1 -1
- {imt_ring-1.6.12.dist-info → imt_ring-1.6.13.dist-info}/RECORD +6 -6
- ring/ml/ml_utils.py +2 -0
- ring/ml/rnno_v1.py +14 -1
- {imt_ring-1.6.12.dist-info → imt_ring-1.6.13.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.12.dist-info → imt_ring-1.6.13.dist-info}/top_level.txt +0 -0
@@ -53,10 +53,10 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
53
53
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
54
54
|
ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
|
-
ring/ml/ml_utils.py,sha256=
|
56
|
+
ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
|
59
|
-
ring/ml/rnno_v1.py,sha256=
|
59
|
+
ring/ml/rnno_v1.py,sha256=ujyIkDxMSTag9iRFEmoHqfqSrlOFjcZs9_rBbLd8p9Q,1380
|
60
60
|
ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
|
61
61
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
@@ -84,7 +84,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
84
84
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
85
85
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
86
86
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
87
|
-
imt_ring-1.6.
|
88
|
-
imt_ring-1.6.
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
87
|
+
imt_ring-1.6.13.dist-info/METADATA,sha256=wMwfHX8PsYaxXZRldgqd71fGltIfUo_W9xVE9DSz5o0,3821
|
88
|
+
imt_ring-1.6.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
89
|
+
imt_ring-1.6.13.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
90
|
+
imt_ring-1.6.13.dist-info/RECORD,,
|
ring/ml/ml_utils.py
CHANGED
ring/ml/rnno_v1.py
CHANGED
@@ -4,6 +4,8 @@ import haiku as hk
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
6
|
|
7
|
+
from .ringnet import LSTM
|
8
|
+
|
7
9
|
|
8
10
|
def rnno_v1_forward_factory(
|
9
11
|
output_dim: int,
|
@@ -13,18 +15,29 @@ def rnno_v1_forward_factory(
|
|
13
15
|
act_fn_linear=jax.nn.relu,
|
14
16
|
act_fn_rnn=jax.nn.elu,
|
15
17
|
lam: Optional[tuple[int]] = None,
|
18
|
+
celltype: str = "gru",
|
16
19
|
):
|
17
20
|
# unused
|
18
21
|
del lam
|
19
22
|
|
23
|
+
if celltype == "gru":
|
24
|
+
_cell = hk.GRU
|
25
|
+
_factor = 1
|
26
|
+
elif celltype == "lstm":
|
27
|
+
_cell = LSTM
|
28
|
+
_factor = 2
|
29
|
+
else:
|
30
|
+
raise NotImplementedError
|
31
|
+
|
20
32
|
@hk.without_apply_rng
|
21
33
|
@hk.transform_with_state
|
22
34
|
def forward_fn(X):
|
23
35
|
assert X.shape[-2] == 1
|
24
36
|
|
25
37
|
for i, n_units in enumerate(rnn_layers):
|
38
|
+
n_units = _factor * n_units
|
26
39
|
state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
|
27
|
-
X, state = hk.dynamic_unroll(
|
40
|
+
X, state = hk.dynamic_unroll(_cell(n_units), X, state)
|
28
41
|
hk.set_state(f"rnn_{i}", state)
|
29
42
|
|
30
43
|
if layernorm:
|
File without changes
|
File without changes
|