@zigrivers/scaffold 3.14.0 → 3.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -9
- package/content/knowledge/research/research-architecture.md +385 -0
- package/content/knowledge/research/research-conventions.md +248 -0
- package/content/knowledge/research/research-dev-environment.md +303 -0
- package/content/knowledge/research/research-experiment-loop.md +429 -0
- package/content/knowledge/research/research-experiment-tracking.md +336 -0
- package/content/knowledge/research/research-ml-architecture-search.md +383 -0
- package/content/knowledge/research/research-ml-evaluation.md +407 -0
- package/content/knowledge/research/research-ml-experiment-tracking.md +466 -0
- package/content/knowledge/research/research-ml-training-patterns.md +413 -0
- package/content/knowledge/research/research-observability.md +395 -0
- package/content/knowledge/research/research-overfitting-prevention.md +306 -0
- package/content/knowledge/research/research-project-structure.md +264 -0
- package/content/knowledge/research/research-quant-backtesting.md +326 -0
- package/content/knowledge/research/research-quant-market-data.md +366 -0
- package/content/knowledge/research/research-quant-metrics.md +335 -0
- package/content/knowledge/research/research-quant-requirements.md +223 -0
- package/content/knowledge/research/research-quant-risk.md +469 -0
- package/content/knowledge/research/research-quant-strategy-patterns.md +412 -0
- package/content/knowledge/research/research-requirements.md +201 -0
- package/content/knowledge/research/research-security.md +374 -0
- package/content/knowledge/research/research-sim-compute-management.md +538 -0
- package/content/knowledge/research/research-sim-engine-patterns.md +448 -0
- package/content/knowledge/research/research-sim-parameter-spaces.md +425 -0
- package/content/knowledge/research/research-sim-validation.md +456 -0
- package/content/knowledge/research/research-testing.md +334 -0
- package/content/methodology/research-ml-research.yml +23 -0
- package/content/methodology/research-overlay.yml +65 -0
- package/content/methodology/research-quant-finance.yml +29 -0
- package/content/methodology/research-simulation.yml +23 -0
- package/dist/cli/commands/adopt.d.ts.map +1 -1
- package/dist/cli/commands/adopt.js +22 -1
- package/dist/cli/commands/adopt.js.map +1 -1
- package/dist/cli/commands/adopt.serialization.test.js +41 -0
- package/dist/cli/commands/adopt.serialization.test.js.map +1 -1
- package/dist/cli/commands/init.d.ts +4 -0
- package/dist/cli/commands/init.d.ts.map +1 -1
- package/dist/cli/commands/init.js +32 -2
- package/dist/cli/commands/init.js.map +1 -1
- package/dist/cli/init-flag-families.d.ts +6 -1
- package/dist/cli/init-flag-families.d.ts.map +1 -1
- package/dist/cli/init-flag-families.js +32 -1
- package/dist/cli/init-flag-families.js.map +1 -1
- package/dist/cli/init-flag-families.test.js +47 -0
- package/dist/cli/init-flag-families.test.js.map +1 -1
- package/dist/config/schema.d.ts +272 -16
- package/dist/config/schema.d.ts.map +1 -1
- package/dist/config/schema.js +25 -1
- package/dist/config/schema.js.map +1 -1
- package/dist/config/schema.test.js +103 -3
- package/dist/config/schema.test.js.map +1 -1
- package/dist/core/assembly/overlay-loader.d.ts +12 -0
- package/dist/core/assembly/overlay-loader.d.ts.map +1 -1
- package/dist/core/assembly/overlay-loader.js +30 -0
- package/dist/core/assembly/overlay-loader.js.map +1 -1
- package/dist/core/assembly/overlay-loader.test.js +66 -1
- package/dist/core/assembly/overlay-loader.test.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.d.ts.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.js +48 -19
- package/dist/core/assembly/overlay-state-resolver.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.test.js +80 -0
- package/dist/core/assembly/overlay-state-resolver.test.js.map +1 -1
- package/dist/e2e/project-type-overlays.test.js +119 -0
- package/dist/e2e/project-type-overlays.test.js.map +1 -1
- package/dist/project/adopt.d.ts.map +1 -1
- package/dist/project/adopt.js +3 -1
- package/dist/project/adopt.js.map +1 -1
- package/dist/project/detectors/disambiguate.js +1 -1
- package/dist/project/detectors/disambiguate.js.map +1 -1
- package/dist/project/detectors/index.d.ts.map +1 -1
- package/dist/project/detectors/index.js +2 -1
- package/dist/project/detectors/index.js.map +1 -1
- package/dist/project/detectors/ml.d.ts.map +1 -1
- package/dist/project/detectors/ml.js +2 -6
- package/dist/project/detectors/ml.js.map +1 -1
- package/dist/project/detectors/research.d.ts +4 -0
- package/dist/project/detectors/research.d.ts.map +1 -0
- package/dist/project/detectors/research.js +141 -0
- package/dist/project/detectors/research.js.map +1 -0
- package/dist/project/detectors/research.test.d.ts +2 -0
- package/dist/project/detectors/research.test.d.ts.map +1 -0
- package/dist/project/detectors/research.test.js +235 -0
- package/dist/project/detectors/research.test.js.map +1 -0
- package/dist/project/detectors/shared-signals.d.ts +3 -0
- package/dist/project/detectors/shared-signals.d.ts.map +1 -0
- package/dist/project/detectors/shared-signals.js +9 -0
- package/dist/project/detectors/shared-signals.js.map +1 -0
- package/dist/project/detectors/types.d.ts +6 -2
- package/dist/project/detectors/types.d.ts.map +1 -1
- package/dist/project/detectors/types.js.map +1 -1
- package/dist/types/config.d.ts +7 -1
- package/dist/types/config.d.ts.map +1 -1
- package/dist/wizard/copy/core.d.ts.map +1 -1
- package/dist/wizard/copy/core.js +4 -0
- package/dist/wizard/copy/core.js.map +1 -1
- package/dist/wizard/copy/index.d.ts.map +1 -1
- package/dist/wizard/copy/index.js +2 -0
- package/dist/wizard/copy/index.js.map +1 -1
- package/dist/wizard/copy/research.d.ts +3 -0
- package/dist/wizard/copy/research.d.ts.map +1 -0
- package/dist/wizard/copy/research.js +27 -0
- package/dist/wizard/copy/research.js.map +1 -0
- package/dist/wizard/copy/types.d.ts +5 -1
- package/dist/wizard/copy/types.d.ts.map +1 -1
- package/dist/wizard/flags.d.ts +7 -1
- package/dist/wizard/flags.d.ts.map +1 -1
- package/dist/wizard/questions.d.ts +4 -2
- package/dist/wizard/questions.d.ts.map +1 -1
- package/dist/wizard/questions.js +27 -1
- package/dist/wizard/questions.js.map +1 -1
- package/dist/wizard/questions.test.js +51 -0
- package/dist/wizard/questions.test.js.map +1 -1
- package/dist/wizard/wizard.d.ts +3 -2
- package/dist/wizard/wizard.d.ts.map +1 -1
- package/dist/wizard/wizard.js +3 -1
- package/dist/wizard/wizard.js.map +1 -1
- package/package.json +1 -1
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: research-ml-architecture-search
|
|
3
|
+
description: Neural Architecture Search patterns including search space definition, search strategies, mutation operators, performance prediction, and multi-objective optimization
|
|
4
|
+
topics: [research, ml-research, nas, architecture-search, search-space, evolutionary, darts, surrogate-model, multi-objective]
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
Neural Architecture Search (NAS) automates the discovery of model architectures that outperform hand-designed ones. The core challenge is navigating an exponentially large search space efficiently -- evaluating every candidate is infeasible, so search strategies must balance exploration (trying diverse architectures) with exploitation (refining promising ones). A well-designed NAS pipeline defines the search space precisely, applies an appropriate search strategy, uses performance prediction to avoid wasting compute on bad candidates, and manages the total search budget to stay within resource constraints.
|
|
8
|
+
|
|
9
|
+
## Summary
|
|
10
|
+
|
|
11
|
+
Define search spaces as structured graphs with explicit operation choices and connectivity patterns. Choose search strategies based on budget: random search for baselines, evolutionary algorithms for large discrete spaces, reinforcement learning for sequential construction, and differentiable methods (DARTS) for gradient-based continuous relaxation. Use surrogate models to predict performance from partial training, reducing evaluation cost by 10-100x. Apply mutation operators that preserve architectural validity. Track Pareto frontiers for multi-objective NAS (accuracy vs latency, accuracy vs parameters).
|
|
12
|
+
|
|
13
|
+
## Deep Guidance
|
|
14
|
+
|
|
15
|
+
### Search Space Definition
|
|
16
|
+
|
|
17
|
+
The search space defines which architectures can be explored. A space that is too narrow misses good designs; one that is too broad wastes compute on poor or invalid structures:
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
# src/nas/search_space.py
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from enum import Enum
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
class OperationType(Enum):
    """Primitive operations available in the search space.

    Each member's value is the short string identifier of the operation.
    ``SearchSpace.available_ops`` defaults to every member of this enum.
    """
    # Convolution variants (kernel size is encoded in the member name).
    CONV_3X3 = "conv_3x3"
    CONV_5X5 = "conv_5x5"
    SEPARABLE_CONV_3X3 = "sep_conv_3x3"
    SEPARABLE_CONV_5X5 = "sep_conv_5x5"
    DILATED_CONV_3X3 = "dil_conv_3x3"
    # Pooling variants.
    MAX_POOL_3X3 = "max_pool_3x3"
    AVG_POOL_3X3 = "avg_pool_3x3"
    # Structural choices.
    SKIP_CONNECT = "skip_connect"
    ZERO = "zero"  # No connection
