libinephany 1.0.1__tar.gz → 1.0.3__tar.gz

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (72)
  1. libinephany-1.0.3/CODE_VERSION.cfg +1 -0
  2. {libinephany-1.0.1/libinephany.egg-info → libinephany-1.0.3}/PKG-INFO +1 -1
  3. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/base_observers.py +26 -0
  4. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/base_classes.py +50 -64
  5. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/constants.py +14 -1
  6. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/gradient_observers.py +55 -94
  7. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -8
  8. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/loss_observers.py +26 -47
  9. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/model_observers.py +65 -25
  10. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/local_observers.py +90 -52
  11. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/enums.py +1 -0
  12. {libinephany-1.0.1 → libinephany-1.0.3/libinephany.egg-info}/PKG-INFO +1 -1
  13. libinephany-1.0.1/CODE_VERSION.cfg +0 -1
  14. {libinephany-1.0.1 → libinephany-1.0.3}/LICENSE +0 -0
  15. {libinephany-1.0.1 → libinephany-1.0.3}/MANIFEST.in +0 -0
  16. {libinephany-1.0.1 → libinephany-1.0.3}/README.md +0 -0
  17. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/__init__.py +0 -0
  18. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/aws/__init__.py +0 -0
  19. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/aws/s3_functions.py +0 -0
  20. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/__init__.py +0 -0
  21. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observation_utils.py +0 -0
  22. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observer_pipeline.py +0 -0
  23. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/__init__.py +0 -0
  24. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/__init__.py +0 -0
  25. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/global_observers/progress_observers.py +0 -0
  26. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/observers/observer_containers.py +0 -0
  27. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/pipeline_coordinator.py +0 -0
  28. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/post_processors/__init__.py +0 -0
  29. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/post_processors/postprocessors.py +0 -0
  30. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/statistic_manager.py +0 -0
  31. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/observations/statistic_trackers.py +0 -0
  32. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/__init__.py +0 -0
  33. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/configs/__init__.py +0 -0
  34. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
  35. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/configs/observer_config.py +0 -0
  36. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
  37. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/__init__.py +0 -0
  38. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
  39. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
  40. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
  41. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
  42. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
  43. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
  44. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/states/__init__.py +0 -0
  45. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/pydantic_models/states/hyperparameter_states.py +0 -0
  46. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/__init__.py +0 -0
  47. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/agent_utils.py +0 -0
  48. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/asyncio_worker.py +0 -0
  49. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/backend_statuses.py +0 -0
  50. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/constants.py +0 -0
  51. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/directory_utils.py +0 -0
  52. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/dropout_utils.py +0 -0
  53. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/error_severities.py +0 -0
  54. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/exceptions.py +0 -0
  55. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/import_utils.py +0 -0
  56. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/optim_utils.py +0 -0
  57. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/random_seeds.py +0 -0
  58. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/samplers.py +0 -0
  59. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/standardizers.py +0 -0
  60. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/torch_distributed_utils.py +0 -0
  61. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/torch_utils.py +0 -0
  62. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/transforms.py +0 -0
  63. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/utils/typing.py +0 -0
  64. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/web_apps/__init__.py +0 -0
  65. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/web_apps/error_logger.py +0 -0
  66. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany/web_apps/web_app_utils.py +0 -0
  67. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany.egg-info/SOURCES.txt +0 -0
  68. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany.egg-info/dependency_links.txt +0 -0
  69. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany.egg-info/requires.txt +0 -0
  70. {libinephany-1.0.1 → libinephany-1.0.3}/libinephany.egg-info/top_level.txt +0 -0
  71. {libinephany-1.0.1 → libinephany-1.0.3}/pyproject.toml +0 -0
  72. {libinephany-1.0.1 → libinephany-1.0.3}/setup.cfg +0 -0
@@ -0,0 +1 @@
1
+ 1.0.3
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 1.0.1
3
+ Version: 1.0.3
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -44,6 +44,7 @@ class Observer(ABC):
44
44
  observer_config: ObserverConfig,
45
45
  should_standardize: bool = True,
46
46
  include_statistics: list[str] | None = None,
47
+ include_hparams: list[str] | None = None,
47
48
  **kwargs,
48
49
  ) -> None:
