libinephany 0.19.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -76,7 +76,7 @@ class Statistic(ABC):
         self.max_cache_size = max_statistic_cache_size
         self.downsample_percent = tensor_stats_downsample_percentage
         self.sample_frequency = statistic_sample_frequency
-        self.skip_statistics: list[str] | None = None
+        self.include_statistics: list[str] | None = None

     @final
     @property
@@ -195,12 +195,17 @@ class Statistic(ABC):
         Processes the tensor cache to build a TensorStatistic model.
         """

+        if not self.include_statistics:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
         if self._tensor_cache:
             concatenated = torch.cat(self._tensor_cache)
             self._tensor_cache = []

             statistics = TensorStatistics.build(
-                tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
+                tensor=concatenated,
+                include_statistics=self.include_statistics,
+                sample_percentage=self.downsample_percent,
             )
             self._data.append(statistics)  # type: ignore

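With this change, statistic gathering fails fast when no include list is configured instead of silently emitting every field. A minimal sketch of the new calling convention, not taken from the package itself; the only field name known from this diff is inter_quartile_range_ (visible in the TensorStatistics hunks below), and the trailing underscore is optional because filter_include_statistics normalises it:

import torch

# Hypothetical usage; assumes TensorStatistics is importable from libinephany.
stats = TensorStatistics.build(
    tensor=torch.randn(1_000),
    include_statistics=["inter_quartile_range"],  # "inter_quartile_range_" also accepted
    sample_percentage=0.01,
)
values = stats.to_list(include_statistics=["inter_quartile_range"])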
@@ -368,18 +373,18 @@ class FirstOrderGradients(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -422,22 +427,22 @@ class SecondOrderGradients(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         compute_hessian_diagonal: bool = False,
-        skip_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
             or use the squared first order gradients as approximations in the same way Adam does.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

         self.compute_hessian_diagonal = compute_hessian_diagonal
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     @property
     def requires_gradient_graphs(self) -> bool:
@@ -520,18 +525,18 @@ class ActivationStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     @property
     def uses_forward_hook(self) -> bool:
@@ -554,6 +559,9 @@ class ActivationStatistics(Statistic):
         :return: Forward hook to register the function with.
         """

+        if self.include_statistics is None:
+            raise ValueError("include_statistics is required to use forward hooks!")
+
         def hook(module: nn.Module, layer_input: torch.Tensor, layer_output: torch.Tensor) -> None:
             """
             :param module: Module the hook was registered with. Not used here.
@@ -563,7 +571,9 @@ class ActivationStatistics(Statistic):

             if self._sample_number % self.sample_frequency == 0:
                 statistics = TensorStatistics.build(
-                    tensor=layer_output, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
+                    tensor=layer_output,
+                    include_statistics=self.include_statistics,
+                    sample_percentage=self.downsample_percent,
                 )
                 self._data.append(statistics)  # type: ignore

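Together, the two hook hunks move validation to registration time: the closure captures include_statistics once, so raising before the hook is registered avoids a failure in the middle of a training forward pass. A registration sketch under assumed names, since this diff shows only the hook body:

import torch
from torch import nn

# Hypothetical wiring; the ActivationStatistics constructor and the method that
# returns the hook are assumptions.
stat = ActivationStatistics(include_statistics=["inter_quartile_range"])
layer = nn.Linear(8, 4)
handle = layer.register_forward_hook(stat.get_forward_hook())  # assumed method name
layer(torch.randn(2, 8))  # the hook builds TensorStatistics from layer_output
handle.remove()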
@@ -598,23 +608,23 @@ class ActivationStatistics(Statistic):
         return None


-class InnerStepParameterUpdateStatistics(Statistic):
+class ParameterUpdateStatistics(Statistic):

     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -650,56 +660,9 @@ class InnerStepParameterUpdateStatistics(Statistic):
         return update_tensor


-class ParameterUpdateStatistics(Statistic):
-
-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-            StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.TENSOR_STATISTICS
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        update_tensor = observation_utils.form_update_tensor(
-            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
-        )
-
-        if update_tensor is None:
-            update_tensor = torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])
+class LHOPTParameterUpdateStatistics(ParameterUpdateStatistics):

-        return update_tensor
+    pass


 class ParameterStatistics(Statistic):
@@ -707,18 +670,18 @@ class ParameterStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -747,49 +710,9 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])


-class InnerStepParameterStatistics(Statistic):
+class LHOPTParameterStatistics(ParameterStatistics):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-            StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.TENSOR_STATISTICS
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
+    pass


 class LAMBTrustRatioStatistics(Statistic):
@@ -797,16 +720,20 @@ class LAMBTrustRatioStatistics(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         use_log_transform: bool = False,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

+        self.include_statistics = include_statistics
         self.use_log_transform = use_log_transform

     def _get_storage_format(self) -> StatisticStorageTypes:
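For context, the trust ratio these classes emit is the LAMB ratio of weight norm to update norm, optionally passed through ln(1 + R); this sketch mirrors the _gather body that the next hunk deletes from the now-redundant subclass:

import math
import torch

def lamb_trust_ratio(weights: torch.Tensor, updates: torch.Tensor, use_log_transform: bool = False) -> float:
    # R = ||w||_2 / ||u||_2, guarded against a zero update norm.
    update_norm = torch.norm(updates, p=2).item()
    weight_norm = torch.norm(weights, p=2).item()
    ratio = weight_norm / update_norm if update_norm > 0 else 0.0
    # Optional ln(1 + R) transform, matching use_log_transform.
    return math.log(1 + ratio) if use_log_transform else ratio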
@@ -857,69 +784,9 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio


-class LHOPTLAMBTrustRatioStatistics(Statistic):
+class LHOPTLAMBTrustRatioStatistics(LAMBTrustRatioStatistics):

-    def __init__(
-        self,
-        *,
-        use_log_transform: bool = False,
-        **kwargs,
-    ) -> None:
-        """
-        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.use_log_transform = use_log_transform
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-            StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.FLOAT
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        weights_list = [p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
-        if weights_list:
-            weights = torch.cat(weights_list)
-
-        else:
-            weights = None
-
-        updates = observation_utils.form_update_tensor(
-            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
-        )
-
-        update_norm = torch.norm(updates, p=2).item() if updates is not None else 0
-        weight_norm = torch.norm(weights, p=2).item() if weights is not None else 0
-
-        lamb_trust_ratio = 0.0
-        if update_norm > 0:
-            lamb_trust_ratio = weight_norm / update_norm
-
-        if self.use_log_transform:
-            lamb_trust_ratio = math.log(1 + lamb_trust_ratio)
-
-        return lamb_trust_ratio
+    pass


 class NumberOfParameters(Statistic):
@@ -1125,22 +992,6 @@ class GradientVarianceFraction(Statistic):

 class AverageParameterUpdateMagnitudeStatistics(Statistic):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
         :return: Storage format this observation stores data in. Must be one of the enum attributes in the
@@ -1183,22 +1034,6 @@ class AverageParameterUpdateMagnitudeStatistics(Statistic):

 class MomentumGradientRatioStatistics(Statistic):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
         :return: Storage format this observation stores data in. Must be one of the enum attributes in the
@@ -1263,26 +1098,8 @@ class LogOfNoiseScaleStatistics(Statistic):
         - Σ is the noise covariance matrix
         - B is the batch size
         - ε is the learning rate
-
-
     """

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     @property
     def requires_gradient_graphs(self) -> bool:
         """
1370
1187
  # This is a common assumption when the exact noise structure is unknown
1371
1188
  noise_covariance = torch.ones_like(hessian_diagonals)
1372
1189
 
1373
- # Compute tr(HΣ)
1374
- trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
1375
-
1376
- # Avoid division by zero and log of zero
1377
- if trace_hessian_noise_covariance <= 0:
1378
- return None
1190
+ # Compute tr(HΣ), add zero division tolerance to avoid log of zero when gradient is too small
1191
+ trace_hessian_noise_covariance = (
1192
+ torch.sum(hessian_diagonals * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
1193
+ )
1379
1194
 
1380
1195
  log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
1381
1196
 
1382
- # Compute tr(H^3 Σ)
1383
- trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
1384
- if trace_hessian_cubed_noise_covariance <= 0:
1385
- return None
1197
+ # Compute tr(H^3 Σ), add zero division tolerance to avoid log of zero when gradient is too small
1198
+ trace_hessian_cubed_noise_covariance = (
1199
+ torch.sum(hessian_diagonals**3 * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
1200
+ )
1386
1201
 
1387
1202
  log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
1388
1203
 
@@ -13,7 +13,7 @@ from pydantic import BaseModel
 #
 # ======================================================================================================================

-STRIP_SUFFIX = "_"
+FIELD_SUFFIX = "_"

 # ======================================================================================================================
 #
@@ -164,28 +164,38 @@ class TensorStatistics(BaseModel):
         return tensor[random_indices]

     @classmethod
-    def filter_skip_statistics(cls, skip_statistics: list[str] | None) -> list[str]:
+    def filter_include_statistics(cls, include_statistics: list[str]) -> list[str]:
         """
-        :param skip_statistics: Names of the fields in the model to not include in returned observations.
-        :return: Empty list if skip_statistics was None or skip_statistics filtered to include only the names of fields
-            present in this pydantic model.
+        :param include_statistics: Names of the fields in the model to include in returned observations.
+        :return: List of fields from the given include_statistics list that are present in this pydantic model.
+        :raises ValueError: If no statistics to include are given.
         """

-        return (
-            [skip_stat for skip_stat in skip_statistics if skip_stat in cls.model_fields.keys()]
-            if skip_statistics is not None
-            else []
-        )
+        filtered_include_statistics: list[str] = []
+
+        for include_stat in include_statistics:
+            with_suffix = include_stat + FIELD_SUFFIX if not include_stat.endswith(FIELD_SUFFIX) else include_stat
+
+            if with_suffix in cls.model_fields.keys():
+                filtered_include_statistics.append(with_suffix)
+
+        if not filtered_include_statistics:
+            raise ValueError(f"No statistics to include given to {cls.__name__}!")
+
+        return filtered_include_statistics

     @classmethod
     def build(
-        cls, tensor: torch.Tensor, sample_percentage: float = 0.01, skip_statistics: list[str] | None = None
+        cls,
+        tensor: torch.Tensor,
+        include_statistics: list[str],
+        sample_percentage: float = 0.01,
     ) -> "TensorStatistics":
         """
         :param tensor: Tensor to compute and store statistics of.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param sample_percentage: Percentage of the given tensor to randomly sample and compute statistics from.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
         :return: Constructed tensor statistics.
         """

@@ -193,12 +203,10 @@ class TensorStatistics(BaseModel):
         downsampled_tensor = cls.downsample_tensor(tensor=tensor, sample_percentage=sample_percentage)

         for field, field_value in stats.model_dump().items():
-            name = field[:-1] if field.endswith("_") else field
-
-            if skip_statistics is not None and name in skip_statistics:
-                continue
+            name = field[:-1] if field.endswith(FIELD_SUFFIX) else field

-            setattr(stats, name, downsampled_tensor)
+            if name in include_statistics:
+                setattr(stats, name, downsampled_tensor)

         return stats

@@ -219,28 +227,21 @@ class TensorStatistics(BaseModel):
             inter_quartile_range_=tensor[6],
         )

-    def to_list(self, skip_statistics: list[str] | None) -> list[float]:
+    def to_list(self, include_statistics: list[str]) -> list[float]:
         """
-        :param skip_statistics: None or a list of field names to skip from adding to the returned list.
+        :param include_statistics: List of field names to include in the returned list.
         :return: List of field values.
         """

-        if skip_statistics is None:
-            skip_statistics = []
-
-        if not all(skip_stat in self.model_fields.keys() for skip_stat in skip_statistics):
-            raise ValueError(
-                f"One or more skip statistic keys do not exist in TensorStatistics. Valid Skip Keys: "
-                f"{list(self.model_fields.keys())} Given Skip Keys: {skip_statistics}"
-            )
+        filtered_includes = self.filter_include_statistics(include_statistics=include_statistics)

         as_list = []

         for field, field_value in self.model_dump().items():
-            if field in skip_statistics:
-                continue
+            without_suffix = field[:-1]

-            as_list.append(field_value)
+            if field in filtered_includes or without_suffix in filtered_includes:
+                as_list.append(field_value)

         return as_list

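Note the semantic shift here: the old to_list raised if any single skip key was unknown, while the new path delegates to filter_include_statistics, which drops unknown names silently and only raises when nothing matches. A brief sketch, assuming stats is a populated TensorStatistics instance:

stats.to_list(include_statistics=["inter_quartile_range", "made_up_field"])
# -> [<inter_quartile_range_ value>]; the unknown name is dropped because one name matched
stats.to_list(include_statistics=["made_up_field"])
# -> ValueError from filter_include_statistics: nothing matched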
@@ -267,6 +268,6 @@ class TensorStatistics(BaseModel):
         """

         return {
-            field[:-1] if field.endswith(STRIP_SUFFIX) else field: field_value
+            field[:-1] if field.endswith(FIELD_SUFFIX) else field: field_value
             for field, field_value in self.model_dump().items()
         }