|
|
36
|
+
|
|
37
|
+
@dataclass
class SearchSpace:
    """Defines the architecture search space.

    The space is described by the number of nodes per cell, the operations
    that may be placed on an edge, the number of cells in the network and the
    channel widths a cell may take.
    """
    num_nodes: int = 7  # Nodes per cell
    num_ops_per_edge: int = 1  # Operations per edge
    available_ops: list[OperationType] = field(
        default_factory=lambda: list(OperationType)
    )
    num_cells: int = 8  # Total cells in the network
    num_reduction_cells: int = 2  # Cells that downsample
    channel_choices: list[int] = field(
        default_factory=lambda: [16, 32, 64, 128]
    )

    @property
    def space_size(self) -> int:
        """Estimate total number of architectures in the space.

        Counts one operation choice per node pair (``src < dst``) for a single
        cell topology, times one channel choice per cell.

        NOTE(review): the op assignments are counted once, not once per cell,
        while ``random_architecture`` samples every cell independently -- so
        this is a lower-bound style estimate; confirm which is intended.
        """
        # Fully connected DAG over the cell's nodes: n * (n - 1) / 2 edges.
        num_edges = self.num_nodes * (self.num_nodes - 1) // 2
        ops_per_cell = len(self.available_ops) ** num_edges
        return ops_per_cell * len(self.channel_choices) ** self.num_cells

    def validate_architecture(self, arch: "Architecture") -> list[str]:
        """Check that an architecture is valid within this space.

        Args:
            arch: Object exposing ``cells`` (each with ``edges`` whose items
                have an ``op``) and ``channels``.

        Returns:
            Human-readable descriptions of every violation found; an empty
            list means the architecture is valid.
        """
        issues = []
        if len(arch.cells) != self.num_cells:
            issues.append(f"Expected {self.num_cells} cells, got {len(arch.cells)}")
        # Channel widths are part of the space too, so they are checked for
        # both count and membership in channel_choices.
        if len(arch.channels) != self.num_cells:
            issues.append(
                f"Expected {self.num_cells} channel widths, got {len(arch.channels)}"
            )
        for i, cell in enumerate(arch.cells):
            for edge in cell.edges:
                if edge.op not in self.available_ops:
                    issues.append(f"Cell {i}: invalid op {edge.op}")
        for i, width in enumerate(arch.channels):
            if width not in self.channel_choices:
                issues.append(f"Cell {i}: invalid channel width {width}")
        return issues
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class Edge:
    """One connection inside a cell, labelled with the operation it applies."""
    # Node indices within the cell; random_architecture only emits edges
    # with src_node < dst_node.
    src_node: int
    dst_node: int
    # Operation placed on this edge.
    op: OperationType