49
50
  """
@@ -52,6 +53,8 @@ class Observer(ABC):
52
53
  :param should_standardize: Whether standardization should be applied to returned values.
53
54
  :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
54
55
  fields in the model to include in returned observations.
56
+ :param include_hparams: If the observation uses the HyperparameterStates model to return observations, names of the
57
+ hyperparameters to include in returned observations.
55
58
  :param kwargs: Miscellaneous keyword arguments.
56
59
  """
57
60
 
@@ -64,10 +67,17 @@ class Observer(ABC):
64
67
  self.should_standardize = should_standardize and self.can_standardize
65
68
 
66
69
  self.include_statistics: list[str] | None = None
70
+ self.include_hparams = include_hparams
67
71
 
68
72
  if include_statistics is not None:
69
73
  self.include_statistics = TensorStatistics.filter_include_statistics(include_statistics=include_statistics)
70
74
 
75
+ if self.requires_include_statistics and not self.include_statistics:
76
+ raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
77
+
78
+ if self.requires_include_hparams and not self.include_hparams:
79
+ raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
80
+
71
81
  @final
72
82
  @property
73
83
  def in_training_mode(self) -> bool:
@@ -143,6 +153,22 @@ class Observer(ABC):
143
153
 
144
154
  return True
145
155
 
156
+ @property
157
+ def requires_include_statistics(self) -> bool:
158
+ """
159
+ :return: Whether the observation requires include_statistics to be provided.
160
+ """
161
+
162
+ return False
163
+
164
+ @property
165
+ def requires_include_hparams(self) -> bool:
166
+ """
167
+ :return: Whether the observation requires include_hparams to be provided.
168
+ """
169
+
170
+ return False
171
+
146
172
  @property
147
173
  @abstractmethod
148
174
  def standardizer_key_infix(self) -> str:
@@ -1,6 +1,6 @@
1
1
  # ======================================================================================================================
2
2
  #
3
- # BASE CLASSES
3
+ # IMPORTS
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
@@ -15,6 +15,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
15
15
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
16
16
  from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
17
17
 
18
+ # ======================================================================================================================
19
+ #
20
+ # CLASSES
21
+ #
22
+ # ======================================================================================================================
23
+
18
24
 
19
25
  class LHOPTBaseObserver(GlobalObserver, ABC):
20
26
  """
@@ -33,13 +39,14 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
33
39
  :param kwargs: Other observation keyword arguments.
34
40
  """
35
41
  super().__init__(**kwargs)
36
- self.decay_factor = max(0.0, decay_factor)
37
- self.time_window = max(1, time_window)
38
42
 
39
43
  # Store time series data for CDF calculation
40
44
  self._time_series: list[tuple[float, float]] = [] # (time, value) pairs
41
45
  self._current_time: float = 0.0
42
46
 
47
+ self.decay_factor = max(0.0, decay_factor)
48
+ self.time_window = max(1, time_window)
49
+
43
50
  @property
44
51
  def can_standardize(self) -> bool:
45
52
  """
@@ -48,6 +55,28 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
48
55
  """
49
56
  return False
50
57
 
58
+ @staticmethod
59
+ def _compute_log_ratio(numerator: float, denominator: float) -> float:
60
+ """
61
+ Compute the log ratio.
62
+
63
+ :param numerator: Numerator value
64
+ :param denominator: Denominator value
65
+ :return: Log ratio value
66
+ """
67
+ # Calculate the ratio of numerator to denominator
68
+ invalid_denominator = math.isinf(denominator) or math.isnan(denominator)
69
+
70
+ if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"] or invalid_denominator:
71
+ return 0.0
72
+
73
+ ratio = numerator / denominator
74
+
75
+ if ratio <= 0:
76
+ return 0.0
77
+
78
+ return math.log(ratio)
79
+
51
80
  def _get_observation_format(self) -> StatisticStorageTypes:
52
81
  """
53
82
  :return: Format the observation returns data in. Must be one of the StatisticStorageTypes
@@ -68,18 +97,6 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
68
97
  """Update the current time counter."""
69
98
  self._current_time += 1.0
70
99
 
71
- def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
72
- """
73
- :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
74
- needed.
75
- """
76
- return {}
77
-
78
- def reset(self) -> None:
79
- """Reset the observer by clearing the time series."""
80
- self._time_series = []
81
- self._current_time = 0.0
82
-
83
100
  @abstractmethod
84
101
  def _observe(
85
102
  self,
@@ -94,27 +111,13 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
94
111
  :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
95
112
  :param action_taken: Action taken by the agent this class instance is assigned to.
96
113
  """
