libinephany 0.15.1__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,7 +14,7 @@ import torch
14
14
  import torch.optim as optim
15
15
 
16
16
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
17
- from libinephany.utils import optim_utils
17
+ from libinephany.utils import optim_utils, torch_distributed_utils
18
18
 
19
19
  # ======================================================================================================================
20
20
  #
@@ -173,7 +173,14 @@ def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
173
173
  :return: Whether the tensor is owned by the local rank.
174
174
  """
175
175
 
176
- return tensor is not None and tensor.grad is not None and tensor.numel() > 0
176
+ valid_tensor = tensor is not None and tensor.grad is not None and tensor.numel() > 0
177
+
178
+ if valid_tensor and tensor.is_cuda:
179
+ local_rank = torch_distributed_utils.get_local_rank()
180
+
181
+ return tensor.device.index == local_rank
182
+
183
+ return valid_tensor
177
184
 
178
185
 
179
186
  def form_update_tensor(
@@ -193,13 +193,14 @@ class Statistic(ABC):
193
193
  Processes the tensor cache to build a TensorStatistic model.
194
194
  """
195
195
 
196
- concatenated = torch.cat(self._tensor_cache)
197
- self._tensor_cache = []
196
+ if self._tensor_cache:
197
+ concatenated = torch.cat(self._tensor_cache)
198
+ self._tensor_cache = []
198
199
 
199
- statistics = TensorStatistics.build(
200
- tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
201
- )
202
- self._data.append(statistics) # type: ignore
200
+ statistics = TensorStatistics.build(
201
+ tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
202
+ )
203
+ self._data.append(statistics) # type: ignore
203
204
 
204
205
  @staticmethod
205
206
  @final
@@ -281,7 +282,6 @@ class Statistic(ABC):
281
282
 
282
283
  parameter_group = self._find_parameter_group(optimizer=optimizer)
283
284
  parameters = self._get_parameters(parameter_group=parameter_group)
284
- self._sample_number += 1
285
285
 
286
286
  if self._sample_number % self.sample_frequency == 0:
287
287
  statistic = self._gather(
@@ -290,18 +290,18 @@ class Statistic(ABC):
290
290
 
291
291
  statistic = self._distributed_reduce(statistic=statistic)
292
292
 
293
- if not torch_distributed_utils.is_scheduler_master_rank():
294
- return
293
+ if torch_distributed_utils.is_scheduler_master_rank():
294
+ if isinstance(statistic, torch.Tensor):
295
+ statistic = statistic.view(-1)
296
+ self._tensor_cache.append(statistic)
295
297
 
296
- if isinstance(statistic, torch.Tensor):
297
- statistic = statistic.view(-1)
298
- self._tensor_cache.append(statistic)
298
+ if len(self._tensor_cache) >= self.max_cache_size:
299
+ self._process_tensor_cache()
299
300
 
300
- if len(self._tensor_cache) >= self.max_cache_size:
301
- self._process_tensor_cache()
301
+ elif statistic is not None:
302
+ self._data.append(statistic) # type: ignore
302
303
 
303
- elif statistic is not None:
304
- self._data.append(statistic) # type: ignore
304
+ self._sample_number += 1
305
305
 
306
306
  @final
307
307
  def fetch(self) -> TensorStatistics | float | None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.15.1
3
+ Version: 0.15.2
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -2,11 +2,11 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
4
4
  libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- libinephany/observations/observation_utils.py,sha256=wb6EZiaEiPuOqN26zzuT1rHyehoKh-8818KXn8pHweI,8688
5
+ libinephany/observations/observation_utils.py,sha256=pR7MM57KYYJxqRXsr9eMnhm7m_aGffH-dyFejnj2w_I,8899
6
6
  libinephany/observations/observer_pipeline.py,sha256=ZhONGXJQSgs2VJJn9d2F7ItkYqntvchl9-JTyxW9eU0,12146
7
7
  libinephany/observations/pipeline_coordinator.py,sha256=FrN3linKaC0pVE5uKjlh_0Fi8Mb1oK91NzH3Fq7PvyM,7420
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
9
- libinephany/observations/statistic_trackers.py,sha256=J444i9EZ30vcYOEYqcDBzz7_UDpEE2hW_ISYBu_hwYc,30180
9
+ libinephany/observations/statistic_trackers.py,sha256=PvuIqCkeaiAzCUJOYLOxM-Dl655HH1kQgiK4kjpEIyo,30236
10
10
  libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
12
12
  libinephany/observations/observers/global_observers.py,sha256=3TaiV2AxMOXfDq-kXMU3ZSo-rQENNCFhdWCJtpY99ok,38684
@@ -50,8 +50,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
50
50
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
52
52
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
53
- libinephany-0.15.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
- libinephany-0.15.1.dist-info/METADATA,sha256=lsqYtqyJ_k_clascJkzx8rR7gEN75tZ8lCKNqcH1cps,8354
55
- libinephany-0.15.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
- libinephany-0.15.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
- libinephany-0.15.1.dist-info/RECORD,,
53
+ libinephany-0.15.2.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
+ libinephany-0.15.2.dist-info/METADATA,sha256=GU4yudoPYVfLXO5dLp8UNY2XMQMpSSdw8lP-ZZUaQy4,8354
55
+ libinephany-0.15.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
+ libinephany-0.15.2.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
+ libinephany-0.15.2.dist-info/RECORD,,