imt-ring 1.6.3__tar.gz → 1.6.5__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (116) hide show
  1. {imt_ring-1.6.3 → imt_ring-1.6.5}/PKG-INFO +31 -1
  2. {imt_ring-1.6.3 → imt_ring-1.6.5}/pyproject.toml +1 -1
  3. {imt_ring-1.6.3 → imt_ring-1.6.5}/readme.md +30 -0
  4. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/imt_ring.egg-info/PKG-INFO +31 -1
  5. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/__init__.py +43 -40
  6. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/dynamics.py +1 -0
  7. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/base.py +6 -2
  8. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/base.py +4 -1
  9. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/ringnet.py +1 -0
  10. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/utils.py +3 -2
  11. {imt_ring-1.6.3 → imt_ring-1.6.5}/setup.cfg +0 -0
  12. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  13. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  14. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/imt_ring.egg-info/requires.txt +0 -0
  15. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/imt_ring.egg-info/top_level.txt +0 -0
  16. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algebra.py +0 -0
  17. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/__init__.py +0 -0
  18. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/_random.py +0 -0
  19. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  20. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  21. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  22. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  23. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/__init__.py +0 -0
  24. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/batch.py +0 -0
  25. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/finalize_fns.py +0 -0
  26. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  27. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/pd_control.py +0 -0
  28. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/setup_fns.py +0 -0
  29. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/generator/types.py +0 -0
  30. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/jcalc.py +0 -0
  31. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/algorithms/sensors.py +0 -0
  33. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/__init__.py +0 -0
  34. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/branched.xml +0 -0
  35. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  36. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  37. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  38. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/inv_pendulum.xml +0 -0
  39. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  40. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/spherical_stiff.xml +0 -0
  41. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/symmetric.xml +0 -0
  42. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_all_1.xml +0 -0
  43. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_all_2.xml +0 -0
  44. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  45. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_control.xml +0 -0
  46. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  47. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_free.xml +0 -0
  48. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_kinematics.xml +0 -0
  49. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  50. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  51. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_randomize_position.xml +0 -0
  52. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_sensors.xml +0 -0
  53. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  54. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/examples.py +0 -0
  55. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/test_examples.py +0 -0
  56. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/__init__.py +0 -0
  57. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/abstract.py +0 -0
  58. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/from_xml.py +0 -0
  59. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/test_from_xml.py +0 -0
  60. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/test_to_xml.py +0 -0
  61. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/io/xml/to_xml.py +0 -0
  62. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/maths.py +0 -0
  63. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/__init__.py +0 -0
  64. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/base.py +0 -0
  65. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/callbacks.py +0 -0
  66. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/ml_utils.py +0 -0
  67. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/optimizer.py +0 -0
  68. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  69. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  70. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/rnno_v1.py +0 -0
  71. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/train.py +0 -0
  72. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/ml/training_loop.py +0 -0
  73. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/rendering/__init__.py +0 -0
  74. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/rendering/base_render.py +0 -0
  75. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/rendering/mujoco_render.py +0 -0
  76. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/rendering/vispy_render.py +0 -0
  77. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/rendering/vispy_visuals.py +0 -0
  78. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sim2real/__init__.py +0 -0
  79. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sim2real/sim2real.py +0 -0
  80. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/spatial.py +0 -0
  81. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sys_composer/__init__.py +0 -0
  82. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sys_composer/delete_sys.py +0 -0
  83. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sys_composer/inject_sys.py +0 -0
  84. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/sys_composer/morph_sys.py +0 -0
  85. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/__init__.py +0 -0
  86. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/backend.py +0 -0
  87. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/batchsize.py +0 -0
  88. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/colab.py +0 -0
  89. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/hdf5.py +0 -0
  90. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/normalizer.py +0 -0
  91. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/path.py +0 -0
  92. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/randomize_sys.py +0 -0
  93. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  94. {imt_ring-1.6.3 → imt_ring-1.6.5}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  95. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_algebra.py +0 -0
  96. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_base.py +0 -0
  97. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_custom_joints.py +0 -0
  98. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_dynamics.py +0 -0
  99. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_generator.py +0 -0
  100. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_jcalc.py +0 -0
  101. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_jit.py +0 -0
  102. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_kinematics.py +0 -0
  103. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_maths.py +0 -0
  104. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_ml_utils.py +0 -0
  105. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_motion_artifacts.py +0 -0
  106. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_pd_control.py +0 -0
  107. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_quickstart_example.py +0 -0
  108. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_random.py +0 -0
  109. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_randomize.py +0 -0
  110. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_rcmg.py +0 -0
  111. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_render.py +0 -0
  112. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_sensors.py +0 -0
  113. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_sim2real.py +0 -0
  114. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_sys_composer.py +0 -0
  115. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_train.py +0 -0
  116. {imt_ring-1.6.3 → imt_ring-1.6.5}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.3
