imt-ring 1.6.30__py3-none-any.whl → 1.6.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.30
3
+ Version: 1.6.31
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -78,7 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
- ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
+ ring/utils/dataloader_torch.py,sha256=wMKJ-eCJ4cHjisGODOZgDVG2r-XQjSANBQFfC05wpzo,2092
82
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
83
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
84
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.30.dist-info/METADATA,sha256=CC1Tqt_uwTMJZiSWoxBCLjkc3J1hWsNhPTJNsxamM0U,4251
90
- imt_ring-1.6.30.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.30.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.30.dist-info/RECORD,,
89
+ imt_ring-1.6.31.dist-info/METADATA,sha256=ISL0fShgxIGskumWa3mtqCBcOvOfEtJc4XE8Rt-EJCA,4251
90
+ imt_ring-1.6.31.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.31.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.31.dist-info/RECORD,,
@@ -1,4 +1,6 @@
1
1
  import os
2
+ from typing import Optional
3
+ import warnings
2
4
 
3
5
  import jax
4
6
  import torch
@@ -35,15 +37,20 @@ def dataset_to_generator(
35
37
  batch_size: int,
36
38
  shuffle=True,
37
39
  seed: int = 1,
40
+ num_workers: Optional[int] = None,
38
41
  **kwargs,
39
42
  ):
40
43
  torch.manual_seed(seed)
41
44
 
45
+ if num_workers is None:
46
+ num_workers = _get_number_of_logical_cores()
47
+
42
48
  dl = DataLoader(
43
49
  dataset,
44
50
  batch_size=batch_size,
45
51
  shuffle=shuffle,
46
- multiprocessing_context="spawn" if kwargs.get("num_workers", 0) > 0 else None,
52
+ multiprocessing_context="spawn" if num_workers > 0 else None,
53
+ num_workers=num_workers,
47
54
  **kwargs,
48
55
  )
49
56
  dl_iter = iter(dl)
@@ -60,3 +67,20 @@ def dataset_to_generator(
60
67
  return to_numpy(next(dl_iter))
61
68
 
62
69
  return generator
70
+
71
+
72
+ def _get_number_of_logical_cores() -> int:
73
+ N = None
74
+ if hasattr(os, "sched_getaffinity"):
75
+ try:
76
+ N = len(os.sched_getaffinity(0))
77
+ except Exception:
78
+ pass
79
+ if N is None:
80
+ N = os.cpu_count()
81
+ if N is None:
82
+ warnings.warn(
83
+ "Could not automatically set the `num_workers` variable, defaults to `0`"
84
+ )
85
+ N = 0
86
+ return N