libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +19 -2
- libinephany/observations/observers/base_observers.py +20 -8
- libinephany/observations/observers/global_observers/__init__.py +19 -1
- libinephany/observations/observers/global_observers/constants.py +2 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +320 -3
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -18
- libinephany/observations/observers/global_observers/model_observers.py +220 -6
- libinephany/observations/observers/global_observers/progress_observers.py +7 -1
- libinephany/observations/observers/local_observers.py +158 -25
- libinephany/observations/statistic_trackers.py +435 -23
- libinephany/pydantic_models/schemas/tensor_statistics.py +33 -32
- libinephany/pydantic_models/states/hyperparameter_states.py +32 -30
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/METADATA +1 -1
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/RECORD +17 -17
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/WHEEL +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|
13
13
|
#
|
14
14
|
# ======================================================================================================================
|
15
15
|
|
16
|
-
|
16
|
+
FIELD_SUFFIX = "_"
|
17
17
|
|
18
18
|
# ======================================================================================================================
|
19
19
|
#
|
@@ -164,28 +164,38 @@ class TensorStatistics(BaseModel):
|
|
164
164
|
return tensor[random_indices]
|
165
165
|
|
166
166
|
@classmethod
|
167
|
-
def
|
167
|
+
def filter_include_statistics(cls, include_statistics: list[str]) -> list[str]:
|
168
168
|
"""
|
169
|
-
:param
|
170
|
-
:return:
|
171
|
-
|
169
|
+
:param include_statistics: Names of the fields in the model to include in returned observations.
|
170
|
+
:return: List of fields from the given include_statistics list that are present in this pydantic model.
|
171
|
+
:raises ValueError: If no statistics to include are given.
|
172
172
|
"""
|
173
173
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
else
|
178
|
-
|
174
|
+
filtered_include_statistics: list[str] = []
|
175
|
+
|
176
|
+
for include_stat in include_statistics:
|
177
|
+
with_suffix = include_stat + FIELD_SUFFIX if not include_stat.endswith(FIELD_SUFFIX) else include_stat
|
178
|
+
|
179
|
+
if with_suffix in cls.model_fields.keys():
|
180
|
+
filtered_include_statistics.append(with_suffix)
|
181
|
+
|
182
|
+
if not filtered_include_statistics:
|
183
|
+
raise ValueError(f"No statistics to include given to {cls.__name__}!")
|
184
|
+
|
185
|
+
return filtered_include_statistics
|
179
186
|
|
180
187
|
@classmethod
|
181
188
|
def build(
|
182
|
-
cls,
|
189
|
+
cls,
|
190
|
+
tensor: torch.Tensor,
|
191
|
+
include_statistics: list[str],
|
192
|
+
sample_percentage: float = 0.01,
|
183
193
|
) -> "TensorStatistics":
|
184
194
|
"""
|
185
195
|
:param tensor: Tensor to compute and store statistics of.
|
196
|
+
:param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
|
197
|
+
fields in the model to include in returned observations.
|
186
198
|
:param sample_percentage: Percentage of the given tensor to randomly sample and compute statistics from.
|
187
|
-
:param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
|
188
|
-
fields in the model to not include in returned observations.
|
189
199
|
:return: Constructed tensor statistics.
|
190
200
|
"""
|
191
201
|
|
@@ -193,12 +203,10 @@ class TensorStatistics(BaseModel):
|
|
193
203
|
downsampled_tensor = cls.downsample_tensor(tensor=tensor, sample_percentage=sample_percentage)
|
194
204
|
|
195
205
|
for field, field_value in stats.model_dump().items():
|
196
|
-
name = field[:-1] if field.endswith(
|
197
|
-
|
198
|
-
if skip_statistics is not None and name in skip_statistics:
|
199
|
-
continue
|
206
|
+
name = field[:-1] if field.endswith(FIELD_SUFFIX) else field
|
200
207
|
|
201
|
-
|
208
|
+
if name in include_statistics:
|
209
|
+
setattr(stats, name, downsampled_tensor)
|
202
210
|
|
203
211
|
return stats
|
204
212
|
|
@@ -219,28 +227,21 @@ class TensorStatistics(BaseModel):
|
|
219
227
|
inter_quartile_range_=tensor[6],
|
220
228
|
)
|
221
229
|
|
222
|
-
def to_list(self,
|
230
|
+
def to_list(self, include_statistics: list[str]) -> list[float]:
|
223
231
|
"""
|
224
|
-
:param
|
232
|
+
:param include_statistics: List of field names to include in the returned list.
|
225
233
|
:return: List of field values.
|
226
234
|
"""
|
227
235
|
|
228
|
-
|
229
|
-
skip_statistics = []
|
230
|
-
|
231
|
-
if not all(skip_stat in self.model_fields.keys() for skip_stat in skip_statistics):
|
232
|
-
raise ValueError(
|
233
|
-
f"One or more skip statistic keys do not exist in TensorStatistics. Valid Skip Keys: "
|
234
|
-
f"{list(self.model_fields.keys())} Given Skip Keys: {skip_statistics}"
|
235
|
-
)
|
236
|
+
filtered_includes = self.filter_include_statistics(include_statistics=include_statistics)
|
236
237
|
|
237
238
|
as_list = []
|
238
239
|
|
239
240
|
for field, field_value in self.model_dump().items():
|
240
|
-
|
241
|
-
continue
|
241
|
+
without_suffix = field[:-1]
|
242
242
|
|
243
|
-
|
243
|
+
if field in filtered_includes or without_suffix in filtered_includes:
|
244
|
+
as_list.append(field_value)
|
244
245
|
|
245
246
|
return as_list
|
246
247
|
|
@@ -267,6 +268,6 @@ class TensorStatistics(BaseModel):
|
|
267
268
|
"""
|
268
269
|
|
269
270
|
return {
|
270
|
-
field[:-1] if field.endswith(
|
271
|
+
field[:-1] if field.endswith(FIELD_SUFFIX) else field: field_value
|
271
272
|
for field, field_value in self.model_dump().items()
|
272
273
|
}
|
@@ -295,52 +295,52 @@ class HyperparameterContainer(BaseModel):
|
|
295
295
|
f"{self.get_hyperparameters()}."
|
296
296
|
)
|
297
297
|
|
298
|
-
def get_initial_internal_values(self,
|
298
|
+
def get_initial_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
|
299
299
|
"""
|
300
|
-
:param
|
301
|
-
|
300
|
+
:param include_hparams: Names of the hyperparameters to include in the hyperparameter name to initial internal
|
301
|
+
value mapping.
|
302
302
|
:return: Dictionary mapping hyperparameter names to their current values for this parameter group.
|
303
303
|
"""
|
304
304
|
|
305
|
-
if
|
306
|
-
|
305
|
+
if include_hparams is None:
|
306
|
+
include_hparams = list(self.model_fields.keys())
|
307
307
|
|
308
308
|
return {
|
309
309
|
field_name: field_value.initial_internal_value
|
310
310
|
for field_name, field_value in self.__dict__.items()
|
311
|
-
if isinstance(field_value, Hyperparameter) and field_name
|
311
|
+
if isinstance(field_value, Hyperparameter) and field_name in include_hparams
|
312
312
|
}
|
313
313
|
|
314
|
-
def get_current_internal_values(self,
|
314
|
+
def get_current_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
|
315
315
|
"""
|
316
|
-
:param
|
317
|
-
|
316
|
+
:param include_hparams: Names of the hyperparameters to include in the hyperparameter name to current internal
|
317
|
+
value mapping.
|
318
318
|
:return: Dictionary mapping hyperparameter names to their current values for this parameter group.
|
319
319
|
"""
|
320
320
|
|
321
|
-
if
|
322
|
-
|
321
|
+
if include_hparams is None:
|
322
|
+
include_hparams = list(self.model_fields.keys())
|
323
323
|
|
324
324
|
return {
|
325
325
|
field_name: field_value.current_internal_value
|
326
326
|
for field_name, field_value in self.__dict__.items()
|
327
|
-
if isinstance(field_value, Hyperparameter) and field_name
|
327
|
+
if isinstance(field_value, Hyperparameter) and field_name in include_hparams
|
328
328
|
}
|
329
329
|
|
330
|
-
def get_current_deltas(self,
|
330
|
+
def get_current_deltas(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
|
331
331
|
"""
|
332
|
-
:param
|
333
|
-
|
332
|
+
:param include_hparams: Names of the hyperparameters to include in the hyperparameter name to current delta
|
333
|
+
mapping.
|
334
334
|
:return: Dictionary mapping hyperparameter names to their current deltas for this parameter group.
|
335
335
|
"""
|
336
336
|
|
337
|
-
if
|
338
|
-
|
337
|
+
if include_hparams is None:
|
338
|
+
include_hparams = list(self.model_fields.keys())
|
339
339
|
|
340
340
|
return {
|
341
341
|
field_name: field_value.current_delta
|
342
342
|
for field_name, field_value in self.__dict__.items()
|
343
|
-
if isinstance(field_value, Hyperparameter) and field_name
|
343
|
+
if isinstance(field_value, Hyperparameter) and field_name in include_hparams
|
344
344
|
}
|
345
345
|
|
346
346
|
def set_internal_values(self, internal_values: dict[str, float | int | None]) -> None:
|
@@ -440,19 +440,21 @@ class ParameterGroupHParams(HyperparameterContainer):
|
|
440
440
|
)
|
441
441
|
|
442
442
|
def get_hyperparameter_transform_types(
|
443
|
-
self,
|
443
|
+
self, include_hparams: list[str] | None = None
|
444
444
|
) -> dict[str, HyperparameterTransformType]:
|
445
445
|
"""
|
446
|
+
:param include_hparams: Names of the hyperparameters to include in the hyperparameter name to transform type
|
447
|
+
mapping.
|
446
448
|
:return: Dictionary mapping hyperparameter names to their transform type for this parameter group.
|
447
449
|
"""
|
448
450
|
|
449
|
-
if
|
450
|
-
|
451
|
+
if include_hparams is None:
|
452
|
+
include_hparams = list(self.model_fields.keys())
|
451
453
|
|
452
454
|
return {
|
453
455
|
field_name: field_value.transform_type
|
454
456
|
for field_name, field_value in self.__dict__.items()
|
455
|
-
if isinstance(field_value, Hyperparameter) and field_name
|
457
|
+
if isinstance(field_value, Hyperparameter) and field_name in include_hparams
|
456
458
|
}
|
457
459
|
|
458
460
|
|
@@ -839,15 +841,15 @@ class HyperparameterStates(BaseModel):
|
|
839
841
|
for parameter_group_hparams in self.parameter_group_hparams.values():
|
840
842
|
parameter_group_hparams.set_to_initial_values()
|
841
843
|
|
842
|
-
def get_initial_internal_values(self,
|
844
|
+
def get_initial_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
|
843
845
|
"""
|
844
|
-
:param
|
846
|
+
:param include_hparams: Hyperparameters to include while retrieving initial values.
|
845
847
|
:return: Dictionary mapping hyperparameter names to their initial values at the start of training.
|
846
848
|
"""
|
847
849
|
|
848
850
|
initial_internal_values = {
|
849
|
-
**self.global_hparams.get_initial_internal_values(
|
850
|
-
**next(iter(self.parameter_group_hparams.values())).get_initial_internal_values(
|
851
|
+
**self.global_hparams.get_initial_internal_values(include_hparams),
|
852
|
+
**next(iter(self.parameter_group_hparams.values())).get_initial_internal_values(include_hparams),
|
851
853
|
}
|
852
854
|
initial_internal_values = {
|
853
855
|
hparam_name: initial_internal_values.get(hparam_name, None)
|
@@ -856,15 +858,15 @@ class HyperparameterStates(BaseModel):
|
|
856
858
|
|
857
859
|
return initial_internal_values
|
858
860
|
|
859
|
-
def get_current_internal_values(self,
|
861
|
+
def get_current_internal_values(self, include_hparams: list[str] | None = None) -> dict[str, float | int | None]:
|
860
862
|
"""
|
861
|
-
:param
|
863
|
+
:param include_hparams: Hyperparameters to include while retrieving current values.
|
862
864
|
:return: Dictionary mapping hyperparameter names to their current values during training.
|
863
865
|
"""
|
864
866
|
|
865
867
|
current_internal_values = {
|
866
|
-
**self.global_hparams.get_current_internal_values(
|
867
|
-
**next(iter(self.parameter_group_hparams.values())).get_current_internal_values(
|
868
|
+
**self.global_hparams.get_current_internal_values(include_hparams),
|
869
|
+
**next(iter(self.parameter_group_hparams.values())).get_current_internal_values(include_hparams),
|
868
870
|
}
|
869
871
|
current_internal_values = {
|
870
872
|
hparam_name: current_internal_values.get(hparam_name, None)
|
@@ -2,23 +2,23 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
|
4
4
|
libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
libinephany/observations/observation_utils.py,sha256=
|
5
|
+
libinephany/observations/observation_utils.py,sha256=JSNJYEi2d-VQ0ZovfHrn28RDv41u-a6M-W4ZT8UUyhI,17279
|
6
6
|
libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
|
7
7
|
libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
|
8
8
|
libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
|
9
|
-
libinephany/observations/statistic_trackers.py,sha256=
|
9
|
+
libinephany/observations/statistic_trackers.py,sha256=3LHvBXQ977-9ZW-KE9UohDeOayZdqQ5UQMTt0kuea40,47574
|
10
10
|
libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
libinephany/observations/observers/base_observers.py,sha256=
|
12
|
-
libinephany/observations/observers/local_observers.py,sha256=
|
11
|
+
libinephany/observations/observers/base_observers.py,sha256=9YI_jkivoCjyPtNSn3VnPADF1VqwGdHPkgi1kDzed3Y,16516
|
12
|
+
libinephany/observations/observers/local_observers.py,sha256=yBXmuCaDZotJPmBZKdrPfGYtVO-CCvT5ZS0KvROOcE4,45657
|
13
13
|
libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
|
14
|
-
libinephany/observations/observers/global_observers/__init__.py,sha256=
|
14
|
+
libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
|
15
15
|
libinephany/observations/observers/global_observers/base_classes.py,sha256=CCkRx86Lll3gFzfqervP0jKdzNFKkKU7tEBh8ic1Yrc,8249
|
16
|
-
libinephany/observations/observers/global_observers/constants.py,sha256=
|
17
|
-
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=
|
18
|
-
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=
|
16
|
+
libinephany/observations/observers/global_observers/constants.py,sha256=C_PwYhKxatJxNe5Jzb1tpoiRXAxxPrGkcdQBMQD8msY,1139
|
17
|
+
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=ZeujBhKeq8adw_J13omurjZnfloiadMYkiPuXYUZ8BU,20972
|
18
|
+
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=o035-nSfjj7dy7Pz1IxpAqvU3tYQraxQd8Pttknxa6A,15034
|
19
19
|
libinephany/observations/observers/global_observers/loss_observers.py,sha256=FlSuJqAJIXcAS_ypdZna6xxz89glI23A6D00sDn7ZLU,18508
|
20
|
-
libinephany/observations/observers/global_observers/model_observers.py,sha256=
|
21
|
-
libinephany/observations/observers/global_observers/progress_observers.py,sha256=
|
20
|
+
libinephany/observations/observers/global_observers/model_observers.py,sha256=HVNHnqk2uuXkmP8y3SL-IQ0AYArfXc7b-wckv9X7qbM,28457
|
21
|
+
libinephany/observations/observers/global_observers/progress_observers.py,sha256=ypLk1_POAjA8V8rAaQ0B6Qh8m_04s9PAoXsw1KxVrLg,5872
|
22
22
|
libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
23
|
libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
|
24
24
|
libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -32,9 +32,9 @@ libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=Xu0tQmhGwV043tT
|
|
32
32
|
libinephany/pydantic_models/schemas/observation_models.py,sha256=MLhxqDet9Yol1D5mkQGQsQT23sm37AStRLnPc4sgcZc,2110
|
33
33
|
libinephany/pydantic_models/schemas/request_schemas.py,sha256=VED8eAUvBofxeAx9gWU8DyCZOTVD3QsHRq-TO7kyOqk,1260
|
34
34
|
libinephany/pydantic_models/schemas/response_schemas.py,sha256=SKFuasdjX5aH_I0vT3SwnpwhyMf9cNPB1ZpDeAGgoO8,2158
|
35
|
-
libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=
|
35
|
+
libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=hWl52SxPYHFBXUwwZ9X7E1iFhuxGsPypd14hFTx91jw,8166
|
36
36
|
libinephany/pydantic_models/states/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
|
-
libinephany/pydantic_models/states/hyperparameter_states.py,sha256=
|
37
|
+
libinephany/pydantic_models/states/hyperparameter_states.py,sha256=Esi1xdrH9xOJwhpSezkfbTzbKX4O26IpN5zzkWD3Mf8,33779
|
38
38
|
libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
39
39
|
libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
|
40
40
|
libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
|
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
|
|
57
57
|
libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
58
58
|
libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
|
59
59
|
libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
|
60
|
-
libinephany-0.
|
61
|
-
libinephany-0.
|
62
|
-
libinephany-0.
|
63
|
-
libinephany-0.
|
64
|
-
libinephany-0.
|
60
|
+
libinephany-1.0.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
|
61
|
+
libinephany-1.0.0.dist-info/METADATA,sha256=N_0lNQBtTOt8bzKEyQhmiA_l4AkWKtMoq6GDL3FGXKI,8389
|
62
|
+
libinephany-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
63
|
+
libinephany-1.0.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
|
64
|
+
libinephany-1.0.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|