libinephany 0.16.3__py3-none-any.whl → 0.16.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +2 -14
- libinephany/observations/observers/global_observers/__init__.py +58 -0
- libinephany/observations/observers/global_observers/base_classes.py +183 -0
- libinephany/observations/observers/global_observers/constants.py +33 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +112 -0
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +286 -0
- libinephany/observations/observers/global_observers/loss_observers.py +464 -0
- libinephany/observations/observers/global_observers/model_observers.py +327 -0
- libinephany/observations/observers/global_observers/progress_observers.py +142 -0
- libinephany/observations/observers/observer_containers.py +3 -1
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/METADATA +1 -1
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/RECORD +15 -8
- libinephany/observations/observers/global_observers.py +0 -991
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/WHEEL +0 -0
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,464 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# IMPORTS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
import math
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes
|
11
|
+
from libinephany.observations.observers.base_observers import GlobalObserver
|
12
|
+
from libinephany.observations.observers.global_observers.base_classes import (
|
13
|
+
LHOPTCheckpointBaseObserver,
|
14
|
+
LHOPTOuterStepBaseObserver,
|
15
|
+
)
|
16
|
+
from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
|
17
|
+
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
18
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
19
|
+
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
20
|
+
|
21
|
+
# ======================================================================================================================
|
22
|
+
#
|
23
|
+
# CLASSES
|
24
|
+
#
|
25
|
+
# ======================================================================================================================
|
26
|
+
|
27
|
+
|
28
|
+
class TrainingLoss(GlobalObserver):
    """
    Global observer that reports the training loss exactly as provided in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - the training loss is read straight from the observation inputs, so no statistic
        trackers are requested.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: Always False for this observer.
        """

        standardizable = False
        return standardizable

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT - the observation is a single scalar.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: The training loss from ``observation_inputs``, unmodified.
        """

        loss_value = observation_inputs.training_loss
        return loss_value
|
71
|
+
|
72
|
+
|
73
|
+
class ValidationLoss(GlobalObserver):
    """
    Global observer that reports the validation loss exactly as provided in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - the validation loss is read straight from the observation inputs, so no statistic
        trackers are requested.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: Always False for this observer.
        """

        standardizable = False
        return standardizable

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT - the observation is a single scalar.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: The validation loss from ``observation_inputs``, unmodified.
        """

        loss_value = observation_inputs.validation_loss
        return loss_value
|
116
|
+
|
117
|
+
|
118
|
+
class LossRatio(GlobalObserver):
    """
    Global observer reporting the ratio ``training_loss / validation_loss``.

    A validation loss of exactly zero would make the ratio undefined; in that case 0.0 is reported instead.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistics models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: ``training_loss / validation_loss``, or 0.0 when the validation loss is zero.
        """

        training_loss = observation_inputs.training_loss
        validation_loss = observation_inputs.validation_loss

        # Dividing by a zero validation loss would raise ZeroDivisionError. Report 0.0 (a float, matching the FLOAT
        # observation format above) instead of the int 0.
        if validation_loss == 0:
            return 0.0

        return training_loss / validation_loss

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {}
|
156
|
+
|
157
|
+
|
158
|
+
class TrainingScore(GlobalObserver):
    """
    Global observer that reports the training score exactly as provided in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - the training score is read straight from the observation inputs, so no statistic
        trackers are requested.
        """

        return dict()

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT - the observation is a single scalar.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: The training score from ``observation_inputs``, unmodified.
        """

        score_value = observation_inputs.training_score
        return score_value
|
193
|
+
|
194
|
+
|
195
|
+
class ValidationScore(GlobalObserver):
    """
    Global observer that reports the validation score exactly as provided in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - the validation score is read straight from the observation inputs, so no statistic
        trackers are requested.
        """

        return dict()

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT - the observation is a single scalar.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: The validation score from ``observation_inputs``, unmodified.
        """

        score_value = observation_inputs.validation_score
        return score_value
