libinephany 0.19.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observers/base_observers.py +20 -8
- libinephany/observations/observers/global_observers/gradient_observers.py +15 -16
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -18
- libinephany/observations/observers/global_observers/model_observers.py +18 -20
- libinephany/observations/observers/global_observers/progress_observers.py +7 -1
- libinephany/observations/observers/local_observers.py +35 -28
- libinephany/observations/statistic_trackers.py +52 -237
- libinephany/pydantic_models/schemas/tensor_statistics.py +33 -32
- libinephany/pydantic_models/states/hyperparameter_states.py +32 -30
- libinephany/utils/enums.py +6 -0
- {libinephany-0.19.0.dist-info → libinephany-1.0.1.dist-info}/METADATA +1 -1
- {libinephany-0.19.0.dist-info → libinephany-1.0.1.dist-info}/RECORD +15 -15
- {libinephany-0.19.0.dist-info → libinephany-1.0.1.dist-info}/WHEEL +0 -0
- {libinephany-0.19.0.dist-info → libinephany-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.19.0.dist-info → libinephany-1.0.1.dist-info}/top_level.txt +0 -0
libinephany/observations/statistic_trackers.py

@@ -76,7 +76,7 @@ class Statistic(ABC):
         self.max_cache_size = max_statistic_cache_size
         self.downsample_percent = tensor_stats_downsample_percentage
         self.sample_frequency = statistic_sample_frequency
-        self.skip_statistics: list[str] | None = None
+        self.include_statistics: list[str] | None = None

     @final
     @property
@@ -195,12 +195,17 @@ class Statistic(ABC):
         Processes the tensor cache to build a TensorStatistic model.
         """

+        if not self.include_statistics:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
         if self._tensor_cache:
             concatenated = torch.cat(self._tensor_cache)
             self._tensor_cache = []

             statistics = TensorStatistics.build(
-                tensor=concatenated,
+                tensor=concatenated,
+                include_statistics=self.include_statistics,
+                sample_percentage=self.downsample_percent,
             )
             self._data.append(statistics)  # type: ignore

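Together with the `__init__` hunk above, statistic selection becomes opt-in: the cache processor now fails fast when no statistics were requested instead of silently emitting every field. A minimal usage sketch of the new `build` contract (hedged: "mean" is an assumed field name; only `inter_quartile_range_` is visible elsewhere in this diff):

    import torch

    # Hedged sketch of the new opt-in API; "mean" is an assumed field name.
    stats = TensorStatistics.build(
        tensor=torch.randn(10_000),
        include_statistics=["mean", "inter_quartile_range"],
        sample_percentage=0.01,  # compute statistics on a 1% random subsample
    )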
@@ -368,18 +373,18 @@ class FirstOrderGradients(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
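The same skip_statistics → include_statistics swap repeats for every tracker below. Note the inverted semantics: fields were previously returned unless skipped; now nothing is returned unless explicitly included, and the base class raises if the list is never set. A hedged construction sketch (the base Statistic's other keyword arguments are omitted):

    # Hypothetical usage; "mean" is an assumed TensorStatistics field name.
    tracker = FirstOrderGradients(include_statistics=["mean"])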
@@ -422,22 +427,22 @@ class SecondOrderGradients(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         compute_hessian_diagonal: bool = False,
-        skip_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
         or use the squared first order gradients as approximations in the same way Adam does.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

         self.compute_hessian_diagonal = compute_hessian_diagonal
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     @property
     def requires_gradient_graphs(self) -> bool:
@@ -520,18 +525,18 @@ class ActivationStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     @property
     def uses_forward_hook(self) -> bool:
@@ -554,6 +559,9 @@ class ActivationStatistics(Statistic):
         :return: Forward hook to register the function with.
         """

+        if self.include_statistics is None:
+            raise ValueError("include_statistics is required to use forward hooks!")
+
         def hook(module: nn.Module, layer_input: torch.Tensor, layer_output: torch.Tensor) -> None:
             """
             :param module: Module the hook was registered with. Not used here.
@@ -563,7 +571,9 @@ class ActivationStatistics(Statistic):

             if self._sample_number % self.sample_frequency == 0:
                 statistics = TensorStatistics.build(
-                    tensor=layer_output,
+                    tensor=layer_output,
+                    include_statistics=self.include_statistics,
+                    sample_percentage=self.downsample_percent,
                 )
                 self._data.append(statistics)  # type: ignore

@@ -598,23 +608,23 @@ class ActivationStatistics(Statistic):
         return None


-class InnerStepParameterUpdateStatistics(Statistic):
+class ParameterUpdateStatistics(Statistic):

     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -650,56 +660,9 @@ class InnerStepParameterUpdateStatistics(Statistic):
         return update_tensor


-class ParameterUpdateStatistics(Statistic):
-
-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-        StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.TENSOR_STATISTICS
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        update_tensor = observation_utils.form_update_tensor(
-            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
-        )
-
-        if update_tensor is None:
-            update_tensor = torch.cat([torch.zeros(p.view(-1).shape, device=p.device) for p in parameters])
+class LHOPTParameterUpdateStatistics(ParameterUpdateStatistics):

-
+    pass


 class ParameterStatistics(Statistic):
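Here the old standalone ParameterUpdateStatistics is deleted in favour of the renamed base class above, and the LHOPT variant shrinks to an empty subclass. Keeping a distinct subclass rather than dropping the name outright preserves the class name itself, which matters if trackers are registered or configured by name; the resulting hierarchy, sketched:

    # Post-change shape (sketch only, bodies elided).
    class ParameterUpdateStatistics(Statistic):  # renamed from InnerStepParameterUpdateStatistics
        ...

    class LHOPTParameterUpdateStatistics(ParameterUpdateStatistics):
        pass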
@@ -707,18 +670,18 @@ class ParameterStatistics(Statistic):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics

     def _get_storage_format(self) -> StatisticStorageTypes:
         """
@@ -747,49 +710,9 @@ class ParameterStatistics(Statistic):
         return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])


-class LHOPTParameterStatistics(Statistic):
+class LHOPTParameterStatistics(ParameterStatistics):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-        StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.TENSOR_STATISTICS
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        return torch.cat([p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)])
+    pass


 class LAMBTrustRatioStatistics(Statistic):
@@ -797,16 +720,20 @@ class LAMBTrustRatioStatistics(Statistic):
     def __init__(
         self,
         *,
+        include_statistics: list[str] | None = None,
         use_log_transform: bool = False,
         **kwargs,
     ) -> None:
         """
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
         :param kwargs: Other observation keyword arguments.
         """

         super().__init__(**kwargs)

+        self.include_statistics = include_statistics
         self.use_log_transform = use_log_transform

     def _get_storage_format(self) -> StatisticStorageTypes:
@@ -857,69 +784,9 @@ class LAMBTrustRatioStatistics(Statistic):
         return lamb_trust_ratio


-class LHOPTLAMBTrustRatioStatistics(Statistic):
+class LHOPTLAMBTrustRatioStatistics(LAMBTrustRatioStatistics):

-    def __init__(
-        self,
-        *,
-        use_log_transform: bool = False,
-        **kwargs,
-    ) -> None:
-        """
-        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.use_log_transform = use_log_transform
-
-    def _get_storage_format(self) -> StatisticStorageTypes:
-        """
-        :return: Storage format this observation stores data in. Must be one of the enum attributes in the
-        StatisticStorageTypes enumeration class.
-        """
-
-        return StatisticStorageTypes.FLOAT
-
-    def _gather(
-        self,
-        *,
-        optimizer: optim.Optimizer,
-        model: nn.Module,
-        parameters: list[torch.Tensor],
-        parameter_group: dict[str, Any],
-    ) -> torch.Tensor | TensorStatistics | float | None:
-        """
-        :param optimizer: Optimizer the given parameters and parameter group came from.
-        :param model: Inner model to gather statistics from.
-        :param parameters: List of parameters to gather statistics from.
-        :param parameter_group: Parameter group the parameters originate from.
-        :return: None, TensorStatistics model or a float.
-        """
-
-        weights_list = [p.data.view(-1) for p in parameters if observation_utils.tensor_on_local_rank(p)]
-        if weights_list:
-            weights = torch.cat(weights_list)
-
-        else:
-            weights = None
-
-        updates = observation_utils.form_update_tensor(
-            optimizer=optimizer, parameters=parameters, parameter_group=parameter_group
-        )
-
-        update_norm = torch.norm(updates, p=2).item() if updates is not None else 0
-        weight_norm = torch.norm(weights, p=2).item() if weights is not None else 0
-
-        lamb_trust_ratio = 0.0
-        if update_norm > 0:
-            lamb_trust_ratio = weight_norm / update_norm
-
-        if self.use_log_transform:
-            lamb_trust_ratio = math.log(1 + lamb_trust_ratio)
-
-        return lamb_trust_ratio
+    pass


 class NumberOfParameters(Statistic):
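The deleted LHOPT override reimplemented the trust-ratio computation the parent already provides (compare the `return lamb_trust_ratio` context above), so the subclass now simply inherits it. For reference, the retained computation is R = ||w||₂ / ||Δw||₂, optionally log-transformed; a standalone sketch:

    import math
    import torch

    # Sketch of the trust-ratio math kept in LAMBTrustRatioStatistics.
    weights = torch.randn(1_000)
    updates = 1e-3 * torch.randn(1_000)

    weight_norm = torch.norm(weights, p=2).item()
    update_norm = torch.norm(updates, p=2).item()

    trust_ratio = weight_norm / update_norm if update_norm > 0 else 0.0
    log_trust_ratio = math.log(1 + trust_ratio)  # the use_log_transform=True path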
@@ -1125,22 +992,6 @@ class GradientVarianceFraction(Statistic):

 class AverageParameterUpdateMagnitudeStatistics(Statistic):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
         :return: Storage format this observation stores data in. Must be one of the enum attributes in the
@@ -1183,22 +1034,6 @@ class AverageParameterUpdateMagnitudeStatistics(Statistic):

 class MomentumGradientRatioStatistics(Statistic):

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     def _get_storage_format(self) -> StatisticStorageTypes:
         """
         :return: Storage format this observation stores data in. Must be one of the enum attributes in the
@@ -1263,26 +1098,8 @@ class LogOfNoiseScaleStatistics(Statistic):
     - Σ is the noise covariance matrix
     - B is the batch size
     - ε is the learning rate
-
-
     """

-    def __init__(
-        self,
-        *,
-        skip_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
-        :param kwargs: Other observation keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.skip_statistics = skip_statistics
-
     @property
     def requires_gradient_graphs(self) -> bool:
         """
@@ -1370,19 +1187,17 @@ class LogOfNoiseScaleStatistics(Statistic):
             # This is a common assumption when the exact noise structure is unknown
             noise_covariance = torch.ones_like(hessian_diagonals)

-            # Compute tr(HΣ)
-            trace_hessian_noise_covariance = torch.sum(hessian_diagonals * noise_covariance)
-
-
-            if trace_hessian_noise_covariance <= 0:
-                return None
+            # Compute tr(HΣ), add zero division tolerance to avoid log of zero when gradient is too small
+            trace_hessian_noise_covariance = (
+                torch.sum(hessian_diagonals * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
+            )

             log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()

-            # Compute tr(H^3 Σ)
-            trace_hessian_cubed_noise_covariance = torch.sum(hessian_diagonals**3 * noise_covariance)
-
-
+            # Compute tr(H^3 Σ), add zero division tolerance to avoid log of zero when gradient is too small
+            trace_hessian_cubed_noise_covariance = (
+                torch.sum(hessian_diagonals**3 * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
+            )

             log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()

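Previously the code returned None whenever tr(HΣ) was non-positive; with the tolerance the logarithm stays finite even for vanishing gradients. A small sketch of the degenerate case (the constant's value is not shown in this diff, so 1e-12 below is an assumption):

    import torch

    ZERO_DIVISION_TOLERANCE = 1e-12  # assumed value; only the key name appears in the diff

    hessian_diagonals = torch.zeros(4)  # vanishing curvature
    noise_covariance = torch.ones_like(hessian_diagonals)

    trace = torch.sum(hessian_diagonals * noise_covariance) + ZERO_DIVISION_TOLERANCE
    print(torch.log(trace).item())  # finite (about -27.6) rather than -inf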
libinephany/pydantic_models/schemas/tensor_statistics.py

@@ -13,7 +13,7 @@ from pydantic import BaseModel
 #
 # ======================================================================================================================

-…
+FIELD_SUFFIX = "_"

 # ======================================================================================================================
 #
@@ -164,28 +164,38 @@ class TensorStatistics(BaseModel):
         return tensor[random_indices]

     @classmethod
-    def …
+    def filter_include_statistics(cls, include_statistics: list[str]) -> list[str]:
         """
-        :param …
-        :return: …
-        …
+        :param include_statistics: Names of the fields in the model to include in returned observations.
+        :return: List of fields from the given include_statistics list that are present in this pydantic model.
+        :raises ValueError: If no statistics to include are given.
         """

-        …
-        …
-        …
-        … else …
-        …
+        filtered_include_statistics: list[str] = []
+
+        for include_stat in include_statistics:
+            with_suffix = include_stat + FIELD_SUFFIX if not include_stat.endswith(FIELD_SUFFIX) else include_stat
+
+            if with_suffix in cls.model_fields.keys():
+                filtered_include_statistics.append(with_suffix)
+
+        if not filtered_include_statistics:
+            raise ValueError(f"No statistics to include given to {cls.__name__}!")
+
+        return filtered_include_statistics

     @classmethod
     def build(
-        cls, …
+        cls,
+        tensor: torch.Tensor,
+        include_statistics: list[str],
+        sample_percentage: float = 0.01,
     ) -> "TensorStatistics":
         """
         :param tensor: Tensor to compute and store statistics of.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+        fields in the model to include in returned observations.
         :param sample_percentage: Percentage of the given tensor to randomly sample and compute statistics from.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-        fields in the model to not include in returned observations.
         :return: Constructed tensor statistics.
         """

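The new helper normalises caller-supplied names to the model's trailing-underscore field naming (FIELD_SUFFIX) and raises when nothing survives the filter. A hedged behaviour sketch (mean_ is an assumed model field; inter_quartile_range_ appears later in this diff):

    # Names may be passed with or without the "_" suffix; matches come back
    # in the suffixed model-field form.
    TensorStatistics.filter_include_statistics(["mean", "inter_quartile_range_"])
    # -> ["mean_", "inter_quartile_range_"]

    TensorStatistics.filter_include_statistics(["not_a_field"])
    # raises ValueError: No statistics to include given to TensorStatistics!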
@@ -193,12 +203,10 @@ class TensorStatistics(BaseModel):
         downsampled_tensor = cls.downsample_tensor(tensor=tensor, sample_percentage=sample_percentage)

         for field, field_value in stats.model_dump().items():
-            name = field[:-1] if field.endswith("_") else field
-
-            if skip_statistics is not None and name in skip_statistics:
-                continue
+            name = field[:-1] if field.endswith(FIELD_SUFFIX) else field

-            setattr(stats, name, downsampled_tensor)
+            if name in include_statistics:
+                setattr(stats, name, downsampled_tensor)

         return stats

@@ -219,28 +227,21 @@ class TensorStatistics(BaseModel):
             inter_quartile_range_=tensor[6],
         )

-    def to_list(self, skip_statistics: list[str] | None = None) -> list[float]:
+    def to_list(self, include_statistics: list[str]) -> list[float]:
         """
-        :param skip_statistics: …
+        :param include_statistics: List of field names to include in the returned list.
         :return: List of field values.
         """

-        if skip_statistics is None:
-            skip_statistics = []
-
-        if not all(skip_stat in self.model_fields.keys() for skip_stat in skip_statistics):
-            raise ValueError(
-                f"One or more skip statistic keys do not exist in TensorStatistics. Valid Skip Keys: "
-                f"{list(self.model_fields.keys())} Given Skip Keys: {skip_statistics}"
-            )
+        filtered_includes = self.filter_include_statistics(include_statistics=include_statistics)

         as_list = []

         for field, field_value in self.model_dump().items():
-            …
-                continue
+            without_suffix = field[:-1]

-            as_list.append(field_value)
+            if field in filtered_includes or without_suffix in filtered_includes:
+                as_list.append(field_value)

         return as_list

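to_list now takes the include list directly and reuses filter_include_statistics, so callers enumerate the fields they want rather than the ones they do not. A usage sketch with assumed field names:

    # Values of the requested fields, in model-field order.
    values = stats.to_list(include_statistics=["mean", "standard_deviation"])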
@@ -267,6 +268,6 @@ class TensorStatistics(BaseModel):
         """

         return {
-            field[:-1] if field.endswith("_") else field: field_value
+            field[:-1] if field.endswith(FIELD_SUFFIX) else field: field_value
             for field, field_value in self.model_dump().items()
         }
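Finally, the dictionary view strips the shared FIELD_SUFFIX constant from each key, so keys come back unsuffixed:

    # e.g. {"inter_quartile_range": 1.34, ...} instead of {"inter_quartile_range_": 1.34, ...}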