3
+ Version: 1.6.5
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain (acc_segment*/gyr_segment* are your own measured arrays of shape (B, T_i, 3))
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.3"
7
+ version = "1.6.5"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -22,6 +22,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
22
22
 
23
23
  Available [here](https://simipixel.github.io/ring/).
24
24
 
25
+ ## Quickstart Example
26
+ ```python
27
+ import ring
28
+ import numpy as np
29
+
30
+ T : int = 30 # sequence length [s]
31
+ Ts : float = 0.01 # sampling interval [s]
32
+ B : int = 1 # batch size
33
+ lam: list[int] = [0, 1, 2] # parent array
34
+ N : int = len(lam) # number of bodies
35
+ T_i: int = int(T/Ts) # number of timesteps
36
+
37
+ X = np.zeros((B, T_i, N, 9))
38
+ # where X is structured as follows:
39
+ # X[..., :3] = acc
40
+ # X[..., 3:6] = gyr
41
+ # X[..., 6:9] = jointaxis
42
+
43
+ # let's assume we have an IMU on each outer segment of the
44
+ # three-segment kinematic chain (acc_segment*/gyr_segment* are your own measured arrays of shape (B, T_i, 3))
45
+ X[..., 0, :3] = acc_segment1
46
+ X[..., 2, :3] = acc_segment3
47
+ X[..., 0, 3:6] = gyr_segment1
48
+ X[..., 2, 3:6] = gyr_segment3
49
+
50
+ ringnet = ring.RING(lam, Ts)
51
+ yhat, _ = ringnet.apply(X)
52
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
53
+ ```
54
+
25
55
  ### Known fixes
26
56
 
27
57
  #### Offscreen rendering with Mujoco
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.3
3
+ Version: 1.6.5
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain (acc_segment*/gyr_segment* are your own measured arrays of shape (B, T_i, 3))
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -20,52 +20,55 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int] | None, Ts: float | None, **kwargs):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter:
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
27
  lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