97
- raise NotImplementedError
98
-
99
- def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
100
- """
101
- Compute the log ratio.
102
-
103
- :param numerator: Numerator value
104
- :param denominator: Denominator value
105
- :return: Log ratio value
106
- """
107
- # Calculate the ratio of numerator to denominator
108
114
 
109
- if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
110
- return 0.0
115
+ ...
111
116
 
112
- ratio = numerator / denominator
113
-
114
- if ratio <= 0:
115
- return 0.0
116
-
117
- return math.log(ratio)
117
+ def reset(self) -> None:
118
+ """Reset the observer by clearing the time series."""
119
+ self._time_series = []
120
+ self._current_time = 0.0
118
121
 
119
122
 
120
123
  class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
@@ -128,8 +131,10 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
128
131
  :param kwargs: Miscellaneous keyword arguments.
129
132
  """
130
133
  super().__init__(**kwargs)
131
- self.checkpoint_interval = checkpoint_interval
134
+
132
135
  self._history: list[float] = []
136
+
137
+ self.checkpoint_interval = checkpoint_interval
133
138
  self.last_value: float | None = None
134
139
 
135
140
  @property
@@ -175,18 +180,6 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
175
180
  if self.last_value is None:
176
181
  self.last_value = value
177
182
 
178
- def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
179
- """
180
- :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
181
- needed.
182
- """
183
- return {}
184
-
185
- def reset(self) -> None:
186
- """Reset the observer by clearing history."""
187
- self._history = []
188
- self.last_value = None
189
-
190
183
  @abstractmethod
191
184
  def _observe(
192
185
  self,
@@ -201,24 +194,17 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
201
194
  :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
202
195
  :param action_taken: Action taken by the agent this class instance is assigned to.
203
196
  """
204
- raise NotImplementedError
205
197
 
206
- def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
207
- """
208
- Compute the log ratio.
198
+ ...
209
199
 
210
- :param numerator: Numerator value
211
- :param denominator: Denominator value
212
- :return: Log ratio value
200
+ def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
213
201
  """
214
- # Calculate the ratio of numerator to denominator
215
-
216
- if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
217
- return 0.0
218
-
219
- ratio = numerator / denominator
220
-
221
- if ratio <= 0:
222
- return 0.0
202
+ :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
203
+ needed.
204
+ """
205
+ return {}
223
206
 
224
- return math.log(ratio)
207
+ def reset(self) -> None:
208
+ """Reset the observer by clearing history."""
209
+ self._history = []
210
+ self.last_value = None
@@ -1,11 +1,17 @@
1
1
  # ======================================================================================================================
2
2
  #
3
- # CONSTANTS
3
+ # IMPORTS
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
7
7
  from typing import TypedDict
8
8
 
9
+ # ======================================================================================================================
10
+ #
11
+ # CLASSES
12
+ #
13
+ # ======================================================================================================================
14
+
9
15
 
10
16
  class LHOPTConstants(TypedDict):
11
17
  IS_NAN: float
@@ -23,6 +29,13 @@ class LHOPTConstants(TypedDict):
23
29
  DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int
24
30
 
25
31
 
32
+ # ======================================================================================================================
33
+ #
34
+ # CONSTANTS
35
+ #
36
+ # ======================================================================================================================
37
+
38
+
26
39
  # Create the constants instance
27
40
  LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
28
41
  IS_NAN=1.0,
@@ -1,6 +1,6 @@
1
1
  # ======================================================================================================================
2
2
  #
3
- # GRADIENT OBSERVERS
3
+ # IMPORTS
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
@@ -16,9 +16,23 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
16
16
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
17
17
  from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
18
18
 
19
+ # ======================================================================================================================
20
+ #
21
+ # CLASSES
22
+ #
23
+ # ======================================================================================================================
24
+
19
25
 
20
26
  class GlobalFirstOrderGradients(GlobalObserver):
21
27
 
28
+ @property
29
+ def requires_include_statistics(self) -> bool:
30
+ """
31
+ :return: Whether the observation requires include_statistics to be provided.
32
+ """
33
+
34
+ return True
35
+
22
36
  def _get_observation_format(self) -> StatisticStorageTypes:
23
37
  """
24
38
  :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -74,6 +88,14 @@ class GlobalSecondOrderGradients(GlobalObserver):
74
88
 
75
89
  self.compute_hessian_diagonal = compute_hessian_diagonal