|
|
75
|
+
|
|
76
|
+
@dataclass
class Cell:
    """A cell of the network: a collection of edges plus a reduction flag."""
    # Edges that make up the cell.
    edges: list[Edge]
    # True for cells that downsample (reduction cells).
    is_reduction: bool = False
|
|
80
|
+
|
|
81
|
+
@dataclass
class Architecture:
    """A candidate architecture: its cells, channel widths and free-form metadata."""
    # Cells in network order.
    cells: list[Cell]
    # Channel widths; random_architecture produces one entry per cell.
    channels: list[int]
    # Arbitrary extra information; defaults to a fresh empty dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
### Search Strategies
|
|
89
|
+
|
|
90
|
+
#### Random Search (Baseline)
|
|
91
|
+
|
|
92
|
+
Always implement random search first -- it is surprisingly competitive and provides the baseline that any sophisticated method must beat:
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
# src/nas/strategies/random_search.py
|
|
96
|
+
import random
|
|
97
|
+
from src.nas.search_space import SearchSpace, Architecture, Cell, Edge
|
|
98
|
+
|
|
99
|
+
def random_architecture(space: SearchSpace, seed: int | None = None) -> Architecture:
    """Sample a random architecture from ``space``.

    Every node pair ``(src, dst)`` with ``src < dst`` gets one edge whose
    operation is drawn from ``space.available_ops``; each cell then gets a
    channel width drawn from ``space.channel_choices``.

    Args:
        space: Search space to sample from.
        seed: Seed for a private ``random.Random``; the same seed yields the
            same architecture. ``None`` seeds from system entropy.

    Returns:
        A new ``Architecture`` with ``space.num_cells`` cells and channels.
    """
    rng = random.Random(seed)
    # Loop-invariant: compute the reduction positions once (the original
    # recomputed them for every cell) and use a set for O(1) membership.
    reduction_cells = set(_reduction_positions(space))
    cells = []
    for i in range(space.num_cells):
        edges = []
        for src in range(space.num_nodes):
            for dst in range(src + 1, space.num_nodes):
                op = rng.choice(space.available_ops)
                edges.append(Edge(src_node=src, dst_node=dst, op=op))
        cells.append(Cell(edges=edges, is_reduction=i in reduction_cells))
    # Channels are drawn after all ops so the RNG stream order is unchanged.
    channels = [rng.choice(space.channel_choices) for _ in range(space.num_cells)]
    return Architecture(cells=cells, channels=channels)
