evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,926 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Callbacks for monitoring and controlling the optimisation process.
|
|
3
|
+
|
|
4
|
+
This module provides a callback system inspired by Keras/PyTorch Lightning,
|
|
5
|
+
allowing users to hook into the optimisation loop at various points.
|
|
6
|
+
|
|
7
|
+
Available callbacks:
|
|
8
|
+
- HistoryCallback: Track fitness and hyperparameter history
|
|
9
|
+
- EarlyStoppingCallback: Stop when no improvement is detected
|
|
10
|
+
- ConvergenceCallback: Stop when fitness change falls below threshold
|
|
11
|
+
- PrintCallback: Print progress during optimisation
|
|
12
|
+
- CheckpointCallback: Save algorithm state periodically
|
|
13
|
+
- CompositeCallback: Combine multiple callbacks
|
|
14
|
+
|
|
15
|
+
Example:
|
|
16
|
+
>>> from evograd.utils.callbacks import HistoryCallback, EarlyStoppingCallback
|
|
17
|
+
>>>
|
|
18
|
+
>>> history = HistoryCallback(track_population=False)
|
|
19
|
+
>>> early_stop = EarlyStoppingCallback(patience=50, min_delta=1e-6)
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Pass to minimize function
|
|
22
|
+
>>> result = minimize(algorithm, callbacks=[history, early_stop])
|
|
23
|
+
>>>
|
|
24
|
+
>>> # Access history after optimisation
|
|
25
|
+
>>> print(history.best_fitness) # List of best fitness per generation
|
|
26
|
+
>>> print(history.to_dataframe()) # Pandas DataFrame if available
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import time
|
|
32
|
+
import warnings
|
|
33
|
+
from abc import ABC, abstractmethod
|
|
34
|
+
from dataclasses import dataclass, field
|
|
35
|
+
from enum import Enum, auto
|
|
36
|
+
from pathlib import Path
|
|
37
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
38
|
+
|
|
39
|
+
import torch
|
|
40
|
+
|
|
41
|
+
if TYPE_CHECKING:
|
|
42
|
+
from torch import Tensor
|
|
43
|
+
|
|
44
|
+
# Public API of this module. CallbackState is exported too because it is the
# argument type of every Callback hook, so user-defined callbacks need it.
__all__ = [
    "CallbackEvent",
    "CallbackState",
    "Callback",
    "HistoryCallback",
    "EarlyStoppingCallback",
    "ConvergenceCallback",
    "PrintCallback",
    "CheckpointCallback",
    "CompositeCallback",
    "CallbackList",
]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# =============================================================================
|
|
58
|
+
# Callback Events
|
|
59
|
+
# =============================================================================
|
|
60
|
+
|
|
61
|
+
class CallbackEvent(Enum):
    """Points in the optimisation loop at which callback hooks can fire.

    Members come in start/end pairs for the whole run, for each generation
    and for each fitness evaluation. They carry the integer values 1-6 in
    declaration order (the same values ``enum.auto()`` would assign).
    """

    # Whole optimisation run.
    OPTIMISATION_START = 1
    OPTIMISATION_END = 2
    # One generation of the algorithm.
    GENERATION_START = 3
    GENERATION_END = 4
    # One fitness evaluation.
    EVALUATION_START = 5
    EVALUATION_END = 6
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# =============================================================================
|
|
73
|
+
# Callback State Container
|
|
74
|
+
# =============================================================================
|
|
75
|
+
|
|
76
|
+
@dataclass
class CallbackState:
    """
    Mutable snapshot of the optimisation run handed to every callback hook.

    Callbacks read the current progress from these fields and may ask the
    loop to terminate through :meth:`request_stop`.

    Attributes:
        generation: Index of the current generation (starts at 0).
        n_evals: Number of fitness evaluations performed so far.
        max_evals: Evaluation budget, or ``None`` if not set.
        max_generations: Generation budget, or ``None`` if not set.
        best_fitness: Best fitness value found so far (``inf`` by default).
        best_solution: Best solution found so far, if known.
        current_fitness: Fitness values of the current population.
        current_population: Current population tensor.
        algorithm: Reference to the algorithm instance.
        hyperparams: Current hyperparameter values keyed by name.
        elapsed_time: Time elapsed since optimisation start.
        stop_optimisation: ``True`` once a stop has been requested.
        extra: Free-form dictionary for algorithm/callback data; holds
            ``"stop_reason"`` after :meth:`request_stop` is called.
    """

    # Progress counters and budgets.
    generation: int = 0
    n_evals: int = 0
    max_evals: Optional[int] = None
    max_generations: Optional[int] = None

    # Best-so-far and current population data.
    best_fitness: float = float('inf')
    best_solution: Optional[Tensor] = None
    current_fitness: Optional[Tensor] = None
    current_population: Optional[Tensor] = None

    # Context for the run.
    algorithm: Optional[Any] = None
    hyperparams: Dict[str, Any] = field(default_factory=dict)
    elapsed_time: float = 0.0

    # Control flag and scratch space.
    stop_optimisation: bool = False
    extra: Dict[str, Any] = field(default_factory=dict)

    def request_stop(self, reason: str = "") -> None:
        """Flag the run for termination and store *reason* in ``extra["stop_reason"]``.

        A later call overwrites any reason recorded by an earlier one.
        """
        self.extra.update(stop_reason=reason)
        self.stop_optimisation = True
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# =============================================================================
|
|
118
|
+
# Base Callback
|
|
119
|
+
# =============================================================================
|
|
120
|
+
|
|
121
|
+
class Callback(ABC):
    """
    Base class for optimisation callbacks.

    Every hook receives the shared ``CallbackState`` and does nothing by
    default, so a subclass overrides only the events it cares about:

    - on_optimisation_start: once, before optimisation begins
    - on_optimisation_end: once, after optimisation completes
    - on_generation_start / on_generation_end: around each generation
    - on_evaluation_start / on_evaluation_end: around each fitness evaluation

    None of the hooks is abstract, so the class itself can be instantiated
    and behaves as a no-op callback.

    Example:
        >>> class MyCallback(Callback):
        ...     def on_generation_end(self, state: CallbackState) -> None:
        ...         if state.generation % 100 == 0:
        ...             print(f"Gen {state.generation}: {state.best_fitness:.6f}")
    """

    # --- run-level hooks -------------------------------------------------

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Hook run once when optimisation begins; no-op by default."""

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Hook run once when optimisation ends; no-op by default."""

    # --- generation-level hooks ------------------------------------------

    def on_generation_start(self, state: CallbackState) -> None:
        """Hook run at the start of each generation; no-op by default."""

    def on_generation_end(self, state: CallbackState) -> None:
        """Hook run at the end of each generation; no-op by default."""

    # --- evaluation-level hooks ------------------------------------------

    def on_evaluation_start(self, state: CallbackState) -> None:
        """Hook run before fitness evaluation; no-op by default."""

    def on_evaluation_end(self, state: CallbackState) -> None:
        """Hook run after fitness evaluation; no-op by default."""
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# =============================================================================
|
|
168
|
+
# History Callback
|
|
169
|
+
# =============================================================================
|
|
170
|
+
|
|
171
|
+
class HistoryCallback(Callback):
    """
    Track optimisation history including fitness values and hyperparameters.

    This callback records various metrics at each generation, providing
    a complete picture of the optimisation trajectory. One record is taken in
    ``on_optimisation_start`` and one in every ``on_generation_end``.

    Every per-generation series exported by :meth:`to_dict` has the same
    length as ``generations``, so the history can always be tabulated. When a
    value is unavailable for a generation a placeholder is stored instead:
    ``nan`` for fitness statistics and diversity, ``None`` for
    hyperparameters (including back-filling generations recorded before a
    hyperparameter first appeared).

    Args:
        track_population: Whether to store full population at each generation.
            Warning: This can consume significant memory for large populations.
        track_hyperparams: Whether to track hyperparameter changes.
        track_diversity: Whether to compute population diversity metrics.
        track_fitness_stats: Whether to track min/max/mean/std of fitness.

    Attributes:
        generations: List of generation numbers that were recorded.
        best_fitness: List of best fitness at each generation.
        best_solution: List of best solutions (only for generations where the
            state provided one).
        mean_fitness: List of mean fitness at each generation.
        std_fitness: List of fitness std at each generation (0.0 when the
            population has a single individual).
        min_fitness: List of min fitness at each generation.
        max_fitness: List of max fitness at each generation.
        n_evals: List of cumulative evaluations at each generation.
        elapsed_time: List of elapsed time at each generation.
        hyperparams: Dict mapping hyperparam names to lists of values. Tensor
            values are reduced to a float (mean for multi-element tensors).
        populations: List of population tensors (if track_population=True).
        diversity: List of diversity metrics (if track_diversity=True).

    Example:
        >>> history = HistoryCallback(track_hyperparams=True)
        >>> result = minimize(algorithm, callbacks=[history])
        >>>
        >>> # Plot convergence
        >>> import matplotlib.pyplot as plt
        >>> plt.plot(history.best_fitness)
        >>> plt.xlabel('Generation')
        >>> plt.ylabel('Best Fitness')
    """

    def __init__(
        self,
        track_population: bool = False,
        track_hyperparams: bool = True,
        track_diversity: bool = False,
        track_fitness_stats: bool = True,
    ) -> None:
        self.track_population = track_population
        self.track_hyperparams = track_hyperparams
        self.track_diversity = track_diversity
        self.track_fitness_stats = track_fitness_stats

        # Core tracking
        self.best_fitness: List[float] = []
        self.best_solution: List[Tensor] = []
        self.n_evals: List[int] = []
        self.elapsed_time: List[float] = []
        self.generations: List[int] = []

        # Fitness statistics
        self.mean_fitness: List[float] = []
        self.std_fitness: List[float] = []
        self.min_fitness: List[float] = []
        self.max_fitness: List[float] = []

        # Hyperparameters
        self.hyperparams: Dict[str, List[Any]] = {}

        # Population (optional)
        self.populations: List[Tensor] = []

        # Diversity (optional)
        self.diversity: List[float] = []

        # Timing
        self._start_time: Optional[float] = None

    @staticmethod
    def _to_scalar(value: Any) -> Any:
        """
        Reduce a tensor hyperparameter to a Python float; pass other values through.

        Multi-element tensors (e.g. per-individual values) are summarised by
        their mean. Integer/bool tensors are promoted to float64 first because
        ``Tensor.mean`` rejects non-floating dtypes. Empty tensors give ``nan``.
        """
        if not isinstance(value, torch.Tensor):
            return value
        tensor = value.detach()
        if tensor.numel() == 0:
            return float('nan')
        if tensor.numel() == 1:
            return float(tensor)
        if not tensor.is_floating_point():
            tensor = tensor.double()
        return float(tensor.mean())

    def _record_fitness_stats(self, fitness: Optional[Tensor]) -> None:
        """Append mean/std/min/max of *fitness*, or ``nan`` placeholders if unavailable."""
        if fitness is None or fitness.numel() == 0:
            stats = (float('nan'),) * 4
        else:
            values = fitness.detach()
            if not values.is_floating_point():
                # mean/std require a floating dtype.
                values = values.double()
            # The sample std of a single value is undefined (torch returns nan
            # and warns); report zero spread instead.
            std = float(values.std()) if values.numel() > 1 else 0.0
            stats = (
                float(values.mean()),
                std,
                float(values.min()),
                float(values.max()),
            )
        series_list = (self.mean_fitness, self.std_fitness, self.min_fitness, self.max_fitness)
        for series, value in zip(series_list, stats):
            series.append(value)

    def _record_hyperparams(
        self, hyperparams: Optional[Dict[str, Any]], n_previous: int
    ) -> None:
        """
        Append one value per tracked hyperparameter, keeping every series aligned.

        A name seen for the first time is back-filled with ``None`` for the
        *n_previous* generations already recorded; a known name that is
        missing from this generation's *hyperparams* gets ``None``.
        """
        current = hyperparams or {}
        for name, value in current.items():
            self.hyperparams.setdefault(name, [None] * n_previous).append(
                self._to_scalar(value)
            )
        for name, series in self.hyperparams.items():
            if name not in current:
                series.append(None)

    def _record_metrics(self, state: CallbackState) -> None:
        """Append one entry for *state* to every tracked series."""
        # Number of generations recorded before this one (used for back-filling).
        n_previous = len(self.generations)

        # Core metrics
        self.generations.append(state.generation)
        self.best_fitness.append(float(state.best_fitness))
        self.n_evals.append(state.n_evals)

        # Start the clock lazily if on_optimisation_start was never called,
        # so elapsed_time stays aligned with generations.
        if self._start_time is None:
            self._start_time = time.perf_counter()
        self.elapsed_time.append(time.perf_counter() - self._start_time)

        # Best solution (detached clone)
        if state.best_solution is not None:
            self.best_solution.append(state.best_solution.detach().clone())

        # Fitness statistics
        if self.track_fitness_stats:
            self._record_fitness_stats(state.current_fitness)

        # Hyperparameters
        if self.track_hyperparams:
            self._record_hyperparams(state.hyperparams, n_previous)

        population = state.current_population

        # Population snapshot
        if self.track_population and population is not None:
            self.populations.append(population.detach().clone())

        # Diversity (nan placeholder keeps the series aligned)
        if self.track_diversity:
            self.diversity.append(
                self._compute_diversity(population)
                if population is not None
                else float('nan')
            )

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Reset history, start timer and record the initial state."""
        self._start_time = time.perf_counter()

        # Clear all lists in place (external references stay valid)
        self.best_fitness.clear()
        self.best_solution.clear()
        self.n_evals.clear()
        self.elapsed_time.clear()
        self.generations.clear()
        self.mean_fitness.clear()
        self.std_fitness.clear()
        self.min_fitness.clear()
        self.max_fitness.clear()
        self.hyperparams.clear()
        self.populations.clear()
        self.diversity.clear()

        self._record_metrics(state)

    def on_generation_end(self, state: CallbackState) -> None:
        """Record metrics at end of generation."""
        self._record_metrics(state)

    def _compute_diversity(self, population: Tensor) -> float:
        """
        Compute population diversity as mean pairwise distance.

        Uses the L2 norm between individuals. Each individual is flattened to
        a vector, and integer/bool populations are promoted to float64
        because ``torch.cdist`` only supports floating dtypes. Returns 0.0
        for fewer than two individuals.
        """
        n = population.shape[0]
        if n < 2:
            return 0.0

        points = population.detach().reshape(n, -1)
        if not points.is_floating_point():
            points = points.double()

        # Compute pairwise distances
        dists = torch.cdist(points, points, p=2)

        # Mean of upper triangle (excluding diagonal): each pair counted once
        mask = torch.triu(torch.ones(n, n, device=points.device), diagonal=1).bool()
        return float(dists[mask].mean())

    def to_dict(self) -> Dict[str, Any]:
        """Convert history to a dictionary of equally long per-generation lists."""
        result = {
            "generation": self.generations,
            "best_fitness": self.best_fitness,
            "n_evals": self.n_evals,
            "elapsed_time": self.elapsed_time,
        }

        if self.track_fitness_stats:
            result["mean_fitness"] = self.mean_fitness
            result["std_fitness"] = self.std_fitness
            result["min_fitness"] = self.min_fitness
            result["max_fitness"] = self.max_fitness

        if self.track_hyperparams:
            for name, values in self.hyperparams.items():
                result[f"hp_{name}"] = values

        if self.track_diversity:
            result["diversity"] = self.diversity

        return result

    def to_dataframe(self):
        """
        Convert history to pandas DataFrame.

        Returns:
            pandas.DataFrame with one row per recorded generation.

        Raises:
            ImportError: If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as err:
            raise ImportError(
                "pandas is required for to_dataframe(). "
                "Install with: pip install pandas"
            ) from err
        return pd.DataFrame(self.to_dict())

    def __len__(self) -> int:
        """Number of recorded generations."""
        return len(self.generations)

    def __repr__(self) -> str:
        n_gens = len(self.generations)
        best = self.best_fitness[-1] if self.best_fitness else float('inf')
        return f"HistoryCallback(generations={n_gens}, best_fitness={best:.6g})"
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
# =============================================================================
|
|
380
|
+
# Early Stopping Callback
|
|
381
|
+
# =============================================================================
|
|
382
|
+
|
|
383
|
+
class EarlyStoppingCallback(Callback):
    """
    Stop optimisation when fitness stops improving.

    Monitors the best fitness and stops if no improvement is seen
    for a specified number of generations (patience).

    Args:
        patience: Number of consecutive generations without improvement
            after which a stop is requested (must be >= 1).
        min_delta: Minimum change to qualify as an improvement (must be >= 0).
            Improvement is defined as: new_best < best - min_delta
        baseline: Initial reference value the first improvement must beat.
            If None, the first generation's best always counts as an
            improvement.
        restore_best: Whether to publish the best solution seen (and its
            fitness) in ``state.extra["restored_solution"]`` /
            ``state.extra["restored_fitness"]`` when optimisation ends.
        verbose: Whether to print when stopping.

    Attributes:
        best_fitness: Best fitness seen so far, stored as a Python float.
        best_generation: Generation where best fitness was found.
        best_solution: Detached copy of the best solution, if the state
            provided one when the best fitness was recorded.
        wait: Current number of generations without improvement.
        stopped_generation: Generation where stopping was triggered.

    Example:
        >>> early_stop = EarlyStoppingCallback(patience=100, min_delta=1e-8)
        >>> result = minimize(algorithm, callbacks=[early_stop])
        >>>
        >>> if early_stop.stopped_generation is not None:
        ...     print(f"Stopped at generation {early_stop.stopped_generation}")
    """

    def __init__(
        self,
        patience: int = 50,
        min_delta: float = 0.0,
        baseline: Optional[float] = None,
        restore_best: bool = True,
        verbose: bool = False,
    ) -> None:
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta}")

        self.patience = patience
        self.min_delta = min_delta
        self.baseline = baseline
        self.restore_best = restore_best
        self.verbose = verbose

        # State
        self.best_fitness: float = float('inf')
        self.best_generation: int = 0
        self.best_solution: Optional[Tensor] = None
        self.wait: int = 0
        self.stopped_generation: Optional[int] = None

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Reset state at start of optimisation."""
        self.best_fitness = float(self.baseline) if self.baseline is not None else float('inf')
        self.best_generation = 0
        self.best_solution = None
        self.wait = 0
        self.stopped_generation = None

    def on_generation_end(self, state: CallbackState) -> None:
        """Check for improvement and request a stop once patience is exhausted."""
        # Coerce to a plain float (as the other callbacks here do) so that a
        # tensor-valued best fitness is not stored, compared or formatted as a
        # tensor, and does not keep device memory alive.
        current = float(state.best_fitness)

        # Check for improvement
        if current < self.best_fitness - self.min_delta:
            self.best_fitness = current
            self.best_generation = state.generation
            if state.best_solution is not None:
                self.best_solution = state.best_solution.detach().clone()
            self.wait = 0
        else:
            self.wait += 1

        # Check patience
        if self.wait >= self.patience:
            self.stopped_generation = state.generation
            state.request_stop(
                f"EarlyStopping: No improvement for {self.patience} generations. "
                f"Best: {self.best_fitness:.6g} at generation {self.best_generation}"
            )

            if self.verbose:
                print(
                    f"Early stopping at generation {state.generation}. "
                    f"Best fitness: {self.best_fitness:.6g} "
                    f"(generation {self.best_generation})"
                )

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Optionally publish the best solution seen via ``state.extra``."""
        if self.restore_best and self.best_solution is not None:
            # Store in extra for minimize to pick up
            state.extra["restored_solution"] = self.best_solution
            state.extra["restored_fitness"] = self.best_fitness

    def __repr__(self) -> str:
        status = "active" if self.stopped_generation is None else f"stopped@{self.stopped_generation}"
        return (
            f"EarlyStoppingCallback(patience={self.patience}, "
            f"min_delta={self.min_delta}, status={status})"
        )
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
# =============================================================================
|
|
491
|
+
# Convergence Callback
|
|
492
|
+
# =============================================================================
|
|
493
|
+
|
|
494
|
+
class ConvergenceCallback(Callback):
    """
    Stop optimisation once the best fitness has effectively stopped moving.

    After every generation the best fitness is appended to
    ``fitness_history``. As soon as at least ``window`` values exist and
    ``state.generation >= min_generations``, the oldest and newest of the
    last ``window`` values are compared. A change strictly below
    ``threshold`` requests a stop.

    Args:
        threshold: Convergence threshold for fitness change (must be > 0).
        window: How many of the most recent best-fitness values the
            comparison spans (must be >= 2). Only the first and last of them
            are compared; intermediate values are not averaged.
        mode: 'absolute' uses ``|new - old|``. 'relative' uses
            ``|(new - old) / old|``, falling back to the absolute change when
            ``|old| < 1e-10``.
        min_generations: Generation index below which no check is made.
        verbose: Whether to print when stopping.

    Attributes:
        fitness_history: Best fitness recorded at each generation end.
        stopped_generation: Generation at which a stop was requested, if any.
        final_change: Most recently measured change, if any was measured.

    Example:
        >>> conv = ConvergenceCallback(threshold=1e-6, window=20, mode='relative')
        >>> result = minimize(algorithm, callbacks=[conv])
    """

    def __init__(
        self,
        threshold: float = 1e-6,
        window: int = 10,
        mode: str = "absolute",
        min_generations: int = 100,
        verbose: bool = False,
    ) -> None:
        # Validate in a fixed order so the first offending argument is reported.
        if threshold <= 0:
            raise ValueError(f"threshold must be > 0, got {threshold}")
        if window < 2:
            raise ValueError(f"window must be >= 2, got {window}")
        if mode not in ("absolute", "relative"):
            raise ValueError(f"mode must be 'absolute' or 'relative', got '{mode}'")

        self.threshold, self.window, self.mode = threshold, window, mode
        self.min_generations, self.verbose = min_generations, verbose

        # Run state
        self.fitness_history: List[float] = []
        self.stopped_generation: Optional[int] = None
        self.final_change: Optional[float] = None

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Forget everything recorded during a previous run."""
        self.final_change = None
        self.stopped_generation = None
        self.fitness_history.clear()

    def _measure_change(self, old: float, new: float) -> float:
        """Return the absolute or relative change from *old* to *new*."""
        delta = new - old
        # Relative change is meaningless for (near-)zero references.
        if self.mode == "absolute" or abs(old) < 1e-10:
            return abs(delta)
        return abs(delta / old)

    def on_generation_end(self, state: CallbackState) -> None:
        """Record the latest best fitness and request a stop on convergence."""
        history = self.fitness_history
        history.append(float(state.best_fitness))

        enough_history = len(history) >= self.window
        if not (enough_history and state.generation >= self.min_generations):
            return

        change = self._measure_change(history[-self.window], history[-1])
        self.final_change = change

        # Written as "not <" so a nan change never triggers a stop.
        if not change < self.threshold:
            return

        self.stopped_generation = state.generation
        state.request_stop(
            f"Convergence: {self.mode} change {change:.2e} < {self.threshold:.2e} "
            f"over {self.window} generations"
        )

        if self.verbose:
            print(
                f"Converged at generation {state.generation}. "
                f"{self.mode.capitalize()} change: {change:.2e}"
            )

    def __repr__(self) -> str:
        if self.stopped_generation is None:
            status = "active"
        else:
            status = f"stopped@{self.stopped_generation}"
        return (
            f"ConvergenceCallback(threshold={self.threshold}, "
            f"window={self.window}, mode='{self.mode}', status={status})"
        )
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
# =============================================================================
|
|
595
|
+
# Print Callback
|
|
596
|
+
# =============================================================================
|
|
597
|
+
|
|
598
|
+
class PrintCallback(Callback):
    """
    Print progress during optimisation.

    ``on_optimisation_start`` prints a header and reports the starting state.
    ``on_generation_end`` prints one line for every generation whose index is
    a multiple of ``every``. ``on_optimisation_end`` prints a summary.

    Args:
        every: Print every N generations (must be >= 1).
        show_hyperparams: Whether to show hyperparameter values (at most the
            first three). Tensor-valued hyperparameters, e.g. per-individual
            values, are shown as their mean.
        show_evals: Whether to show evaluation count.
        show_time: Whether to show elapsed time.
        format_spec: Format specification for fitness (e.g., '.6f', '.4e').

    Example:
        >>> printer = PrintCallback(every=50, show_time=True)
        >>> result = minimize(algorithm, callbacks=[printer])

        # Output:
        # Gen 50 | Best: 1.234567e+02 | Evals: 5000 | Time: 1.23s
        # Gen 100 | Best: 4.567890e+01 | Evals: 10000 | Time: 2.45s
    """

    def __init__(
        self,
        every: int = 1,
        show_hyperparams: bool = False,
        show_evals: bool = True,
        show_time: bool = True,
        format_spec: str = ".6e",
    ) -> None:
        if every < 1:
            raise ValueError(f"every must be >= 1, got {every}")

        self.every = every
        self.show_hyperparams = show_hyperparams
        self.show_evals = show_evals
        self.show_time = show_time
        self.format_spec = format_spec

        self._start_time: Optional[float] = None

    @staticmethod
    def _as_number(value: Any) -> Any:
        """Unwrap a single-element tensor to a Python number so format specs apply."""
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    @staticmethod
    def _format_hyperparam(name: Any, value: Any) -> str:
        """
        Render one hyperparameter as ``name=value``.

        Non-empty tensors are summarised by their mean rather than dumping the
        whole tensor into the progress line. Integer/bool tensors are promoted
        first because ``Tensor.mean`` needs a floating dtype. Floats are shown
        with four decimals.
        """
        if isinstance(value, torch.Tensor) and value.numel() > 0:
            tensor = value.detach()
            if not tensor.is_floating_point():
                tensor = tensor.double()
            value = float(tensor.mean())
        if isinstance(value, float):
            return f"{name}={value:.4f}"
        return f"{name}={value}"

    def _report_generation(self, state: CallbackState) -> None:
        """Print progress if at print interval."""
        if state.generation % self.every != 0:
            return

        # Build output string
        parts = [f"Gen {state.generation:5d}"]
        parts.append(f"Best: {self._as_number(state.best_fitness):{self.format_spec}}")

        if self.show_evals:
            parts.append(f"Evals: {state.n_evals:7d}")

        if self.show_time and self._start_time is not None:
            elapsed = time.perf_counter() - self._start_time
            parts.append(f"Time: {elapsed:.2f}s")

        if self.show_hyperparams and state.hyperparams:
            # Limit to the first 3 entries to keep the line readable.
            shown = list(state.hyperparams.items())[:3]
            hp_str = ", ".join(self._format_hyperparam(k, v) for k, v in shown)
            parts.append(f"HP: {hp_str}")

        print(" | ".join(parts))

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Record start time, print header, and report the initial state."""
        self._start_time = time.perf_counter()

        # Print header
        header = "Gen"
        header += " | Best Fitness"
        if self.show_evals:
            header += " | Evals"
        if self.show_time:
            header += " | Time"

        print("-" * len(header))
        print(header)
        print("-" * len(header))

        self._report_generation(state)

    def on_generation_end(self, state: CallbackState) -> None:
        """Print progress after each generation (subject to ``every``)."""
        self._report_generation(state)

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Print final summary."""
        elapsed = (
            time.perf_counter() - self._start_time
            if self._start_time is not None
            else 0.0
        )
        print("-" * 60)
        print("Optimisation complete!")
        print(f" Final best: {self._as_number(state.best_fitness):{self.format_spec}}")
        print(f" Generations: {state.generation}")
        print(f" Evaluations: {state.n_evals}")
        print(f" Time: {elapsed:.2f}s")

        if state.extra.get("stop_reason"):
            print(f" Stop reason: {state.extra['stop_reason']}")
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
# =============================================================================
|
|
700
|
+
# Checkpoint Callback
|
|
701
|
+
# =============================================================================
|
|
702
|
+
|
|
703
|
+
class CheckpointCallback(Callback):
    """
    Save algorithm state periodically.

    Each checkpoint is a dict written with ``torch.save`` to
    ``<directory>/<prefix>_gen<generation zero-padded to 6 digits>.pt``. It
    contains these keys:

    - ``generation`` and ``n_evals``.
    - ``best_fitness``, stored as a float.
    - ``best_solution``: a detached CPU tensor, or ``None``.
    - ``hyperparams``: tensor values are detached and moved to the CPU.
    - ``algorithm_state``: present only when the algorithm exposes
      ``state_dict()``.

    Args:
        directory: Directory to save checkpoints. It is created, with parents,
            when optimisation starts.
        every: Save every N generations. This is ignored when
            ``save_best_only`` is True; otherwise it must be >= 1.
        save_best_only: Only save when best fitness improves. Lower is better,
            and improvement is checked after every generation.
        max_to_keep: Maximum number of checkpoints to keep (None for all). The
            oldest checkpoints written by this callback are deleted first.
        prefix: Filename prefix for checkpoints.

    Raises:
        ValueError: If ``every`` < 1 while ``save_best_only`` is False, or if
            ``max_to_keep`` is negative.

    Example:
        >>> ckpt = CheckpointCallback(
        ...     directory="checkpoints",
        ...     every=100,
        ... )
        >>> result = minimize(algorithm, callbacks=[ckpt])
    """

    def __init__(
        self,
        directory: Union[str, Path],
        every: int = 100,
        save_best_only: bool = False,
        max_to_keep: Optional[int] = 5,
        prefix: str = "checkpoint",
    ) -> None:
        # Validate eagerly. Without these checks, every=0 raises
        # ZeroDivisionError at the first generation, and a negative
        # max_to_keep makes the cleanup loop pop from an empty list.
        if not save_best_only and every < 1:
            raise ValueError(f"every must be >= 1, got {every}")
        if max_to_keep is not None and max_to_keep < 0:
            raise ValueError(f"max_to_keep must be >= 0 or None, got {max_to_keep}")

        self.directory = Path(directory)
        self.every = every
        self.save_best_only = save_best_only
        self.max_to_keep = max_to_keep
        self.prefix = prefix

        # State
        # Best fitness seen so far; starts at +inf because lower is better.
        self.best_fitness: float = float('inf')
        # Checkpoints written by this callback, oldest first (used for pruning).
        self.saved_paths: List[Path] = []

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Create the checkpoint directory and reset per-run state."""
        self.directory.mkdir(parents=True, exist_ok=True)
        self.best_fitness = float('inf')
        self.saved_paths.clear()

    def on_generation_end(self, state: CallbackState) -> None:
        """Save a checkpoint if this generation meets the saving condition.

        With ``save_best_only``, a checkpoint is written whenever
        ``state.best_fitness`` is strictly lower than the best seen so far.
        Otherwise one is written when ``state.generation`` is a multiple of
        ``every``. After each save, the oldest checkpoints beyond
        ``max_to_keep`` are deleted.
        """
        if self.save_best_only:
            # Compare as a plain float, so a tensor-valued fitness neither
            # yields a tensor from ``<`` nor gets stored as a tensor.
            # ``not <`` (rather than ``>=``) keeps NaN from ever counting as
            # an improvement.
            current = float(state.best_fitness)
            if not current < self.best_fitness:
                return
            self.best_fitness = current
        elif state.generation % self.every != 0:
            return

        checkpoint = self._build_checkpoint(state)

        filepath = self.directory / f"{self.prefix}_gen{state.generation:06d}.pt"
        # Write to a sibling temp file, then rename it into place. An
        # interrupted save therefore never leaves a truncated checkpoint
        # under the final name. Path.replace overwrites an existing target.
        tmp_path = filepath.with_name(filepath.name + ".tmp")
        torch.save(checkpoint, tmp_path)
        tmp_path.replace(filepath)

        # Re-saving the same generation must not register the path twice;
        # otherwise pruning could delete the file that was just written.
        if filepath in self.saved_paths:
            self.saved_paths.remove(filepath)
        self.saved_paths.append(filepath)

        self._prune_old_checkpoints()

    @staticmethod
    def _build_checkpoint(state: CallbackState) -> Dict[str, Any]:
        """Snapshot the callback state into a CPU-resident, serialisable dict."""
        hyperparams = state.hyperparams or {}
        checkpoint: Dict[str, Any] = {
            "generation": state.generation,
            "n_evals": state.n_evals,
            "best_fitness": float(state.best_fitness),
            "best_solution": (
                state.best_solution.detach().cpu()
                if state.best_solution is not None
                else None
            ),
            "hyperparams": {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in hyperparams.items()
            },
        }

        # Save algorithm state dict if available
        algorithm = state.algorithm
        if algorithm is not None and hasattr(algorithm, "state_dict"):
            checkpoint["algorithm_state"] = algorithm.state_dict()
        return checkpoint

    def _prune_old_checkpoints(self) -> None:
        """Delete the oldest saved checkpoints until at most ``max_to_keep`` remain."""
        if self.max_to_keep is None:
            return
        # Clamp in case the attribute was set negative after construction.
        keep = max(self.max_to_keep, 0)
        while len(self.saved_paths) > keep:
            old_path = self.saved_paths.pop(0)
            # The file may already have been removed externally.
            if old_path.exists():
                old_path.unlink()

    @staticmethod
    def load_checkpoint(
        path: Union[str, Path], map_location: Any = None
    ) -> Dict[str, Any]:
        """Load a checkpoint file written by this callback.

        Args:
            path: Path to the ``.pt`` checkpoint file.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads tensors saved on a GPU onto a CPU-only machine. ``None``
                keeps torch's default behaviour.

        Returns:
            The checkpoint dict.

        Warning:
            ``weights_only=False`` unpickles arbitrary Python objects, which
            can execute code. Only load checkpoints from trusted sources.
            NOTE(review): this is presumably required for non-tensor values
            in ``hyperparams`` / ``algorithm_state`` — confirm.
        """
        return torch.load(path, map_location=map_location, weights_only=False)

    def __repr__(self) -> str:
        return (
            f"CheckpointCallback(directory='{self.directory}', "
            f"every={self.every}, saved={len(self.saved_paths)})"
        )
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
# =============================================================================
|
|
804
|
+
# Composite Callback
|
|
805
|
+
# =============================================================================
|
|
806
|
+
|
|
807
|
+
class CompositeCallback(Callback):
    """
    Combine multiple callbacks into one.

    This is useful for passing a single callback object that internally
    manages multiple callbacks. Every hook is forwarded, in insertion order,
    to each wrapped callback with the same ``state`` object.

    Args:
        callbacks: Callbacks to combine. The iterable is copied into a new
            list. Defaults to ``None``, meaning no callbacks.

    Example:
        >>> composite = CompositeCallback([
        ...     HistoryCallback(),
        ...     EarlyStoppingCallback(patience=50),
        ...     PrintCallback(every=10)
        ... ])
        >>> result = minimize(algorithm, callbacks=[composite])
    """

    def __init__(self, callbacks: Optional[List[Callback]] = None) -> None:
        # Copy so later changes to the caller's list do not leak in.
        self.callbacks: List[Callback] = [] if callbacks is None else list(callbacks)

    def add(self, callback: Callback) -> "CompositeCallback":
        """Append ``callback`` and return ``self`` so calls can be chained."""
        self.callbacks.append(callback)
        return self

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Forward ``on_optimisation_start`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_optimisation_start(state)

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Forward ``on_optimisation_end`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_optimisation_end(state)

    def on_generation_start(self, state: CallbackState) -> None:
        """Forward ``on_generation_start`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_generation_start(state)

    def on_generation_end(self, state: CallbackState) -> None:
        """Forward ``on_generation_end`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_generation_end(state)

    def on_evaluation_start(self, state: CallbackState) -> None:
        """Forward ``on_evaluation_start`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_evaluation_start(state)

    def on_evaluation_end(self, state: CallbackState) -> None:
        """Forward ``on_evaluation_end`` to every wrapped callback."""
        for cb in self.callbacks:
            cb.on_evaluation_end(state)

    def __len__(self) -> int:
        """Number of wrapped callbacks."""
        return len(self.callbacks)

    def __iter__(self):
        """Iterate over the wrapped callbacks in insertion order."""
        return iter(self.callbacks)

    def __repr__(self) -> str:
        return f"CompositeCallback({len(self.callbacks)} callbacks)"
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
# =============================================================================
|
|
869
|
+
# Callback List (Convenience Alias)
|
|
870
|
+
# =============================================================================
|
|
871
|
+
|
|
872
|
+
class CallbackList(CompositeCallback):
    """
    List-flavoured :class:`CompositeCallback`.

    Behaves exactly like the composite, but additionally offers ``append``,
    ``extend`` and index access. The API mirrors Keras's ``CallbackList``.
    """

    def append(self, callback: Callback) -> None:
        """Add ``callback`` at the end; returns ``None`` like ``list.append``."""
        self.add(callback)

    def extend(self, callbacks: List[Callback]) -> None:
        """Add each of ``callbacks`` at the end, preserving their order."""
        self.callbacks += callbacks

    def __getitem__(self, idx: int) -> Callback:
        """Return the wrapped callback at position ``idx``."""
        wrapped = self.callbacks
        return wrapped[idx]
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
# =============================================================================
|
|
892
|
+
# Utility Functions
|
|
893
|
+
# =============================================================================
|
|
894
|
+
|
|
895
|
+
def create_default_callbacks(
    verbose: bool = True,
    history: bool = True,
    early_stopping: bool = False,
    patience: int = 50,
    print_every: int = 100,
) -> CallbackList:
    """
    Build a ``CallbackList`` from a few on/off switches.

    Enabled callbacks are added in a fixed order: history recording first,
    then early stopping, then console printing.

    Args:
        verbose: Add a ``PrintCallback`` that reports every ``print_every`` generations.
        history: Add a ``HistoryCallback``.
        early_stopping: Add an ``EarlyStoppingCallback`` with the given ``patience``.
        patience: Patience passed to the early-stopping callback.
        print_every: Reporting interval passed to the print callback.

    Returns:
        CallbackList with requested callbacks.
    """
    # (enabled?, factory) pairs; factories run only for enabled entries,
    # and the tuple order fixes the callback order.
    factories = (
        (history, HistoryCallback),
        (early_stopping, lambda: EarlyStoppingCallback(patience=patience)),
        (verbose, lambda: PrintCallback(every=print_every)),
    )
    return CallbackList([make() for enabled, make in factories if enabled])
|