76
90
 
91
+ @property
92
+ def requires_include_statistics(self) -> bool:
93
+ """
94
+ :return: Whether the observation requires include_statistics to be provided.
95
+ """
96
+
97
+ return True
98
+
77
99
  def _get_observation_format(self) -> StatisticStorageTypes:
78
100
  """
79
101
  :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -136,21 +158,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
136
158
  super().__init__(**kwargs)
137
159
  self.variance_threshold = variance_threshold
138
160
 
139
- @property
140
- def can_standardize(self) -> bool:
141
- """
142
- This observer has its own CDF calculation, no need to standardize.
143
- :return: Whether the observation can be standardized.
144
- """
145
- return False
146
-
147
- def _get_observation_format(self) -> StatisticStorageTypes:
148
- """
149
- :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
150
- enumeration class.
151
- """
152
- return StatisticStorageTypes.VECTOR
153
-
154
161
  @property
155
162
  def vector_length(self) -> int:
156
163
  """
@@ -173,8 +180,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
173
180
  :param action_taken: Action taken by the agent this class instance is assigned to.
174
181
  :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
175
182
  """
176
- if statistic_trackers.GradientVarianceFraction.__name__ not in tracked_statistics:
177
- return [0.0, 0.0]
178
183
 
179
184
  raw_value = list(tracked_statistics[statistic_trackers.GradientVarianceFraction.__name__].values())[0] # type: ignore[list-item]
180
185
 
@@ -204,14 +209,6 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
204
209
  It returns two-dimensional observations: [raw_value, cdf_feature] for momentum gradient ratio values.
205
210
  """
206
211
 
207
- def _get_observation_format(self) -> StatisticStorageTypes:
208
- """
209
- :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
210
- enumeration class.
211
- """
212
-
213
- return StatisticStorageTypes.VECTOR
214
-
215
212
  @property
216
213
  def vector_length(self) -> int:
217
214
  """
@@ -219,6 +216,14 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
219
216
  """
220
217
  return 2 # [raw_value, cdf_feature]
221
218
 
