libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from typing import Any, Callable, final
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from torch.distributed import ReduceOp
 
@@ -75,7 +76,7 @@ class Statistic(ABC):
         self.max_cache_size = max_statistic_cache_size
         self.downsample_percent = tensor_stats_downsample_percentage
         self.sample_frequency = statistic_sample_frequency
-        self.skip_statistics: list[str] | None = None
+        self.include_statistics: list[str] | None = None
 
     @final
     @property
@@ -194,12 +195,17 @@ class Statistic(ABC):
         Processes the tensor cache to build a TensorStatistic model.
         """
 
+        if not self.include_statistics:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
         if self._tensor_cache:
             concatenated = torch.cat(self._tensor_cache)
             self._tensor_cache = []
 
             statistics = TensorStatistics.build(
-                tensor=concatenated, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
+                tensor=concatenated,
+                include_statistics=self.include_statistics,
+                sample_percentage=self.downsample_percent,
             )
             self._data.append(statistics)  # type: ignore
 
@@ -367,18 +373,18 @@ class FirstOrderGradients(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -421,22 +427,22 @@ class SecondOrderGradients(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         compute_hessian_diagonal: bool = False,
-        skip_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
            or use the squared first order gradients as approximations in the same way Adam does.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
         self.compute_hessian_diagonal = compute_hessian_diagonal
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     @property
     def requires_gradient_graphs(self) -> bool:
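
The SecondOrderGradients docstring above names two modes: an exact Hessian diagonal obtained from a second backward pass, or the Adam-style squared first-order gradients. A toy, self-contained contrast of the two quantities is sketched below; the objective is illustrative, and the double-backward result equals the true diagonal here only because this toy Hessian happens to be diagonal.

    import torch

    # Toy objective f(w) = sum(w**4); its true Hessian diagonal is 12 * w**2.
    w = torch.randn(5, requires_grad=True)
    loss = (w ** 4).sum()

    (grad,) = torch.autograd.grad(loss, w, create_graph=True)  # first-order gradients, graph kept
    (hessian_diag,) = torch.autograd.grad(grad.sum(), w)       # second backward pass: Hessian row sums
    adam_style = grad.detach() ** 2                             # squared first-order gradient proxy

    print(hessian_diag)  # equals 12 * w**2 for this separable objective
    print(adam_style)    # a different, cheaper quantity (16 * w**6 here)
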
@@ -519,18 +525,18 @@ class ActivationStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
        :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     @property
     def uses_forward_hook(self) -> bool:
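
ActivationStatistics collects its data through a forward hook, as the uses_forward_hook property above and the next two hunks show. Hooks with this signature are attached with nn.Module.register_forward_hook; below is a minimal, self-contained sketch of the registration and every-Nth-call sampling pattern. The module, counter, and recorded summary are illustrative stand-ins, not the library's implementation.

    import torch
    import torch.nn as nn

    sample_frequency = 10  # illustrative value
    sample_number = 0
    collected: list[float] = []

    def hook(module: nn.Module, layer_input, layer_output: torch.Tensor) -> None:
        global sample_number
        if sample_number % sample_frequency == 0:
            # Stand-in for TensorStatistics.build(...): record one summary of the activations.
            collected.append(layer_output.detach().abs().mean().item())
        sample_number += 1

    layer = nn.Linear(8, 4)
    handle = layer.register_forward_hook(hook)  # the hook now fires on every forward pass
    layer(torch.randn(2, 8))
    handle.remove()
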
@@ -553,6 +559,9 @@ class ActivationStatistics(Statistic):
         :return: Forward hook to register the function with.
         """
 
+        if self.include_statistics is None:
+            raise ValueError("include_statistics is required to use forward hooks!")
+
         def hook(module: nn.Module, layer_input: torch.Tensor, layer_output: torch.Tensor) -> None:
             """
             :param module: Module the hook was registered with. Not used here.
@@ -562,7 +571,9 @@ class ActivationStatistics(Statistic):
 
             if self._sample_number % self.sample_frequency == 0:
                 statistics = TensorStatistics.build(
-                    tensor=layer_output, skip_statistics=self.skip_statistics, sample_percentage=self.downsample_percent
+                    tensor=layer_output,
+                    include_statistics=self.include_statistics,
+                    sample_percentage=self.downsample_percent,
                 )
                 self._data.append(statistics)  # type: ignore
 
@@ -602,18 +613,18 @@ class ParameterUpdateStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -649,23 +660,28 @@ class ParameterUpdateStatistics(Statistic):
         return update_tensor
 
 
+class LHOPTParameterUpdateStatistics(ParameterUpdateStatistics):
+
+    pass
+
+
 class ParameterStatistics(Statistic):
 
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -694,21 +710,30 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
 
 
+class LHOPTParameterStatistics(ParameterStatistics):
+
+    pass
+
+
 class LAMBTrustRatioStatistics(Statistic):
 
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         use_log_transform: bool = False,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
         :param kwargs: Other observation keyword arguments.
         """
 
         super().__init__(**kwargs)
 
+        self.include_statistics = include_statistics
         self.use_log_transform = use_log_transform
 
     def _get_storage_format(self) -> StatisticStorageTypes:
@@ -759,6 +784,11 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio
 
 
+class LHOPTLAMBTrustRatioStatistics(LAMBTrustRatioStatistics):
+
+    pass
+
+
 class NumberOfParameters(Statistic):
 
     def __init__(
@@ -958,3 +988,385 @@ class GradientVarianceFraction(Statistic):
             return 0.0
 
         return variance_parameters / total_parameters
+
+
+class AverageParameterUpdateMagnitudeStatistics(Statistic):
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        # when update tensor is none, return 0.0
+        if update_tensor is None:
+            return 0.0
+
+        update_tensor = update_tensor.view(-1)
+        update_tensor = update_tensor.abs()
+
+        average_update_magnitude = update_tensor.mean().item()
+
+        return average_update_magnitude
+
+
+class MomentumGradientRatioStatistics(Statistic):
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_list = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        # Handle empty gradients list
+        if not gradients_list:
+            return 0.0
+
+        gradients = torch.cat(gradients_list).view(-1)
+
+        # momentum_gradient_ratio r^t=\frac{\|g^t\|_2}{\|\nabla f(w^t)\|_2}
+        gradients_norm = gradients.norm(p=2)
+        momentum_norm = momentum.norm(p=2)
+
+        if momentum_norm == 0:
+            momentum_gradient_ratio = 0.0
+        else:
+            momentum_gradient_ratio = (gradients_norm / momentum_norm).item()
+
+        return momentum_gradient_ratio
+
+
+class LogOfNoiseScaleStatistics(Statistic):
+    """
+    Statistics for the log of noise scale in training.
+
+    Tracks the log of noise scale B_{noise} using the formula:
+        B_{noise} = tr(ΣH) / (G^T H G) ≈ (B/ε) * tr(HΣ) / tr(H^3 Σ)
+    where:
+        - H is the Hessian matrix
+        - G is the gradient vector
+        - Σ is the noise covariance matrix
+        - B is the batch size
+        - ε is the learning rate
+    """
+
+    @property
+    def requires_gradient_graphs(self) -> bool:
+        """
+        :return: Whether the statistic requires gradient graphs to be retained.
+        """
+
+        return False
+
+    @staticmethod
+    def compute_hessian_diagonals(parameters: list[torch.Tensor]) -> torch.Tensor:
+        """
+        :param parameters: Parameters to compute the hessian diagonal matrices for.
+        :return: Tensor containing the hessian diagonal matrices for all given parameters.
+        """
+
+        hessian_diagonals = []
+
+        for parameter in parameters:
+            if parameter.grad is not None:
+                so_gradient = torch.autograd.grad(
+                    outputs=parameter.grad.clone(),
+                    inputs=parameter,
+                    grad_outputs=torch.ones_like(parameter.grad, requires_grad=True),
+                    only_inputs=True,
+                    retain_graph=True,
+                    create_graph=True,
+                    allow_unused=True,
+                )[0]
+
+                if so_gradient is not None:
+                    hessian_diagonals.append(so_gradient.view(-1))
+                else:
+                    hessian_diagonals.append(torch.zeros_like(parameter.view(-1)))
+
+        return torch.cat(hessian_diagonals)
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+
+        Computes the log of noise scale using the approximate formula:
+            log(B_{noise}) ≈ log(B/ε) (move to observer) + log(tr(HΣ)) - log(tr(H^3 Σ))
+        where:
+            - H is the Hessian matrix
+            - Σ is the noise covariance matrix
+            - B is the batch size
+            - ε is the learning rate
+
+        """
+
+        # Compute Hessian diagonals as in SecondOrderGradients Observation
+        # hessian_diagonals = self.compute_hessian_diagonals(parameters)
+        # use squared first order gradients as approximations
+        fo_gradients = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+        if not fo_gradients:
+            return None
+
+        hessian_diagonals = torch.cat(fo_gradients) ** 2
+
+        if hessian_diagonals.numel() == 0:  # No gradients
+            return None
+
+        # For noise covariance matrix Σ, we'll use the identity matrix as an approximation
+        # This is a common assumption when the exact noise structure is unknown
+        noise_covariance = torch.ones_like(hessian_diagonals)
+
+        # Compute tr(HΣ)
+        trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
+
+        # Avoid division by zero and log of zero
+        if trace_hessian_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
+
+        # Compute tr(H^3 Σ)
+        trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
+        if trace_hessian_cubed_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
+
+        # Compute final result: log(B_{noise}) ≈ log(tr(HΣ)) - log(tr(H^3 Σ))
+        # Note: log(B/ε) term is handled in the observer layer
+        log_noise_scale_without_log_b_over_epsilon = (
+            log_trace_hessian_noise_covariance - log_trace_hessian_cubed_noise_covariance
+        )
+
+        return log_noise_scale_without_log_b_over_epsilon
+
+
+class CosineSimilarityObserverOfGradientAndMomentumStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and momentum.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        momentum_2d = momentum.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, momentum_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityObserverOfGradientAndUpdateStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and update.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        update_tensor_2d = update_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, update_tensor_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityOfGradientAndParameterStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and parameter.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        parameters_list = [p.view(-1) for p in parameters_with_grads]
+
+        if not parameters_list:
+            return None
+
+        parameters_tensor = torch.cat(parameters_list).view(-1)
+
+        gradients_2d = gradients.unsqueeze(0)
+        parameters_tensor_2d = parameters_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, parameters_tensor_2d, dim=1).item()
+
+        return cosine_similarity
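
The LogOfNoiseScaleStatistics class added above approximates H by the squared first-order gradients and Σ by the identity, so the value it returns reduces to log(sum(g**2)) - log(sum(g**6)); the log(B/ε) term is, per the code comment, added in the observer layer. A standalone numerical sketch of that reduction is below, using a random vector in place of real model gradients.

    import torch

    def log_noise_scale_proxy(flat_gradients: torch.Tensor) -> float | None:
        # H approximated by diag(g**2) and Sigma by the identity, as in the diff above.
        hessian_diagonals = flat_gradients.view(-1) ** 2
        trace_h_sigma = hessian_diagonals.sum()               # tr(H Sigma)
        trace_h_cubed_sigma = (hessian_diagonals ** 3).sum()  # tr(H^3 Sigma)
        if trace_h_sigma <= 0 or trace_h_cubed_sigma <= 0:    # avoid log(0), mirroring the guards above
            return None
        return (torch.log(trace_h_sigma) - torch.log(trace_h_cubed_sigma)).item()

    print(log_noise_scale_proxy(torch.randn(10_000) * 1e-3))

For gradients of scale σ the proxy behaves like a constant minus 4·log σ, so it decreases as gradient magnitudes grow and increases as they shrink.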