libinephany 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +19 -2
- libinephany/observations/observers/global_observers/__init__.py +19 -1
- libinephany/observations/observers/global_observers/constants.py +2 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +319 -1
- libinephany/observations/observers/global_observers/model_observers.py +219 -3
- libinephany/observations/observers/local_observers.py +127 -1
- libinephany/observations/statistic_trackers.py +595 -0
- libinephany/utils/constants.py +3 -3
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/METADATA +1 -1
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/RECORD +13 -13
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/WHEEL +0 -0
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@ from typing import Any, Callable, final
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from torch.distributed import ReduceOp

@@ -597,6 +598,58 @@ class ActivationStatistics(Statistic):
         return None


+class InnerStepParameterUpdateStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            update_tensor = torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])
+
+        return update_tensor
+
+
 class ParameterUpdateStatistics(Statistic):

     def __init__(
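These hunks appear to come from libinephany/observations/statistic_trackers.py, the file that gains 595 lines in the listing above. As a rough, self-contained sketch of the fallback in `InnerStepParameterUpdateStatistics._gather`: when no per-parameter update can be formed, a flat zero vector matching the concatenated parameter size is returned so downstream statistics stay shape-stable. `gather_update_or_zeros` is an illustrative name, not part of the library; the real method obtains the update via `observation_utils.form_update_tensor`.

```python
import torch

def gather_update_or_zeros(update_tensor: torch.Tensor | None,
                           parameters: list[torch.Tensor]) -> torch.Tensor:
    """Mimic the zero-fill fallback: one flat vector sized like all parameters combined."""
    if update_tensor is None:
        return torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])
    return update_tensor

params = [torch.randn(3, 4), torch.randn(5)]
print(gather_update_or_zeros(None, params).shape)  # torch.Size([17])
```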
@@ -694,6 +747,51 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])


+class InnerStepParameterStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
+
+
 class LAMBTrustRatioStatistics(Statistic):

     def __init__(
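`InnerStepParameterStatistics` simply snapshots the current weights as one flat vector per inner step. A minimal sketch of that flattening, assuming plain tensors and omitting the `observation_utils.tensor_on_local_rank` filter the real method applies:

```python
import torch

def flatten_parameters(parameters: list[torch.Tensor]) -> torch.Tensor:
    """Concatenate every parameter as a 1-D view into a single flat vector."""
    return torch.cat([p.data.view(-1) for p in parameters])

weights = flatten_parameters([torch.randn(2, 2), torch.randn(3)])
print(weights.shape)  # torch.Size([7])
```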
@@ -759,6 +857,71 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio


+class LHOPTLAMBTrustRatioStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        use_log_transform: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.use_log_transform = use_log_transform
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        weights_list = [p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
+        if weights_list:
+            weights = torch.cat(weights_list)
+
+        else:
+            weights = None
+
+        updates = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        update_norm = torch.norm(updates, p=2).item() if updates is not None else 0
+        weight_norm = torch.norm(weights, p=2).item() if weights is not None else 0
+
+        lamb_trust_ratio = 0.0
+        if update_norm > 0:
+            lamb_trust_ratio = weight_norm / update_norm
+
+        if self.use_log_transform:
+            lamb_trust_ratio = math.log(1 + lamb_trust_ratio)
+
+        return lamb_trust_ratio
+
+
 class NumberOfParameters(Statistic):

     def __init__(
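The new LHOPT-style statistic reduces each parameter group to the trust ratio R = ||w||₂ / ||u||₂, optionally transformed to ln(1 + R) to compress large values when the update norm is tiny. Below is a standalone sketch of that arithmetic only; `lhopt_trust_ratio` is an illustrative name, and the real statistic pulls the update vector from the optimizer state via `observation_utils.form_update_tensor`.

```python
import math
import torch

def lhopt_trust_ratio(weights: torch.Tensor, update: torch.Tensor, use_log_transform: bool = False) -> float:
    """Trust ratio R = ||w||_2 / ||u||_2, with an optional ln(1 + R) transform."""
    update_norm = torch.norm(update, p=2).item()
    weight_norm = torch.norm(weights, p=2).item()
    ratio = weight_norm / update_norm if update_norm > 0 else 0.0
    return math.log(1 + ratio) if use_log_transform else ratio

w = torch.full((4,), 2.0)   # ||w|| = 4
u = torch.full((4,), 0.5)   # ||u|| = 1
print(lhopt_trust_ratio(w, u))                           # 4.0
print(lhopt_trust_ratio(w, u, use_log_transform=True))   # ln(5) ≈ 1.609
```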
@@ -958,3 +1121,435 @@ class GradientVarianceFraction(Statistic):
             return 0.0

         return variance_parameters / total_parameters
+
+
+class AverageParameterUpdateMagnitudeStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        # When the update tensor is None, return 0.0.
+        if update_tensor is None:
+            return 0.0
+
+        update_tensor = update_tensor.view(-1)
+        update_tensor = update_tensor.abs()
+
+        average_update_magnitude = update_tensor.mean().item()
+
+        return average_update_magnitude
+
+
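`AverageParameterUpdateMagnitudeStatistics` collapses the optimizer update to a single scalar: the mean absolute entry of the flattened update, with 0.0 when no update is available. A sketch of just that reduction (the helper name is illustrative):

```python
import torch

def average_update_magnitude(update_tensor: torch.Tensor | None) -> float:
    """Mean |entry| of the flattened update, or 0.0 when there is no update."""
    if update_tensor is None:
        return 0.0
    return update_tensor.view(-1).abs().mean().item()

print(average_update_magnitude(torch.tensor([[0.1, -0.3], [0.2, -0.2]])))  # ~0.2
```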
+class MomentumGradientRatioStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_list = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        # Handle empty gradients list
+        if not gradients_list:
+            return 0.0
+
+        gradients = torch.cat(gradients_list).view(-1)
+
+        # momentum_gradient_ratio r^t=\frac{\|g^t\|_2}{\|\nabla f(w^t)\|_2}
+        gradients_norm = gradients.norm(p=2)
+        momentum_norm = momentum.norm(p=2)
+
+        if momentum_norm == 0:
+            momentum_gradient_ratio = 0.0
+        else:
+            momentum_gradient_ratio = (gradients_norm / momentum_norm).item()
+
+        return momentum_gradient_ratio
+
+
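`MomentumGradientRatioStatistics` reports ||g||₂ / ||m||₂ and falls back to 0.0 when the momentum norm vanishes. A toy sketch with synthetic vectors; the real statistic reads the momentum buffer through `observation_utils.form_momentum_tensor`:

```python
import torch

def momentum_gradient_ratio(gradients: torch.Tensor, momentum: torch.Tensor) -> float:
    """||g||_2 / ||m||_2, defined as 0.0 when the momentum norm is zero."""
    momentum_norm = momentum.norm(p=2)
    if momentum_norm == 0:
        return 0.0
    return (gradients.norm(p=2) / momentum_norm).item()

g = torch.tensor([3.0, 4.0])   # ||g|| = 5
m = torch.tensor([0.6, 0.8])   # ||m|| = 1
print(momentum_gradient_ratio(g, m))  # 5.0
```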
+class LogOfNoiseScaleStatistics(Statistic):
+    """
+    Statistics for the log of noise scale in training.
+
+    Tracks the log of noise scale B_{noise} using the formula:
+        B_{noise} = tr(ΣH) / (G^T H G) ≈ (B/ε) * tr(HΣ) / tr(H^3 Σ)
+    where:
+        - H is the Hessian matrix
+        - G is the gradient vector
+        - Σ is the noise covariance matrix
+        - B is the batch size
+        - ε is the learning rate
+    """
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    @property
+    def requires_gradient_graphs(self) -> bool:
+        """
+        :return: Whether the statistic requires gradient graphs to be retained.
+        """
+
+        return False
+
+    @staticmethod
+    def compute_hessian_diagonals(parameters: list[torch.Tensor]) -> torch.Tensor:
+        """
+        :param parameters: Parameters to compute the hessian diagonal matrices for.
+        :return: Tensor containing the hessian diagonal matrices for all given parameters.
+        """
+
+        hessian_diagonals = []
+
+        for parameter in parameters:
+            if parameter.grad is not None:
+                so_gradient = torch.autograd.grad(
+                    outputs=parameter.grad.clone(),
+                    inputs=parameter,
+                    grad_outputs=torch.ones_like(parameter.grad, requires_grad=True),
+                    only_inputs=True,
+                    retain_graph=True,
+                    create_graph=True,
+                    allow_unused=True,
+                )[0]
+
+                if so_gradient is not None:
+                    hessian_diagonals.append(so_gradient.view(-1))
+                else:
+                    hessian_diagonals.append(torch.zeros_like(parameter.view(-1)))
+
+        return torch.cat(hessian_diagonals)
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+
+        Computes the log of noise scale using the approximate formula:
+            log(B_{noise}) ≈ log(B/ε) (handled in the observer) + log(tr(HΣ)) - log(tr(H^3 Σ))
+        where:
+            - H is the Hessian matrix
+            - Σ is the noise covariance matrix
+            - B is the batch size
+            - ε is the learning rate
+        """
+
+        # Compute Hessian diagonals as in the SecondOrderGradients observation:
+        # hessian_diagonals = self.compute_hessian_diagonals(parameters)
+        # Instead, use squared first-order gradients as an approximation.
+        fo_gradients = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+        if not fo_gradients:
+            return None
+
+        hessian_diagonals = torch.cat(fo_gradients) ** 2
+
+        if hessian_diagonals.numel() == 0:  # No gradients
+            return None
+
+        # For the noise covariance matrix Σ, use the identity matrix as an approximation.
+        # This is a common assumption when the exact noise structure is unknown.
+        noise_covariance = torch.ones_like(hessian_diagonals)
+
+        # Compute tr(HΣ)
+        trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
+
+        # Avoid division by zero and log of zero
+        if trace_hessian_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
+
+        # Compute tr(H^3 Σ)
+        trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
+        if trace_hessian_cubed_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
+
+        # Compute the final result: log(B_{noise}) ≈ log(tr(HΣ)) - log(tr(H^3 Σ)).
+        # Note: the log(B/ε) term is handled in the observer layer.
+        log_noise_scale_without_log_b_over_epsilon = (
+            log_trace_hessian_noise_covariance - log_trace_hessian_cubed_noise_covariance
+        )
+
+        return log_noise_scale_without_log_b_over_epsilon
+
+
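With squared first-order gradients standing in for the Hessian diagonal and Σ approximated by the identity, the value returned above reduces to log(Σ g²) − log(Σ g⁶); the log(B/ε) term is added in the observer layer. A compact sketch of that proxy (`log_noise_scale_proxy` is a hypothetical name):

```python
import torch

def log_noise_scale_proxy(flat_gradients: torch.Tensor) -> float | None:
    """log(tr(HΣ)) - log(tr(H^3 Σ)) with H ≈ diag(g^2) and Σ = I."""
    hessian_diag = flat_gradients ** 2
    trace_h = hessian_diag.sum()               # tr(HΣ)    = Σ g^2
    trace_h_cubed = (hessian_diag ** 3).sum()  # tr(H^3 Σ) = Σ g^6
    if trace_h <= 0 or trace_h_cubed <= 0:
        return None
    return (torch.log(trace_h) - torch.log(trace_h_cubed)).item()

print(log_noise_scale_proxy(torch.tensor([0.1, -0.2, 0.05])))
```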
+class CosineSimilarityObserverOfGradientAndMomentumStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and momentum.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        momentum_2d = momentum.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, momentum_2d, dim=1).item()
+
+        return cosine_similarity
+
+
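All three cosine-similarity statistics in this hunk share one pattern: flatten both vectors, lift them to shape (1, N), and call `torch.nn.functional.cosine_similarity` along dim=1. A minimal sketch of the shared computation:

```python
import torch
import torch.nn.functional as F

def flat_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two tensors treated as single flat vectors."""
    return F.cosine_similarity(a.view(1, -1), b.view(1, -1), dim=1).item()

g = torch.tensor([1.0, 0.0])
m = torch.tensor([1.0, 1.0])
print(flat_cosine_similarity(g, m))  # ≈ 0.7071
```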
+class CosineSimilarityObserverOfGradientAndUpdateStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and update.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        update_tensor_2d = update_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, update_tensor_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityOfGradientAndParameterStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and parameter.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+        StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        parameters_list = [p.view(-1) for p in parameters_with_grads]
+
+        if not parameters_list:
+            return None
+
+        parameters_tensor = torch.cat(parameters_list).view(-1)
+
+        gradients_2d = gradients.unsqueeze(0)
+        parameters_tensor_2d = parameters_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, parameters_tensor_2d, dim=1).item()
+
+        return cosine_similarity
libinephany/utils/constants.py
CHANGED
@@ -43,7 +43,7 @@ AGENT_PREFIX_EPS = "adam-eps"
 AGENT_PREFIX_SGD_MOMENTUM = "sgd-momentum"

 AGENT_BATCH_SIZE = "batch-size"
-
+AGENT_PREFIX_GRADIENT_ACCUMULATION = "gradient-accumulation"

 AGENT_BANDIT_SUFFIX = "bandit-agent"

@@ -68,7 +68,7 @@ PREFIXES = [
     AGENT_PREFIX_BETA_TWO,
     AGENT_PREFIX_EPS,
     AGENT_PREFIX_SGD_MOMENTUM,
-
+    AGENT_PREFIX_GRADIENT_ACCUMULATION,
 ]
 PREFIXES_TO_HPARAMS = {
     AGENT_PREFIX_LR: LEARNING_RATE,
@@ -79,6 +79,6 @@ PREFIXES_TO_HPARAMS = {
     AGENT_PREFIX_BETA_TWO: ADAM_BETA_TWO,
     AGENT_PREFIX_EPS: ADAM_EPS,
     AGENT_PREFIX_SGD_MOMENTUM: SGD_MOMENTUM,
-
+    AGENT_PREFIX_GRADIENT_ACCUMULATION: GRADIENT_ACCUMULATION,
 }
 HPARAMS_TO_PREFIXES = {hparam: prefix for prefix, hparam in PREFIXES_TO_HPARAMS.items()}
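For context, the constants change only extends the existing prefix ↔ hyperparameter mapping; `HPARAMS_TO_PREFIXES` is derived by inverting `PREFIXES_TO_HPARAMS`, so the new gradient-accumulation agent is addressable in both directions. The sketch below uses placeholder values; only the "gradient-accumulation" prefix string is taken from the diff:

```python
# Placeholder mapping illustrating the inversion pattern; not the library's real values.
PREFIXES_TO_HPARAMS = {
    "learning-rate": "learning_rate",
    "gradient-accumulation": "gradient_accumulation",  # new agent prefix in 0.19.0
}
HPARAMS_TO_PREFIXES = {hparam: prefix for prefix, hparam in PREFIXES_TO_HPARAMS.items()}

print(HPARAMS_TO_PREFIXES["gradient_accumulation"])  # "gradient-accumulation"
```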