|
230
|
+
|
231
|
+
|
232
|
+
class LHOPTTrainingLoss(LHOPTOuterStepBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    Each observation is a three-element vector for the training loss: [is_nan, is_inf, cdf_feature].

    The CDF feature comes from the base class's ``_compute_cdf_feature``, which applies the paper's CDF transformation
    using the CDF mean and std.
    """

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in each observation vector: is_nan, is_inf and cdf_feature.
        """
        return 3

    def reset(self) -> None:
        """Reset the observer: restart the clock and empty the time series."""
        self._current_time: float = 0.0
        self._time_series: list[tuple[float, float]] = []

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - only the training loss from the observation inputs is needed.
        """

        return dict()

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> list[int | float]:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: [is_nan, is_inf, cdf_feature], where the first two entries are 0/1 flags for the raw training loss.
        """

        loss = observation_inputs.training_loss

        # The CDF feature must be computed before the clock advances.
        feature = self._compute_cdf_feature(loss)
        self._update_time()

        nan_flag = 1 if math.isnan(loss) else 0
        inf_flag = 1 if math.isinf(loss) else 0
        return [nan_flag, inf_flag, feature]
|
285
|
+
|
286
|
+
|
287
|
+
class LHOPTValidationLoss(LHOPTOuterStepBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    Each observation is a three-element vector for the validation loss: [is_nan, is_inf, cdf_feature].

    The CDF feature comes from the base class's ``_compute_cdf_feature``, which applies the paper's CDF transformation
    using the CDF mean and std.
    """

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in each observation vector: is_nan, is_inf and cdf_feature.
        """
        return 3

    def reset(self) -> None:
        """Reset the observer: restart the clock and empty the time series."""
        self._current_time: float = 0.0
        self._time_series: list[tuple[float, float]] = []

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping - only the validation loss from the observation inputs is needed.
        """

        return dict()

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> list[int | float]:
        """
        :param observation_inputs: Metrics supplied directly rather than computed by statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Mapping from statistic tracker class names to per-module floats or TensorStatistics
        models (unused here).
        :param action_taken: Action most recently taken by the owning agent (unused here).
        :return: [is_nan, is_inf, cdf_feature], where the first two entries are 0/1 flags for the raw validation loss.
        """

        loss = observation_inputs.validation_loss

        # The CDF feature must be computed before the clock advances.
        feature = self._compute_cdf_feature(loss)
        self._update_time()

        nan_flag = 1 if math.isnan(loss) else 0
        inf_flag = 1 if math.isinf(loss) else 0
        return [nan_flag, inf_flag, feature]
|
340
|
+
|
341
|
+
|
342
|
+
class LHOPTLossRatio(LHOPTOuterStepBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns three-dimensional observations: [is_nan, tanh, cdf_feature] for loss ratio values.

    Despite the name, this observer computes the logarithm of the ratio between validation_score and training_score,
    providing three features:
    1. is_nan - whether the log ratio is NaN
    2. tanh(log_ratio) - bounded feature using hyperbolic tangent (log ratio first clipped to +/- TANH_BOUND)
    3. cdf_feature - CDF transformed feature using CDF mean and std
    """

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> list[int | float]:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistics models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: List of three features: [is_nan, tanh, cdf_feature]
        """

        log_ratio = self._compute_log_ratio(
            training_score=observation_inputs.training_score, validation_score=observation_inputs.validation_score
        )

        # Clip into [-TANH_BOUND, TANH_BOUND] before tanh so that +/-inf log ratios saturate at tanh(+/-TANH_BOUND).
        # Note: a NaN log ratio falls through the clip as +TANH_BOUND (every comparison with NaN is False); the
        # is_nan flag below marks that case.
        tanh_bound = LHOPT_CONSTANTS["TANH_BOUND"]
        tanh_feature = math.tanh(max(-tanh_bound, min(tanh_bound, log_ratio)))

        cdf_feature = self._compute_cdf_feature(log_ratio)

        self._update_time()

        return [int(math.isnan(log_ratio)), tanh_feature, cdf_feature]

    def _compute_log_ratio(self, training_score: float, validation_score: float) -> float:
        """
        Compute log(validation_score / training_score).

        The value is computed as ``log(validation_score) - log(training_score)`` rather than the log of the quotient.
        The quotient can round to 0.0, for example when training_score is inf, or when a very small validation score
        is divided by a large training score. ``math.log(0.0)`` raises ValueError, which would abort the observation.
        Taking the logs separately yields a finite value or -inf in those cases. NaN inputs and inf/inf still
        produce NaN, exactly as before.

        :param training_score: Training score value
        :param validation_score: Validation score value
        :return: Log ratio value, or 0.0 if either score is non-positive.
        """
        # The logarithm is only defined for positive scores. NaN fails both comparisons and propagates to a NaN result.
        if training_score <= 0 or validation_score <= 0:
            return 0.0

        return math.log(validation_score) - math.log(training_score)

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 3  # [is_nan, tanh, cdf_feature]
|
409
|
+
|
410
|
+
|
411
|
+
class PercentileOfLossAtEachCheckpoint(LHOPTCheckpointBaseObserver):
    """
    Observer that computes a percentile of the training-loss history at each checkpoint.

    At checkpoint steps the configured percentile of the non-NaN history is reported. At all other steps
    ``self.last_value`` is returned.
    """

    def __init__(
        self,
        checkpoint_interval: int = LHOPT_CONSTANTS["DEFAULT_CHECKPOINT_INTERVAL"],
        percentile: float = LHOPT_CONSTANTS["DEFAULT_PERCENTILE"],
        **kwargs,
    ) -> None:
        """
        :param checkpoint_interval: How often to create checkpoints (in training steps).
        :param percentile: Percentile to compute (0.0 to 1.0). Values outside this range are clamped into it.
        :param kwargs: Miscellaneous keyword arguments.
        """
        super().__init__(checkpoint_interval=checkpoint_interval, **kwargs)
        # Clamping into [0, 1] keeps the index computed in _observe inside the sorted history.
        self.percentile = max(0.0, min(1.0, percentile))

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistics models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: At a checkpoint, the configured percentile of the non-NaN loss history, or NaN if no non-NaN values
        are present. Otherwise ``self.last_value``.
        """
        training_loss = observation_inputs.training_loss

        # Handle cold start
        self._cold_start(training_loss)

        # Update history
        self._update_history(training_loss)

        # Check if we should create a checkpoint
        if self._should_create_checkpoint():
            # NOTE(review): assumes self._history holds the numeric losses passed to _update_history - confirm.
            # NaN compares False against every value, so a single NaN loss leaves sorted() with an arbitrary order and
            # makes the indexed "percentile" meaningless. NaNs are dropped first, in the spirit of numpy.nanpercentile.
            sorted_history = sorted(value for value in self._history if not math.isnan(value))

            if sorted_history:
                # The index is truncated (rounded down) toward the lower neighbouring rank.
                index = int(self.percentile * (len(sorted_history) - 1))
                percentile_value = sorted_history[index]
            else:
                # No usable values (all NaN or empty history). Report NaN rather than raising IndexError.
                percentile_value = math.nan

            self._cached_observation = percentile_value
            return percentile_value

        # Between checkpoints, return self.last_value. It is not assigned in this class and comes from the base class.
        self._cached_observation = self.last_value
        return self.last_value
|