morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/integrations/jax_adapter.py
@@ -0,0 +1,206 @@
"""JAX/Flax adapter for MorphML.

Converts ModelGraph to Flax Module for functional neural networks.

Example:
    >>> from morphml.integrations import JAXAdapter
    >>> adapter = JAXAdapter()
    >>> model = adapter.build_model(graph)
    >>> params = model.init(rng, x)
    >>> output = model.apply(params, x)
"""

from typing import Optional, Tuple

try:
    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    jax = None
    jnp = None
    nn = None

from morphml.core.graph import GraphNode, ModelGraph
from morphml.logging_config import get_logger

logger = get_logger(__name__)


class JAXAdapter:
    """
    Convert ModelGraph to JAX/Flax Module.

    Provides functional neural network implementation using JAX and Flax.

    Example:
        >>> adapter = JAXAdapter()
        >>> model = adapter.build_model(graph)
        >>> rng = jax.random.PRNGKey(0)
        >>> params = model.init(rng, jnp.ones((1, 32, 32, 3)))
        >>> output = model.apply(params, x)
    """

    def __init__(self):
        """Initialize JAX adapter."""
        if not JAX_AVAILABLE:
            raise ImportError(
                "JAX and Flax are required for JAXAdapter. Install with: pip install jax flax"
            )
        logger.info("Initialized JAXAdapter")

    def build_model(
        self, graph: ModelGraph, input_shape: Optional[Tuple[int, ...]] = None
    ) -> nn.Module:
        """
        Build Flax module from graph.

        Args:
            graph: ModelGraph to convert
            input_shape: Input shape (H, W, C) for JAX

        Returns:
            Flax Module instance

        Example:
            >>> model = adapter.build_model(graph, input_shape=(32, 32, 3))
        """
        return GraphModule(graph, input_shape)


class GraphModule(nn.Module):
    """
    Flax module generated from ModelGraph.

    Implements functional neural network following graph topology.

    Attributes:
        graph: Source ModelGraph
        input_shape: Expected input shape
    """

    graph: ModelGraph
    input_shape: Optional[Tuple[int, ...]] = None

    def setup(self):
        """Setup layers."""
        self.layers = {}

        for node_id, node in self.graph.nodes.items():
            layer = self._create_layer(node)
            if layer is not None:
                self.layers[str(node_id)] = layer

    def _create_layer(self, node: GraphNode):
        """
        Create Flax layer from node.

        Args:
            node: GraphNode to convert

        Returns:
            Flax layer or None
        """
        op = node.operation
        params = node.params

        if op == "input":
            return None

        elif op == "conv2d":
            return nn.Conv(
                features=params.get("filters", 64),
                kernel_size=(params.get("kernel_size", 3),) * 2,
                strides=(params.get("stride", 1),) * 2,
                padding=params.get("padding", "SAME"),
            )

        elif op == "dense":
            return nn.Dense(features=params.get("units", 10))

        elif op == "batchnorm":
            return nn.BatchNorm()

        elif op == "dropout":
            return nn.Dropout(rate=params.get("rate", 0.5))

        else:
            return None

    @nn.compact
    def __call__(self, x, training: bool = False):
        """
        Forward pass.

        Args:
            x: Input array
            training: Whether in training mode

        Returns:
            Output array
        """
        # Track outputs
        outputs = {}

        for node in self.graph.topological_sort():
            layer = self.layers.get(str(node.id))

            # Get input
            if not node.predecessors:
                node_input = x
            else:
                pred_outputs = [outputs[pred.id] for pred in node.predecessors]

                if len(pred_outputs) == 1:
                    node_input = pred_outputs[0]
                else:
                    # Concatenate along channel dimension
                    node_input = jnp.concatenate(pred_outputs, axis=-1)

            # Apply layer
            if layer is not None:
                if isinstance(layer, nn.Dropout):
                    outputs[node.id] = layer(node_input, deterministic=not training)
                elif isinstance(layer, nn.BatchNorm):
                    outputs[node.id] = layer(node_input, use_running_average=not training)
                else:
                    outputs[node.id] = layer(node_input)
            else:
                # Apply activation or pass through
                op = node.operation
                if op == "relu":
                    outputs[node.id] = nn.relu(node_input)
                elif op == "sigmoid":
                    outputs[node.id] = nn.sigmoid(node_input)
                elif op == "tanh":
                    outputs[node.id] = nn.tanh(node_input)
                elif op == "softmax":
                    outputs[node.id] = nn.softmax(node_input)
                elif op == "maxpool":
                    pool_size = node.params.get("pool_size", 2)
                    outputs[node.id] = nn.max_pool(
                        node_input,
                        window_shape=(pool_size, pool_size),
                        strides=(pool_size, pool_size),
                    )
                elif op == "avgpool":
                    pool_size = node.params.get("pool_size", 2)
                    outputs[node.id] = nn.avg_pool(
                        node_input,
                        window_shape=(pool_size, pool_size),
                        strides=(pool_size, pool_size),
                    )
                elif op == "flatten":
                    outputs[node.id] = node_input.reshape((node_input.shape[0], -1))
                else:
                    outputs[node.id] = node_input

        # Return output
        output_nodes = [n for n in self.graph.nodes.values() if not n.successors]
        if output_nodes:
            return outputs[output_nodes[0].id]
        else:
            return outputs[list(outputs.keys())[-1]]
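
For reference, the call pattern the docstrings above describe looks roughly like the sketch below. It assumes an already-constructed ModelGraph named `graph` (graph construction lives elsewhere in morphml and is not shown here) and an NHWC input, matching flax.linen.Conv conventions; this is an illustrative sketch, not part of the released file.

    # Usage sketch for JAXAdapter, assuming `graph` is an existing morphml
    # ModelGraph (its construction is defined outside this file).
    import jax
    import jax.numpy as jnp

    from morphml.integrations import JAXAdapter

    adapter = JAXAdapter()  # raises ImportError if jax/flax are not installed
    model = adapter.build_model(graph, input_shape=(32, 32, 3))

    rng = jax.random.PRNGKey(0)
    x = jnp.ones((1, 32, 32, 3))  # NHWC batch, as expected by flax.linen.Conv

    variables = model.init(rng, x)                     # initialize parameters
    preds = model.apply(variables, x, training=False)  # inference-mode forward pass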