|
|
113
|
+
|
|
114
|
+
def _reduction_positions(space: SearchSpace) -> list[int]:
    """Place reduction cells evenly through the network.

    Position ``i`` is ``num_cells * (i + 1) // (num_reduction_cells + 1)``.
    Multiplying before the floor division avoids the rounding drift of a
    pre-floored step (8 cells / 2 reductions gave [2, 4] instead of [2, 5])
    and the collapse to position 0 when the step floored to zero.

    Args:
        space: Search space providing ``num_cells`` and ``num_reduction_cells``.

    Returns:
        Cell indices, in increasing order, that should be reduction cells.
    """
    slots = space.num_reduction_cells + 1
    return [
        space.num_cells * (i + 1) // slots
        for i in range(space.num_reduction_cells)
    ]
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
#### Evolutionary Search
|
|
121
|
+
|
|
122
|
+
Evolutionary NAS maintains a population of architectures, selects the best, applies mutations, and iterates:
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
# src/nas/strategies/evolutionary.py
|
|
126
|
+
from dataclasses import dataclass
|
|
127
|
+
from src.nas.search_space import SearchSpace, Architecture
|
|
128
|
+
from src.nas.mutation import mutate_architecture
|
|
129
|
+
|
|
130
|
+
@dataclass
class EvolutionConfig:
    """Hyperparameters for evolutionary_search."""
    # Number of random architectures in the initial population.
    population_size: int = 50
    # Number of population members sampled for each tournament selection.
    tournament_size: int = 10
    # Passed through to mutate_architecture as its mutation_rate.
    mutation_rate: float = 0.3
    # Upper bound on loop iterations; each iteration evaluates one child.
    max_generations: int = 500
    early_stop_patience: int = 50  # Generations without improvement
