imt-ring 1.6.11__py3-none-any.whl → 1.6.13__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.11
3
+ Version: 1.6.13
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,22 +1,22 @@
1
1
  ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=MkkziQx01sdMOpB8MFUDFgFlZUrXCFjpb8hS9yKHUyM,33751
3
+ ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
4
4
  ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
8
  ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
- ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
9
+ ring/algorithms/jcalc.py,sha256=bwfVH3qKEnUs6RFgEEeUBnecpBt-nf8cesJbNGDrE7g,28974
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=0zLjd9TzOkU5W5GJU6Dk4QwYjwqs9AUlzUKU8aSX_dc,18126
11
+ ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=LRAKxzrwq6fp4lgVw6IUg4i7isx3iqJLHvpFK1aTRcg,15732
17
+ ring/algorithms/generator/base.py,sha256=vxUdA0ZeSNH3SOanL51qVRvCiJrmsWQyQX0g2fdm3Rg,15825
18
18
  ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
19
- ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
19
+ ring/algorithms/generator/finalize_fns.py,sha256=559sGXs06n46p-eme0SE8hn0lXwGT0P2r3-52ElTldo,9861
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -53,10 +53,10 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
53
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
54
54
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
- ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
56
+ ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
- ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
59
+ ring/ml/rnno_v1.py,sha256=ujyIkDxMSTag9iRFEmoHqfqSrlOFjcZs9_rBbLd8p9Q,1380
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
@@ -76,14 +76,15 @@ ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
76
76
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
77
77
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
78
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
79
+ ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
79
80
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
81
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
82
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
82
83
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=oGC7kh19s5zvmnUvWy8B3fBl9loVU58ppz91osk2m3w,6550
84
+ ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
84
85
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
86
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.11.dist-info/METADATA,sha256=kkQfOD5LOSzB4lR7LvkHeck6fB_KPNrSKIsvPizJAKI,3821
87
- imt_ring-1.6.11.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
88
- imt_ring-1.6.11.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.11.dist-info/RECORD,,
87
+ imt_ring-1.6.13.dist-info/METADATA,sha256=wMwfHX8PsYaxXZRldgqd71fGltIfUo_W9xVE9DSz5o0,3821
88
+ imt_ring-1.6.13.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
+ imt_ring-1.6.13.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
+ imt_ring-1.6.13.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.1.2)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -321,6 +321,9 @@ def _build_mconfig_batched_generator(
321
321
  "using the `randomize_motion_artifacts` flag, so it must be enabled."
322
322
  )
323
323
 
324
+ if dynamic_simulation:
325
+ finalize_fns.DynamicalSimulation.assert_test_system(sys)
326
+
324
327
  def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
325
328
  pipe = []
326
329
  if imu_motion_artifacts and randomize_motion_artifacts:
@@ -180,6 +180,23 @@ class DynamicalSimulation:
180
180
  self.overwrite_q_ref = overwrite_q_ref
181
181
  self.unroll_kwargs = unroll_kwargs
182
182
 
183
+ @staticmethod
184
+ def assert_test_system(sys: base.System) -> None:
185
+ "test that system has no zero mass bodies and no joints without damping"
186
+
187
+ def f(_, __, n, m, d):
188
+ assert d.size == 0 or m > 0, (
189
+ "Dynamic simulation is set to `True` which requires masses >= 0, "
190
+ f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
191
+ )
192
+
193
+ assert d.size == 0 or all(d > 0.0), (
194
+ "Dynamic simulation is set to `True` which requires dampings > 0, "
195
+ f"but found body `{n}` with damping={d}. This can lead to NaNs."
196
+ )
197
+
198
+ sys.scan(f, "lld", sys.link_names, sys.links.inertia.mass, sys.link_damping)
199
+
183
200
  def __call__(
184
201
  self, Xy: types.Xy, extras: types.OutputExtras
185
202
  ) -> tuple[types.Xy, types.OutputExtras]:
ring/algorithms/jcalc.py CHANGED
@@ -205,7 +205,7 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
205
205
  return False
206
206
  return True
