libinephany 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
libinephany/observations/observers/global_observers.py
@@ -36,7 +36,7 @@ class InitialHyperparameters(GlobalObserver):
 
         super().__init__(**kwargs)
 
-        force_skip = ["samples"]
+        force_skip = ["samples", "gradient_accumulation"]
         skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
         self.skip_hparams = [] if skip_hparams is None else skip_hparams
         self.pad_with = pad_with
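
Note: the change above only widens the always-skipped list. A minimal sketch of the merge semantics in the surrounding lines (the "learning_rate" input is hypothetical caller data): user-supplied skips are appended to the forced ones, so "samples" and "gradient_accumulation" can never be observed.

force_skip = ["samples", "gradient_accumulation"]
skip_hparams = ["learning_rate"]  # hypothetical caller input
skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
print(skip_hparams)  # ['learning_rate', 'samples', 'gradient_accumulation']
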
libinephany/observations/statistic_trackers.py
@@ -213,10 +213,10 @@ class Statistic(ABC):
 
         if torch_distributed_utils.is_scheduler_master_rank():
             if isinstance(statistic, torch.Tensor):
-                shape = statistic.shape
+                shape = statistic.view(-1).shape
 
             elif isinstance(statistic, TensorStatistics):
-                shape = statistic.to_tensor().shape
+                shape = statistic.to_tensor().view(-1).shape
 
             elif statistic is not None:
                 shape = torch.tensor([statistic]).shape
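
Flattening before taking the shape means the master rank always reports a 1-D shape, which matches the flattened payloads reduced in the next hunk. A quick illustration:

import torch

t = torch.randn(3, 4)
print(t.shape)           # torch.Size([3, 4])
print(t.view(-1).shape)  # torch.Size([12])
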
@@ -239,23 +239,21 @@ class Statistic(ABC):
         if not torch_distributed_utils.is_distributed():
             return statistic
 
-        if statistic is None:
-            shape = self._determine_reduction_shape(statistic=statistic)
-
-            if shape is None:
-                return statistic
+        shape = self._determine_reduction_shape(statistic=statistic)
 
-            to_reduce = torch.zeros(shape)
+        if statistic is None:
+            to_reduce = torch.zeros(shape, dtype=torch.float64)
 
         elif isinstance(statistic, torch.Tensor):
-            to_reduce = statistic.clone()
+            to_reduce = statistic.clone().to(torch.float64).view(-1)
 
         elif isinstance(statistic, TensorStatistics):
-            to_reduce = statistic.to_tensor()
+            to_reduce = statistic.to_tensor().to(torch.float64).view(-1)
 
         else:
-            to_reduce = torch.tensor([statistic])
+            to_reduce = torch.tensor([statistic], dtype=torch.float64)
 
+        to_reduce = to_reduce.to(torch_distributed_utils.get_local_device())
         dist.reduce(to_reduce, dst=MASTER_SCHEDULER_RANK, op=ReduceOp.SUM)
 
         if not torch_distributed_utils.is_scheduler_master_rank():
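
This hunk makes every rank contribute a tensor with the same dtype (float64), the same 1-D shape, and the device the process group expects, which is what torch.distributed collectives require. A minimal sketch of that contract, runnable with `torchrun --nproc_per_node=2 demo.py` (the file name and the gloo backend are illustrative):

import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # CPU backend; torchrun supplies rank/world size
rank = dist.get_rank()

# Every rank must pass a tensor of identical shape, dtype, and device kind;
# each element here carries the contributing rank's value.
payload = torch.full((4,), float(rank), dtype=torch.float64)
dist.reduce(payload, dst=0, op=dist.ReduceOp.SUM)

if rank == 0:
    print(payload)  # with 2 ranks: tensor([1., 1., 1., 1.], dtype=torch.float64)

dist.destroy_process_group()
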
@@ -283,11 +281,13 @@ class Statistic(ABC):
 
         parameter_group = self._find_parameter_group(optimizer=optimizer)
         parameters = self._get_parameters(parameter_group=parameter_group)
+        self._sample_number += 1
 
         if self._sample_number % self.sample_frequency == 0:
             statistic = self._gather(
                 optimizer=optimizer, model=model, parameters=parameters, parameter_group=parameter_group
             )
+
             statistic = self._distributed_reduce(statistic=statistic)
 
             if not torch_distributed_utils.is_scheduler_master_rank():
@@ -303,9 +303,6 @@ class Statistic(ABC):
             elif statistic is not None:
                 self._data.append(statistic)  # type: ignore
 
-        if torch_distributed_utils.is_scheduler_master_rank():
-            self._sample_number += 1
-
     @final
     def fetch(self) -> TensorStatistics | float | None:
         """
libinephany/utils/torch_distributed_utils.py
@@ -4,8 +4,10 @@
 #
 # ======================================================================================================================
 
+import os
 from typing import Any
 
+import torch
 import torch.distributed as dist
 
 # ======================================================================================================================
@@ -14,7 +16,11 @@ import torch.distributed as dist
 #
 # ======================================================================================================================
 
+CUDA = "cuda"
+CPU = "cpu"
+CUDA_PREFIX = f"{CUDA}:"
 MASTER_SCHEDULER_RANK = 0
+LOCAL_RANK = "LOCAL_RANK"
 
 # ======================================================================================================================
 #
@@ -48,7 +54,10 @@ def get_local_rank() -> int:
     :return: Distributed computing rank of this process.
     """
 
-    return dist.get_rank() if is_distributed() else MASTER_SCHEDULER_RANK
+    if not is_distributed():
+        return MASTER_SCHEDULER_RANK
+
+    return dist.get_rank()
 
 
 def is_scheduler_master_rank() -> bool:
@@ -83,3 +92,15 @@ def barrier() -> None:
 
     if is_distributed():
         dist.barrier()
+
+
+def get_local_device() -> torch.device:
+    """
+    :return: Local device of the current rank.
+    """
+
+    if not is_distributed():
+        return torch.device(CUDA if torch.cuda.is_available() else CPU)
+
+    local_device_rank = os.environ.get(LOCAL_RANK, MASTER_SCHEDULER_RANK)
+    return torch.device(f"{CUDA_PREFIX}{local_device_rank}" if torch.cuda.is_available() else CPU)
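
`get_local_device()` gives `_distributed_reduce` a device to stage its payload on. A sketch of the resolution logic, assuming the launcher (e.g. torchrun) exports LOCAL_RANK per worker: with `torchrun --nproc_per_node=4` on a GPU host the workers resolve cuda:0 through cuda:3; outside a process group the helper falls back to a bare cuda or cpu device.

import os

import torch

local_rank = os.environ.get("LOCAL_RANK", 0)  # launcher-provided per-node rank
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
print(device)
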
libinephany-0.15.0.dist-info/METADATA → libinephany-0.15.1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: libinephany
-Version: 0.15.0
+Version: 0.15.1
 Summary: Inephany library containing code commonly used by multiple subpackages.
 Author-email: Inephany <info@inephany.com>
 License: Apache 2.0
libinephany-0.15.0.dist-info/RECORD → libinephany-0.15.1.dist-info/RECORD
@@ -6,10 +6,10 @@ libinephany/observations/observation_utils.py,sha256=wb6EZiaEiPuOqN26zzuT1rHyeho
 libinephany/observations/observer_pipeline.py,sha256=ZhONGXJQSgs2VJJn9d2F7ItkYqntvchl9-JTyxW9eU0,12146
 libinephany/observations/pipeline_coordinator.py,sha256=FrN3linKaC0pVE5uKjlh_0Fi8Mb1oK91NzH3Fq7PvyM,7420
 libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
-libinephany/observations/statistic_trackers.py,sha256=flkXquMHvY6YjnQAvRElsV5OUm7Ek_PhA1_fvtX-0oQ,30124
+libinephany/observations/statistic_trackers.py,sha256=J444i9EZ30vcYOEYqcDBzz7_UDpEE2hW_ISYBu_hwYc,30180
 libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
-libinephany/observations/observers/global_observers.py,sha256=-BJJaYjQSO82qskIlY_iijd3Lk1Ei1d3Hg1fzmYUPSM,38659
+libinephany/observations/observers/global_observers.py,sha256=3TaiV2AxMOXfDq-kXMU3ZSo-rQENNCFhdWCJtpY99ok,38684
 libinephany/observations/observers/local_observers.py,sha256=EdivrylOcmxRsu4xiMwZqwmPX8Ru9-IRwoPk6En7qvw,37050
 libinephany/observations/observers/observer_containers.py,sha256=g73ScbRRVTNbGEBb-Nyk8AQwoDhKZaqTd6OYP8FIcOs,8771
 libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -43,15 +43,15 @@ libinephany/utils/optim_utils.py,sha256=-PLqsyuq4ZH3spBy_olNB3yuLwvhnLrCF0384elC
 libinephany/utils/random_seeds.py,sha256=eF-ErrMShu8mp9V_gXrB_iUxR-Lb-OtHypEEUQAGn2Y,1565
 libinephany/utils/samplers.py,sha256=uyVGAy5cm5bCyWMOuySJmzUc_vFuieO_3zydJciwdv4,12158
 libinephany/utils/standardizers.py,sha256=pG1K_XL4OR_NjVtT6Hjbln1dk1BtQdDuSK1PQTkA17Y,8014
-libinephany/utils/torch_distributed_utils.py,sha256=ygdVz-s7hMRoBJcZkNRBlF81MYnxoRJt8S0SAwq6SC4,2467
+libinephany/utils/torch_distributed_utils.py,sha256=UPMfhdZZwyHX_r3h55AAK4PcB-zFtjK37Z5aawAKNmE,2968
 libinephany/utils/torch_utils.py,sha256=o5TsqrXe6Id04P6SqB_avGBRZutbu6IBB61llAHQ_PY,2696
 libinephany/utils/transforms.py,sha256=Ca4pbCs_FbCpXb8M8oPxrP5QOqOAwGSdGpKzy5YUubc,3503
 libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,624
 libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
 libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
-libinephany-0.15.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
-libinephany-0.15.0.dist-info/METADATA,sha256=lU7SqV1ArMEAyuZ845Z1jAYxNUEYGfJ8Tl6Df6EwSpc,8354
-libinephany-0.15.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-libinephany-0.15.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
-libinephany-0.15.0.dist-info/RECORD,,
+libinephany-0.15.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
+libinephany-0.15.1.dist-info/METADATA,sha256=lsqYtqyJ_k_clascJkzx8rR7gEN75tZ8lCKNqcH1cps,8354
+libinephany-0.15.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+libinephany-0.15.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
+libinephany-0.15.1.dist-info/RECORD,,