|
|
137
|
+
|
|
138
|
+
def evolutionary_search(
    space: SearchSpace,
    evaluate_fn,  # Architecture -> float (fitness)
    config: EvolutionConfig | None = None,
) -> list[tuple[Architecture, float]]:
    """Run evolutionary architecture search.

    Starts from ``config.population_size`` random architectures. Each
    generation selects a parent by tournament, mutates it, evaluates the
    child and puts the child in place of the lowest-fitness member.

    Args:
        space: Search space used for sampling and mutation.
        evaluate_fn: Callable mapping an architecture to a fitness value,
            where higher is better.
        config: Search hyperparameters. ``None`` (the default) builds a fresh
            ``EvolutionConfig()`` for this call.

    Returns:
        The final population as ``(architecture, fitness)`` pairs sorted by
        fitness, best first. Empty when ``population_size`` is 0.
    """
    # Imported here because this module's import header never brings
    # random_architecture into scope; the original raised NameError.
    from src.nas.strategies.random_search import random_architecture

    # A default instance in the signature would be shared (and mutable)
    # across calls, so the default is created per call instead.
    if config is None:
        config = EvolutionConfig()

    # Initialize, then evaluate, the starting population.
    founders = [random_architecture(space) for _ in range(config.population_size)]
    population = [(arch, evaluate_fn(arch)) for arch in founders]
    if not population:
        # Nothing to evolve; max() below would fail on an empty population.
        return []

    best_fitness = max(fitness for _, fitness in population)
    stale_generations = 0

    for _ in range(config.max_generations):
        # Tournament selection
        parent = _tournament_select(population, config.tournament_size)

        # Mutation
        child = mutate_architecture(parent, space, config.mutation_rate)

        # Evaluate child
        child_fitness = evaluate_fn(child)

        # Replace worst in population (ascending sort puts it at index 0).
        population.sort(key=lambda entry: entry[1])
        population[0] = (child, child_fitness)

        # Track progress and stop once fitness has stalled for too long.
        gen_best = max(fitness for _, fitness in population)
        if gen_best > best_fitness:
            best_fitness = gen_best
            stale_generations = 0
        else:
            stale_generations += 1
            if stale_generations >= config.early_stop_patience:
                break

    population.sort(key=lambda entry: entry[1], reverse=True)
    return population
