imt-ring 1.6.31__tar.gz → 1.6.33__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119)
  1. {imt_ring-1.6.31 → imt_ring-1.6.33}/PKG-INFO +1 -1
  2. {imt_ring-1.6.31 → imt_ring-1.6.33}/pyproject.toml +1 -1
  3. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/base.py +2 -0
  5. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/batch.py +6 -1
  6. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/jcalc.py +19 -4
  7. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/dataloader_torch.py +65 -3
  8. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_custom_joints.py +4 -4
  9. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_random.py +9 -7
  10. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_train.py +2 -1
  11. {imt_ring-1.6.31 → imt_ring-1.6.33}/readme.md +0 -0
  12. {imt_ring-1.6.31 → imt_ring-1.6.33}/setup.cfg +0 -0
  13. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  14. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  15. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/imt_ring.egg-info/requires.txt +0 -0
  16. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/imt_ring.egg-info/top_level.txt +0 -0
  17. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/__init__.py +0 -0
  18. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algebra.py +0 -0
  19. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/__init__.py +0 -0
  20. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/_random.py +0 -0
  21. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  22. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  23. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  24. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  25. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  26. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/dynamics.py +0 -0
  27. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/__init__.py +0 -0
  28. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/finalize_fns.py +0 -0
  29. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  30. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/pd_control.py +0 -0
  31. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/setup_fns.py +0 -0
  32. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/generator/types.py +0 -0
  33. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/algorithms/sensors.py +0 -0
  35. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/base.py +0 -0
  36. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/__init__.py +0 -0
  37. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/branched.xml +0 -0
  38. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  39. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  40. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  41. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/inv_pendulum.xml +0 -0
  42. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  43. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/spherical_stiff.xml +0 -0
  44. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/symmetric.xml +0 -0
  45. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_all_1.xml +0 -0
  46. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_all_2.xml +0 -0
  47. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  48. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_control.xml +0 -0
  49. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  50. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_free.xml +0 -0
  51. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_kinematics.xml +0 -0
  52. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  53. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  54. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_randomize_position.xml +0 -0
  55. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_sensors.xml +0 -0
  56. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  57. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/examples.py +0 -0
  58. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/test_examples.py +0 -0
  59. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/__init__.py +0 -0
  60. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/abstract.py +0 -0
  61. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/from_xml.py +0 -0
  62. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/test_from_xml.py +0 -0
  63. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/test_to_xml.py +0 -0
  64. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/io/xml/to_xml.py +0 -0
  65. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/maths.py +0 -0
  66. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/__init__.py +0 -0
  67. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/base.py +0 -0
  68. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/callbacks.py +0 -0
  69. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/ml_utils.py +0 -0
  70. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/optimizer.py +0 -0
  71. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  72. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  73. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/ringnet.py +0 -0
  74. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/rnno_v1.py +0 -0
  75. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/train.py +0 -0
  76. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/ml/training_loop.py +0 -0
  77. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/rendering/__init__.py +0 -0
  78. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/rendering/base_render.py +0 -0
  79. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/rendering/mujoco_render.py +0 -0
  80. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/rendering/vispy_render.py +0 -0
  81. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/rendering/vispy_visuals.py +0 -0
  82. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sim2real/__init__.py +0 -0
  83. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sim2real/sim2real.py +0 -0
  84. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/spatial.py +0 -0
  85. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sys_composer/__init__.py +0 -0
  86. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sys_composer/delete_sys.py +0 -0
  87. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sys_composer/inject_sys.py +0 -0
  88. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/sys_composer/morph_sys.py +0 -0
  89. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/__init__.py +0 -0
  90. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/backend.py +0 -0
  91. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/batchsize.py +0 -0
  92. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/colab.py +0 -0
  93. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/dataloader.py +0 -0
  94. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/hdf5.py +0 -0
  95. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/normalizer.py +0 -0
  96. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/path.py +0 -0
  97. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/randomize_sys.py +0 -0
  98. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  99. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  100. {imt_ring-1.6.31 → imt_ring-1.6.33}/src/ring/utils/utils.py +0 -0
  101. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_algebra.py +0 -0
  102. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_base.py +0 -0
  103. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_dynamics.py +0 -0
  104. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_generator.py +0 -0
  105. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_jcalc.py +0 -0
  106. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_jit.py +0 -0
  107. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_kinematics.py +0 -0
  108. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_maths.py +0 -0
  109. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_ml_utils.py +0 -0
  110. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_motion_artifacts.py +0 -0
  111. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_pd_control.py +0 -0
  112. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_quickstart_example.py +0 -0
  113. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_randomize.py +0 -0
  114. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_rcmg.py +0 -0
  115. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_render.py +0 -0
  116. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_sensors.py +0 -0
  117. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_sim2real.py +0 -0
  118. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_sys_composer.py +0 -0
  119. {imt_ring-1.6.31 → imt_ring-1.6.33}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.31
