libinephany-0.18.1-py3-none-any.whl → libinephany-0.19.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -11,6 +11,7 @@ from typing import Any, Callable, final
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from torch.distributed import ReduceOp
 
@@ -597,6 +598,58 @@ class ActivationStatistics(Statistic):
         return None
 
 
+class InnerStepParameterUpdateStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            update_tensor = torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])
+
+        return update_tensor
+
+
 class ParameterUpdateStatistics(Statistic):
 
     def __init__(
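Note: when observation_utils.form_update_tensor cannot produce an update (it returns None), the new InnerStepParameterUpdateStatistics substitutes a flat zero tensor sized to the flattened parameters. A minimal sketch of that fallback, assuming nothing beyond standard PyTorch (the parameter shapes here are invented for illustration):

    import torch

    # Hypothetical stand-ins for the tracked `parameters` list.
    parameters = [torch.randn(3, 4), torch.randn(5)]

    # The fallback concatenates one flattened zero tensor per parameter.
    fallback = torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])

    assert fallback.shape == (17,)  # 3 * 4 + 5 elements, all zero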
@@ -694,6 +747,51 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
 
 
+class InnerStepParameterStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
+
+
 class LAMBTrustRatioStatistics(Statistic):
 
     def __init__(
@@ -759,6 +857,71 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio
 
 
+class LHOPTLAMBTrustRatioStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        use_log_transform: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.use_log_transform = use_log_transform
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        weights_list = [p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
+        if weights_list:
+            weights = torch.cat(weights_list)
+
+        else:
+            weights = None
+
+        updates = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        update_norm = torch.norm(updates, p=2).item() if updates is not None else 0
+        weight_norm = torch.norm(weights, p=2).item() if weights is not None else 0
+
+        lamb_trust_ratio = 0.0
+        if update_norm > 0:
+            lamb_trust_ratio = weight_norm / update_norm
+
+        if self.use_log_transform:
+            lamb_trust_ratio = math.log(1 + lamb_trust_ratio)
+
+        return lamb_trust_ratio
+
+
 class NumberOfParameters(Statistic):
 
     def __init__(
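Note: LHOPTLAMBTrustRatioStatistics reduces a parameter group to the LAMB trust ratio R = ||w||_2 / ||Δw||_2 and, when use_log_transform is set, reports ln(1 + R). A self-contained worked example with invented vectors, not a call into the class itself:

    import math

    import torch

    weights = torch.tensor([3.0, 4.0])  # ||w||_2 = 5.0
    updates = torch.tensor([0.6, 0.8])  # ||Δw||_2 = 1.0

    trust_ratio = weights.norm(p=2).item() / updates.norm(p=2).item()  # R = 5.0
    log_trust_ratio = math.log(1 + trust_ratio)  # ln(6) ≈ 1.79, the use_log_transform output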
@@ -958,3 +1121,435 @@ class GradientVarianceFraction(Statistic):
             return 0.0
 
         return variance_parameters / total_parameters
+
+
+class AverageParameterUpdateMagnitudeStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None or a float.
+        """
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+
+        # When the update tensor is None, return 0.0.
+        if update_tensor is None:
+            return 0.0
+
+        update_tensor = update_tensor.view(-1)
+        update_tensor = update_tensor.abs()
+
+        average_update_magnitude = update_tensor.mean().item()
+
+        return average_update_magnitude
+
+
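Note: AverageParameterUpdateMagnitudeStatistics collapses the update tensor to a single float, the mean of its absolute entries. A tiny sketch of that arithmetic with an invented update tensor:

    import torch

    update_tensor = torch.tensor([0.1, -0.3, 0.2])

    # Flatten, take absolute values, average: (0.1 + 0.3 + 0.2) / 3 = 0.2
    average_update_magnitude = update_tensor.view(-1).abs().mean().item()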
+class MomentumGradientRatioStatistics(Statistic):
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_list = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        # Handle empty gradients list
+        if not gradients_list:
+            return 0.0
+
+        gradients = torch.cat(gradients_list).view(-1)
+
+        # Momentum-gradient ratio: r^t = ||g^t||_2 / ||m^t||_2, where m^t is the momentum buffer.
+        gradients_norm = gradients.norm(p=2)
+        momentum_norm = momentum.norm(p=2)
+
+        if momentum_norm == 0:
+            momentum_gradient_ratio = 0.0
+        else:
+            momentum_gradient_ratio = (gradients_norm / momentum_norm).item()
+
+        return momentum_gradient_ratio
+
+
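Note: MomentumGradientRatioStatistics reports ||g||_2 / ||m||_2, falling back to 0.0 when the momentum norm is zero. A sketch of the same arithmetic on invented stand-ins for the concatenated gradient and momentum buffers:

    import torch

    gradients = torch.tensor([1.0, 2.0, 2.0])  # ||g||_2 = 3.0
    momentum = torch.tensor([0.0, 0.0, 1.5])   # ||m||_2 = 1.5

    momentum_norm = momentum.norm(p=2)
    ratio = 0.0 if momentum_norm == 0 else (gradients.norm(p=2) / momentum_norm).item()  # 2.0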
+class LogOfNoiseScaleStatistics(Statistic):
+    """
+    Statistics for the log of the noise scale in training.
+
+    Tracks the log of the noise scale B_{noise} using the formula:
+        B_{noise} = tr(ΣH) / (G^T H G) ≈ (B/ε) * tr(HΣ) / tr(H³Σ)
+    where:
+    - H is the Hessian matrix
+    - G is the gradient vector
+    - Σ is the noise covariance matrix
+    - B is the batch size
+    - ε is the learning rate
+    """
+
+    def __init__(
+        self,
+        *,
+        skip_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to not include in returned observations.
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.skip_statistics = skip_statistics
+
+    @property
+    def requires_gradient_graphs(self) -> bool:
+        """
+        :return: Whether the statistic requires gradient graphs to be retained.
+        """
+
+        return False
+
+    @staticmethod
+    def compute_hessian_diagonals(parameters: list[torch.Tensor]) -> torch.Tensor:
+        """
+        :param parameters: Parameters to compute the Hessian diagonal matrices for.
+        :return: Tensor containing the Hessian diagonal matrices for all given parameters.
+        """
+
+        hessian_diagonals = []
+
+        for parameter in parameters:
+            if parameter.grad is not None:
+                so_gradient = torch.autograd.grad(
+                    outputs=parameter.grad.clone(),
+                    inputs=parameter,
+                    grad_outputs=torch.ones_like(parameter.grad, requires_grad=True),
+                    only_inputs=True,
+                    retain_graph=True,
+                    create_graph=True,
+                    allow_unused=True,
+                )[0]
+
+                if so_gradient is not None:
+                    hessian_diagonals.append(so_gradient.view(-1))
+                else:
+                    hessian_diagonals.append(torch.zeros_like(parameter.view(-1)))
+
+        return torch.cat(hessian_diagonals)
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+
+        Computes the log of the noise scale using the approximate formula:
+            log(B_{noise}) ≈ log(B/ε) + log(tr(HΣ)) - log(tr(H³Σ))
+        where the log(B/ε) term is handled in the observer layer and:
+        - H is the Hessian matrix
+        - Σ is the noise covariance matrix
+        - B is the batch size
+        - ε is the learning rate
+        """
+
+        # Compute Hessian diagonals as in the SecondOrderGradients observation:
+        # hessian_diagonals = self.compute_hessian_diagonals(parameters)
+        # Here, squared first-order gradients are used as an approximation instead.
+        fo_gradients = [
+            p.grad.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+        if not fo_gradients:
+            return None
+
+        hessian_diagonals = torch.cat(fo_gradients) ** 2
+
+        if hessian_diagonals.numel() == 0:  # No gradients
+            return None
+
+        # For the noise covariance matrix Σ, use the identity matrix as an approximation.
+        # This is a common assumption when the exact noise structure is unknown.
+        noise_covariance = torch.ones_like(hessian_diagonals)
+
+        # Compute tr(HΣ)
+        trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
+
+        # Avoid division by zero and log of zero
+        if trace_hessian_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
+
+        # Compute tr(H³Σ)
+        trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
+        if trace_hessian_cubed_noise_covariance <= 0:
+            return None
+
+        log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
+
+        # Final result: log(B_{noise}) ≈ log(tr(HΣ)) - log(tr(H³Σ)).
+        # Note: the log(B/ε) term is handled in the observer layer.
+        log_noise_scale_without_log_b_over_epsilon = (
+            log_trace_hessian_noise_covariance - log_trace_hessian_cubed_noise_covariance
+        )
+
+        return log_noise_scale_without_log_b_over_epsilon
+
+
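Note: with squared first-order gradients standing in for the Hessian diagonal and Σ taken to be the identity, the statistic above reduces to log(Σ g²) − log(Σ g⁶). A numeric sketch of that reduction with invented gradient values:

    import torch

    grads = torch.tensor([0.1, -0.2, 0.3])

    hessian_diag = grads ** 2                # squared first-order gradients as the Hessian proxy
    tr_h_sigma = hessian_diag.sum()          # tr(HΣ) with Σ = I
    tr_h3_sigma = (hessian_diag ** 3).sum()  # tr(H³Σ) with Σ = I

    log_noise_scale = (torch.log(tr_h_sigma) - torch.log(tr_h3_sigma)).item()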
+class CosineSimilarityObserverOfGradientAndMomentumStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and momentum.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        momentum = observation_utils.form_momentum_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+        if momentum is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        momentum_2d = momentum.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, momentum_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityObserverOfGradientAndUpdateStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and update.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        update_tensor = observation_utils.form_update_tensor(
+            optimizer=optimizer, parameters=parameters_with_grads, parameter_group=parameter_group
+        )
+
+        if update_tensor is None:
+            return None
+
+        gradients_2d = gradients.unsqueeze(0)
+        update_tensor_2d = update_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, update_tensor_2d, dim=1).item()
+
+        return cosine_similarity
+
+
+class CosineSimilarityOfGradientAndParameterStatistics(Statistic):
+    """
+    Statistics for the cosine similarity of gradient and parameter.
+    """
+
+    def _get_storage_format(self) -> StatisticStorageTypes:
+        """
+        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
+            StatisticStorageTypes enumeration class.
+        """
+
+        return StatisticStorageTypes.FLOAT
+
+    def _gather(
+        self,
+        *,
+        optimizer: optim.Optimizer,
+        model: nn.Module,
+        parameters: list[torch.Tensor],
+        parameter_group: dict[str, Any],
+    ) -> torch.Tensor | TensorStatistics | float | None:
+        """
+        :param optimizer: Optimizer the given parameters and parameter group came from.
+        :param model: Inner model to gather statistics from.
+        :param parameters: List of parameters to gather statistics from.
+        :param parameter_group: Parameter group the parameters originate from.
+        :return: None, TensorStatistics model or a float.
+        """
+        # Filter parameters that have gradients to ensure consistent tensor sizes
+        parameters_with_grads = [
+            p for p in parameters if observation_utils.tensor_on_local_rank(p) and p.grad is not None
+        ]
+
+        if not parameters_with_grads:
+            return None
+
+        gradients_list = [p.grad.view(-1) for p in parameters_with_grads]
+        gradients = torch.cat(gradients_list).view(-1)
+
+        parameters_list = [p.view(-1) for p in parameters_with_grads]
+
+        if not parameters_list:
+            return None
+
+        parameters_tensor = torch.cat(parameters_list).view(-1)
+
+        gradients_2d = gradients.unsqueeze(0)
+        parameters_tensor_2d = parameters_tensor.unsqueeze(0)
+
+        cosine_similarity = F.cosine_similarity(gradients_2d, parameters_tensor_2d, dim=1).item()
+
+        return cosine_similarity
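Note: all three cosine-similarity statistics above follow the same recipe: flatten and concatenate the two quantities, lift each to shape (1, N) with unsqueeze(0), and compare along dim=1. A minimal sketch with arbitrary vectors:

    import torch
    import torch.nn.functional as F

    gradients = torch.tensor([1.0, 0.0, 1.0])
    momentum = torch.tensor([1.0, 0.0, -1.0])

    # Orthogonal vectors, so the similarity is 0.0.
    similarity = F.cosine_similarity(gradients.unsqueeze(0), momentum.unsqueeze(0), dim=1).item()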
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: libinephany
-Version: 0.18.1
+Version: 0.19.0
 Summary: Inephany library containing code commonly used by multiple subpackages.
 Author-email: Inephany <info@inephany.com>
 License: Apache 2.0
@@ -2,22 +2,22 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
 libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-libinephany/observations/observation_utils.py,sha256=ejQ-hzKq_MX7A0304KypScmzloReNrH8dmcgonoRl4Q,16266
+libinephany/observations/observation_utils.py,sha256=JSNJYEi2d-VQ0ZovfHrn28RDv41u-a6M-W4ZT8UUyhI,17279
 libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
 libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
 libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
-libinephany/observations/statistic_trackers.py,sha256=06NjCy1MI865oU0KB5f-wQE3b2RvYawOOWxNJx4rFpw,32939
+libinephany/observations/statistic_trackers.py,sha256=Z0JCJouVaPRUdR9o8GbD4Vm72aRREuxUIqGY8zgFUWg,53622
 libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/observations/observers/base_observers.py,sha256=Tkk2AvQ6mRY6fdhhxHQWohA0xSHGdPqxul1C7C7Frj4,15924
-libinephany/observations/observers/local_observers.py,sha256=7azTW227-rG_QJR5a0StfE0O4Ca1boMV9nrCphnhWik,40344
+libinephany/observations/observers/local_observers.py,sha256=9zWPizg6LuuembSibX67azIF12_Mr07EIzx_6d9ANVU,45290
 libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
-libinephany/observations/observers/global_observers/__init__.py,sha256=lQO-nq7hILu4F3ddFXcCR-ghfv0dzEA9nAYxZta7rxk,2306
+libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
 libinephany/observations/observers/global_observers/base_classes.py,sha256=CCkRx86Lll3gFzfqervP0jKdzNFKkKU7tEBh8ic1Yrc,8249
-libinephany/observations/observers/global_observers/constants.py,sha256=olXYxh353Th6hyhfk85kHxAWPeaDpuPkcz_awFvEz6c,1054
-libinephany/observations/observers/global_observers/gradient_observers.py,sha256=z2ow6zmTr7ujeaYyh5qGRE1woIt_yc2uxAaie7zgxqc,8398
+libinephany/observations/observers/global_observers/constants.py,sha256=C_PwYhKxatJxNe5Jzb1tpoiRXAxxPrGkcdQBMQD8msY,1139
+libinephany/observations/observers/global_observers/gradient_observers.py,sha256=6L9Hw6K--vTRe0Zp4dQ9ByOLwU6p9Z8vhMYcTQVA4IU,21042
 libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=soGWoYpO5rqUQU0p4pr6QVjBvVg2odK71mVkxV38Ras,14838
 libinephany/observations/observers/global_observers/loss_observers.py,sha256=FlSuJqAJIXcAS_ypdZna6xxz89glI23A6D00sDn7ZLU,18508
-libinephany/observations/observers/global_observers/model_observers.py,sha256=bJIEdq5wWkLrw-MNllCe36FE4aMLbC4RoR1-wOOHMxc,19537
+libinephany/observations/observers/global_observers/model_observers.py,sha256=xNqNWz2YUATakm_BuHXCGtGCeYodpyeupa8yxLo-TAA,28433
 libinephany/observations/observers/global_observers/progress_observers.py,sha256=m62jUiwPaOUzYG1h7Vg6znj_jK9699lDhg4AhK212s8,5615
 libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
 libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
 libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
-libinephany-0.18.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
-libinephany-0.18.1.dist-info/METADATA,sha256=d2rH-yA7F1cVk2OnTyvdaQ6_CFey1uiY4ATbqsxH9Pg,8390
-libinephany-0.18.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-libinephany-0.18.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
-libinephany-0.18.1.dist-info/RECORD,,
+libinephany-0.19.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
+libinephany-0.19.0.dist-info/METADATA,sha256=iojcQg83wJVyPzuFssPtW_Maz0wO2JGFXCLcolBsnfQ,8390
+libinephany-0.19.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+libinephany-0.19.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
+libinephany-0.19.0.dist-info/RECORD,,