207
207
 
208
- return all(
208
+ cond1 = all(
209
209
  [
210
210
  dx_deltax_check(*args)
211
211
  for args in zip(
@@ -217,6 +217,15 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
217
217
  ]
218
218
  )
219
219
 
220
+ # this one tests that the initial value is inside the feasible value range
221
+ # so e.g. if you choose pos0_min=-10 then you can't choose pos_min=-1
222
+ def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
223
+ return (x0_min >= x_min) and (x0_max <= x_max)
224
+
225
+ cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
226
+
227
+ return cond1 and cond2
228
+
220
229
 
221
230
  def _find_interval(t: jax.Array, boundaries: jax.Array):
222
231
  """Find the interval of `boundaries` between which `t` lies.
@@ -131,7 +131,7 @@ def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
131
131
  # - gyr: rad/s
132
132
  # - mag: a.u.
133
133
  NOISE_LEVELS = {"acc": 0.048, "gyr": jnp.deg2rad(0.7), "mag": 0.01}
134
- BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3.6), "mag": 0.0}
134
+ BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3), "mag": 0.0}
135
135
 
136
136
 
137
137
  def add_noise_bias(
ring/base.py CHANGED
@@ -690,7 +690,9 @@ class System(_Base):
690
690
  transparent_segment_to_root: bool = True,
691
691
  **kwargs,
692
692
  ):
693
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
693
+ """`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
694
+ Note that the body in yhat that connects to -1, is parent-to-child!
695
+ """
694
696
  return ring.rendering.render_prediction(
695
697
  self, xs, yhat, transparent_segment_to_root, **kwargs
696
698
  )
ring/ml/ml_utils.py CHANGED
@@ -12,7 +12,6 @@ import numpy as np
12
12
  from tree_utils import PyTree
13
13
 
14
14
  import ring
15
- from ring.utils import import_lib
16
15
  import wandb
17
16
 
18
17
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -185,41 +184,37 @@ def on_cluster() -> bool:
185
184
 
186
185
 
187
186
  def unique_id() -> str:
187
+ if wandb.run is not None:
188
+ wandb.config.setdefault("unique_id", ring._UNIQUE_ID)
188
189
  return ring._UNIQUE_ID
189
190
 
190
191
 
191
192
  def save_model_tf(jax_func, path: str, *input, validate: bool = True):
192
193
  from jax.experimental import jax2tf
194
+ import tensorflow as tf
193
195
 
194
- tf = import_lib("tensorflow", "the function `save_model_tf`")
195
-
196
- def _create_module(jax_func, input):
197
- signature = jax.tree_map(
198
- lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
199
- )
196
+ signature = jax.tree_map(
197
+ lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
198
+ )
200
199
 
201
- class RingTFModule(tf.Module):
202
- def __init__(self, jax_func):
203
- super().__init__()
204
- self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
200
+ tf_func = jax2tf.convert(jax_func, with_gradient=False)
205
201
 
206
- @partial(
207
- tf.function,
208
- autograph=False,
209
- jit_compile=True,
210
- input_signature=signature,
211
- )
212
- def __call__(self, *args):
213
- return self.tf_func(*args)
202
+ class RingTFModule(tf.Module):
203
+ @partial(
204
+ tf.function, autograph=False, jit_compile=True, input_signature=signature
205
+ )
206
+ def __call__(self, *args):
207
+ return tf_func(*args)
214
208
 
215
- return RingTFModule(jax_func)
209
+ model = RingTFModule()
216
210
 
217
- model = _create_module(jax_func, input)
218
211
  tf.saved_model.save(
219
212
  model,
220
213
  path,
221
214
  options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
215
+ signatures={"default": model.__call__},
222
216
  )
217
+
223
218
  if validate:
224
219
  output_jax = jax_func(*input)
225
220
  output_tf = tf.saved_model.load(path)(*input)
ring/ml/rnno_v1.py CHANGED
@@ -4,6 +4,8 @@ import haiku as hk
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
 
7
+ from .ringnet import LSTM
8
+
7
9
 
8
10
  def rnno_v1_forward_factory(
9
11
  output_dim: int,
@@ -13,18 +15,29 @@ def rnno_v1_forward_factory(
13
15
  act_fn_linear=jax.nn.relu,
14
16
  act_fn_rnn=jax.nn.elu,
15
17
  lam: Optional[tuple[int]] = None,
18
+ celltype: str = "gru",
16
19
  ):
17
20
  # unused
18
21
  del lam
19
22
 
23
+ if celltype == "gru":
24
+ _cell = hk.GRU
25
+ _factor = 1
26
+ elif celltype == "lstm":
27
+ _cell = LSTM
28
+ _factor = 2
29
+ else:
30
+ raise NotImplementedError
31
+
20
32
  @hk.without_apply_rng
21
33
  @hk.transform_with_state
22
34
  def forward_fn(X):
23
35
  assert X.shape[-2] == 1
24
36
 
25
37
  for i, n_units in enumerate(rnn_layers):
38
+ n_units = _factor * n_units
26
39
  state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
27
- X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
40
+ X, state = hk.dynamic_unroll(_cell(n_units), X, state)
28
41
  hk.set_state(f"rnn_{i}", state)
29
42
 
30
43
  if layernorm:
@@ -0,0 +1,159 @@
1
+ import os
2
+ import random
3
+ from typing import Callable, Optional
4
+
5
+ import jax
6
+ import numpy as np
7
+ from ring.utils import parse_path
8
+ from ring.utils import pickle_load
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from torch.utils.data import Dataset
12
+ import tqdm
13
+ from tree_utils import PyTree
14
+
15
+
16
+ def make_generator(
17
+ *paths,
18
+ batch_size,
19
+ transform,
20
+ shuffle=True,
21
+ seed: int = 1,
22
+ backend: str = "eager",
23
+ **kwargs,
24
+ ):
25
+ if backend == "grain":
26
+ _make_gen = pygrain_generator
27
+ elif backend == "torch":
28
+ _make_gen = pytorch_generator
29
+ elif backend == "eager":
30
+ _make_gen = eager_generator
31
+ else:
32
+ raise NotImplementedError
33
+
34
+ return _make_gen(
35
+ *paths,
36
+ batch_size=batch_size,
37
+ transform=transform,
38
+ shuffle=shuffle,
39
+ seed=seed,
40
+ **kwargs,
41
+ )
42
+
43
+
44
+ T = PyTree[np.ndarray]
45
+
46
+
47
+ class _Dataset(Dataset):
48
+ def __init__(self, *paths, transform):
49
+
50
+ self.files = [self.listdir(path) for path in paths]
51
+ Ns = set([len(f) for f in self.files])
52
+ assert len(Ns) == 1, f"{Ns}"
53
+
54
+ self.P = len(self.files)
55
+ self.N = list(Ns)[0]
56
+ self.transform = transform
57
+
58
+ def __len__(self):
59
+ return self.N
60
+
61
+ def __getitem__(self, idx: int):
62
+ element = [pickle_load(self.files[p][idx]) for p in range(self.P)]
63
+ if self.transform is not None:
64
+ element = self.transform(element)
65
+ return element
66
+
67
+ @staticmethod
68
+ def listdir(path: str) -> list:
69
+ return [parse_path(path, file) for file in os.listdir(path)]
70
+
71
+ def __call__(self, idx: int):
72
+ return self[idx]
73
+
74
+
75
+ class TransformTransform:
76
+ def __init__(self, transform):
77
+ self.transform = transform
78
+
79
+ def __call__(self, element):
80
+ if self.transform is None:
81
+ return element
82
+ return self.transform(element, np.random.default_rng())
83
+
84
+
85
+ def pytorch_generator(
86
+ *paths,
87
+ batch_size: int,
88
+ transform: Optional[Callable[[T], T]] = None,
89
+ shuffle=True,
90
+ seed: int = 1,
91
+ **kwargs,
92
+ ):
93
+ torch.manual_seed(seed)
94
+
95
+ ds = _Dataset(*paths, transform=TransformTransform(transform))
96
+ dl = DataLoader(
97
+ ds,
98
+ batch_size=batch_size,
99
+ shuffle=shuffle,
100
+ multiprocessing_context="spawn" if kwargs.get("num_workers", 0) > 0 else None,
101
+ **kwargs,
102
+ )
103
+ dl_iter = iter(dl)
104
+
105
+ def to_numpy(tree: PyTree[torch.Tensor]):
106
+ return jax.tree_map(lambda tensor: tensor.numpy(), tree)
107
+
108
+ def generator(_):
109
+ nonlocal dl, dl_iter
110
+ try:
111
+ return to_numpy(next(dl_iter))
112
+ except StopIteration:
113
+ dl_iter = iter(dl)
114
+ return to_numpy(next(dl_iter))
115
+
116
+ return generator
117
+
118
+
119
+ def eager_generator(
120
+ *paths,
121
+ batch_size: int,
122
+ transform: Optional[Callable[[T], T]] = None,
123
+ shuffle=True,
124
+ seed=1,
125
+ ):
126
+ from ring import RCMG
127
+
128
+ random.seed(seed)
129
+
130
+ ds = _Dataset(*paths, transform=TransformTransform(transform))
131
+ data = [ds[i] for i in tqdm.tqdm(range(len(ds)), total=len(ds))]
132
+ return RCMG.eager_gen_from_list(data, batch_size, shuffle=shuffle)
133
+
134
+
135
+ def pygrain_generator(
136
+ *paths, batch_size: int, transform=None, shuffle=True, seed=1, **kwargs
137
+ ):
138
+
139
+ import grain.python as pygrain # type: ignore
140
+
141
+ class _Transform(pygrain.RandomMapTransform):
142
+ def random_map(self, element, rng: np.random.Generator):
143
+ return transform(element, rng)
144
+
145
+ ds = _Dataset(*paths, transform=None)
146
+ dl = pygrain.load(
147
+ ds,
148
+ batch_size=batch_size,
149
+ shuffle=shuffle,
150
+ seed=seed,
151
+ transformations=[_Transform()],
152
+ **kwargs,
153
+ )
154
+ iter_dl = iter(dl)
155
+
156
+ def generator(_):
157
+ return next(iter_dl)
158
+
159
+ return generator
ring/utils/utils.py CHANGED
@@ -3,6 +3,7 @@ import io
3
3
  import pickle
4
4
  import random
5
5
  from typing import Optional
6
+ import warnings
6
7
 
7
8
  import jax
8
9
  import jax.numpy as jnp
@@ -195,7 +196,7 @@ def replace_elements_w_nans(
195
196
  assert min(include_elements) >= 0
196
197
  assert max(include_elements) < len(list_of_data)
197
198
 
198
- def _is_nan(ele: tree_utils.PyTree, i: int):
199
+ def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool):
199
200
  isnan = np.any(
200
201
  [np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
201
202
  )
@@ -205,13 +206,22 @@ def replace_elements_w_nans(
205
206
  return True
206
207
  return False
207
208
 
209
+ list_of_isnan = [int(_is_nan(e, 0, False)) for e in list_of_data]
210
+ perc_of_isnan = sum(list_of_isnan) / len(list_of_data)
211
+
212
+ if perc_of_isnan >= 0.02:
213
+ warnings.warn(
214
+ f"{perc_of_isnan * 100}% of {len(list_of_data)} datapoints are NaN"
215
+ )
216
+ assert perc_of_isnan != 1
217
+
208
218
  list_of_data_nonan = []
209
219
  for i, ele in enumerate(list_of_data):
210
- if _is_nan(ele, i):
220
+ if _is_nan(ele, i, verbose):
211
221
  while True:
212
222
  j = random.choice(include_elements)
213
223
  ele_j = list_of_data[j]
214
- if not _is_nan(ele_j, j):
224
+ if not _is_nan(ele_j, j, verbose):
215
225
  ele = pytree_deepcopy(ele_j)
216
226
  break
217
227
  list_of_data_nonan.append(ele)