libinephany 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the contents of each version exactly as it appears in its public registry.
--- a/libinephany/observations/observation_utils.py
+++ b/libinephany/observations/observation_utils.py
@@ -14,7 +14,7 @@ import torch
 import torch.optim as optim
 
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
-from libinephany.utils import optim_utils
+from libinephany.utils import optim_utils, torch_distributed_utils
 
 # ======================================================================================================================
 #
@@ -173,7 +173,14 @@ def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
     :return: Whether the tensor is owned by the local rank.
     """
 
-    return tensor is not None and tensor.grad is not None and tensor.numel() > 0
+    valid_tensor = tensor is not None and tensor.grad is not None and tensor.numel() > 0
+
+    if valid_tensor and tensor.is_cuda:
+        local_rank = torch_distributed_utils.get_local_rank()
+
+        return tensor.device.index == local_rank
+
+    return valid_tensor
 
 
 def form_update_tensor(
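
The change above tightens ownership semantics: a CUDA tensor now counts as local only when its device index matches this process's local rank, so a rank no longer claims tensors resident on another GPU. A minimal standalone sketch of the new predicate, with the get_local_rank helper stubbed out for illustration:

    import torch

    def _local_rank() -> int:
        # Stand-in for torch_distributed_utils.get_local_rank(): 0 in a
        # single-process run, the distributed rank otherwise.
        return 0

    def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
        # A tensor is "owned" only if it exists, carries a gradient and is non-empty ...
        valid_tensor = tensor is not None and tensor.grad is not None and tensor.numel() > 0

        # ... and, when it sits on a GPU, only if that GPU belongs to this rank.
        if valid_tensor and tensor.is_cuda:
            return tensor.device.index == _local_rank()

        return valid_tensor
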
--- a/libinephany/observations/observers/global_observers.py
+++ b/libinephany/observations/observers/global_observers.py
@@ -36,7 +36,7 @@ class InitialHyperparameters(GlobalObserver):
 
         super().__init__(**kwargs)
 
-        force_skip = ["samples"]
+        force_skip = ["samples", "gradient_accumulation"]
         skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
         self.skip_hparams = [] if skip_hparams is None else skip_hparams
         self.pad_with = pad_with
--- a/libinephany/observations/statistic_trackers.py
+++ b/libinephany/observations/statistic_trackers.py
@@ -193,13 +193,14 @@ class Statistic(ABC):
         Processes the tensor cache to build a TensorStatistic model.
         """
 
-        concatenated = torch.cat(self._tensor_cache)
-        self._tensor_cache = []
+        if self._tensor_cache:
+            concatenated = torch.cat(self._tensor_cache)
+            self._tensor_cache = []
 
-        statistics = TensorStatistics.build(
-            tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
-        )
-        self._data.append(statistics)  # type: ignore
+            statistics = TensorStatistics.build(
+                tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
+            )
+            self._data.append(statistics)  # type: ignore
 
     @staticmethod
     @final
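
The new guard matters because torch.cat rejects an empty list, so flushing an empty cache would raise instead of being a no-op:

    import torch

    cache: list[torch.Tensor] = []

    # torch.cat([]) raises RuntimeError ("expected a non-empty list of
    # Tensors"), so the flush must be skipped when nothing was cached.
    if cache:
        concatenated = torch.cat(cache)
        cache = []
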
@@ -213,10 +214,10 @@ class Statistic(ABC):
 
         if torch_distributed_utils.is_scheduler_master_rank():
             if isinstance(statistic, torch.Tensor):
-                shape = statistic.shape
+                shape = statistic.view(-1).shape
 
             elif isinstance(statistic, TensorStatistics):
-                shape = statistic.to_tensor().shape
+                shape = statistic.to_tensor().view(-1).shape
 
             elif statistic is not None:
                 shape = torch.tensor([statistic]).shape
@@ -239,23 +240,21 @@ class Statistic(ABC):
         if not torch_distributed_utils.is_distributed():
             return statistic
 
-        if statistic is None:
-            shape = self._determine_reduction_shape(statistic=statistic)
-
-            if shape is None:
-                return statistic
+        shape = self._determine_reduction_shape(statistic=statistic)
 
-            to_reduce = torch.zeros(shape)
+        if statistic is None:
+            to_reduce = torch.zeros(shape, dtype=torch.float64)
 
         elif isinstance(statistic, torch.Tensor):
-            to_reduce = statistic.clone()
+            to_reduce = statistic.clone().to(torch.float64).view(-1)
 
         elif isinstance(statistic, TensorStatistics):
-            to_reduce = statistic.to_tensor()
+            to_reduce = statistic.to_tensor().to(torch.float64).view(-1)
 
         else:
-            to_reduce = torch.tensor([statistic])
+            to_reduce = torch.tensor([statistic], dtype=torch.float64)
 
+        to_reduce = to_reduce.to(torch_distributed_utils.get_local_device())
         dist.reduce(to_reduce, dst=MASTER_SCHEDULER_RANK, op=ReduceOp.SUM)
 
         if not torch_distributed_utils.is_scheduler_master_rank():
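
Two things change in the reduction: every contribution is normalised to a flat float64 tensor (matching the flattened shapes from the previous hunk, so ranks holding no statistic contribute a zero tensor of identical shape), and the tensor is moved to the rank's local device before the collective call, since backends such as NCCL only reduce CUDA tensors resident on the calling process's GPU. A sketch of the same pattern; the helper name reduce_to_master and its signature are illustrative, not part of the library:

    import torch
    import torch.distributed as dist

    def reduce_to_master(value: float | None, shape: torch.Size, device: torch.device) -> torch.Tensor:
        # Every rank must contribute a tensor of identical shape and dtype;
        # ranks with nothing to report send zeros, which are neutral under SUM.
        if value is None:
            to_reduce = torch.zeros(shape, dtype=torch.float64)
        else:
            to_reduce = torch.tensor([value], dtype=torch.float64)

        # NCCL can only reduce CUDA tensors living on this process's GPU,
        # hence the move to the local device before the collective call.
        # Assumes dist.init_process_group() has already run.
        to_reduce = to_reduce.to(device)
        dist.reduce(to_reduce, dst=0, op=dist.ReduceOp.SUM)
        return to_reduce
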
@@ -288,23 +287,21 @@ class Statistic(ABC):
         statistic = self._gather(
             optimizer=optimizer, model=model, parameters=parameters, parameter_group=parameter_group
         )
-        statistic = self._distributed_reduce(statistic=statistic)
 
-        if not torch_distributed_utils.is_scheduler_master_rank():
-            return
+        statistic = self._distributed_reduce(statistic=statistic)
 
-        if isinstance(statistic, torch.Tensor):
-            statistic = statistic.view(-1)
-            self._tensor_cache.append(statistic)
+        if torch_distributed_utils.is_scheduler_master_rank():
+            if isinstance(statistic, torch.Tensor):
+                statistic = statistic.view(-1)
+                self._tensor_cache.append(statistic)
 
-            if len(self._tensor_cache) >= self.max_cache_size:
-                self._process_tensor_cache()
+                if len(self._tensor_cache) >= self.max_cache_size:
+                    self._process_tensor_cache()
 
-        elif statistic is not None:
-            self._data.append(statistic)  # type: ignore
+            elif statistic is not None:
+                self._data.append(statistic)  # type: ignore
 
-        if torch_distributed_utils.is_scheduler_master_rank():
-            self._sample_number += 1
+            self._sample_number += 1
 
     @final
     def fetch(self) -> TensorStatistics | float | None:
--- a/libinephany/utils/torch_distributed_utils.py
+++ b/libinephany/utils/torch_distributed_utils.py
@@ -4,8 +4,10 @@
 #
 # ======================================================================================================================
 
+import os
 from typing import Any
 
+import torch
 import torch.distributed as dist
 
 # ======================================================================================================================
@@ -14,7 +16,11 @@ import torch.distributed as dist
 #
 # ======================================================================================================================
 
+CUDA = "cuda"
+CPU = "cpu"
+CUDA_PREFIX = f"{CUDA}:"
 MASTER_SCHEDULER_RANK = 0
+LOCAL_RANK = "LOCAL_RANK"
 
 # ======================================================================================================================
 #
@@ -48,7 +54,10 @@ def get_local_rank() -> int:
     :return: Distributed computing rank of this process.
     """
 
-    return dist.get_rank() if is_distributed() else MASTER_SCHEDULER_RANK
+    if not is_distributed():
+        return MASTER_SCHEDULER_RANK
+
+    return dist.get_rank()
 
 
 def is_scheduler_master_rank() -> bool:
@@ -83,3 +92,15 @@ def barrier() -> None:
 
     if is_distributed():
         dist.barrier()
+
+
+def get_local_device() -> torch.device:
+    """
+    :return: Local device of the current rank.
+    """
+
+    if not is_distributed():
+        return torch.device(CUDA if torch.cuda.is_available() else CPU)
+
+    local_device_rank = os.environ.get(LOCAL_RANK, MASTER_SCHEDULER_RANK)
+    return torch.device(f"{CUDA_PREFIX}{local_device_rank}" if torch.cuda.is_available() else CPU)
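
The new get_local_device helper gives callers (such as the reduction in statistic_trackers.py) the correct target device. In a non-distributed run it falls back to plain "cuda" or "cpu"; in a distributed run it reads the LOCAL_RANK environment variable, which launchers such as torchrun set per worker to identify that worker's GPU within the node. A brief usage sketch:

    import torch

    from libinephany.utils import torch_distributed_utils

    # Single process: resolves to "cuda" (or "cpu" without a GPU).
    # Under torchrun with LOCAL_RANK=3: resolves to "cuda:3".
    device = torch_distributed_utils.get_local_device()
    to_reduce = torch.zeros(4, dtype=torch.float64).to(device)
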
--- a/libinephany-0.15.0.dist-info/METADATA
+++ b/libinephany-0.15.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: libinephany
-Version: 0.15.0
+Version: 0.15.2
 Summary: Inephany library containing code commonly used by multiple subpackages.
 Author-email: Inephany <info@inephany.com>
 License: Apache 2.0
--- a/libinephany-0.15.0.dist-info/RECORD
+++ b/libinephany-0.15.2.dist-info/RECORD
@@ -2,14 +2,14 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
 libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-libinephany/observations/observation_utils.py,sha256=wb6EZiaEiPuOqN26zzuT1rHyehoKh-8818KXn8pHweI,8688
+libinephany/observations/observation_utils.py,sha256=pR7MM57KYYJxqRXsr9eMnhm7m_aGffH-dyFejnj2w_I,8899
 libinephany/observations/observer_pipeline.py,sha256=ZhONGXJQSgs2VJJn9d2F7ItkYqntvchl9-JTyxW9eU0,12146
 libinephany/observations/pipeline_coordinator.py,sha256=FrN3linKaC0pVE5uKjlh_0Fi8Mb1oK91NzH3Fq7PvyM,7420
 libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
-libinephany/observations/statistic_trackers.py,sha256=flkXquMHvY6YjnQAvRElsV5OUm7Ek_PhA1_fvtX-0oQ,30124
+libinephany/observations/statistic_trackers.py,sha256=PvuIqCkeaiAzCUJOYLOxM-Dl655HH1kQgiK4kjpEIyo,30236
 libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
-libinephany/observations/observers/global_observers.py,sha256=-BJJaYjQSO82qskIlY_iijd3Lk1Ei1d3Hg1fzmYUPSM,38659
+libinephany/observations/observers/global_observers.py,sha256=3TaiV2AxMOXfDq-kXMU3ZSo-rQENNCFhdWCJtpY99ok,38684
 libinephany/observations/observers/local_observers.py,sha256=EdivrylOcmxRsu4xiMwZqwmPX8Ru9-IRwoPk6En7qvw,37050
 libinephany/observations/observers/observer_containers.py,sha256=g73ScbRRVTNbGEBb-Nyk8AQwoDhKZaqTd6OYP8FIcOs,8771
 libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -43,15 +43,15 @@ libinephany/utils/optim_utils.py,sha256=-PLqsyuq4ZH3spBy_olNB3yuLwvhnLrCF0384elC
 libinephany/utils/random_seeds.py,sha256=eF-ErrMShu8mp9V_gXrB_iUxR-Lb-OtHypEEUQAGn2Y,1565
 libinephany/utils/samplers.py,sha256=uyVGAy5cm5bCyWMOuySJmzUc_vFuieO_3zydJciwdv4,12158
 libinephany/utils/standardizers.py,sha256=pG1K_XL4OR_NjVtT6Hjbln1dk1BtQdDuSK1PQTkA17Y,8014
-libinephany/utils/torch_distributed_utils.py,sha256=ygdVz-s7hMRoBJcZkNRBlF81MYnxoRJt8S0SAwq6SC4,2467
+libinephany/utils/torch_distributed_utils.py,sha256=UPMfhdZZwyHX_r3h55AAK4PcB-zFtjK37Z5aawAKNmE,2968
 libinephany/utils/torch_utils.py,sha256=o5TsqrXe6Id04P6SqB_avGBRZutbu6IBB61llAHQ_PY,2696
 libinephany/utils/transforms.py,sha256=Ca4pbCs_FbCpXb8M8oPxrP5QOqOAwGSdGpKzy5YUubc,3503
 libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,624
 libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
 libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
-libinephany-0.15.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
-libinephany-0.15.0.dist-info/METADATA,sha256=lU7SqV1ArMEAyuZ845Z1jAYxNUEYGfJ8Tl6Df6EwSpc,8354
-libinephany-0.15.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-libinephany-0.15.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
-libinephany-0.15.0.dist-info/RECORD,,
+libinephany-0.15.2.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
+libinephany-0.15.2.dist-info/METADATA,sha256=GU4yudoPYVfLXO5dLp8UNY2XMQMpSSdw8lP-ZZUaQy4,8354
+libinephany-0.15.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+libinephany-0.15.2.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
+libinephany-0.15.2.dist-info/RECORD,,