|
|
186
|
+
|
|
187
|
+
def _tournament_select(population, k):
    """Pick the fittest architecture among ``k`` randomly sampled members.

    Args:
        population: Sequence of ``(architecture, fitness)`` pairs.
        k: Tournament size. Values larger than the population are clamped to
            the population size, since ``random.sample`` would otherwise
            raise ValueError.

    Returns:
        The architecture (first pair element) of the best sampled member.

    Raises:
        ValueError: If no candidate can be sampled (empty population or
            ``k <= 0``).
    """
    import random

    contenders = random.sample(population, min(k, len(population)))
    return max(contenders, key=lambda entry: entry[1])[0]
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
#### Differentiable NAS (DARTS)
|
|
194
|
+
|
|
195
|
+
DARTS relaxes the discrete search space into a continuous one, enabling gradient-based optimization:
|
|
196
|
+
|
|
197
|
+
```python
|
|
198
|
+
# src/nas/strategies/darts.py
|
|
199
|
+
import torch
|
|
200
|
+
import torch.nn as nn
|
|
201
|
+
import torch.nn.functional as F
|
|
202
|
+
|
|
203
|
+
class MixedOp(nn.Module):
    """Softmax-weighted blend of candidate operations (continuous relaxation)."""

    def __init__(self, channels: int, ops: list[nn.Module]):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        # One architecture weight per candidate op, initialised to zero.
        self.alpha = nn.Parameter(torch.zeros(len(ops)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of every op's output scaled by its softmax weight."""
        mix = F.softmax(self.alpha, dim=0)
        blended = 0
        for weight, op in zip(mix, self.ops):
            blended = blended + weight * op(x)
        return blended

    def discretize(self) -> int:
        """Return the index of the op whose architecture weight is largest."""
        return torch.argmax(self.alpha).item()
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
### Mutation Operators
|
|
221
|
+
|
|
222
|
+
Mutations must preserve architectural validity while enabling meaningful exploration:
|
|
223
|
+
|
|
224
|
+
```python
|
|
225
|
+
# src/nas/mutation.py
|
|
226
|
+
import random
|
|
227
|
+
from src.nas.search_space import SearchSpace, Architecture, Cell, Edge
|
|
228
|
+
|
|
229
|
+
def mutate_architecture(
    arch: Architecture,
    space: SearchSpace,
    mutation_rate: float = 0.3,
    rng: random.Random | None = None,
) -> Architecture:
    """Apply random mutations while preserving validity.

    Works on a deep copy, so ``arch`` is never modified. Each edge's op is
    resampled from ``space.available_ops`` with probability
    ``mutation_rate``; each channel width is resampled from
    ``space.channel_choices`` with probability ``mutation_rate * 0.5``.
    Resampling may pick the value that was already there.

    Args:
        arch: Parent architecture.
        space: Search space supplying the legal ops and channel widths.
        mutation_rate: Per-edge mutation probability.
        rng: Optional ``random.Random`` for reproducible mutation. When
            omitted, the module-level ``random`` generator is used, as before.

    Returns:
        The mutated copy.
    """
    import copy

    source = random if rng is None else rng
    child = copy.deepcopy(arch)

    for cell in child.cells:
        for edge in cell.edges:
            if source.random() < mutation_rate:
                edge.op = source.choice(space.available_ops)

    # Channel widths mutate at half the edge rate. Slice assignment keeps the
    # same list object; draws happen in the same order as an index loop.
    channel_rate = mutation_rate * 0.5
    child.channels[:] = [
        source.choice(space.channel_choices) if source.random() < channel_rate else width
        for width in child.channels
    ]

    return child
|
|
249
|
+
|
|
250
|
+
def crossover(parent_a: Architecture, parent_b: Architecture) -> Architecture:
    """Single-point crossover between two architectures.

    The child takes ``parent_a``'s cells and channels before a random cut
    point and ``parent_b``'s from the cut point onward. Neither parent is
    modified.

    Args:
        parent_a: Supplies the leading cells and channels.
        parent_b: Supplies the trailing cells and channels.

    Returns:
        A new architecture. When ``parent_a`` has fewer than two cells there
        is no interior cut point, so a deep copy of ``parent_a`` is returned.
    """
    import copy

    child = copy.deepcopy(parent_a)
    # randint(1, n - 1) needs n >= 2; with one or zero cells the original
    # raised ValueError on an empty range.
    if len(child.cells) < 2:
        return child
    crossover_point = random.randint(1, len(child.cells) - 1)
    child.cells[crossover_point:] = copy.deepcopy(parent_b.cells[crossover_point:])
    child.channels[crossover_point:] = parent_b.channels[crossover_point:]
    return child
