morphml-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/distributed/master.py
@@ -0,0 +1,709 @@
"""Master node for distributed architecture search.
|
|
2
|
+
|
|
3
|
+
The master coordinates optimization across multiple worker nodes, distributing
|
|
4
|
+
evaluation tasks and aggregating results.
|
|
5
|
+
|
|
6
|
+
Author: Eshan Roy <eshanized@proton.me>
|
|
7
|
+
Organization: TONMOY INFRASTRUCTURE & VISION
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import threading
|
|
11
|
+
import time
|
|
12
|
+
import uuid
|
|
13
|
+
from concurrent import futures
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from queue import Empty, Queue
|
|
16
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import grpc
|
|
20
|
+
|
|
21
|
+
from morphml.distributed.proto import worker_pb2, worker_pb2_grpc
|
|
22
|
+
|
|
23
|
+
GRPC_AVAILABLE = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
GRPC_AVAILABLE = False
|
|
26
|
+
|
|
27
|
+
# Create stub modules when grpc is not available
|
|
28
|
+
class _StubModule:
|
|
29
|
+
def __getattr__(self, name):
|
|
30
|
+
raise ImportError(
|
|
31
|
+
"grpc is not installed. Install with: pip install grpcio grpcio-tools"
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
worker_pb2 = _StubModule()
|
|
35
|
+
worker_pb2_grpc = _StubModule()
|
|
36
|
+
|
|
37
|
+
from morphml.core.graph import ModelGraph
|
|
38
|
+
from morphml.exceptions import DistributedError
|
|
39
|
+
from morphml.logging_config import get_logger
|
|
40
|
+
|
|
41
|
+
logger = get_logger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class WorkerInfo:
|
|
46
|
+
"""Worker node metadata and state."""
|
|
47
|
+
|
|
48
|
+
worker_id: str
|
|
49
|
+
host: str
|
|
50
|
+
port: int
|
|
51
|
+
num_gpus: int
|
|
52
|
+
gpu_ids: List[int] = field(default_factory=list)
|
|
53
|
+
status: str = "idle" # 'idle', 'busy', 'dead'
|
|
54
|
+
last_heartbeat: float = field(default_factory=time.time)
|
|
55
|
+
current_task: Optional[str] = None
|
|
56
|
+
tasks_completed: int = 0
|
|
57
|
+
tasks_failed: int = 0
|
|
58
|
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
59
|
+
|
|
60
|
+
def is_alive(self, timeout: float = 30.0) -> bool:
|
|
61
|
+
"""Check if worker is alive based on heartbeat."""
|
|
62
|
+
return (time.time() - self.last_heartbeat) < timeout
|
|
63
|
+
|
|
64
|
+
def is_available(self) -> bool:
|
|
65
|
+
"""Check if worker is available for new tasks."""
|
|
66
|
+
return self.status == "idle" and self.is_alive()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
|
|
70
|
+
class Task:
|
|
71
|
+
"""Evaluation task for distributed execution."""
|
|
72
|
+
|
|
73
|
+
task_id: str
|
|
74
|
+
architecture: ModelGraph
|
|
75
|
+
status: str = "pending" # 'pending', 'running', 'completed', 'failed'
|
|
76
|
+
worker_id: Optional[str] = None
|
|
77
|
+
created_at: float = field(default_factory=time.time)
|
|
78
|
+
started_at: Optional[float] = None
|
|
79
|
+
completed_at: Optional[float] = None
|
|
80
|
+
result: Optional[Dict[str, Any]] = None
|
|
81
|
+
error: Optional[str] = None
|
|
82
|
+
num_retries: int = 0
|
|
83
|
+
max_retries: int = 3
|
|
84
|
+
|
|
85
|
+
def duration(self) -> Optional[float]:
|
|
86
|
+
"""Get task execution duration."""
|
|
87
|
+
if self.started_at and self.completed_at:
|
|
88
|
+
return self.completed_at - self.started_at
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
def can_retry(self) -> bool:
|
|
92
|
+
"""Check if task can be retried."""
|
|
93
|
+
return self.num_retries < self.max_retries
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
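# Task lifecycle (for reference): a Task enters the queue as 'pending', becomes
# 'running' when dispatched to a worker, and ends up 'completed' or 'failed';
# failed tasks are re-queued while Task.can_retry() is True.
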
class MasterNode:
    """
    Master node for distributed NAS.

    Coordinates architecture search across multiple worker nodes:
    1. Distributes evaluation tasks to workers
    2. Collects and aggregates results
    3. Monitors worker health via heartbeat
    4. Handles worker failures and task reassignment
    5. Manages optimization state

    Args:
        optimizer: Base optimizer (GA, BO, etc.)
        config: Master configuration
            - host: Master host (default: '0.0.0.0')
            - port: Master port (default: 50051)
            - num_workers: Expected number of workers
            - heartbeat_interval: Heartbeat check interval (seconds, default: 10)
            - task_timeout: Task timeout (seconds, default: 3600)
            - max_retries: Maximum task retries (default: 3)

    Example:
        >>> from morphml.optimizers import GeneticAlgorithm
        >>> optimizer = GeneticAlgorithm(space, population_size=50)
        >>> master = MasterNode(optimizer, {'port': 50051, 'num_workers': 4})
        >>> master.start()
        >>> best = master.run_experiment(num_generations=100)
        >>> master.stop()
    """

    def __init__(self, optimizer: Any, config: Dict[str, Any]):
        """Initialize master node."""
        if not GRPC_AVAILABLE:
            raise DistributedError(
                "gRPC not available. Install with: pip install grpcio grpcio-tools"
            )

        self.optimizer = optimizer
        self.config = config

        # Server configuration
        self.host = config.get("host", "0.0.0.0")
        self.port = config.get("port", 50051)
        self.num_workers = config.get("num_workers", 4)
        self.heartbeat_interval = config.get("heartbeat_interval", 10)
        self.task_timeout = config.get("task_timeout", 3600)
        self.max_retries = config.get("max_retries", 3)

        # Worker registry
        self.workers: Dict[str, WorkerInfo] = {}
        self.worker_lock = threading.Lock()

        # Task management
        self.pending_tasks: Queue = Queue()
        self.running_tasks: Dict[str, Task] = {}
        self.completed_tasks: Dict[str, Task] = {}
        self.failed_tasks: Dict[str, Task] = {}
        self.task_lock = threading.Lock()

        # gRPC server
        self.server: Optional[grpc.Server] = None
        self.master_id = str(uuid.uuid4())[:8]

        # State
        self.running = False
        self.total_evaluations = 0

        logger.info(f"Initialized MasterNode (id={self.master_id}) on {self.host}:{self.port}")

    def start(self) -> None:
        """Start master node server."""
        logger.info(f"Starting master node on {self.host}:{self.port}")

        # Create gRPC server
        self.server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=20),
            options=[
                ("grpc.max_send_message_length", 100 * 1024 * 1024),
                ("grpc.max_receive_message_length", 100 * 1024 * 1024),
            ],
        )

        # Add servicer
        worker_pb2_grpc.add_MasterServiceServicer_to_server(MasterServicer(self), self.server)

        # Start server
        self.server.add_insecure_port(f"{self.host}:{self.port}")
        self.server.start()

        self.running = True

        # Start background threads
        self._start_heartbeat_monitor()
        self._start_task_dispatcher()

        logger.info(f"Master node started successfully (id={self.master_id})")

    def stop(self) -> None:
        """Stop master node gracefully."""
        logger.info("Stopping master node")
        self.running = False

        if self.server:
            self.server.stop(grace=5)

        logger.info("Master node stopped")

    def register_worker(self, worker_id: str, worker_info: Dict[str, Any]) -> bool:
        """
        Register a worker node.

        Args:
            worker_id: Unique worker identifier
            worker_info: Worker metadata

        Returns:
            True if registration successful
        """
        with self.worker_lock:
            self.workers[worker_id] = WorkerInfo(
                worker_id=worker_id,
                host=worker_info["host"],
                port=worker_info["port"],
                num_gpus=worker_info.get("num_gpus", 1),
                gpu_ids=worker_info.get("gpu_ids", []),
                status="idle",
                last_heartbeat=time.time(),
                metadata=worker_info.get("metadata", {}),
            )

        logger.info(
            f"Worker registered: {worker_id} "
            f"({worker_info['host']}:{worker_info['port']}, "
            f"GPUs: {worker_info.get('num_gpus', 1)})"
        )

        return True

    def update_heartbeat(self, worker_id: str, status: str) -> bool:
        """
        Update worker heartbeat.

        Args:
            worker_id: Worker identifier
            status: Worker status ('idle', 'busy', 'error')

        Returns:
            True if worker found and updated
        """
        with self.worker_lock:
            if worker_id in self.workers:
                worker = self.workers[worker_id]
                worker.last_heartbeat = time.time()
                worker.status = status
                return True

        logger.warning(f"Heartbeat from unknown worker: {worker_id}")
        return False

    def submit_task(self, architecture: ModelGraph, task_id: Optional[str] = None) -> str:
        """
        Submit architecture evaluation task.

        Args:
            architecture: ModelGraph to evaluate
            task_id: Optional task ID (generated if not provided)

        Returns:
            Task ID
        """
        if task_id is None:
            task_id = str(uuid.uuid4())

        task = Task(
            task_id=task_id,
            architecture=architecture,
            status="pending",
            created_at=time.time(),
            max_retries=self.max_retries,
        )

        self.pending_tasks.put(task)

        logger.debug(f"Task submitted: {task_id}")

        return task_id

    def get_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[Dict[str, Any]]:
        """
        Get result for task (blocking).

        Args:
            task_id: Task identifier
            timeout: Maximum wait time (seconds)

        Returns:
            Task result dictionary or None if timeout/failed
        """
        start_time = time.time()

        while True:
            with self.task_lock:
                # Check completed
                if task_id in self.completed_tasks:
                    task = self.completed_tasks[task_id]
                    return task.result

                # Check failed
                if task_id in self.failed_tasks:
                    logger.warning(f"Task {task_id} failed: {self.failed_tasks[task_id].error}")
                    return None

            # Check timeout
            if timeout and (time.time() - start_time) > timeout:
                logger.warning(f"Task {task_id} timeout after {timeout}s")
                return None

            time.sleep(0.1)

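    # Usage sketch (illustrative): fire-and-wait evaluation of a single
    # architecture with the submit/wait pair above; `graph` stands for a
    # hypothetical ModelGraph built elsewhere.
    #   task_id = master.submit_task(graph)
    #   metrics = master.get_result(task_id, timeout=600)  # e.g. {'fitness': ...}
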
    def run_experiment(
        self, num_generations: int = 100, callback: Optional[Callable] = None
    ) -> List[ModelGraph]:
        """
        Run full distributed NAS experiment.

        Args:
            num_generations: Number of optimization generations
            callback: Optional callback(generation, stats) called each generation

        Returns:
            List of best architectures found

        Example:
            >>> def progress(gen, stats):
            ...     print(f"Gen {gen}: best={stats['best_fitness']:.4f}")
            >>> best = master.run_experiment(100, callback=progress)
        """
        logger.info(
            f"Starting distributed experiment "
            f"({num_generations} generations, {self.num_workers} workers)"
        )

        # Wait for workers
        self._wait_for_workers(timeout=300)

        # Initialize optimizer
        logger.info("Initializing population")
        if hasattr(self.optimizer, "initialize_population"):
            self.optimizer.initialize_population()

        # Evolution loop
        for generation in range(num_generations):
            logger.info(f"Generation {generation + 1}/{num_generations}")

            # Get individuals to evaluate
            if hasattr(self.optimizer, "population"):
                individuals = self.optimizer.population.get_unevaluated()
            else:
                # Fallback for optimizers without population
                individuals = []

            if not individuals:
                logger.warning(f"No individuals to evaluate in generation {generation}")
                continue

            # Submit tasks
            task_mapping = []
            for individual in individuals:
                task_id = self.submit_task(individual.graph)
                task_mapping.append((individual, task_id))

            # Wait for results
            for individual, task_id in task_mapping:
                result = self.get_result(task_id, timeout=self.task_timeout)

                if result is None:
                    logger.warning(f"Task {task_id} failed")
                    individual.set_fitness(0.0)
                else:
                    fitness = result.get("fitness", result.get("val_accuracy", 0.0))
                    individual.set_fitness(fitness)
                    self.total_evaluations += 1

            # Optimizer step (if applicable)
            if hasattr(self.optimizer, "evolve"):
                self.optimizer.evolve()

            # Statistics
            stats = self._get_generation_stats()
            logger.info(
                f"Gen {generation + 1}: "
                f"best={stats['best_fitness']:.4f}, "
                f"avg={stats['avg_fitness']:.4f}, "
                f"workers={len(self.workers)}/{self.num_workers}"
            )

            # Callback
            if callback:
                callback(generation + 1, stats)

        # Return best architectures
        logger.info(f"Experiment complete. Total evaluations: {self.total_evaluations}")

        if hasattr(self.optimizer, "population"):
            return self.optimizer.population.get_best_individuals(10)
        elif hasattr(self.optimizer, "best_individual"):
            return [self.optimizer.best_individual]
        else:
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get master node statistics."""
        with self.worker_lock:
            alive_workers = sum(1 for w in self.workers.values() if w.is_alive())
            busy_workers = sum(1 for w in self.workers.values() if w.status == "busy")

        with self.task_lock:
            pending = self.pending_tasks.qsize()
            running = len(self.running_tasks)
            completed = len(self.completed_tasks)
            failed = len(self.failed_tasks)

        return {
            "workers_total": len(self.workers),
            "workers_alive": alive_workers,
            "workers_busy": busy_workers,
            "tasks_pending": pending,
            "tasks_running": running,
            "tasks_completed": completed,
            "tasks_failed": failed,
            "total_evaluations": self.total_evaluations,
        }

    def _wait_for_workers(self, timeout: float = 300) -> None:
        """Wait for expected number of workers to connect."""
        logger.info(f"Waiting for {self.num_workers} workers (timeout: {timeout}s)...")

        start_time = time.time()
        while len(self.workers) < self.num_workers:
            if (time.time() - start_time) > timeout:
                raise DistributedError(
                    f"Only {len(self.workers)}/{self.num_workers} workers connected "
                    f"after {timeout}s"
                )

            time.sleep(1)

        logger.info(f"All {self.num_workers} workers connected")

    def _start_heartbeat_monitor(self) -> None:
        """Start background thread to monitor worker heartbeats."""

        def monitor() -> None:
            while self.running:
                with self.worker_lock:
                    for worker_id, worker in list(self.workers.items()):
                        # Check heartbeat timeout
                        if not worker.is_alive():
                            logger.warning(f"Worker {worker_id} heartbeat timeout")
                            worker.status = "dead"

                            # Reassign tasks
                            self._reassign_worker_tasks(worker_id)

                time.sleep(self.heartbeat_interval)

        thread = threading.Thread(target=monitor, daemon=True, name="HeartbeatMonitor")
        thread.start()
        logger.debug("Heartbeat monitor started")

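    # Note: the monitor above flags a worker as dead once its last heartbeat is
    # older than WorkerInfo.is_alive()'s timeout (30s by default), so it assumes
    # heartbeat_interval (default 10s) is comfortably smaller than that window.
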
    def _start_task_dispatcher(self) -> None:
        """Start background thread to dispatch tasks to workers."""

        def dispatch() -> None:
            while self.running:
                try:
                    # Get pending task
                    task = self.pending_tasks.get(timeout=1.0)

                    # Find available worker
                    worker = self._find_available_worker()

                    if worker:
                        # Dispatch task
                        self._dispatch_task_to_worker(task, worker)
                    else:
                        # No workers available, requeue
                        self.pending_tasks.put(task)
                        time.sleep(0.5)

                except Empty:
                    continue
                except Exception as e:
                    logger.error(f"Task dispatcher error: {e}")

        thread = threading.Thread(target=dispatch, daemon=True, name="TaskDispatcher")
        thread.start()
        logger.debug("Task dispatcher started")

    def _find_available_worker(self) -> Optional[WorkerInfo]:
        """Find an idle worker."""
        with self.worker_lock:
            for worker in self.workers.values():
                if worker.is_available():
                    return worker
        return None

    def _dispatch_task_to_worker(self, task: Task, worker: WorkerInfo) -> None:
        """Dispatch task to specific worker."""
        logger.debug(f"Dispatching task {task.task_id} to worker {worker.worker_id}")

        # Update task status
        task.status = "running"
        task.worker_id = worker.worker_id
        task.started_at = time.time()

        with self.task_lock:
            self.running_tasks[task.task_id] = task

        # Update worker status
        worker.status = "busy"
        worker.current_task = task.task_id

        # Send task via gRPC (async)
        def send_task() -> None:
            try:
                channel = grpc.insecure_channel(f"{worker.host}:{worker.port}")
                stub = worker_pb2_grpc.WorkerServiceStub(channel)

                request = worker_pb2.EvaluateRequest(
                    task_id=task.task_id, architecture=task.architecture.to_json()
                )

                # Note: This is async, result comes via SubmitResult RPC
                stub.Evaluate(request, timeout=self.task_timeout)

            except grpc.RpcError as e:
                logger.error(f"Failed to dispatch task {task.task_id}: {e}")
                self._handle_task_failure(task.task_id, str(e))

        # Run in thread to avoid blocking
        threading.Thread(target=send_task, daemon=True).start()

    def _handle_task_result(self, task_id: str, result: Dict[str, Any], duration: float) -> None:
        """Handle task result from worker."""
        with self.task_lock:
            if task_id not in self.running_tasks:
                logger.warning(f"Received result for unknown task: {task_id}")
                return

            task = self.running_tasks.pop(task_id)
            task.status = "completed"
            task.completed_at = time.time()
            task.result = result

            self.completed_tasks[task_id] = task

        # Update worker
        if task.worker_id:
            with self.worker_lock:
                if task.worker_id in self.workers:
                    worker = self.workers[task.worker_id]
                    worker.status = "idle"
                    worker.current_task = None
                    worker.tasks_completed += 1

        logger.debug(
            f"Task {task_id} completed in {duration:.2f}s "
            f"(fitness: {result.get('fitness', 'N/A')})"
        )

    def _handle_task_failure(self, task_id: str, error: str) -> None:
        """Handle task failure."""
        with self.task_lock:
            if task_id in self.running_tasks:
                task = self.running_tasks.pop(task_id)
            else:
                logger.warning(f"Failure for unknown task: {task_id}")
                return

            task.error = error
            task.num_retries += 1

            # Retry if possible
            if task.can_retry():
                logger.warning(
                    f"Task {task_id} failed (retry {task.num_retries}/{task.max_retries}): {error}"
                )
                task.status = "pending"
                task.worker_id = None
                self.pending_tasks.put(task)
            else:
                logger.error(f"Task {task_id} failed permanently after {task.num_retries} retries")
                task.status = "failed"
                self.failed_tasks[task_id] = task

        # Update worker
        if task.worker_id:
            with self.worker_lock:
                if task.worker_id in self.workers:
                    worker = self.workers[task.worker_id]
                    worker.status = "idle"
                    worker.current_task = None
                    worker.tasks_failed += 1

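    # Tasks recovered from a dead worker below are pushed back onto the pending
    # queue and charged against the same max_retries budget as ordinary failures.
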
    def _reassign_worker_tasks(self, worker_id: str) -> None:
        """Reassign tasks from dead worker."""
        with self.task_lock:
            tasks_to_reassign = []

            for _task_id, task in list(self.running_tasks.items()):
                if task.worker_id == worker_id:
                    tasks_to_reassign.append(task)

            for task in tasks_to_reassign:
                logger.warning(f"Reassigning task {task.task_id} from dead worker {worker_id}")
                self.running_tasks.pop(task.task_id)
                task.status = "pending"
                task.worker_id = None
                task.num_retries += 1

                if task.can_retry():
                    self.pending_tasks.put(task)
                else:
                    task.status = "failed"
                    self.failed_tasks[task.task_id] = task

    def _get_generation_stats(self) -> Dict[str, float]:
        """Get statistics for current generation."""
        if hasattr(self.optimizer, "population"):
            evaluated = [ind for ind in self.optimizer.population.individuals if ind.is_evaluated()]

            if evaluated:
                fitnesses = [ind.fitness for ind in evaluated]
                return {
                    "best_fitness": max(fitnesses),
                    "avg_fitness": sum(fitnesses) / len(fitnesses),
                    "min_fitness": min(fitnesses),
                }

        return {"best_fitness": 0.0, "avg_fitness": 0.0, "min_fitness": 0.0}


if GRPC_AVAILABLE:

    class MasterServicer(worker_pb2_grpc.MasterServiceServicer):
        """gRPC servicer for master node."""

        def __init__(self, master: MasterNode):
            """Initialize servicer."""
            self.master = master

        def RegisterWorker(
            self, request: worker_pb2.RegisterRequest, context: grpc.ServicerContext
        ) -> worker_pb2.RegisterResponse:
            """Handle worker registration."""
            try:
                worker_info = {
                    "host": request.host,
                    "port": request.port,
                    "num_gpus": request.num_gpus,
                    "gpu_ids": list(request.gpu_ids),
                    "metadata": dict(request.metadata),
                }

                success = self.master.register_worker(request.worker_id, worker_info)

                return worker_pb2.RegisterResponse(
                    success=success,
                    message="Worker registered successfully",
                    master_id=self.master.master_id,
                )

            except Exception as e:
                logger.error(f"Worker registration failed: {e}")
                return worker_pb2.RegisterResponse(
                    success=False, message=str(e), master_id=self.master.master_id
                )

        def Heartbeat(
            self, request: worker_pb2.HeartbeatRequest, context: grpc.ServicerContext
        ) -> worker_pb2.HeartbeatResponse:
            """Handle worker heartbeat."""
            success = self.master.update_heartbeat(request.worker_id, request.status)

            return worker_pb2.HeartbeatResponse(
                acknowledged=success, should_continue=self.master.running
            )

        def SubmitResult(
            self, request: worker_pb2.ResultRequest, context: grpc.ServicerContext
        ) -> worker_pb2.ResultResponse:
            """Handle task result submission."""
            try:
                if request.success:
                    result = dict(request.metrics)
                    self.master._handle_task_result(request.task_id, result, request.duration)
                else:
                    self.master._handle_task_failure(request.task_id, request.error)

                return worker_pb2.ResultResponse(acknowledged=True, message="Result received")

            except Exception as e:
                logger.error(f"Failed to handle result: {e}")
                return worker_pb2.ResultResponse(acknowledged=False, message=str(e))

        def RequestTask(
            self, request: worker_pb2.TaskRequest, context: grpc.ServicerContext
        ) -> worker_pb2.TaskResponse:
            """Handle task request from worker (pull model)."""
            # Pull model implementation for future use
            return worker_pb2.TaskResponse(has_task=False, tasks=[])
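Usage sketch: the flow below mirrors the MasterNode docstring example, pairing a GeneticAlgorithm optimizer with a four-worker cluster and a per-generation progress callback. `space` stands in for a search-space object defined elsewhere, and worker processes are assumed to be started separately (see morphml/distributed/worker.py); this is an illustrative sketch, not part of the package.

    from morphml.distributed.master import MasterNode
    from morphml.optimizers import GeneticAlgorithm

    optimizer = GeneticAlgorithm(space, population_size=50)  # `space`: search space defined elsewhere
    master = MasterNode(optimizer, {"port": 50051, "num_workers": 4})
    master.start()  # serve gRPC and wait for workers to register

    def progress(gen, stats):
        print(f"Gen {gen}: best={stats['best_fitness']:.4f}")

    best = master.run_experiment(num_generations=100, callback=progress)
    print(master.get_statistics())
    master.stop()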