morphml-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/meta_learning/transfer.py
@@ -0,0 +1,581 @@
"""Transfer learning across tasks for neural architecture search.

Enables architecture transfer, fine-tuning, and domain adaptation.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

from typing import Any, Dict, List, Optional

import numpy as np

from morphml.core.dsl import SearchSpace
from morphml.core.graph import ModelGraph
from morphml.logging_config import get_logger
from morphml.meta_learning.architecture_similarity import compute_task_similarity
from morphml.meta_learning.experiment_database import TaskMetadata

logger = get_logger(__name__)


class ArchitectureTransfer:
    """
    Transfer architectures between related tasks.

    Strategies:
    1. **Direct Transfer**: Use architecture as-is
    2. **Adaptation**: Modify input/output layers for new task
    3. **Capacity Adjustment**: Scale model size based on task complexity
    4. **Progressive Transfer**: Gradually adapt through intermediate tasks

    Example:
        >>> # Transfer CIFAR-10 architecture to CIFAR-100
        >>> source_task = TaskMetadata(
        ...     task_id='cifar10',
        ...     dataset_name='CIFAR-10',
        ...     num_classes=10,
        ...     input_size=(3, 32, 32),
        ...     num_samples=50000
        ... )
        >>>
        >>> target_task = TaskMetadata(
        ...     task_id='cifar100',
        ...     dataset_name='CIFAR-100',
        ...     num_classes=100,
        ...     input_size=(3, 32, 32),
        ...     num_samples=50000
        ... )
        >>>
        >>> transferred = ArchitectureTransfer.transfer_architecture(
        ...     source_arch=best_arch_cifar10,
        ...     source_task=source_task,
        ...     target_task=target_task,
        ...     adaptation_strategy='modify_head'
        ... )
    """

    @staticmethod
    def transfer_architecture(
        source_arch: ModelGraph,
        source_task: TaskMetadata,
        target_task: TaskMetadata,
        adaptation_strategy: str = "modify_head",
        capacity_scale: float = 1.0,
    ) -> ModelGraph:
        """
        Adapt architecture for new task.

        Args:
            source_arch: Architecture from source task
            source_task: Source task metadata
            target_task: Target task metadata
            adaptation_strategy: How to adapt ('direct', 'modify_head', 'full_adapt')
            capacity_scale: Scale model capacity (e.g., 1.5 for larger model)

        Returns:
            Transferred architecture adapted for target task
        """
        logger.info(
            f"Transferring architecture from {source_task.dataset_name} "
            f"to {target_task.dataset_name} (strategy={adaptation_strategy})"
        )

        # Clone architecture
        transferred = source_arch.clone()

        if adaptation_strategy == "direct":
            # No modifications - direct transfer
            logger.info("Direct transfer: no modifications")
            return transferred

        elif adaptation_strategy == "modify_head":
            # Modify input and output layers only
            transferred = ArchitectureTransfer._modify_io_layers(
                transferred, source_task, target_task
            )

        elif adaptation_strategy == "full_adapt":
            # Full adaptation: IO + capacity scaling
            transferred = ArchitectureTransfer._modify_io_layers(
                transferred, source_task, target_task
            )

            if capacity_scale != 1.0:
                transferred = ArchitectureTransfer._scale_capacity(transferred, capacity_scale)

        else:
            raise ValueError(f"Unknown adaptation strategy: {adaptation_strategy}")

        logger.info(
            f"Transfer complete. Nodes: {len(source_arch.nodes)} → {len(transferred.nodes)}"
        )

        return transferred

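    # Illustrative sketch (not package code): 'full_adapt' combined with
    # capacity scaling, reusing the tasks from the class docstring example;
    # `best_arch_cifar10` is assumed to be a ModelGraph from a prior search.
    #
    #     transferred = ArchitectureTransfer.transfer_architecture(
    #         source_arch=best_arch_cifar10,
    #         source_task=source_task,
    #         target_task=target_task,
    #         adaptation_strategy="full_adapt",
    #         capacity_scale=1.5,  # 50% wider conv filters and hidden dense layers
    #     )
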
    @staticmethod
    def _modify_io_layers(
        graph: ModelGraph,
        source_task: TaskMetadata,
        target_task: TaskMetadata,
    ) -> ModelGraph:
        """Modify input and output layers for new task."""
        modified = graph.clone()

        # Modify input layer
        input_nodes = [n for n in modified.nodes.values() if n.operation == "input"]

        if input_nodes and source_task.input_size != target_task.input_size:
            for input_node in input_nodes:
                input_node.params["input_shape"] = target_task.input_size
                logger.debug(
                    f"Updated input shape: {source_task.input_size} → {target_task.input_size}"
                )

        # Modify output layer
        output_nodes = [n for n in modified.nodes.values() if n.operation == "output"]

        if not output_nodes:
            # Find dense/linear layers near the end
            topo_order = modified.topological_sort()
            if topo_order:
                last_nodes = [n for n in topo_order[-3:] if n.operation in ["dense", "linear"]]
                output_nodes = last_nodes[-1:] if last_nodes else []

        if output_nodes and source_task.num_classes != target_task.num_classes:
            for output_node in output_nodes:
                if "units" in output_node.params:
                    output_node.params["units"] = target_task.num_classes
                    logger.debug(
                        f"Updated output units: {source_task.num_classes} → {target_task.num_classes}"
                    )
                elif "out_features" in output_node.params:
                    output_node.params["out_features"] = target_task.num_classes

        return modified

    @staticmethod
    def _scale_capacity(graph: ModelGraph, scale: float) -> ModelGraph:
        """
        Scale model capacity by adjusting layer widths.

        Args:
            graph: Architecture to scale
            scale: Scaling factor (e.g., 1.5 = 50% wider)

        Returns:
            Scaled architecture
        """
        scaled = graph.clone()

        # Identify the likely classification head (the last dense/linear node,
        # mirroring the fallback in _modify_io_layers) so its width is not scaled.
        topo_order = scaled.topological_sort()
        last_dense = [n for n in topo_order[-3:] if n.operation in ["dense", "linear"]]
        head_id = last_dense[-1].id if last_dense else None

        for node in scaled.nodes.values():
            # Scale convolutional filters
            if node.operation in ["conv2d", "conv1d"] and "filters" in node.params:
                original = node.params["filters"]
                node.params["filters"] = int(original * scale)
                logger.debug(f"Scaled {node.id} filters: {original} → {node.params['filters']}")

            # Scale dense units, leaving the output head untouched
            elif (
                node.operation in ["dense", "linear"]
                and "units" in node.params
                and node.id != head_id
            ):
                original = node.params["units"]
                node.params["units"] = int(original * scale)
                logger.debug(f"Scaled {node.id} units: {original} → {node.params['units']}")

        return scaled

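    # Worked example (sketch): with scale=1.5, a conv2d layer with 64 filters
    # becomes int(64 * 1.5) = 96, and a hidden dense layer with 128 units
    # becomes int(128 * 1.5) = 192; the final dense head keeps its width, so the
    # class count set by _modify_io_layers is preserved.
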
    @staticmethod
    def evaluate_transferability(
        source_task: TaskMetadata,
        target_task: TaskMetadata,
        method: str = "comprehensive",
    ) -> float:
        """
        Estimate how well architectures will transfer between tasks.

        Args:
            source_task: Source task metadata
            target_task: Target task metadata
            method: Scoring method ('comprehensive', 'simple', 'similarity')

        Returns:
            Transferability score (0-1, higher = better transfer expected)
        """
        if method == "simple":
            return ArchitectureTransfer._simple_transferability(source_task, target_task)
        elif method == "similarity":
            return compute_task_similarity(source_task, target_task)
        else:  # comprehensive
            return ArchitectureTransfer._comprehensive_transferability(source_task, target_task)

    @staticmethod
    def _simple_transferability(
        source_task: TaskMetadata,
        target_task: TaskMetadata,
    ) -> float:
        """Simple heuristic-based transferability."""
        # Same dataset = perfect transfer
        if source_task.dataset_name == target_task.dataset_name:
            return 1.0

        # Different problem types = poor transfer
        if source_task.problem_type != target_task.problem_type:
            return 0.3

        # Compare task properties
        size_ratio = min(source_task.num_samples, target_task.num_samples) / max(
            source_task.num_samples, target_task.num_samples
        )

        class_ratio = min(source_task.num_classes, target_task.num_classes) / max(
            source_task.num_classes, target_task.num_classes
        )

        # Average the ratios
        transferability = (size_ratio + class_ratio) / 2.0

        return float(np.clip(transferability, 0.0, 1.0))

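    # Worked example (sketch): CIFAR-10 -> CIFAR-100, both with 50,000 samples:
    # size_ratio = 50000 / 50000 = 1.0 and class_ratio = 10 / 100 = 0.1, so
    # transferability = (1.0 + 0.1) / 2 = 0.55.
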
    @staticmethod
    def _comprehensive_transferability(
        source_task: TaskMetadata,
        target_task: TaskMetadata,
    ) -> float:
        """Comprehensive transferability scoring."""
        scores = []

        # 1. Problem type match (0.3 weight)
        if source_task.problem_type == target_task.problem_type:
            scores.append((1.0, 0.3))
        else:
            scores.append((0.2, 0.3))

        # 2. Dataset family (0.2 weight)
        dataset_families = {
            "CIFAR-10": "cifar",
            "CIFAR-100": "cifar",
            "ImageNet": "imagenet",
            "ImageNet-16": "imagenet",
            "MNIST": "mnist",
            "Fashion-MNIST": "mnist",
        }

        source_family = dataset_families.get(source_task.dataset_name, "other")
        target_family = dataset_families.get(target_task.dataset_name, "other")

        if source_family == target_family:
            scores.append((1.0, 0.2))
        else:
            scores.append((0.5, 0.2))

        # 3. Input size similarity (0.2 weight)
        if source_task.input_size == target_task.input_size:
            input_score = 1.0
        else:
            # Compute dimensionality ratio
            source_dims = np.prod(source_task.input_size) if source_task.input_size else 1
            target_dims = np.prod(target_task.input_size) if target_task.input_size else 1
            ratio = min(source_dims, target_dims) / max(source_dims, target_dims)
            input_score = float(ratio)

        scores.append((input_score, 0.2))

        # 4. Class count similarity (0.15 weight)
        class_ratio = min(source_task.num_classes, target_task.num_classes) / max(
            source_task.num_classes, target_task.num_classes
        )
        scores.append((class_ratio, 0.15))

        # 5. Dataset size similarity (0.15 weight)
        size_ratio = min(source_task.num_samples, target_task.num_samples) / max(
            source_task.num_samples, target_task.num_samples
        )
        scores.append((size_ratio, 0.15))

        # Weighted average
        total_score = sum(score * weight for score, weight in scores)

        return float(np.clip(total_score, 0.0, 1.0))

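    # Worked example (sketch): CIFAR-10 -> CIFAR-100 under the weights above:
    # 1.0*0.3 (same problem type) + 1.0*0.2 (both in the "cifar" family)
    # + 1.0*0.2 (identical input size) + (10/100)*0.15 + 1.0*0.15 = 0.865,
    # which lands in the "modify_head" band of recommend_transfer_strategy below.
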
    @staticmethod
    def recommend_transfer_strategy(
        source_task: TaskMetadata,
        target_task: TaskMetadata,
    ) -> Dict[str, Any]:
        """
        Recommend optimal transfer strategy.

        Returns:
            Dict with:
            - strategy: Recommended strategy name
            - capacity_scale: Recommended capacity scaling
            - reasoning: Explanation
        """
        transferability = ArchitectureTransfer.evaluate_transferability(source_task, target_task)

        # Determine strategy
        if transferability > 0.9:
            strategy = "direct"
            capacity_scale = 1.0
            reasoning = "Tasks are very similar - direct transfer recommended"

        elif transferability > 0.7:
            strategy = "modify_head"
            capacity_scale = 1.0
            reasoning = "Tasks are similar - only modify input/output layers"

        elif transferability > 0.4:
            # Check if target is larger
            if target_task.num_classes > source_task.num_classes * 2:
                strategy = "full_adapt"
                capacity_scale = 1.5
                reasoning = "Target task larger - increase capacity"
            elif target_task.num_classes < source_task.num_classes / 2:
                strategy = "full_adapt"
                capacity_scale = 0.7
                reasoning = "Target task smaller - reduce capacity"
            else:
                strategy = "full_adapt"
                capacity_scale = 1.0
                reasoning = "Moderate similarity - full adaptation"

        else:
            strategy = "full_adapt"
            capacity_scale = 1.2
            reasoning = "Tasks differ significantly - extensive adaptation needed"

        return {
            "strategy": strategy,
            "capacity_scale": capacity_scale,
            "transferability": transferability,
            "reasoning": reasoning,
        }

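# Illustrative glue (sketch, not part of the module): the recommendation can
# drive the transfer directly, reusing the names from the docstring example.
#
#     rec = ArchitectureTransfer.recommend_transfer_strategy(source_task, target_task)
#     adapted = ArchitectureTransfer.transfer_architecture(
#         source_arch=best_arch_cifar10,
#         source_task=source_task,
#         target_task=target_task,
#         adaptation_strategy=rec["strategy"],
#         capacity_scale=rec["capacity_scale"],
#     )
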
class FineTuningStrategy:
    """
    Fine-tuning protocols for transferred architectures.

    When training a transferred architecture on a new task, different
    strategies can be employed:

    1. **Full Fine-Tuning**: Train all parameters
    2. **Freeze Early Layers**: Only train later layers
    3. **Differential Learning Rates**: Lower LR for early layers
    4. **Progressive Unfreezing**: Gradually unfreeze layers

    Note: This class provides configuration. Actual training requires
    a framework-specific implementation (PyTorch, TensorFlow, etc.)

    Example:
        >>> strategy = FineTuningStrategy.get_strategy(
        ...     transfer_type='similar_tasks',
        ...     model_depth=20,
        ...     target_dataset_size=5000
        ... )
        >>>
        >>> print(strategy)
        {
            'method': 'freeze_early',
            'freeze_ratio': 0.5,
            'learning_rate': 0.0005,
            'num_epochs': 50,
            ...
        }
    """

    @staticmethod
    def get_strategy(
        transfer_type: str,
        model_depth: int,
        target_dataset_size: int = 50000,
    ) -> Dict[str, Any]:
        """
        Get recommended fine-tuning strategy.

        Args:
            transfer_type: Type of transfer
                - 'same_domain': Same dataset family
                - 'similar_tasks': Related but different tasks
                - 'distant_tasks': Very different tasks
            model_depth: Number of layers in model
            target_dataset_size: Size of target dataset

        Returns:
            Fine-tuning configuration dict
        """
        if transfer_type == "same_domain":
            # Minimal adaptation needed
            return {
                "method": "freeze_early",
                "freeze_ratio": 0.7,  # Freeze first 70% of layers
                "learning_rate": 1e-4,
                "num_epochs": 30,
                "warmup_epochs": 5,
                "description": "Freeze early layers, train classification head",
            }

        elif transfer_type == "similar_tasks":
            # Moderate adaptation
            if target_dataset_size < 10000:
                # Small dataset - be conservative
                return {
                    "method": "freeze_early",
                    "freeze_ratio": 0.5,
                    "learning_rate": 5e-4,
                    "num_epochs": 50,
                    "warmup_epochs": 10,
                    "description": "Freeze half, careful training to avoid overfitting",
                }
            else:
                # Larger dataset - can train more
                return {
                    "method": "differential_lr",
                    "early_lr": 1e-5,
                    "late_lr": 1e-3,
                    "num_epochs": 75,
                    "warmup_epochs": 10,
                    "description": "Lower LR for early layers, higher for later layers",
                }

        else:  # distant_tasks
            # Full adaptation needed
            return {
                "method": "progressive_unfreezing",
                "initial_freeze_ratio": 0.8,
                "unfreeze_schedule": [0.6, 0.4, 0.2, 0.0],  # Gradually unfreeze
                "learning_rate": 1e-3,
                "num_epochs": 100,
                "warmup_epochs": 15,
                "description": "Progressively unfreeze layers during training",
            }

    @staticmethod
    def generate_freeze_mask(
        num_layers: int,
        freeze_ratio: float,
    ) -> List[bool]:
        """
        Generate mask indicating which layers to freeze.

        Args:
            num_layers: Total number of layers
            freeze_ratio: Ratio of layers to freeze (0-1)

        Returns:
            List of booleans (True = freeze, False = train)
        """
        freeze_until = int(num_layers * freeze_ratio)

        mask = []
        for i in range(num_layers):
            mask.append(i < freeze_until)

        return mask

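# Sketch (assumes PyTorch, which this module does not import): applying a
# freeze mask to a model whose top-level children correspond to "layers".
#
#     blocks = list(model.children())
#     cfg = FineTuningStrategy.get_strategy("same_domain", model_depth=len(blocks))
#     mask = FineTuningStrategy.generate_freeze_mask(len(blocks), cfg["freeze_ratio"])
#     for block, frozen in zip(blocks, mask):
#         for p in block.parameters():
#             p.requires_grad = not frozen
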

class MultiTaskNAS:
    """
    Neural Architecture Search for multiple tasks simultaneously.

    Finds architectures that perform well across a distribution of tasks,
    enabling better transfer learning and generalization.

    Args:
        tasks: List of tasks to optimize for
        search_space: Architecture search space
        task_weights: Optional weights for each task (default: equal)

    Example:
        >>> tasks = [
        ...     TaskMetadata(task_id='cifar10', ...),
        ...     TaskMetadata(task_id='cifar100', ...),
        ...     TaskMetadata(task_id='svhn', ...),
        ... ]
        >>>
        >>> multi_nas = MultiTaskNAS(
        ...     tasks=tasks,
        ...     search_space=space,
        ...     task_weights=[0.5, 0.3, 0.2]  # Prioritize CIFAR-10
        ... )
        >>>
        >>> # This would be integrated with an optimizer
        >>> # best_arch = multi_nas.search(optimizer, evaluator)
    """

    def __init__(
        self,
        tasks: List[TaskMetadata],
        search_space: SearchSpace,
        task_weights: Optional[List[float]] = None,
    ):
        """Initialize multi-task NAS."""
        self.tasks = tasks
        self.search_space = search_space

        # Normalize task weights
        if task_weights is None:
            task_weights = [1.0 / len(tasks)] * len(tasks)

        total = sum(task_weights)
        self.task_weights = [w / total for w in task_weights]

        logger.info(
            f"Initialized MultiTaskNAS with {len(tasks)} tasks "
            f"(weights={self.task_weights})"
        )

    def evaluate_multi_task_fitness(
        self,
        architecture: ModelGraph,
        evaluator_fn,
    ) -> Dict[str, float]:
        """
        Evaluate architecture on all tasks.

        Args:
            architecture: Architecture to evaluate
            evaluator_fn: Function to evaluate arch on a task
                Signature: evaluator_fn(arch, task) -> fitness

        Returns:
            Dict with per-task fitness and weighted average
        """
        fitnesses = {}

        for task in self.tasks:
            # Adapt architecture for this task
            adapted = ArchitectureTransfer.transfer_architecture(
                source_arch=architecture,
                source_task=self.tasks[0],  # Use first task as "source"
                target_task=task,
                adaptation_strategy="modify_head",
            )

            # Evaluate
            fitness = evaluator_fn(adapted, task)
            fitnesses[task.task_id] = fitness

        # Weighted average
        weighted_fitness = sum(
            fitnesses[task.task_id] * weight
            for task, weight in zip(self.tasks, self.task_weights)
        )

        fitnesses["weighted_average"] = weighted_fitness

        return fitnesses

    def create_multi_task_evaluator(self, base_evaluator):
        """
        Create evaluator function for multi-task optimization.

        Returns a function that can be used with any optimizer.
        """

        def multi_task_eval(architecture: ModelGraph) -> float:
            results = self.evaluate_multi_task_fitness(
                architecture,
                lambda arch, task: base_evaluator(arch),
            )
            return results["weighted_average"]

        return multi_task_eval
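
# Sketch (not part of the module): wiring MultiTaskNAS into an optimizer loop.
# `space`, `tasks`, and `evaluate_arch` are assumed to exist; any optimizer
# that accepts a callable evaluator would consume this the same way.
#
#     multi_nas = MultiTaskNAS(tasks=tasks, search_space=space, task_weights=[0.5, 0.3, 0.2])
#     evaluator = multi_nas.create_multi_task_evaluator(evaluate_arch)
#     score = evaluator(candidate_graph)  # weighted fitness across all tasks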