imt-ring 1.6.30__py3-none-any.whl → 1.6.31__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.30
3
+ Version: 1.6.31
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -78,7 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
- ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
+ ring/utils/dataloader_torch.py,sha256=wMKJ-eCJ4cHjisGODOZgDVG2r-XQjSANBQFfC05wpzo,2092
82
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
83
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
84
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.30.dist-info/METADATA,sha256=CC1Tqt_uwTMJZiSWoxBCLjkc3J1hWsNhPTJNsxamM0U,4251
90
- imt_ring-1.6.30.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.30.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.30.dist-info/RECORD,,
89
+ imt_ring-1.6.31.dist-info/METADATA,sha256=ISL0fShgxIGskumWa3mtqCBcOvOfEtJc4XE8Rt-EJCA,4251
90
+ imt_ring-1.6.31.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.31.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.31.dist-info/RECORD,,
@@ -1,4 +1,6 @@
1
1
  import os
2
+ from typing import Optional
3
+ import warnings
2
4
 
3
5
  import jax
4
6
  import torch
@@ -35,15 +37,20 @@ def dataset_to_generator(
35
37
  batch_size: int,
36
38
  shuffle=True,
37
39
  seed: int = 1,
40
+ num_workers: Optional[int] = None,
38
41
  **kwargs,
39
42
  ):
40
43
  torch.manual_seed(seed)
41
44
 
45
+ if num_workers is None:
46
+ num_workers = _get_number_of_logical_cores()
47
+
42
48
  dl = DataLoader(
43
49
  dataset,
44
50
  batch_size=batch_size,
45
51
  shuffle=shuffle,
46
- multiprocessing_context="spawn" if kwargs.get("num_workers", 0) > 0 else None,
52
+ multiprocessing_context="spawn" if num_workers > 0 else None,
53
+ num_workers=num_workers,
47
54
  **kwargs,
48
55
  )
49
56
  dl_iter = iter(dl)
@@ -60,3 +67,20 @@ def dataset_to_generator(
60
67
  return to_numpy(next(dl_iter))
61
68
 
62
69
  return generator
70
+
71
+
72
+ def _get_number_of_logical_cores() -> int:
73
+ N = None
74
+ if hasattr(os, "sched_getaffinity"):
75
+ try:
76
+ N = len(os.sched_getaffinity(0))
77
+ except Exception:
78
+ pass
79
+ if N is None:
80
+ N = os.cpu_count()
81
+ if N is None:
82
+ warnings.warn(
83
+ "Could not automatically set the `num_workers` variable, defaults to `0`"
84
+ )
85
+ N = 0
86
+ return N