libinephany 0.15.1__py3-none-any.whl → 0.15.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -173,7 +173,21 @@ def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
173
173
  :return: Whether the tensor is owned by the local rank.
174
174
  """
175
175
 
176
- return tensor is not None and tensor.grad is not None and tensor.numel() > 0
176
+ valid_tensor = tensor is not None and tensor.numel() > 0
177
+ device_index = tensor.device.index if valid_tensor else None
178
+
179
+ xpu_available = torch.xpu.is_available()
180
+ cuda_available = torch.cuda.is_available()
181
+
182
+ if valid_tensor and xpu_available and tensor.is_xpu:
183
+ current_device = torch.xpu.current_device()
184
+ return device_index == current_device
185
+
186
+ elif valid_tensor and cuda_available and tensor.is_cuda:
187
+ current_device = torch.cuda.current_device()
188
+ return device_index == current_device
189
+
190
+ return valid_tensor
177
191
 
178
192
 
179
193
  def form_update_tensor(
@@ -193,13 +193,14 @@ class Statistic(ABC):
193
193
  Processes the tensor cache to build a TensorStatistic model.
194
194
  """
195
195
 
196
- concatenated = torch.cat(self._tensor_cache)
197
- self._tensor_cache = []
196
+ if self._tensor_cache:
197
+ concatenated = torch.cat(self._tensor_cache)
198
+ self._tensor_cache = []
198
199
 
199
- statistics = TensorStatistics.build(
200
- tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
201
- )
202
- self._data.append(statistics) # type: ignore
200
+ statistics = TensorStatistics.build(
201
+ tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
202
+ )
203
+ self._data.append(statistics) # type: ignore
203
204
 
204
205
  @staticmethod
205
206
  @final
@@ -281,7 +282,6 @@ class Statistic(ABC):
281
282
 
282
283
  parameter_group = self._find_parameter_group(optimizer=optimizer)
283
284
  parameters = self._get_parameters(parameter_group=parameter_group)
284
- self._sample_number += 1
285
285
 
286
286
  if self._sample_number % self.sample_frequency == 0:
287
287
  statistic = self._gather(
@@ -290,18 +290,18 @@ class Statistic(ABC):
290
290
 
291
291
  statistic = self._distributed_reduce(statistic=statistic)
292
292
 
293
- if not torch_distributed_utils.is_scheduler_master_rank():
294
- return
293
+ if torch_distributed_utils.is_scheduler_master_rank():
294
+ if isinstance(statistic, torch.Tensor):
295
+ statistic = statistic.view(-1)
296
+ self._tensor_cache.append(statistic)
295
297
 
296
- if isinstance(statistic, torch.Tensor):
297
- statistic = statistic.view(-1)
298
- self._tensor_cache.append(statistic)
298
+ if len(self._tensor_cache) >= self.max_cache_size:
299
+ self._process_tensor_cache()
299
300
 
300
- if len(self._tensor_cache) >= self.max_cache_size:
301
- self._process_tensor_cache()
301
+ elif statistic is not None:
302
+ self._data.append(statistic) # type: ignore
302
303
 
303
- elif statistic is not None:
304
- self._data.append(statistic) # type: ignore
304
+ self._sample_number += 1
305
305
 
306
306
  @final
307
307
  def fetch(self) -> TensorStatistics | float | None:
@@ -403,7 +403,9 @@ class FirstOrderGradients(Statistic):
403
403
  :return: None, TensorStatistics model or a float.
404
404
  """
405
405
 
406
- gradients = [p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
406
+ gradients = [
407
+ p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
408
+ ]
407
409
 
408
410
  if not gradients:
409
411
  return None
@@ -495,7 +497,9 @@ class SecondOrderGradients(Statistic):
495
497
  :return: None, TensorStatistics model or a float.
496
498
  """
497
499
 
498
- fo_gradients = [p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
500
+ fo_gradients = [
501
+ p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
502
+ ]
499
503
 
500
504
  if not fo_gradients:
501
505
  return None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.15.1
3
+ Version: 0.15.3
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -2,11 +2,11 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
4
4
  libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- libinephany/observations/observation_utils.py,sha256=wb6EZiaEiPuOqN26zzuT1rHyehoKh-8818KXn8pHweI,8688
5
+ libinephany/observations/observation_utils.py,sha256=IidbKlZ4YaLnUFwMFcUdHVQjxEu78bvpD_nqaCnrgTA,9168
6
6
  libinephany/observations/observer_pipeline.py,sha256=ZhONGXJQSgs2VJJn9d2F7ItkYqntvchl9-JTyxW9eU0,12146
7
7
  libinephany/observations/pipeline_coordinator.py,sha256=FrN3linKaC0pVE5uKjlh_0Fi8Mb1oK91NzH3Fq7PvyM,7420
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
9
- libinephany/observations/statistic_trackers.py,sha256=J444i9EZ30vcYOEYqcDBzz7_UDpEE2hW_ISYBu_hwYc,30180
9
+ libinephany/observations/statistic_trackers.py,sha256=PUBqGgMRi51SmiNh5HAH5kpYxsaflRepmM-uKyMiQZg,30326
10
10
  libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
12
12
  libinephany/observations/observers/global_observers.py,sha256=3TaiV2AxMOXfDq-kXMU3ZSo-rQENNCFhdWCJtpY99ok,38684
@@ -50,8 +50,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
50
50
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
52
52
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
53
- libinephany-0.15.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
- libinephany-0.15.1.dist-info/METADATA,sha256=lsqYtqyJ_k_clascJkzx8rR7gEN75tZ8lCKNqcH1cps,8354
55
- libinephany-0.15.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
- libinephany-0.15.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
- libinephany-0.15.1.dist-info/RECORD,,
53
+ libinephany-0.15.3.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
+ libinephany-0.15.3.dist-info/METADATA,sha256=UuSiJtqLBermVKve6SNjBQz_ViSvMVf5NEootGJNp4w,8354
55
+ libinephany-0.15.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
+ libinephany-0.15.3.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
+ libinephany-0.15.3.dist-info/RECORD,,