219
+ @property
220
+ def requires_include_statistics(self) -> bool:
221
+ """
222
+ :return: Whether the observation requires include_statistics to be provided.
223
+ """
224
+
225
+ return True
226
+
222
227
  def _observe(
223
228
  self,
224
229
  observation_inputs: ObservationInputs,
@@ -266,29 +271,6 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
266
271
  It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and momentum values.
267
272
  """
268
273
 
269
- def __init__(
270
- self,
271
- *,
272
- include_statistics: list[str] | None = None,
273
- **kwargs,
274
- ) -> None:
275
- """
276
- :param include_statistics: List of statistics to include.
277
- :param kwargs: Miscellaneous keyword arguments.
278
- """
279
-
280
- super().__init__(**kwargs)
281
-
282
- self.include_statistics = include_statistics
283
-
284
- def _get_observation_format(self) -> StatisticStorageTypes:
285
- """
286
- :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
287
- enumeration class.
288
- """
289
-
290
- return StatisticStorageTypes.VECTOR
291
-
292
274
  @property
293
275
  def vector_length(self) -> int:
294
276
  """
@@ -296,6 +278,14 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
296
278
  """
297
279
  return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
298
280
 
281
+ @property
282
+ def requires_include_statistics(self) -> bool:
283
+ """
284
+ :return: Whether the observation requires include_statistics to be provided.
285
+ """
286
+
287
+ return True
288
+
299
289
  def _observe(
300
290
  self,
301
291
  observation_inputs: ObservationInputs,
@@ -351,29 +341,6 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
351
341
  It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and update values.
352
342
  """
353
343
 
354
- def __init__(
355
- self,
356
- *,
357
- include_statistics: list[str] | None = None,
358
- **kwargs,
359
- ) -> None:
360
- """
361
- :param include_statistics: List of statistics to include.
362
- :param kwargs: Miscellaneous keyword arguments.
363
- """
364
-
365
- super().__init__(**kwargs)
366
-
367
- self.include_statistics = include_statistics
368
-
369
- def _get_observation_format(self) -> StatisticStorageTypes:
370
- """
371
- :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
372
- enumeration class.
373
- """
374
-
375
- return StatisticStorageTypes.VECTOR
376
-
377
344
  @property
378
345
  def vector_length(self) -> int:
379
346
  """
@@ -381,6 +348,14 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
381
348
  """
382
349
  return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
383
350
 
351
+ @property
352
+ def requires_include_statistics(self) -> bool:
353
+ """
354
+ :return: Whether the observation requires include_statistics to be provided.
355
+ """
356
+
357
+ return True
358
+
384
359
  def _observe(
385
360
  self,
386
361
  observation_inputs: ObservationInputs,
@@ -436,28 +411,6 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
436
411
  It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and parameter values.
437
412
  """
438
413
 
439
- def __init__(
440
- self,
441
- *,
442
- include_statistics: list[str] | None = None,
443
- **kwargs,
444
- ) -> None:
445
- """
446
- :param include_statistics: List of statistics to include.
447
- :param kwargs: Miscellaneous keyword arguments.
448
- """
449
- super().__init__(**kwargs)
450
-
451
- self.include_statistics = include_statistics
452
-
453
- def _get_observation_format(self) -> StatisticStorageTypes:
454
- """
455
- :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
456
- enumeration class.
457
- """
458
-
459
- return StatisticStorageTypes.VECTOR
460
-
461
414
  @property
462
415
  def vector_length(self) -> int:
463
416
  """
@@ -465,6 +418,14 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
465
418
  """
466
419
  return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
467
420
 
421
+ @property
422
+ def requires_include_statistics(self) -> bool:
423
+ """
424
+ :return: Whether the observation requires include_statistics to be provided.
425
+ """
426
+
427
+ return True
428
+
468
429
  def _observe(
469
430
  self,
470
431
  observation_inputs: ObservationInputs,
@@ -1,6 +1,6 @@
1
1
  # ======================================================================================================================
2
2
  #
3
- # HYPERPARAMETER OBSERVERS
3
+ # IMPORTS
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
@@ -12,25 +12,28 @@ from torch.optim import SGD, Adam, AdamW
12
12
  from libinephany.observations import observation_utils
13
13
  from libinephany.observations.observation_utils import StatisticStorageTypes
14
14
  from libinephany.observations.observers.base_observers import GlobalObserver
15
- from libinephany.observations.observers.global_observers.base_classes import LHOPT_CONSTANTS
15
+ from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
16
16
  from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
17
17
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
18
18
  from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
19
19
  from libinephany.utils.enums import ModelFamilies
20
20
 
21
+ # ======================================================================================================================
22
+ #
23
+ # CLASSES
24
+ #
25
+ # ======================================================================================================================
26
+
21
27
 
22
28
  class InitialHyperparameters(GlobalObserver):
23
29
 
24
- def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
30
+ def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
25
31
  """
26
- :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
27
- this observation.
28
32
  :param kwargs: Miscellaneous keyword arguments.
29
33
  """
30
34
 
31
35
  super().__init__(**kwargs)
32
36
 
33
- self.include_hparams = include_hparams
34
37
  self.pad_with = pad_with
35
38
 
36
39
  @property
@@ -62,6 +65,14 @@ class InitialHyperparameters(GlobalObserver):
62
65
 
63
66
  return False
64
67
 
68
+ @property
69
+ def requires_include_hparams(self) -> bool:
70
+ """
71
+ :return: Whether the observation requires include_hparams to be provided.
72
+ """
73
+
74
+ return True
75
+
65
76
  def _get_observation_format(self) -> StatisticStorageTypes:
66
77
  """
67
78
  :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -298,7 +309,7 @@ class LHOPTHyperparameterRatio(GlobalObserver):
298
309
  providing insights into how much hyperparameters have changed from their starting values.
299
310
  """
300
311
 
301
- def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
312
+ def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
302
313
  """
303
314
  :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
304
315
  this observation.
@@ -307,7 +318,6 @@ class LHOPTHyperparameterRatio(GlobalObserver):
307
318
 
308
319
  super().__init__(**kwargs)
309
320
 
310
- self.include_hparams = include_hparams
311
321
  self.pad_with = pad_with
312
322
 
313
323
  @property
@@ -339,6 +349,14 @@ class LHOPTHyperparameterRatio(GlobalObserver):
339
349
 
340
350
  return False
341
351
 
352
+ @property
353
+ def requires_include_hparams(self) -> bool:
354
+ """
355
+ :return: Whether the observation requires include_hparams to be provided.
356
+ """
357
+
358
+ return True
359
+
342
360
  def _get_observation_format(self) -> StatisticStorageTypes:
343
361
  """
344
362
  :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes