libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +19 -2
- libinephany/observations/observers/base_observers.py +20 -8
- libinephany/observations/observers/global_observers/__init__.py +19 -1
- libinephany/observations/observers/global_observers/constants.py +2 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +320 -3
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -18
- libinephany/observations/observers/global_observers/model_observers.py +220 -6
- libinephany/observations/observers/global_observers/progress_observers.py +7 -1
- libinephany/observations/observers/local_observers.py +158 -25
- libinephany/observations/statistic_trackers.py +435 -23
- libinephany/pydantic_models/schemas/tensor_statistics.py +33 -32
- libinephany/pydantic_models/states/hyperparameter_states.py +32 -30
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/METADATA +1 -1
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/RECORD +17 -17
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/WHEEL +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/top_level.txt +0 -0
libinephany/observations/statistic_trackers.py

@@ -11,6 +11,7 @@ from typing import Any, Callable, final
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from torch.distributed import ReduceOp

@@ -75,7 +76,7 @@ class Statistic(ABC):
         self.max_cache_size = max_statistic_cache_size
         self.downsample_percent = tensor_stats_downsample_percentage
         self.sample_frequency = statistic_sample_frequency
-        self.
+        self.include_statistics: list[str] | None = None

     @final
     @property
@@ -194,12 +195,17 @@ class Statistic(ABC):
         Processes the tensor cache to build a TensorStatistic model.
         """

+        if not self.include_statistics:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
         if self._tensor_cache:
             concatenated = torch.cat(self._tensor_cache)
             self._tensor_cache = []

             statistics = TensorStatistics.build(
-                tensor=concatenated,
+                tensor=concatenated,
+                include_statistics=self.include_statistics,
+                sample_percentage=self.downsample_percent,
             )
             self._data.append(statistics)  # type: ignore

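Note: with this change, field selection moves to a single include_statistics list that the base class validates and threads into TensorStatistics.build together with the downsample percentage. A minimal standalone sketch of that include-list pattern in plain PyTorch; the statistic names "mean", "standard_deviation" and "norm" are placeholders, not necessarily fields of the library's TensorStatistics model:

    import torch

    def build_statistics(
        tensor: torch.Tensor, include_statistics: list[str], sample_percentage: float = 1.0
    ) -> dict[str, float]:
        # Optionally downsample the flattened tensor before summarising it.
        flat = tensor.detach().view(-1)
        if sample_percentage < 1.0:
            sample_size = max(1, int(flat.numel() * sample_percentage))
            flat = flat[torch.randperm(flat.numel())[:sample_size]]

        # Compute every candidate statistic, then keep only the requested fields.
        candidates = {
            "mean": flat.mean().item(),
            "standard_deviation": flat.std().item(),
            "norm": flat.norm(p=2).item(),
        }
        return {name: value for name, value in candidates.items() if name in include_statistics}

    stats = build_statistics(torch.randn(10_000), include_statistics=["mean", "norm"], sample_percentage=0.25)
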
@@ -367,18 +373,18 @@ class FirstOrderGradients(Statistic):
     def __init__(
         self,
         *,
-
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param
-        fields in the model to
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -421,22 +427,22 @@ class SecondOrderGradients(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         compute_hessian_diagonal: bool = False,
-        skip_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
         or use the squared first order gradients as approximations in the same way Adam does.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

         self.compute_hessian_diagonal = compute_hessian_diagonal
-        self.
+        self.include_statistics = include_statistics

     @property
     def requires_gradient_graphs(self) -> bool:
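Note: every statistic that previously accepted skip_statistics now accepts include_statistics instead (FirstOrderGradients, SecondOrderGradients, ActivationStatistics, ParameterUpdateStatistics, ParameterStatistics and LAMBTrustRatioStatistics below). A sketch of the corresponding call-site migration; the field names are placeholders and any other base-class keyword arguments are omitted:

    from libinephany.observations.statistic_trackers import SecondOrderGradients

    # libinephany 0.18.x: fields were selected by exclusion.
    # tracker = SecondOrderGradients(compute_hessian_diagonal=False, skip_statistics=["kurtosis"])

    # libinephany 1.0.0: fields are selected by inclusion.
    tracker = SecondOrderGradients(compute_hessian_diagonal=False, include_statistics=["mean", "norm"])
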
@@ -519,18 +525,18 @@ class ActivationStatistics(Statistic):
     def __init__(
         self,
         *,
-
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param
-        fields in the model to
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.
+        self.include_statistics = include_statistics

     @property
     def uses_forward_hook(self) -> bool:
@@ -553,6 +559,9 @@ class ActivationStatistics(Statistic):
         :return: Forward hook to register the function with.
         """

+        if self.include_statistics is None:
+            raise ValueError("include_statistics is required to use forward hooks!")
+
         def hook(module: nn.Module, layer_input: torch.Tensor, layer_output: torch.Tensor) -> None:
             """
             :param module: Module the hook was registered with. Not used here.
@@ -562,7 +571,9 @@ class ActivationStatistics(Statistic):

             if self._sample_number % self.sample_frequency == 0:
                 statistics = TensorStatistics.build(
-                    tensor=layer_output,
+                    tensor=layer_output,
+                    include_statistics=self.include_statistics,
+                    sample_percentage=self.downsample_percent,
                 )
                 self._data.append(statistics)  # type: ignore

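Note: ActivationStatistics now refuses to hand out a forward hook without include_statistics and passes the downsample percentage through when summarising layer outputs. A standalone sketch of the same forward-hook sampling pattern in plain PyTorch; the counter and summary used here are illustrative, not the library's implementation:

    import torch
    import torch.nn as nn

    sample_frequency = 4
    sample_number = 0
    collected: list[float] = []

    def hook(module: nn.Module, layer_input: tuple[torch.Tensor, ...], layer_output: torch.Tensor) -> None:
        # Summarise the layer output only every `sample_frequency` forward passes.
        global sample_number
        if sample_number % sample_frequency == 0:
            collected.append(layer_output.detach().abs().mean().item())
        sample_number += 1

    layer = nn.Linear(8, 8)
    handle = layer.register_forward_hook(hook)
    for _ in range(10):
        layer(torch.randn(2, 8))
    handle.remove()
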
@@ -602,18 +613,18 @@ class ParameterUpdateStatistics(Statistic):
     def __init__(
         self,
         *,
-
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param
-        fields in the model to
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -649,23 +660,28 @@ class ParameterUpdateStatistics(Statistic):
         return update_tensor


+class LHOPTParameterUpdateStatistics(ParameterUpdateStatistics):
+
+    pass
+
+
 class ParameterStatistics(Statistic):

     def __init__(
         self,
         *,
-
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param
-        fields in the model to
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -694,21 +710,30 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])


+class LHOPTParameterStatistics(ParameterStatistics):
+
+    pass
+
+
 class LAMBTrustRatioStatistics(Statistic):

     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         use_log_transform: bool = False,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

+        self.include_statistics = include_statistics
         self.use_log_transform = use_log_transform

     def _get_storage_format(self) -> StatisticStorageTypes:
@@ -759,6 +784,11 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio


+class LHOPTLAMBTrustRatioStatistics(LAMBTrustRatioStatistics):
+
+    pass
+
+
 class NumberOfParameters(Statistic):

     def __init__(
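Note: the three LHOPT* classes added above (LHOPTParameterUpdateStatistics, LHOPTParameterStatistics, LHOPTLAMBTrustRatioStatistics) are empty subclasses. Presumably a distinct class name lets LHOPT-flavoured observers reference and configure these statistics separately from their parents while inheriting behaviour unchanged; a standalone sketch of that naming pattern, not the library's registry:

    class ParameterStats:
        """Stand-in for a concrete statistic class."""

    class LHOPTParameterStats(ParameterStats):
        """Empty subclass: same behaviour, distinct name."""

    # Configuration keyed by class name can now differ per variant.
    config = {
        "ParameterStats": {"sample_frequency": 1},
        "LHOPTParameterStats": {"sample_frequency": 4},
    }
    print(config[LHOPTParameterStats.__name__])  # -> {'sample_frequency': 4}
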
@@ -958,3 +988,385 @@ class GradientVarianceFraction(Statistic):
             return 0.0

         return variance_parameters / total_parameters
+
+
+class AverageParameterUpdateMagnitudeStatistics(Statistic):
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        # when update tensor is none, return 0.0
+        if update_tensor is None:
+            return 0.0
+
+        update_tensor = update_tensor.view(-1)
+        update_tensor = update_tensor.abs()
+
+        average_update_magnitude = update_tensor.mean().item()
+
+        return average_update_magnitude
+
+
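Note: AverageParameterUpdateMagnitudeStatistics reduces the update tensor to a single float, the mean absolute parameter update for one optimiser step. A standalone sketch of the same quantity computed from before/after parameter snapshots rather than via observation_utils.form_update_tensor:

    import torch

    param = torch.nn.Parameter(torch.randn(100))
    optimizer = torch.optim.SGD([param], lr=0.1)

    before = param.detach().clone()
    (param ** 2).sum().backward()
    optimizer.step()

    # Mean |delta theta| for this step.
    update = (param.detach() - before).view(-1)
    average_update_magnitude = update.abs().mean().item()
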
+class MomentumGradientRatioStatistics(Statistic):
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_list = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        # Handle empty gradients list
+        if not gradients_list:
+            return 0.0
+
+        gradients = torch.cat(gradients_list).view(-1)
+
+        # momentum_gradient_ratio r^t=\frac{\|g^t\|_2}{\|\nabla f(w^t)\|_2}
+        gradients_norm = gradients.norm(p=2)
+        momentum_norm = momentum.norm(p=2)
+
+        if momentum_norm == 0:
+            momentum_gradient_ratio = 0.0
+        else:
+            momentum_gradient_ratio = (gradients_norm / momentum_norm).item()
+
+        return momentum_gradient_ratio
+
+
|
+
class LogOfNoiseScaleStatistics(Statistic):
|
1090
|
+
"""
|
1091
|
+
Statistics for the log of noise scale in training.
|
1092
|
+
|
1093
|
+
Tracks the log of noise scale B_{noise} using the formula:
|
1094
|
+
B_{noise} = tr(ΣH) / (G^T H G) ≈ (B/ε) * tr(HΣ) / tr(H^3 Σ)
|
1095
|
+
where:
|
1096
|
+
- H is the Hessian matrix
|
1097
|
+
- G is the gradient vector
|
1098
|
+
- Σ is the noise covariance matrix
|
1099
|
+
- B is the batch size
|
1100
|
+
- ε is the learning rate
|
1101
|
+
"""
|
1102
|
+
|
1103
|
+
@property
|
1104
|
+
def requires_gradient_graphs(self) -> bool:
|
1105
|
+
"""
|
1106
|
+
:return: Whether the statistic requires gradient graphs to be retained.
|
1107
|
+
"""
|
1108
|
+
|
1109
|
+
return False
|
1110
|
+
|
1111
|
+
@staticmethod
|
1112
|
+
def compute_hessian_diagonals(parameters: list[torch.Tensor]) -> torch.Tensor:
|
1113
|
+
"""
|
1114
|
+
:param parameters: Parameters to compute the hessian diagonal matrices for.
|
1115
|
+
:return: Tensor containing the hessian diagonal matrices for all given parameters.
|
1116
|
+
"""
|
1117
|
+
|
1118
|
+
hessian_diagonals = []
|
1119
|
+
|
1120
|
+
for parameter in parameters:
|
1121
|
+
if parameter.grad is not None:
|
1122
|
+
so_gradient = torch.autograd.grad(
|
1123
|
+
outputs=parameter.grad.clone(),
|
1124
|
+
inputs=parameter,
|
1125
|
+
grad_outputs=torch.ones_like(parameter.grad, requires_grad=True),
|
1126
|
+
only_inputs=True,
|
1127
|
+
retain_graph=True,
|
1128
|
+
create_graph=True,
|
1129
|
+
allow_unused=True,
|
1130
|
+
)[0]
|
1131
|
+
|
1132
|
+
if so_gradient is not None:
|
1133
|
+
hessian_diagonals.append(so_gradient.view(-1))
|
1134
|
+
else:
|
1135
|
+
hessian_diagonals.append(torch.zeros_like(parameter.view(-1)))
|
1136
|
+
|
1137
|
+
return torch.cat(hessian_diagonals)
|
1138
|
+
|
1139
|
+
def _get_storage_format(self) -> StatisticStorageTypes:
|
1140
|
+
"""
|
1141
|
+
:return: Storage format this observation stores data in. Must be one of the enum attributes in the
|
1142
|
+
StatisticStorageTypes enumeration class.
|
1143
|
+
"""
|
1144
|
+
|
1145
|
+
return StatisticStorageTypes.FLOAT
|
1146
|
+
|
1147
|
+
def _gather(
|
1148
|
+
self,
|
1149
|
+
*,
|
1150
|
+
optimizer: optim.Optimizer,
|
1151
|
+
model: nn.Module,
|
1152
|
+
parameters: list[torch.Tensor],
|
1153
|
+
parameter_group: dict[str, Any],
|
1154
|
+
) -> torch.Tensor | TensorStatistics | float | None:
|
1155
|
+
"""
|
1156
|
+
:param optimizer: Optimizer the given parameters and parameter group came from.
|
1157
|
+
:param model: Inner model to gather statistics from.
|
1158
|
+
:param parameters: List of parameters to gather statistics from.
|
1159
|
+
:param parameter_group: Parameter group the parameters originate from.
|
1160
|
+
:return: None, TensorStatistics model or a float.
|
1161
|
+
|
1162
|
+
Computes the log of noise scale using the approximate formula:
|
1163
|
+
log(B_{noise}) ≈ log(B/ε) (move to observer) + log(tr(HΣ)) - log(tr(H^3 Σ))
|
1164
|
+
where:
|
1165
|
+
- H is the Hessian matrix
|
1166
|
+
- Σ is the noise covariance matrix
|
1167
|
+
- B is the batch size
|
1168
|
+
- ε is the learning rate
|
1169
|
+
|
1170
|
+
"""
|
1171
|
+
|
1172
|
+
# Compute Hessian diagonals as in SecondOrderGradients Observation
|
1173
|
+
# hessian_diagonals = self.compute_hessian_diagonals(parameters)
|
1174
|
+
# use squared first order gradients as approximations
|
1175
|
+
fo_gradients = [
|
1176
|
+
p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
|
1177
|
+
]
|
1178
|
+
if not fo_gradients:
|
1179
|
+
return None
|
1180
|
+
|
1181
|
+
hessian_diagonals = torch.cat(fo_gradients) ** 2
|
1182
|
+
|
1183
|
+
if hessian_diagonals.numel() == 0: # No gradients
|
1184
|
+
return None
|
1185
|
+
|
1186
|
+
# For noise covariance matrix Σ, we'll use the identity matrix as an approximation
|
1187
|
+
# This is a common assumption when the exact noise structure is unknown
|
1188
|
+
noise_covariance = torch.ones_like(hessian_diagonals)
|
1189
|
+
|
1190
|
+
# Compute tr(HΣ)
|
1191
|
+
trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
|
1192
|
+
|
1193
|
+
# Avoid division by zero and log of zero
|
1194
|
+
if trace_hessian_noise_covariance <= 0:
|
1195
|
+
return None
|
1196
|
+
|
1197
|
+
log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
|
1198
|
+
|
1199
|
+
# Compute tr(H^3 Σ)
|
1200
|
+
trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
|
1201
|
+
if trace_hessian_cubed_noise_covariance <= 0:
|
1202
|
+
return None
|
1203
|
+
|
1204
|
+
log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
|
1205
|
+
|
1206
|
+
# Compute final result: log(B_{noise}) ≈ log(tr(HΣ)) - log(tr(H^3 Σ))
|
1207
|
+
# Note: log(B/ε) term is handled in the observer layer
|
1208
|
+
log_noise_scale_without_log_b_over_epsilon = (
|
1209
|
+
log_trace_hessian_noise_covariance - log_trace_hessian_cubed_noise_covariance
|
1210
|
+
)
|
1211
|
+
|
1212
|
+
return log_noise_scale_without_log_b_over_epsilon
|
1213
|
+
|
1214
|
+
|
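Note: in the implementation above, the Hessian diagonal is approximated by the squared first-order gradients (h_i = g_i^2), the noise covariance Σ by the identity, and the log(B/ε) term is deferred to the observer layer. Under those assumptions the docstring's formula reduces to the following (a LaTeX restatement, not additional behaviour):

    \log B_{\mathrm{noise}}
      \approx \log\frac{B}{\epsilon}
            + \log\operatorname{tr}(H\Sigma)
            - \log\operatorname{tr}(H^{3}\Sigma)
      = \underbrace{\log\frac{B}{\epsilon}}_{\text{added by the observer}}
            + \log\sum_i g_i^{2}
            - \log\sum_i g_i^{6}

so the statistic itself returns only the last two terms.
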
+class CosineSimilarityObserverOfGradientAndMomentumStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and momentum.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        momentum_2d = momentum.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, momentum_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityObserverOfGradientAndUpdateStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and update.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        update_tensor_2d = update_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, update_tensor_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityOfGradientAndParameterStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and parameter.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        parameters_list = [p.view(-1) for p in parameters_with_grads]
+
+        if not parameters_list:
+            return None
+
+        parameters_tensor = torch.cat(parameters_list).view(-1)
+
+        gradients_2d = gradients.unsqueeze(0)
+        parameters_tensor_2d = parameters_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, parameters_tensor_2d, dim=1).item()
+
+        return cosine_similarity
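
Note: the three cosine-similarity statistics share one pattern: flatten and concatenate the per-parameter tensors, add a batch dimension, and compare along dim=1 with torch.nn.functional.cosine_similarity (hence the new F import at the top of the file). A standalone sketch of that pattern with stand-in tensors:

    import torch
    import torch.nn.functional as F

    gradients = torch.randn(1000)  # stand-in for the concatenated gradients
    other = torch.randn(1000)      # stand-in for momentum, the update tensor, or the parameters

    cosine_similarity = F.cosine_similarity(gradients.unsqueeze(0), other.unsqueeze(0), dim=1).item()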