3
+ Version: 1.6.33
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.31"
7
+ version = "1.6.33"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.31
3
+ Version: 1.6.33
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -213,6 +213,8 @@ class RCMG:
213
213
  )
214
214
  save_fn(d, file)
215
215
  i += 1
216
+ # cleanup
217
+ del data
216
218
 
217
219
  gens, n_calls = self._generators_ncalls(sizes)
218
220
  batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
@@ -1,3 +1,4 @@
1
+ import gc
1
2
  from typing import Callable
2
3
 
3
4
  import jax
@@ -83,4 +84,8 @@ def generators_eager(
83
84
 
84
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
85
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
86
- callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
87
+ callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
88
+
89
+ # cleanup
90
+ del sample, sample_flat
91
+ gc.collect()
@@ -412,6 +412,9 @@ def _find_interval(t: jax.Array, boundaries: jax.Array):
412
412
  def join_motionconfigs(
413
413
  configs: list[MotionConfig], boundaries: list[float]
414
414
  ) -> MotionConfig:
415
+ # to avoid a circular import due to `ring.utils.randomize_sys` importing `jcalc`
416
+ from ring.utils import tree_equal
417
+
415
418
  assert len(configs) == (
416
419
  len(boundaries) + 1
417
420
  ), "length of `boundaries` should be one less than length of `configs`"
@@ -434,10 +437,22 @@ def join_motionconfigs(
434
437
  time_independent_fields = [key for key in attrs if not is_time_dependent_field(key)]
435
438
 
436
439
  for time_dep_field in time_independent_fields:
437
- field_values = set([getattr(config, time_dep_field) for config in configs])
438
- assert (
439
- len(field_values) == 1
440
- ), f"MotionConfig.{time_dep_field}={field_values}. Should be one unique value.."
440
+ try:
441
+ field_values = set([getattr(config, time_dep_field) for config in configs])
442
+ assert (
443
+ len(field_values) == 1
444
+ ), f"MotionConfig.{time_dep_field}={field_values}. "
445
+ "Should be one unique value.."
446
+ except (
447
+ TypeError
448
+ ): # dict is not hashable so test equality of all elements differently
449
+ comparison_ele = getattr(configs[0], time_dep_field)
450
+ for other_config in configs[1:]:
451
+ other_ele = getattr(other_config, time_dep_field)
452
+ assert tree_equal(
453
+ comparison_ele, other_ele
454
+ ), f"MotionConfig.{time_dep_field} with {comparison_ele} != {other_ele}"
455
+ " Should be one unique value.."
441
456
 
442
457
  changes = {field: new_value(field) for field in time_dependent_fields}
443
458
  return replace(configs[0], **changes)
@@ -1,8 +1,9 @@
1
1
  import os
2
- from typing import Optional
2
+ from typing import Any, Optional
3
3
  import warnings
4
4
 
5
5
  import jax
6
+ import numpy as np
6
7
  import torch
7
8
  from torch.utils.data import DataLoader
8
9
  from torch.utils.data import Dataset
@@ -12,7 +13,7 @@ from ring.utils import parse_path
12
13
  from ring.utils import pickle_load
13
14
 
14
15
 
15
- class FolderOfPickleFilesDataset(Dataset):
16
+ class FolderOfFilesDataset(Dataset):
16
17
  def __init__(self, path, transform=None):
17
18
  self.files = self.listdir(path)
18
19
  self.transform = transform
@@ -22,7 +23,7 @@ class FolderOfPickleFilesDataset(Dataset):
22
23
  return self.N
23
24
 
24
25
  def __getitem__(self, idx: int):
25
- element = pickle_load(self.files[idx])
26
+ element = self._load_file(self.files[idx])
26
27
  if self.transform is not None:
27
28
  element = self.transform(element)
28
29
  return element
@@ -31,6 +32,10 @@ class FolderOfPickleFilesDataset(Dataset):
31
32
  def listdir(path: str) -> list:
32
33
  return [parse_path(path, file) for file in os.listdir(path)]
33
34
 
35
+ @staticmethod
36
+ def _load_file(file_path: str) -> Any:
37
+ return pickle_load(file_path)
38
+
34
39
 
35
40
  def dataset_to_generator(
36
41
  dataset: Dataset,
@@ -84,3 +89,60 @@ def _get_number_of_logical_cores() -> int:
84
89
  )
85
90
  N = 0
86
91
  return N
92
+
93
+
94
+ class MultiDataset(Dataset):
95
+ def __init__(self, datasets, transform=None):
96
+ """
97
+ Args:
98
+ datasets: A list of datasets to sample from.
99
+ transform: A function that takes N items (one from each dataset) and combines them.
100
+ """ # noqa: E501
101
+ self.datasets = datasets
102
+ self.transform = transform
103
+
104
+ def __len__(self):
105
+ # Length is defined by the smallest dataset in the list
106
+ return min(len(ds) for ds in self.datasets)
107
+
108
+ def __getitem__(self, idx):
109
+ sampled_items = [ds[idx] for ds in self.datasets]
110
+
111
+ if self.transform:
112
+ # Apply the transformation to all sampled items
113
+ return self.transform(*sampled_items)
114
+
115
+ return tuple(sampled_items)
116
+
117
+
118
+ class ShuffledDataset(Dataset):
119
+ def __init__(self, dataset):
120
+ """
121
+ Wrapper that shuffles the dataset indices once.
122
+
123
+ Args:
124
+ dataset (Dataset): The original dataset to shuffle.
125
+ """
126
+ self.dataset = dataset
127
+ self.shuffled_indices = np.random.permutation(
128
+ len(dataset)
129
+ ) # Shuffle indices once
130
+
131
+ def __len__(self):
132
+ return len(self.dataset)
133
+
134
+ def __getitem__(self, idx):
135
+ """
136
+ Returns the data at the shuffled index.
137
+
138
+ Args:
139
+ idx (int): Index in the shuffled dataset.
140
+ """
141
+ original_idx = self.shuffled_indices[idx]
142
+ return self.dataset[original_idx]
143
+
144
+
145
+ def dataset_to_Xy(ds: Dataset):
146
+ return dataset_to_generator(ds, batch_size=len(ds), shuffle=False, num_workers=0)(
147
+ None
148
+ )
@@ -43,7 +43,7 @@ def test_virtual_input_joint_axes_rr_joint():
43
43
  np.testing.assert_allclose(
44
44
  -X["seg1"]["joint_axes"],
45
45
  np.repeat(np.array([[0.0, 1, 0]]), 1000, axis=0),
46
- atol=4e-7,
46
+ atol=2e-6,
47
47
  )
48
48
  np.testing.assert_allclose(
49
49
  -X["seg3"]["joint_axes"],
@@ -57,8 +57,8 @@ def test_virtual_input_joint_axes_rr_joint():
57
57
  np.testing.assert_allclose(
58
58
  X["seg1"]["joint_axes"],
59
59
  np.repeat(-joint_axes[1:2], 1000, axis=0),
60
- atol=4e-7,
61
- rtol=2e-4,
60
+ atol=2e-6,
61
+ rtol=5e-4,
62
62
  )
63
63
  np.testing.assert_allclose(
64
64
  -X["seg3"]["joint_axes"],
@@ -103,7 +103,7 @@ def test_virtual_input_joint_axes_rr_imp_joint():
103
103
  np.testing.assert_allclose(
104
104
  -X["seg1"]["joint_axes"],
105
105
  np.repeat(np.array([[0.0, 1, 0]]), 1000, axis=0),
106
- atol=4e-7,
106
+ atol=2e-6,
107
107
  )
108
108
  np.testing.assert_allclose(
109
109
  X["seg3"]["joint_axes"],
@@ -26,6 +26,8 @@ def test_delta_ang_min_max(
26
26
  dt, next_phi = _resolve_range_of_motion(
27
27
  range_of_motion,
28
28
  range_of_motion_method,
29
+ -2 * np.pi,
30
+ 2 * np.pi,
29
31
  0.0,
30
32
  3.14,
31
33
  float(delta_ang_min),
@@ -83,10 +85,10 @@ def test_angle(randomized_interpolation, range_of_motion, range_of_motion_method
83
85
  def test_position():
84
86
  for Ts in [0.1, 0.01]:
85
87
  T = 30
86
- POS_0 = 0.0
87
- pos = ring.algorithms.random_position_over_time(
88
- jrand.PRNGKey(1), POS_0, -0.2, 0.2, 0.1, 0.5, 0.1, 0.5, T, Ts, None, 10
89
- )
90
- assert pos.shape == (int(T / Ts),)
91
- # TODO Why does this fail for POS_0 != 0.0?
92
- assert pos[0] == POS_0
88
+ for POS_0 in [0.0, 1.0]:
89
+ pos = ring.algorithms.random_position_over_time(
90
+ jrand.PRNGKey(1), POS_0, -0.2, 0.2, 0.1, 0.5, 0.1, 0.5, T, Ts, None, 10
91
+ )
92
+ assert pos.shape == (int(T / Ts),)
93
+ # TODO Why does this fail for POS_0 != 0.0?
94
+ assert pos[0] == POS_0
@@ -33,7 +33,8 @@ def test_rnno():
33
33
  ml.train_fn(
34
34
  gen,
35
35
  5,
36
- ml.RNNO(N * 4, return_quats=True, eval=False, hidden_state_dim=20),
36
+ # .unwrapped to get ride of the `GroundtruthHeadingWrapper`
37
+ ml.RNNO(N * 4, return_quats=True, eval=False, hidden_state_dim=20).unwrapped,
37
38
  )
38
39
 
39
40
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes