libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,7 +13,7 @@ from pydantic import BaseModel
13
13
  #
14
14
  # ======================================================================================================================
15
15
 
16
- STRIP_SUFFIX = "_"
16
+ FIELD_SUFFIX = "_"
17
17
 
18
18
  # ======================================================================================================================
19
19
  #
@@ -164,28 +164,38 @@ class TensorStatistics(BaseModel):
164
164
  return tensor[random_indices]
165
165
 
166
166
  @classmethod
167
- def filter_skip_statistics(cls, skip_statistics: list[str] | None) -> list[str]:
167
+ def filter_include_statistics(cls, include_statistics: list[str]) -> list[str]:
168
168
  """
169
- :param skip_statistics: Names of the fields in the model to not include in returned observations.
170
- :return: Empty list if skip_statistics was None or skip_statistics filtered to include only the names of fields
171
- present in this pydantic model.
169
+ :param include_statistics: Names of the fields in the model to include in returned observations.
170
+ :return: List of fields from the given include_statistics list that are present in this pydantic model.
171
+ :raises ValueError: If no statistics to include are given.
172
172
  """
173
173
 
174
- return (
175
- [skip_stat for skip_stat in skip_statistics if skip_stat in cls.model_fields.keys()]
176
- if skip_statistics is not None
177
- else []
178
- )
174
+ filtered_include_statistics: list[str] = []
175
+
176
+ for include_stat in include_statistics:
177
+ with_suffix = include_stat + FIELD_SUFFIX if not include_stat.endswith(FIELD_SUFFIX) else include_stat
178
+
179
+ if with_suffix in cls.model_fields.keys():
180
+ filtered_include_statistics.append(with_suffix)
181
+
182
+ if not filtered_include_statistics:
183
+ raise ValueError(f"No statistics to include given to {cls.__name__}!")
184
+
185
+ return filtered_include_statistics
179
186
 
180
187
  @classmethod
181
188
  def build(
182
- cls, tensor: torch.Tensor, sample_percentage: float = 0.01, skip_statistics: list[str] | None = None
189
+ cls,
190
+ tensor: torch.Tensor,
191
+ include_statistics: list[str],
192
+ sample_percentage: float = 0.01,
183
193
  ) -> "TensorStatistics":
184
194
  """
185
195
  :param tensor: Tensor to compute and store statistics of.
196
+ :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
197
+ fields in the model to include in returned observations.
186
198
  :param sample_percentage: Percentage of the given tensor to randomly sample and compute statistics from.
187
- :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
188
- fields in the model to not include in returned observations.
189
199
  :return: Constructed tensor statistics.
190
200
  """
191
201
 
@@ -193,12 +203,10 @@ class TensorStatistics(BaseModel):
193
203
  downsampled_tensor = cls.downsample_tensor(tensor=tensor, sample_percentage=sample_percentage)
194
204
 
195
205
  for field, field_value in stats.model_dump().items():
196
- name = field[:-1] if field.endswith("_") else field
197
-
198
- if skip_statistics is not None and name in skip_statistics:
199
- continue
206
+ name = field[:-1] if field.endswith(FIELD_SUFFIX) else field
200
207
 
201
- setattr(stats, name, downsampled_tensor)
208
+ if name in include_statistics:
209
+ setattr(stats, name, downsampled_tensor)
202
210
 
203
211
  return stats
204
212
 
@@ -219,28 +227,21 @@ class TensorStatistics(BaseModel):
219
227
  inter_quartile_range_=tensor[6],
220
228
  )
221
229
 
222
- def to_list(self, skip_statistics: list[str] | None) -> list[float]:
230
+ def to_list(self, include_statistics: list[str]) -> list[float]:
223
231
  """
224
- :param skip_statistics: None or a list of field names to skip from adding to the returned list.
232
+ :param include_statistics: List of field names to include in the returned list.
225
233
  :return: List of field values.
226
234
  """
227
235
 
228
- if skip_statistics is None:
229
- skip_statistics = []
230
-
231
- if not all(skip_stat in self.model_fields.keys() for skip_stat in skip_statistics):
232
- raise ValueError(
233
- f"One or more skip statistic keys do not exist in TensorStatistics. Valid Skip Keys: "
234
- f"{list(self.model_fields.keys())} Given Skip Keys: {skip_statistics}"
235
- )
236
+ filtered_includes = self.filter_include_statistics(include_statistics=include_statistics)
236
237
 
237
238
  as_list = []
238
239
 
239
240
  for field, field_value in self.model_dump().items():
240
- if field in skip_statistics:
241
- continue
241
+ without_suffix = field[:-1]
242
242
 
243
- as_list.append(field_value)
243
+ if field in filtered_includes or without_suffix in filtered_includes:
244
+ as_list.append(field_value)
244
245
 
245
246
  return as_list
246
247
 
@@ -267,6 +268,6 @@ class TensorStatistics(BaseModel):
267
268
  """
268
269
 
269
270
  return {
270
- field[:-1] if field.endswith(STRIP_SUFFIX) else field: field_value
271
+ field[:-1] if field.endswith(FIELD_SUFFIX) else field: field_value
271
272
  for field, field_value in self.model_dump().items()
272
273
  }
@@ -295,52 +295,52 @@ class HyperparameterContainer(BaseModel):
295
295
  f"{self.get_hyperparameters()}."
296
296
  )
297
297
 
298
- def get_initial_internal_values(self, skip_hparams: list[str] | None = None) -> dict[str, float | int | None]:
298
+ def get_initial_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
299
299
  """
300
- :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
301
- this observation.
300
+ :param include_hparams: Names of the hyperparameters to include in the hyperparameter name to initial internal
301
+ value mapping.
302
302
  :return: Dictionary mapping hyperparameter names to their current values for this parameter group.
303
303
  """
304
304
 
305
- if skip_hparams is None:
306
- skip_hparams = []
305
+ if include_hparams is None:
306
+ include_hparams = list(self.model_fields.keys())
307
307
 
308
308
  return {
309
309
  field_name: field_value.initial_internal_value
310
310
  for field_name, field_value in self.__dict__.items()
311
- if isinstance(field_value, Hyperparameter) and field_name not in skip_hparams
311
+ if isinstance(field_value, Hyperparameter) and field_name in include_hparams
312
312
  }
313
313
 
314
- def get_current_internal_values(self, skip_hparams: list[str] | None = None) -> dict[str, float | int | None]:
314
+ def get_current_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
315
315
  """
316
- :param skip_hparams: Names of the hyperparameters to not include in the current values vector returned by
317
- this observation.
316
+ :param include_hparams: Names of the hyperparameters to include in the hyperparameter name to current internal
317
+ value mapping.
318
318
  :return: Dictionary mapping hyperparameter names to their current values for this parameter group.
319
319
  """
320
320
 
321
- if skip_hparams is None:
322
- skip_hparams = []
321
+ if include_hparams is None:
322
+ include_hparams = list(self.model_fields.keys())
323
323
 
324
324
  return {
325
325
  field_name: field_value.current_internal_value
326
326
  for field_name, field_value in self.__dict__.items()
327
- if isinstance(field_value, Hyperparameter) and field_name not in skip_hparams
327
+ if isinstance(field_value, Hyperparameter) and field_name in include_hparams
328
328
  }
329
329
 
330
- def get_current_deltas(self, skip_hparams: list[str] | None = None) -> dict[str, float | int | None]:
330
+ def get_current_deltas(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
331
331
  """
332
- :param skip_hparams: Names of the hyperparameters to not include in the current deltas vector returned by
333
- this observation.
332
+ :param include_hparams: Names of the hyperparameters to include in the hyperparameter name to current delta
333
+ mapping.
334
334
  :return: Dictionary mapping hyperparameter names to their current deltas for this parameter group.
335
335
  """
336
336
 
337
- if skip_hparams is None:
338
- skip_hparams = []
337
+ if include_hparams is None:
338
+ include_hparams = list(self.model_fields.keys())
339
339
 
340
340
  return {
341
341
  field_name: field_value.current_delta
342
342
  for field_name, field_value in self.__dict__.items()
343
- if isinstance(field_value, Hyperparameter) and field_name not in skip_hparams
343
+ if isinstance(field_value, Hyperparameter) and field_name in include_hparams
344
344
  }
345
345
 
346
346
  def set_internal_values(self, internal_values: dict[str, float | int | None]) -> None:
@@ -440,19 +440,21 @@ class ParameterGroupHParams(HyperparameterContainer):
440
440
  )
441
441
 
442
442
  def get_hyperparameter_transform_types(
443
- self, skip_hparams: list[str] | None = None
443
+ self, include_hparams: list[str] | None = None
444
444
  ) -> dict[str, HyperparameterTransformType]:
445
445
  """
446
+ :param include_hparams: Names of the hyperparameters to include in the hyperparameter name to transform type
447
+ mapping.
446
448
  :return: Dictionary mapping hyperparameter names to their transform type for this parameter group.
447
449
  """
448
450
 
449
- if skip_hparams is None:
450
- skip_hparams = []
451
+ if include_hparams is None:
452
+ include_hparams = list(self.model_fields.keys())
451
453
 
452
454
  return {
453
455
  field_name: field_value.transform_type
454
456
  for field_name, field_value in self.__dict__.items()
455
- if isinstance(field_value, Hyperparameter) and field_name not in skip_hparams
457
+ if isinstance(field_value, Hyperparameter) and field_name in include_hparams
456
458
  }
457
459
 
458
460
 
@@ -839,15 +841,15 @@ class HyperparameterStates(BaseModel):
839
841
  for parameter_group_hparams in self.parameter_group_hparams.values():
840
842
  parameter_group_hparams.set_to_initial_values()
841
843
 
842
- def get_initial_internal_values(self, skip_hparams: list[str] | None = None) -> dict[str, float | int | None]:
844
+ def get_initial_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
843
845
  """
844
- :param skip_hparams: Hyperparameters to ignore while retrieving initial values.
846
+ :param include_hparams: Hyperparameters to include while retrieving initial values.
845
847
  :return: Dictionary mapping hyperparameter names to their initial values at the start of training.
846
848
  """
847
849
 
848
850
  initial_internal_values = {
849
- **self.global_hparams.get_initial_internal_values(skip_hparams),
850
- **next(iter(self.parameter_group_hparams.values())).get_initial_internal_values(skip_hparams),
851
+ **self.global_hparams.get_initial_internal_values(include_hparams),
852
+ **next(iter(self.parameter_group_hparams.values())).get_initial_internal_values(include_hparams),
851
853
  }
852
854
  initial_internal_values = {
853
855
  hparam_name: initial_internal_values.get(hparam_name, None)
@@ -856,15 +858,15 @@ class HyperparameterStates(BaseModel):
856
858
 
857
859
  return initial_internal_values
858
860
 
859
- def get_current_internal_values(self, skip_hparams: list[str] | None = None) -> dict[str, float | int | None]:
861
+ def get_current_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
860
862
  """
861
- :param skip_hparams: Hyperparameters to ignore while retrieving current values.
863
+ :param include_hparams: Hyperparameters to include while retrieving current values.
862
864
  :return: Dictionary mapping hyperparameter names to their current values during training.
863
865
  """
864
866
 
865
867
  current_internal_values = {
866
- **self.global_hparams.get_current_internal_values(skip_hparams),
867
- **next(iter(self.parameter_group_hparams.values())).get_current_internal_values(skip_hparams),
868
+ **self.global_hparams.get_current_internal_values(include_hparams),
869
+ **next(iter(self.parameter_group_hparams.values())).get_current_internal_values(include_hparams),
868
870
  }
869
871
  current_internal_values = {
870
872
  hparam_name: current_internal_values.get(hparam_name, None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.18.1
3
+ Version: 1.0.0
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -2,23 +2,23 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
4
4
  libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- libinephany/observations/observation_utils.py,sha256=ejQ-hzKq_MX7A0304KypScmzloReNrH8dmcgonoRl4Q,16266
5
+ libinephany/observations/observation_utils.py,sha256=JSNJYEi2d-VQ0ZovfHrn28RDv41u-a6M-W4ZT8UUyhI,17279
6
6
  libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
7
7
  libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
9
- libinephany/observations/statistic_trackers.py,sha256=06NjCy1MI865oU0KB5f-wQE3b2RvYawOOWxNJx4rFpw,32939
9
+ libinephany/observations/statistic_trackers.py,sha256=3LHvBXQ977-9ZW-KE9UohDeOayZdqQ5UQMTt0kuea40,47574
10
10
  libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- libinephany/observations/observers/base_observers.py,sha256=Tkk2AvQ6mRY6fdhhxHQWohA0xSHGdPqxul1C7C7Frj4,15924
12
- libinephany/observations/observers/local_observers.py,sha256=7azTW227-rG_QJR5a0StfE0O4Ca1boMV9nrCphnhWik,40344
11
+ libinephany/observations/observers/base_observers.py,sha256=9YI_jkivoCjyPtNSn3VnPADF1VqwGdHPkgi1kDzed3Y,16516
12
+ libinephany/observations/observers/local_observers.py,sha256=yBXmuCaDZotJPmBZKdrPfGYtVO-CCvT5ZS0KvROOcE4,45657
13
13
  libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
14
- libinephany/observations/observers/global_observers/__init__.py,sha256=lQO-nq7hILu4F3ddFXcCR-ghfv0dzEA9nAYxZta7rxk,2306
14
+ libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
15
15
  libinephany/observations/observers/global_observers/base_classes.py,sha256=CCkRx86Lll3gFzfqervP0jKdzNFKkKU7tEBh8ic1Yrc,8249
16
- libinephany/observations/observers/global_observers/constants.py,sha256=olXYxh353Th6hyhfk85kHxAWPeaDpuPkcz_awFvEz6c,1054
17
- libinephany/observations/observers/global_observers/gradient_observers.py,sha256=z2ow6zmTr7ujeaYyh5qGRE1woIt_yc2uxAaie7zgxqc,8398
18
- libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=soGWoYpO5rqUQU0p4pr6QVjBvVg2odK71mVkxV38Ras,14838
16
+ libinephany/observations/observers/global_observers/constants.py,sha256=C_PwYhKxatJxNe5Jzb1tpoiRXAxxPrGkcdQBMQD8msY,1139
17
+ libinephany/observations/observers/global_observers/gradient_observers.py,sha256=ZeujBhKeq8adw_J13omurjZnfloiadMYkiPuXYUZ8BU,20972
18
+ libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=o035-nSfjj7dy7Pz1IxpAqvU3tYQraxQd8Pttknxa6A,15034
19
19
  libinephany/observations/observers/global_observers/loss_observers.py,sha256=FlSuJqAJIXcAS_ypdZna6xxz89glI23A6D00sDn7ZLU,18508
20
- libinephany/observations/observers/global_observers/model_observers.py,sha256=bJIEdq5wWkLrw-MNllCe36FE4aMLbC4RoR1-wOOHMxc,19537
21
- libinephany/observations/observers/global_observers/progress_observers.py,sha256=m62jUiwPaOUzYG1h7Vg6znj_jK9699lDhg4AhK212s8,5615
20
+ libinephany/observations/observers/global_observers/model_observers.py,sha256=HVNHnqk2uuXkmP8y3SL-IQ0AYArfXc7b-wckv9X7qbM,28457
21
+ libinephany/observations/observers/global_observers/progress_observers.py,sha256=ypLk1_POAjA8V8rAaQ0B6Qh8m_04s9PAoXsw1KxVrLg,5872
22
22
  libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
24
24
  libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -32,9 +32,9 @@ libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=Xu0tQmhGwV043tT
32
32
  libinephany/pydantic_models/schemas/observation_models.py,sha256=MLhxqDet9Yol1D5mkQGQsQT23sm37AStRLnPc4sgcZc,2110
33
33
  libinephany/pydantic_models/schemas/request_schemas.py,sha256=VED8eAUvBofxeAx9gWU8DyCZOTVD3QsHRq-TO7kyOqk,1260
34
34
  libinephany/pydantic_models/schemas/response_schemas.py,sha256=SKFuasdjX5aH_I0vT3SwnpwhyMf9cNPB1ZpDeAGgoO8,2158
35
- libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=ZrBDiVy8Pt-Bw5RU3ePfjyFjP-jjC4_KsMmZy-Tjb1s,8115
35
+ libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=hWl52SxPYHFBXUwwZ9X7E1iFhuxGsPypd14hFTx91jw,8166
36
36
  libinephany/pydantic_models/states/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
- libinephany/pydantic_models/states/hyperparameter_states.py,sha256=qQqnYyKR5ucENb_Zyk0x71_zpgUtG_rOzbOcE0kTrkQ,33478
37
+ libinephany/pydantic_models/states/hyperparameter_states.py,sha256=Esi1xdrH9xOJwhpSezkfbTzbKX4O26IpN5zzkWD3Mf8,33779
38
38
  libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
39
  libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
40
40
  libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
57
57
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
58
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
59
59
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
60
- libinephany-0.18.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
- libinephany-0.18.1.dist-info/METADATA,sha256=d2rH-yA7F1cVk2OnTyvdaQ6_CFey1uiY4ATbqsxH9Pg,8390
62
- libinephany-0.18.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
- libinephany-0.18.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
- libinephany-0.18.1.dist-info/RECORD,,
60
+ libinephany-1.0.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
+ libinephany-1.0.0.dist-info/METADATA,sha256=N_0lNQBtTOt8bzKEyQhmiA_l4AkWKtMoq6GDL3FGXKI,8389
62
+ libinephany-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
+ libinephany-1.0.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
+ libinephany-1.0.0.dist-info/RECORD,,