imt-ring 1.6.11__py3-none-any.whl → 1.6.12__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.11.dist-info → imt_ring-1.6.12.dist-info}/METADATA +1 -1
- {imt_ring-1.6.11.dist-info → imt_ring-1.6.12.dist-info}/RECORD +12 -11
- {imt_ring-1.6.11.dist-info → imt_ring-1.6.12.dist-info}/WHEEL +1 -1
- ring/algorithms/generator/base.py +3 -0
- ring/algorithms/generator/finalize_fns.py +17 -0
- ring/algorithms/jcalc.py +10 -1
- ring/algorithms/sensors.py +1 -1
- ring/base.py +3 -1
- ring/ml/ml_utils.py +14 -21
- ring/utils/dataloader.py +159 -0
- ring/utils/utils.py +13 -3
- {imt_ring-1.6.11.dist-info → imt_ring-1.6.12.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,22 @@
|
|
1
1
|
ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
|
4
4
|
ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
7
|
ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
|
8
8
|
ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
9
|
+
ring/algorithms/jcalc.py,sha256=bwfVH3qKEnUs6RFgEEeUBnecpBt-nf8cesJbNGDrE7g,28974
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
|
-
ring/algorithms/sensors.py,sha256=
|
11
|
+
ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
15
|
ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=vxUdA0ZeSNH3SOanL51qVRvCiJrmsWQyQX0g2fdm3Rg,15825
|
18
18
|
ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
|
19
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
19
|
+
ring/algorithms/generator/finalize_fns.py,sha256=559sGXs06n46p-eme0SE8hn0lXwGT0P2r3-52ElTldo,9861
|
20
20
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
21
21
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
22
22
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
53
53
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
54
54
|
ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
|
-
ring/ml/ml_utils.py,sha256=
|
56
|
+
ring/ml/ml_utils.py,sha256=siiRWbUpjYQz1nAlARm47oqR2K74YTiE1syCoOEmiWw,6370
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
|
59
59
|
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
@@ -76,14 +76,15 @@ ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
|
|
76
76
|
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
77
77
|
ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
|
78
78
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
79
|
+
ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
|
79
80
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
80
81
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
81
82
|
ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
|
82
83
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
83
|
-
ring/utils/utils.py,sha256=
|
84
|
+
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
84
85
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
85
86
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
86
|
-
imt_ring-1.6.
|
87
|
-
imt_ring-1.6.
|
88
|
-
imt_ring-1.6.
|
89
|
-
imt_ring-1.6.
|
87
|
+
imt_ring-1.6.12.dist-info/METADATA,sha256=NIcGCBCzA9jwqxvyHYHl5QdfiaFLLxdnQjOk17YX0bA,3821
|
88
|
+
imt_ring-1.6.12.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
89
|
+
imt_ring-1.6.12.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
90
|
+
imt_ring-1.6.12.dist-info/RECORD,,
|
@@ -321,6 +321,9 @@ def _build_mconfig_batched_generator(
|
|
321
321
|
"using the `randomize_motion_artifacts` flag, so it must be enabled."
|
322
322
|
)
|
323
323
|
|
324
|
+
if dynamic_simulation:
|
325
|
+
finalize_fns.DynamicalSimulation.assert_test_system(sys)
|
326
|
+
|
324
327
|
def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
|
325
328
|
pipe = []
|
326
329
|
if imu_motion_artifacts and randomize_motion_artifacts:
|
@@ -180,6 +180,23 @@ class DynamicalSimulation:
|
|
180
180
|
self.overwrite_q_ref = overwrite_q_ref
|
181
181
|
self.unroll_kwargs = unroll_kwargs
|
182
182
|
|
183
|
+
@staticmethod
def assert_test_system(sys: base.System) -> None:
    """Check that `sys` can be simulated dynamically without obvious NaN risks.

    For every link whose damping array is non-empty (`d.size > 0`), assert
    that the link's mass is strictly positive and that all of its damping
    values are strictly positive. Links with an empty damping array are not
    checked (presumably joints without degrees of freedom -- TODO confirm).

    Args:
        sys: The system that is about to be used for dynamic simulation.

    Raises:
        AssertionError: If a checked link has a non-positive mass or a
            non-positive damping value.
    """

    def f(_, __, n, m, d):
        # `n`: link name, `m`: link mass, `d`: link damping (see scan below).
        assert d.size == 0 or m > 0, (
            "Dynamic simulation is set to `True` which requires masses > 0, "
            f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
        )

        assert d.size == 0 or all(d > 0.0), (
            "Dynamic simulation is set to `True` which requires dampings > 0, "
            f"but found body `{n}` with damping={d}. This can lead to NaNs."
        )

    sys.scan(f, "lld", sys.link_names, sys.links.inertia.mass, sys.link_damping)
|
199
|
+
|
183
200
|
def __call__(
|
184
201
|
self, Xy: types.Xy, extras: types.OutputExtras
|
185
202
|
) -> tuple[types.Xy, types.OutputExtras]:
|
ring/algorithms/jcalc.py
CHANGED
@@ -205,7 +205,7 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
|
|
205
205
|
return False
|
206
206
|
return True
|
207
207
|
|
208
|
-
|
208
|
+
cond1 = all(
|
209
209
|
[
|
210
210
|
dx_deltax_check(*args)
|
211
211
|
for args in zip(
|
@@ -217,6 +217,15 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
|
|
217
217
|
]
|
218
218
|
)
|
219
219
|
|
220
|
+
# this one tests that the initial value is inside the feasible value range
|
221
|
+
# so e.g. if you choose pos0_min=-10 then you can't choose pos_min=-1
|
222
|
+
def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
|
223
|
+
return (x0_min >= x_min) and (x0_max <= x_max)
|
224
|
+
|
225
|
+
cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
|
226
|
+
|
227
|
+
return cond1 and cond2
|
228
|
+
|
220
229
|
|
221
230
|
def _find_interval(t: jax.Array, boundaries: jax.Array):
|
222
231
|
"""Find the interval of `boundaries` between which `t` lies.
|
ring/algorithms/sensors.py
CHANGED
@@ -131,7 +131,7 @@ def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
|
|
131
131
|
# - gyr: rad/s
|
132
132
|
# - mag: a.u.
|
133
133
|
NOISE_LEVELS = {"acc": 0.048, "gyr": jnp.deg2rad(0.7), "mag": 0.01}
|
134
|
-
BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3
|
134
|
+
BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3), "mag": 0.0}
|
135
135
|
|
136
136
|
|
137
137
|
def add_noise_bias(
|
ring/base.py
CHANGED
@@ -690,7 +690,9 @@ class System(_Base):
|
|
690
690
|
transparent_segment_to_root: bool = True,
|
691
691
|
**kwargs,
|
692
692
|
):
|
693
|
-
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
|
693
|
+
"""`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
|
694
|
+
Note that the body in yhat that connects to -1, is parent-to-child!
|
695
|
+
"""
|
694
696
|
return ring.rendering.render_prediction(
|
695
697
|
self, xs, yhat, transparent_segment_to_root, **kwargs
|
696
698
|
)
|
ring/ml/ml_utils.py
CHANGED
@@ -12,7 +12,6 @@ import numpy as np
|
|
12
12
|
from tree_utils import PyTree
|
13
13
|
|
14
14
|
import ring
|
15
|
-
from ring.utils import import_lib
|
16
15
|
import wandb
|
17
16
|
|
18
17
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -190,36 +189,30 @@ def unique_id() -> str:
|
|
190
189
|
|
191
190
|
def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
192
191
|
from jax.experimental import jax2tf
|
192
|
+
import tensorflow as tf
|
193
193
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
signature = jax.tree_map(
|
198
|
-
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
199
|
-
)
|
194
|
+
signature = jax.tree_map(
|
195
|
+
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
196
|
+
)
|
200
197
|
|
201
|
-
|
202
|
-
def __init__(self, jax_func):
|
203
|
-
super().__init__()
|
204
|
-
self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
|
198
|
+
tf_func = jax2tf.convert(jax_func, with_gradient=False)
|
205
199
|
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
)
|
212
|
-
def __call__(self, *args):
|
213
|
-
return self.tf_func(*args)
|
200
|
+
class RingTFModule(tf.Module):
|
201
|
+
@partial(
|
202
|
+
tf.function, autograph=False, jit_compile=True, input_signature=signature
|
203
|
+
)
|
204
|
+
def __call__(self, *args):
|
205
|
+
return tf_func(*args)
|
214
206
|
|
215
|
-
|
207
|
+
model = RingTFModule()
|
216
208
|
|
217
|
-
model = _create_module(jax_func, input)
|
218
209
|
tf.saved_model.save(
|
219
210
|
model,
|
220
211
|
path,
|
221
212
|
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
|
213
|
+
signatures={"default": model.__call__},
|
222
214
|
)
|
215
|
+
|
223
216
|
if validate:
|
224
217
|
output_jax = jax_func(*input)
|
225
218
|
output_tf = tf.saved_model.load(path)(*input)
|
ring/utils/dataloader.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
import os
|
2
|
+
import random
|
3
|
+
from typing import Callable, Optional
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import numpy as np
|
7
|
+
from ring.utils import parse_path
|
8
|
+
from ring.utils import pickle_load
|
9
|
+
import torch
|
10
|
+
from torch.utils.data import DataLoader
|
11
|
+
from torch.utils.data import Dataset
|
12
|
+
import tqdm
|
13
|
+
from tree_utils import PyTree
|
14
|
+
|
15
|
+
|
16
|
+
def make_generator(
    *paths,
    batch_size,
    transform=None,
    shuffle=True,
    seed: int = 1,
    backend: str = "eager",
    **kwargs,
):
    """Build a batch generator over one or more folders of pickled files.

    Args:
        *paths: Folders whose files are combined element-wise.
        batch_size: Number of elements per batch.
        transform: Optional callable invoked as ``transform(element, rng)``
            on every element; ``None`` disables transformation.
        shuffle: Whether element order is shuffled.
        seed: Seed forwarded to the selected backend.
        backend: One of ``"eager"``, ``"torch"`` or ``"grain"``.
        **kwargs: Extra keyword arguments forwarded unchanged to the backend.
            Note that the ``"eager"`` backend accepts no extra keywords.

    Returns:
        The generator produced by the selected backend.

    Raises:
        NotImplementedError: If ``backend`` is not a supported name.
    """
    if backend == "grain":
        _make_gen = pygrain_generator
    elif backend == "torch":
        _make_gen = pytorch_generator
    elif backend == "eager":
        _make_gen = eager_generator
    else:
        raise NotImplementedError(
            f"Unknown backend `{backend}`; expected one of "
            "'eager', 'torch' or 'grain'."
        )

    return _make_gen(
        *paths,
        batch_size=batch_size,
        transform=transform,
        shuffle=shuffle,
        seed=seed,
        **kwargs,
    )
|
42
|
+
|
43
|
+
|
44
|
+
# Type alias for a pytree (nested containers) whose leaves are numpy arrays;
# used in the `transform` annotations of the generator functions below.
# NOTE(review): those annotations read `Callable[[T], T]`, but transforms are
# invoked as `transform(element, rng)` -- confirm and align the annotations.
T = PyTree[np.ndarray]
|
45
|
+
|
46
|
+
|
47
|
+
class _Dataset(Dataset):
|
48
|
+
def __init__(self, *paths, transform):
|
49
|
+
|
50
|
+
self.files = [self.listdir(path) for path in paths]
|
51
|
+
Ns = set([len(f) for f in self.files])
|
52
|
+
assert len(Ns) == 1, f"{Ns}"
|
53
|
+
|
54
|
+
self.P = len(self.files)
|
55
|
+
self.N = list(Ns)[0]
|
56
|
+
self.transform = transform
|
57
|
+
|
58
|
+
def __len__(self):
|
59
|
+
return self.N
|
60
|
+
|
61
|
+
def __getitem__(self, idx: int):
|
62
|
+
element = [pickle_load(self.files[p][idx]) for p in range(self.P)]
|
63
|
+
if self.transform is not None:
|
64
|
+
element = self.transform(element)
|
65
|
+
return element
|
66
|
+
|
67
|
+
@staticmethod
|
68
|
+
def listdir(path: str) -> list:
|
69
|
+
return [parse_path(path, file) for file in os.listdir(path)]
|
70
|
+
|
71
|
+
def __call__(self, idx: int):
|
72
|
+
return self[idx]
|
73
|
+
|
74
|
+
|
75
|
+
class TransformTransform:
    """Adapt a two-argument ``transform(element, rng)`` to a one-argument callable.

    The wrapped callable receives a brand-new, unseeded numpy ``Generator`` on
    every invocation. With no transform configured, elements pass through
    unchanged.
    """

    def __init__(self, transform):
        # `None` means "identity".
        self.transform = transform

    def __call__(self, element):
        fn = self.transform
        if fn is not None:
            # Fresh generator per call: independent randomness, not reproducible.
            return fn(element, np.random.default_rng())
        return element
|
83
|
+
|
84
|
+
|
85
|
+
def pytorch_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T, np.random.Generator], T]] = None,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Endless batch generator backed by a ``torch`` ``DataLoader``.

    Args:
        *paths: Folders of pickled files, combined element-wise by ``_Dataset``.
        batch_size: Number of elements per batch.
        transform: Optional callable invoked as ``transform(element, rng)``.
        shuffle: Whether the ``DataLoader`` shuffles elements.
        seed: Passed to ``torch.manual_seed`` (global torch RNG).
        **kwargs: Forwarded to ``DataLoader``. When ``num_workers > 0`` the
            multiprocessing context defaults to ``"spawn"`` unless an explicit
            ``multiprocessing_context`` is supplied.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch with every leaf converted via ``Tensor.numpy()``. When the
        loader is exhausted, a new pass over the data is started.
    """
    # Global torch seed; absent an explicit `generator=` kwarg, DataLoader
    # shuffling draws from torch's default generator.
    torch.manual_seed(seed)

    # Respect a caller-provided context instead of colliding with it.
    if kwargs.get("num_workers", 0) > 0:
        kwargs.setdefault("multiprocessing_context", "spawn")

    ds = _Dataset(*paths, transform=TransformTransform(transform))
    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: PyTree[torch.Tensor]):
        # `jax.tree_map` is deprecated in favour of `jax.tree_util.tree_map`.
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: restart the loader for another pass.
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator
|
117
|
+
|
118
|
+
|
119
|
+
def eager_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T, np.random.Generator], T]] = None,
    shuffle=True,
    seed=1,
):
    """Load (and transform) every element up front, then batch in memory.

    Args:
        *paths: Folders of pickled files, combined element-wise by ``_Dataset``.
        batch_size: Number of elements per batch.
        transform: Optional callable invoked as ``transform(element, rng)``.
        shuffle: Forwarded to ``RCMG.eager_gen_from_list``.
        seed: Seeds Python's ``random`` module and the numpy generator handed
            to ``transform``, so repeated calls with the same seed apply the
            same transformations.

    Returns:
        The result of ``RCMG.eager_gen_from_list`` over the loaded elements.
    """
    from ring import RCMG

    # Global `random` seed; presumably used by the shuffling inside
    # `RCMG.eager_gen_from_list` -- TODO confirm.
    random.seed(seed)

    # Derive transform randomness from `seed` too (the grain backend likewise
    # forwards `seed` to `pygrain.load`). Loading below is sequential and
    # in-process, so one shared generator is safe here.
    rng = np.random.default_rng(seed)

    def _apply_transform(element):
        if transform is None:
            return element
        return transform(element, rng)

    ds = _Dataset(*paths, transform=_apply_transform)
    data = [ds[i] for i in tqdm.tqdm(range(len(ds)))]
    return RCMG.eager_gen_from_list(data, batch_size, shuffle=shuffle)
|
133
|
+
|
134
|
+
|
135
|
+
def pygrain_generator(
    *paths, batch_size: int, transform=None, shuffle=True, seed=1, **kwargs
):
    """Batch generator backed by ``grain.python.load``.

    Args:
        *paths: Folders of pickled files, combined element-wise by ``_Dataset``.
        batch_size: Number of elements per batch.
        transform: Optional callable invoked as ``transform(element, rng)``;
            it is wrapped in a ``pygrain.RandomMapTransform`` so that grain
            supplies ``rng``. When ``None``, no transformation is registered.
        shuffle: Forwarded to ``pygrain.load``.
        seed: Forwarded to ``pygrain.load``.
        **kwargs: Forwarded to ``pygrain.load``.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next item of the grain loader's iterator.
    """
    import grain.python as pygrain  # type: ignore

    class _Transform(pygrain.RandomMapTransform):
        def random_map(self, element, rng: np.random.Generator):
            return transform(element, rng)

    # Registering `_Transform` with `transform=None` would call `None(...)`.
    transformations = [] if transform is None else [_Transform()]

    ds = _Dataset(*paths, transform=None)
    dl = pygrain.load(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        transformations=transformations,
        **kwargs,
    )
    iter_dl = iter(dl)

    def generator(_):
        return next(iter_dl)

    return generator
|
ring/utils/utils.py
CHANGED
@@ -3,6 +3,7 @@ import io
|
|
3
3
|
import pickle
|
4
4
|
import random
|
5
5
|
from typing import Optional
|
6
|
+
import warnings
|
6
7
|
|
7
8
|
import jax
|
8
9
|
import jax.numpy as jnp
|
@@ -195,7 +196,7 @@ def replace_elements_w_nans(
|
|
195
196
|
assert min(include_elements) >= 0
|
196
197
|
assert max(include_elements) < len(list_of_data)
|
197
198
|
|
198
|
-
def _is_nan(ele: tree_utils.PyTree, i: int):
|
199
|
+
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool):
|
199
200
|
isnan = np.any(
|
200
201
|
[np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
|
201
202
|
)
|
@@ -205,13 +206,22 @@ def replace_elements_w_nans(
|
|
205
206
|
return True
|
206
207
|
return False
|
207
208
|
|
209
|
+
list_of_isnan = [int(_is_nan(e, 0, False)) for e in list_of_data]
|
210
|
+
perc_of_isnan = sum(list_of_isnan) / len(list_of_data)
|
211
|
+
|
212
|
+
if perc_of_isnan >= 0.02:
|
213
|
+
warnings.warn(
|
214
|
+
f"{perc_of_isnan * 100}% of {len(list_of_data)} datapoints are NaN"
|
215
|
+
)
|
216
|
+
assert perc_of_isnan != 1
|
217
|
+
|
208
218
|
list_of_data_nonan = []
|
209
219
|
for i, ele in enumerate(list_of_data):
|
210
|
-
if _is_nan(ele, i):
|
220
|
+
if _is_nan(ele, i, verbose):
|
211
221
|
while True:
|
212
222
|
j = random.choice(include_elements)
|
213
223
|
ele_j = list_of_data[j]
|
214
|
-
if not _is_nan(ele_j, j):
|
224
|
+
if not _is_nan(ele_j, j, verbose):
|
215
225
|
ele = pytree_deepcopy(ele_j)
|
216
226
|
break
|
217
227
|
list_of_data_nonan.append(ele)
|
File without changes
|