|
|
258
|
+
```
|
|
259
|
+
|
|
260
|
+
### Performance Prediction (Surrogates)
|
|
261
|
+
|
|
262
|
+
Full training is expensive. Surrogates predict final performance from cheap features:
|
|
263
|
+
|
|
264
|
+
```python
|
|
265
|
+
# src/nas/surrogate.py
|
|
266
|
+
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

from src.nas.search_space import Architecture, OperationType
|
|
268
|
+
|
|
269
|
+
class PerformancePredictor:
    """Predict architecture performance from structural features.

    Wraps a ``GradientBoostingRegressor`` trained on (architecture, score)
    pairs. ``fit`` must be called before ``predict``.
    """

    def __init__(self):
        self.model = GradientBoostingRegressor(n_estimators=100)
        self.is_fitted = False

    def extract_features(self, arch: Architecture) -> np.ndarray:
        """Convert architecture to a feature vector.

        For every cell, emits one count per ``OperationType`` member (how
        many of the cell's edges use that op), followed by the architecture's
        channel widths. The length is fixed only across architectures with
        the same number of cells and channels.

        Raises:
            ValueError: If an edge's op is not an ``OperationType`` member.
        """
        # Hoisted: the member list was previously rebuilt for every edge.
        all_ops = list(OperationType)
        features = []
        for cell in arch.cells:
            op_counts = [0] * len(all_ops)
            for edge in cell.edges:
                op_counts[all_ops.index(edge.op)] += 1
            features.extend(op_counts)
        features.extend(arch.channels)
        return np.array(features, dtype=np.float32)

    def fit(self, architectures: list[Architecture], scores: list[float]) -> None:
        """Train surrogate on evaluated architectures."""
        X = np.array([self.extract_features(a) for a in architectures])
        self.model.fit(X, scores)
        self.is_fitted = True

    def predict(self, arch: Architecture) -> float:
        """Predict performance without full training.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if not self.is_fitted:
            raise RuntimeError("Surrogate not fitted yet")
        # The model expects a 2-D batch; reshape the single vector to (1, n).
        X = self.extract_features(arch).reshape(1, -1)
        return self.model.predict(X)[0]

    def acquisition_score(self, arch: Architecture) -> float:
        """Score used to rank candidates for evaluation.

        Currently this is the surrogate's point prediction only (pure
        exploitation). No uncertainty or exploration bonus is added.
        """
        # TODO: add an uncertainty term if UCB-style exploration is wanted.
        return self.predict(arch)
|
|
305
|
+
```
|
|
306
|
+
|
|
307
|
+
### Search Budget Management
|
|
308
|
+
|
|
309
|
+
NAS must operate within compute constraints. Track and enforce budgets:
|
|
310
|
+
|
|
311
|
+
```python
|
|
312
|
+
# src/nas/budget.py
|
|
313
|
+
from dataclasses import dataclass
|
|
314
|
+
import time
|
|
315
|
+
|
|
316
|
+
@dataclass
class SearchBudget:
    """Track and enforce NAS compute budget.

    Three limits are tracked: GPU hours, number of evaluations, and wall-clock
    hours since ``start()``. A ``start_time`` of 0.0 means the clock has not
    been started, in which case no wall time is counted.
    """
    max_gpu_hours: float = 100.0
    max_evaluations: int = 1000
    max_wall_time_hours: float = 48.0

    # Running totals
    gpu_hours_used: float = 0.0
    evaluations_done: int = 0
    start_time: float = 0.0

    def start(self) -> None:
        """Start the wall-clock timer at the current time."""
        self.start_time = time.time()

    def record_evaluation(self, gpu_hours: float) -> None:
        """Account for one finished evaluation that used ``gpu_hours``."""
        self.gpu_hours_used += gpu_hours
        self.evaluations_done += 1

    def _wall_hours(self) -> float:
        """Hours elapsed since start(), or 0.0 if the clock never started."""
        # Without this guard an unstarted budget measured time since the
        # epoch and always looked exhausted.
        if self.start_time <= 0.0:
            return 0.0
        return (time.time() - self.start_time) / 3600

    @staticmethod
    def _fraction_left(used: float, limit: float) -> float:
        """Remaining share of one limit, clamped to the range [0, 1]."""
        if limit <= 0:
            # A zero limit is already spent; also avoids division by zero.
            return 0.0
        return max(0.0, min(1.0, 1 - used / limit))

    def is_exhausted(self) -> bool:
        """Return True once any of the three limits has been reached."""
        return (
            self.gpu_hours_used >= self.max_gpu_hours
            or self.evaluations_done >= self.max_evaluations
            or self._wall_hours() >= self.max_wall_time_hours
        )

    def remaining_fraction(self) -> float:
        """Return the smallest remaining share across all limits, in [0, 1].

        Wall time is included so that the result reaches 0.0 exactly when
        ``is_exhausted()`` becomes True.
        """
        return min(
            self._fraction_left(self.gpu_hours_used, self.max_gpu_hours),
            self._fraction_left(self.evaluations_done, self.max_evaluations),
            self._fraction_left(self._wall_hours(), self.max_wall_time_hours),
        )
|
|
347
|
+
```
|
|
348
|
+
|
|
349
|
+
### Multi-Objective NAS
|
|
350
|
+
|
|
351
|
+
Real NAS problems have multiple objectives (accuracy, latency, parameters, FLOPs). Track the Pareto frontier:
|
|
352
|
+
|
|
353
|
+
```python
|
|
354
|
+
# src/nas/pareto.py
|
|
355
|
+
from dataclasses import dataclass
|
|
356
|
+
|
|
357
|
+
@dataclass
class ObjectiveResult:
    """Measured objective values for one evaluated architecture."""
    # Identifier of the architecture these measurements belong to.
    architecture_id: str
    # Higher is better (see is_dominated).
    accuracy: float
    # Lower is better (see is_dominated).
    latency_ms: float
    # Lower is better (see is_dominated).
    params_millions: float
    # FLOPs in billions, per the field name.
    flops_billions: float
|
|
364
|
+
|
|
365
|
+
def is_dominated(a: ObjectiveResult, b: ObjectiveResult) -> bool:
    """Return True if ``b`` dominates ``a``.

    ``b`` dominates ``a`` when it is at least as good on every objective and
    strictly better on at least one. Accuracy is maximised; latency,
    parameter count and FLOPs are minimised. Identical results do not
    dominate each other.

    All four objectives are compared; leaving out ``flops_billions`` would
    let a result with far more FLOPs dominate a cheaper one.
    """
    no_worse = (
        b.accuracy >= a.accuracy
        and b.latency_ms <= a.latency_ms
        and b.params_millions <= a.params_millions
        and b.flops_billions <= a.flops_billions
    )
    strictly_better = (
        b.accuracy > a.accuracy
        or b.latency_ms < a.latency_ms
        or b.params_millions < a.params_millions
        or b.flops_billions < a.flops_billions
    )
    return no_worse and strictly_better
|
|
374
|
+
|
|
375
|
+
def pareto_frontier(results: list[ObjectiveResult]) -> list[ObjectiveResult]:
    """Return, in input order, the results that no other result dominates."""
    return [
        candidate
        for candidate in results
        if not any(is_dominated(candidate, rival) for rival in results)
    ]
|
|
383
|
+
```
|