- Usage:
31
- >>> import ring
32
- >>> import numpy as np
33
- >>>
34
- >>> T : int = 30 # sequence length [s]
35
- >>> Ts : float = 0.01 # sampling interval [s]
36
- >>> B : int = 1 # batch size
37
- >>> lam: list[int] = [0, 1, 2] # parent array
38
- >>> N : int = len(lam) # number of bodies
39
- >>> T_i: int = int(T/Ts) # number of timesteps
40
- >>>
41
- >>> X = np.zeros((B, T_i, N, 9))
42
- >>> # where X is structured as follows:
43
- >>> # X[..., :3] = acc
44
- >>> # X[..., 3:6] = gyr
45
- >>> # X[..., 6:9] = jointaxis
46
- >>>
47
- >>> # let's assume we have an IMU on each outer segment of the
48
- >>> # three-segment kinematic chain
49
- >>> X[:, :, 0, :3] = acc_segment1
50
- >>> X[:, :, 2, :3] = acc_segment3
51
- >>> X[:, :, 0, 3:6] = gyr_segment1
52
- >>> X[:, :, 2, 3:6] = gyr_segment3
53
- >>>
54
- >>> ringnet = ring.RING(lam, Ts)
55
- >>>
56
- >>> yhat, _ = ringnet.apply(X)
57
- >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
- >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
59
- >>>
60
- >>> # use `jax.jit` to compile the forward pass
61
- >>> jit_apply = jax.jit(ringnet.apply)
62
- >>> yhat, _ = jit_apply(X)
63
- >>>
64
- >>> # manually pass in and out the hidden state like so
65
- >>> initial_state = None
66
- >>> yhat, state = ringnet.apply(X, state=initial_state)
67
- >>> # state: final hidden state, shape = (B, N, 2*H)
68
-
30
+ Returns:
31
+ ring.ml.AbstractFilter: An instantiation of `ring.ml.ringnet.RING` with trained
32
+ parameters.
33
+
34
+ Examples:
35
+ >>> import ring
36
+ >>> import numpy as np
37
+ >>>
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [0, 1, 2] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
+ >>>
45
+ >>> X = np.zeros((B, T_i, N, 9))
46
+ >>> # where X is structured as follows:
47
+ >>> # X[..., :3] = acc
48
+ >>> # X[..., 3:6] = gyr
49
+ >>> # X[..., 6:9] = jointaxis
50
+ >>>
51
+ >>> # let's assume we have an IMU on each outer segment of the
52
+ >>> # three-segment kinematic chain
53
+ >>> X[:, :, 0, :3] = acc_segment1
54
+ >>> X[:, :, 2, :3] = acc_segment3
55
+ >>> X[:, :, 0, 3:6] = gyr_segment1
56
+ >>> X[:, :, 2, 3:6] = gyr_segment3
57
+ >>>
58
+ >>> ringnet = ring.RING(lam, Ts)
59
+ >>>
60
+ >>> yhat, _ = ringnet.apply(X)
61
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
62
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
63
+ >>>
64
+ >>> # use `jax.jit` to compile the forward pass (requires `import jax`)
65
+ >>> jit_apply = jax.jit(ringnet.apply)
66
+ >>> yhat, _ = jit_apply(X)
67
+ >>>
68
+ >>> # manually pass in and out the hidden state like so
69
+ >>> initial_state = None
70
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
71
+ >>> # state: final hidden state, shape = (B, N, 2*H)
69
72
  """
70
73
  from pathlib import Path
71
74
  import warnings
@@ -303,6 +303,7 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
+ "Steps the dynamics. Returns the state of the next timestep."
306
307
  assert sys.q_size() == state.q.size
307
308
  if taus is None:
308
309
  taus = jnp.zeros_like(state.qd)
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import numpy as np
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -47,6 +48,7 @@ class RCMG:
47
48
  cor: bool = False,
48
49
  disable_tqdm: bool = False,
49
50
  ) -> None:
51
+ "Random Chain Motion Generator"
50
52
 
51
53
  sys, config = utils.to_list(sys), utils.to_list(config)
52
54
  sys_ml = sys[0] if sys_ml is None else sys_ml
@@ -141,7 +143,9 @@ class RCMG:
141
143
 
142
144
  return n_calls
143
145
 
144
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
146
+ def to_list(
147
+ self, sizes: int | list[int] = 1, seed: int = 1
148
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
145
149
  "Returns list of unbatched sequences as numpy arrays."
146
150
  repeats = self._compute_repeats(sizes)
147
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -170,7 +174,7 @@ class RCMG:
170
174
  seed: int = 1,
171
175
  overwrite: bool = True,
172
176
  ) -> None:
173
- data = tree_utils.tree_batch(self.to_list(sizes, seed))
177
+ data = tree_utils.tree_batch(self.to_list(sizes, seed), backend="numpy")
174
178
  utils.pickle_save(data, path, overwrite=overwrite)
175
179
 
176
180
  def to_eager_gen(
@@ -113,7 +113,9 @@ class _Base:
113
113
  class Transform(_Base):
114
114
  """Represents the Transformation from Plücker A to Plücker B,
115
115
  where B is located relative to A at `pos` in frame A and `rot` is the
116
- relative quaternion from A to B."""
116
+ relative quaternion from A to B.
117
+ Create using `Transform.create(pos=..., rot=...)`.
118
+ """
117
119
 
118
120
  pos: Vector
119
121
  rot: Quaternion
@@ -399,6 +401,7 @@ QD_WIDTHS = {
399
401
 
400
402
  @struct.dataclass
401
403
  class System(_Base):
404
+ "System object. Create using `System.create(path_xml)`"
402
405
  link_parents: list[int] = struct.field(False)
403
406
  links: Link
404
407
  link_types: list[str] = struct.field(False)
@@ -200,6 +200,7 @@ class RING(ml_base.AbstractFilter):
200
200
  forward_factory=make_ring,
201
201
  **kwargs,
202
202
  ):
203
+ "Untrained RING network"
203
204
  self.forward_lam_factory = partial(forward_factory, **kwargs)
204
205
  self.params = self._load_params(params)
205
206
  self.lam = lam
@@ -210,8 +210,9 @@ def replace_elements_w_nans(
210
210
  if _is_nan(ele, i):
211
211
  while True:
212
212
  j = random.choice(include_elements)
213
- if not _is_nan(list_of_data[j], j):
214
- ele = list_of_data[j]
213
+ ele_j = list_of_data[j]
214
+ if not _is_nan(ele_j, j):
215
+ ele = pytree_deepcopy(ele_j)
215
216
  break
216
217
  list_of_data_nonan.append(ele)
217
218
  return list_of_data_nonan
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes