morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Constraint handler for managing architecture constraints."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List

from morphml.constraints.predicates import Constraint
from morphml.core.graph import ModelGraph
from morphml.logging_config import get_logger
|
|
8
|
+
|
|
9
|
+
# Module-level logger obtained from the project's logging helper (morphml.logging_config).
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConstraintHandler:
    """
    Manages and evaluates constraints on architectures.

    Constraints are kept in insertion order. ``check`` stops at the first
    violated constraint; the ``get_*`` / ``format_*`` helpers evaluate all of them.

    Example:
        >>> handler = ConstraintHandler()
        >>> handler.add_constraint(MaxParametersConstraint(1000000))
        >>> handler.add_constraint(DepthConstraint(min_depth=5, max_depth=20))
        >>>
        >>> if handler.check(graph):
        ...     print("Valid!")
        >>>
        >>> penalty = handler.total_penalty(graph)
    """

    def __init__(self):
        """Initialize an empty constraint handler."""
        self.constraints: List[Constraint] = []
        logger.debug("Created ConstraintHandler")

    def add_constraint(self, constraint: Constraint) -> None:
        """
        Add a constraint.

        Args:
            constraint: Constraint to add (appended after existing ones)
        """
        self.constraints.append(constraint)
        logger.debug(f"Added constraint: {constraint}")

    def remove_constraint(self, name: str) -> None:
        """
        Remove every constraint whose ``name`` equals ``name``.

        Removing a name that is not registered is a no-op.

        Args:
            name: Constraint name
        """
        before = len(self.constraints)
        # Filter in place (as ``clear`` does) so any outside reference to
        # ``self.constraints`` stays in sync with the handler.
        self.constraints[:] = [c for c in self.constraints if c.name != name]
        if len(self.constraints) < before:
            logger.debug(f"Removed constraint: {name}")
        else:
            # Don't report a removal that did not happen; a misspelled name
            # would otherwise be easy to miss in the logs.
            logger.debug(f"No constraint named '{name}' to remove")

    def check(self, graph: ModelGraph) -> bool:
        """
        Check if graph satisfies all constraints.

        Args:
            graph: Graph to check

        Returns:
            True if every constraint is satisfied (vacuously True when no
            constraints are registered). Stops at the first violation.
        """
        for constraint in self.constraints:
            if not constraint.check(graph):
                logger.debug(f"Constraint violated: {constraint.name}")
                return False
        return True

    def get_violations(self, graph: ModelGraph) -> List[str]:
        """
        Get list of violated constraints.

        Args:
            graph: Graph to check

        Returns:
            Names of the violated constraints, in insertion order
        """
        return [c.name for c in self.constraints if not c.check(graph)]

    def get_detailed_violations(self, graph: ModelGraph) -> List[Dict[str, Any]]:
        """
        Get detailed information about violated constraints.

        Args:
            graph: Graph to check

        Returns:
            One dictionary per violated constraint with keys:
                - name: Constraint name
                - penalty: Penalty value from ``constraint.penalty(graph)``
                - message: Descriptive error message (generic fallback if the
                  constraint has no ``get_violation_message``)
                - actual: Actual value measured (only if the constraint
                  defines ``get_actual_value``)
                - expected: Expected value/range (only if the constraint
                  defines ``get_expected_range``)
        """
        violations: List[Dict[str, Any]] = []
        for constraint in self.constraints:
            if constraint.check(graph):
                continue

            violation_info: Dict[str, Any] = {
                "name": constraint.name,
                "penalty": constraint.penalty(graph),
            }

            # Optional, constraint-specific details (duck-typed hooks).
            if hasattr(constraint, "get_violation_message"):
                violation_info["message"] = constraint.get_violation_message(graph)
            else:
                violation_info["message"] = f"Constraint '{constraint.name}' violated"

            if hasattr(constraint, "get_actual_value"):
                violation_info["actual"] = constraint.get_actual_value(graph)

            if hasattr(constraint, "get_expected_range"):
                violation_info["expected"] = constraint.get_expected_range()

            violations.append(violation_info)

        return violations

    def format_violations(self, graph: ModelGraph) -> str:
        """
        Format violation details as a human-readable string.

        Args:
            graph: Graph to check

        Returns:
            Formatted string describing all violations, or
            "No constraint violations" when there are none
        """
        violations = self.get_detailed_violations(graph)

        if not violations:
            return "No constraint violations"

        lines = [f"Found {len(violations)} constraint violation(s):"]
        for i, v in enumerate(violations, 1):
            lines.append(f"\n{i}. {v['name']}")
            lines.append(f" Message: {v['message']}")
            if "actual" in v:
                lines.append(f" Actual: {v['actual']}")
            if "expected" in v:
                lines.append(f" Expected: {v['expected']}")
            lines.append(f" Penalty: {v['penalty']:.4f}")

        return "\n".join(lines)

    def total_penalty(self, graph: ModelGraph) -> float:
        """
        Calculate the mean penalty across all constraints.

        Despite the name (kept for backward compatibility), this is the
        arithmetic mean of each constraint's ``penalty`` rather than the sum,
        so registering more constraints does not by itself inflate the result.

        Args:
            graph: Graph to evaluate

        Returns:
            Mean penalty; 0.0 when no constraints are registered, and 0.0 when
            every constraint reports a zero penalty (satisfied).
        """
        if not self.constraints:
            return 0.0

        total = sum(c.penalty(graph) for c in self.constraints)
        return total / len(self.constraints)

    def get_penalties(self, graph: ModelGraph) -> Dict[str, float]:
        """
        Get penalties for each constraint.

        Args:
            graph: Graph to evaluate

        Returns:
            Dictionary mapping constraint names to penalties. If several
            constraints share a name, the last one's penalty wins.
        """
        return {c.name: c.penalty(graph) for c in self.constraints}

    def apply_penalty_to_fitness(
        self, fitness: float, graph: ModelGraph, penalty_weight: float = 0.5
    ) -> float:
        """
        Apply constraint penalties to fitness score.

        Computes ``fitness * (1 - penalty_weight * total_penalty(graph))``.

        NOTE(review): this multiplicative scheme assumes larger fitness is
        better and fitness is non-negative. For a negative fitness the scaling
        moves the value toward zero, i.e. it would reward violations. Confirm
        against callers.

        Args:
            fitness: Original fitness
            graph: Architecture graph
            penalty_weight: Weight for penalty term

        Returns:
            Penalized fitness
        """
        penalty = self.total_penalty(graph)
        return fitness * (1 - penalty_weight * penalty)

    def clear(self) -> None:
        """Remove all constraints (mutates the existing list in place)."""
        self.constraints.clear()
        logger.debug("Cleared all constraints")

    def __len__(self) -> int:
        """Return number of constraints."""
        return len(self.constraints)

    def __repr__(self) -> str:
        return f"ConstraintHandler(constraints={len(self.constraints)})"
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""Constraint predicates for architecture validation."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Set
|
|
4
|
+
|
|
5
|
+
from morphml.core.graph import ModelGraph
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Constraint:
    """
    Base class for architecture constraints.

    Subclasses implement ``check``; the default ``penalty`` is binary
    (0.0 when satisfied, 1.0 when violated).
    """

    def __init__(self, name: str):
        """
        Initialize constraint.

        Args:
            name: Identifier for this constraint (used when reporting violations)
        """
        self.name = name

    def check(self, graph: ModelGraph) -> bool:
        """Return True if ``graph`` satisfies this constraint. Subclasses must override."""
        raise NotImplementedError

    def penalty(self, graph: ModelGraph) -> float:
        """Return 0.0 when ``graph`` satisfies the constraint, otherwise 1.0."""
        if self.check(graph):
            return 0.0
        return 1.0

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(name={self.name})"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MaxParametersConstraint(Constraint):
    """Constraint on maximum parameters (as reported by ``graph.estimate_parameters()``)."""

    def __init__(self, max_params: int, name: str = "max_parameters"):
        """
        Initialize max parameters constraint.

        Args:
            max_params: Maximum allowed parameters (inclusive)
            name: Constraint name
        """
        super().__init__(name)
        self.max_params = max_params

    def check(self, graph: ModelGraph) -> bool:
        """Check if the estimated parameter count is within the limit."""
        return graph.estimate_parameters() <= self.max_params

    def penalty(self, graph: ModelGraph) -> float:
        """
        Penalty based on excess parameters.

        Returns 0.0 within the limit, otherwise the relative excess
        ``(params - max_params) / max_params`` capped at 1.0. A non-positive
        limit has no meaningful relative excess, so violating it yields 1.0.
        """
        params = graph.estimate_parameters()
        if params <= self.max_params:
            return 0.0
        if self.max_params <= 0:
            # Avoid ZeroDivisionError (max_params == 0) and a negative ratio
            # (max_params < 0) that would turn a violation into a bonus.
            return 1.0
        excess_ratio = (params - self.max_params) / self.max_params
        return min(1.0, excess_ratio)

    def get_actual_value(self, graph: ModelGraph) -> int:
        """Get actual parameter count."""
        return graph.estimate_parameters()

    def get_expected_range(self) -> str:
        """Get expected parameter range."""
        return f"<= {self.max_params:,}"

    def get_violation_message(self, graph: ModelGraph) -> str:
        """Get detailed violation message (intended for violating graphs)."""
        actual = self.get_actual_value(graph)
        excess = actual - self.max_params
        return f"Architecture has {actual:,} parameters, exceeding limit by {excess:,}"

    def __repr__(self) -> str:
        return f"MaxParametersConstraint(max={self.max_params:,})"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class MinParametersConstraint(Constraint):
    """Constraint on minimum parameters (as reported by ``graph.estimate_parameters()``)."""

    def __init__(self, min_params: int, name: str = "min_parameters"):
        """
        Initialize min parameters constraint.

        Args:
            min_params: Minimum required parameters (inclusive)
            name: Constraint name
        """
        super().__init__(name)
        self.min_params = min_params

    def check(self, graph: ModelGraph) -> bool:
        """Check if the estimated parameter count meets the minimum."""
        return graph.estimate_parameters() >= self.min_params

    def penalty(self, graph: ModelGraph) -> float:
        """
        Penalty based on parameter deficit.

        Returns 0.0 when the minimum is met, otherwise the relative deficit
        ``(min_params - params) / min_params`` capped at 1.0.
        """
        params = graph.estimate_parameters()
        if params >= self.min_params:
            return 0.0
        deficit_ratio = (self.min_params - params) / self.min_params
        return min(1.0, deficit_ratio)

    def get_actual_value(self, graph: ModelGraph) -> int:
        """Get actual parameter count."""
        return graph.estimate_parameters()

    def get_expected_range(self) -> str:
        """Get expected parameter range."""
        return f">= {self.min_params:,}"

    def get_violation_message(self, graph: ModelGraph) -> str:
        """Get detailed violation message (intended for violating graphs)."""
        actual = self.get_actual_value(graph)
        deficit = self.min_params - actual
        return f"Architecture has {actual:,} parameters, below minimum by {deficit:,}"

    def __repr__(self) -> str:
        # Mirrors MaxParametersConstraint so handler logs show the bound.
        return f"MinParametersConstraint(min={self.min_params:,})"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class DepthConstraint(Constraint):
    """Constraint on network depth: ``min_depth <= graph.get_depth() <= max_depth``."""

    def __init__(self, min_depth: int = 1, max_depth: int = 100, name: str = "depth"):
        """
        Initialize depth constraint.

        Args:
            min_depth: Minimum allowed depth (inclusive)
            max_depth: Maximum allowed depth (inclusive)
            name: Constraint name
        """
        super().__init__(name)
        self.min_depth = min_depth
        self.max_depth = max_depth

    def check(self, graph: ModelGraph) -> bool:
        """Check if depth is within range."""
        depth = graph.get_depth()
        return self.min_depth <= depth <= self.max_depth

    def penalty(self, graph: ModelGraph) -> float:
        """
        Penalty based on depth violation.

        Returns 0.0 inside the range, otherwise the relative distance to the
        violated bound, capped at 1.0.
        """
        depth = graph.get_depth()

        if depth < self.min_depth:
            # max(1, bound) keeps the ratio defined for non-positive bounds;
            # for bounds >= 1 it is identical to dividing by the bound.
            deficit = (self.min_depth - depth) / max(1, self.min_depth)
            return min(1.0, deficit)
        if depth > self.max_depth:
            excess = (depth - self.max_depth) / max(1, self.max_depth)
            return min(1.0, excess)
        return 0.0

    def get_actual_value(self, graph: ModelGraph):
        """Get the graph's depth as reported by ``graph.get_depth()``."""
        return graph.get_depth()

    def get_expected_range(self) -> str:
        """Get the allowed (inclusive) depth range."""
        return f"[{self.min_depth}, {self.max_depth}]"

    def get_violation_message(self, graph: ModelGraph) -> str:
        """Get detailed violation message."""
        depth = self.get_actual_value(graph)
        return (
            f"Architecture depth is {depth}, outside allowed range "
            f"[{self.min_depth}, {self.max_depth}]"
        )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class WidthConstraint(Constraint):
    """Constraint on network width: ``min_width <= graph.get_max_width() <= max_width``."""

    def __init__(self, min_width: int = 1, max_width: int = 100, name: str = "width"):
        """
        Initialize width constraint.

        Args:
            min_width: Minimum allowed width (inclusive)
            max_width: Maximum allowed width (inclusive)
            name: Constraint name
        """
        super().__init__(name)
        self.min_width = min_width
        self.max_width = max_width

    def check(self, graph: ModelGraph) -> bool:
        """Check if width is within range."""
        width = graph.get_max_width()
        return self.min_width <= width <= self.max_width

    def penalty(self, graph: ModelGraph) -> float:
        """
        Penalty based on width violation.

        Returns 0.0 inside the range, otherwise the relative distance to the
        violated bound, capped at 1.0. Capping keeps this penalty on the same
        [0, 1] scale as the other built-in constraints; ConstraintHandler
        averages penalties and scales fitness by ``1 - weight * penalty``, so
        an unbounded value here would dominate or flip the sign of fitness.
        """
        width = graph.get_max_width()

        if width < self.min_width:
            # max(1, bound) avoids division by zero for non-positive bounds.
            return min(1.0, (self.min_width - width) / max(1, self.min_width))
        if width > self.max_width:
            return min(1.0, (width - self.max_width) / max(1, self.max_width))
        return 0.0

    def get_actual_value(self, graph: ModelGraph):
        """Get the graph's width as reported by ``graph.get_max_width()``."""
        return graph.get_max_width()

    def get_expected_range(self) -> str:
        """Get the allowed (inclusive) width range."""
        return f"[{self.min_width}, {self.max_width}]"

    def get_violation_message(self, graph: ModelGraph) -> str:
        """Get detailed violation message."""
        width = self.get_actual_value(graph)
        return (
            f"Architecture width is {width}, outside allowed range "
            f"[{self.min_width}, {self.max_width}]"
        )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class OperationConstraint(Constraint):
    """
    Constraint on which operation types may appear in a graph.

    Every entry of ``required_ops`` must occur among the nodes' operations,
    and no entry of ``forbidden_ops`` may occur.
    """

    def __init__(
        self,
        required_ops: Optional[Set[str]] = None,
        forbidden_ops: Optional[Set[str]] = None,
        name: str = "operations",
    ):
        """
        Initialize operation constraint.

        Args:
            required_ops: Operation types that must all appear (None/empty = no requirement)
            forbidden_ops: Operation types that must not appear (None/empty = none forbidden)
            name: Constraint name
        """
        super().__init__(name)
        self.required_ops = required_ops if required_ops else set()
        self.forbidden_ops = forbidden_ops if forbidden_ops else set()

    @staticmethod
    def _ops_in(graph: ModelGraph):
        """Collect the distinct ``operation`` values of every node in ``graph``."""
        return {node.operation for node in graph.nodes.values()}

    def check(self, graph: ModelGraph) -> bool:
        """Return True when all required ops are present and no forbidden op is."""
        present = self._ops_in(graph)
        if self.required_ops and not self.required_ops.issubset(present):
            return False
        return not (self.forbidden_ops and self.forbidden_ops.intersection(present))

    def penalty(self, graph: ModelGraph) -> float:
        """
        Fraction of required ops missing plus fraction of forbidden ops present,
        capped at 1.0.
        """
        present = self._ops_in(graph)
        score = 0.0

        if self.required_ops:
            n_missing = len(self.required_ops - present)
            score += n_missing / max(1, len(self.required_ops))

        if self.forbidden_ops:
            n_hit = len(self.forbidden_ops.intersection(present))
            score += n_hit / max(1, len(self.forbidden_ops))

        return min(1.0, score)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class ConnectivityConstraint(Constraint):
    """Constraint on graph connectivity, measured as ``len(graph.edges)``."""

    def __init__(
        self,
        min_edges: Optional[int] = None,
        max_edges: Optional[int] = None,
        name: str = "connectivity",
    ):
        """
        Initialize connectivity constraint.

        Args:
            min_edges: Minimum edge count (inclusive), or None for no lower bound
            max_edges: Maximum edge count (inclusive), or None for no upper bound
            name: Constraint name
        """
        super().__init__(name)
        self.min_edges = min_edges
        self.max_edges = max_edges

    def check(self, graph: ModelGraph) -> bool:
        """Check if connectivity is within range."""
        num_edges = len(graph.edges)

        if self.min_edges is not None and num_edges < self.min_edges:
            return False
        if self.max_edges is not None and num_edges > self.max_edges:
            return False

        return True

    def penalty(self, graph: ModelGraph) -> float:
        """
        Penalty based on connectivity violation.

        Returns 0.0 inside the range, otherwise the relative distance to the
        violated bound, capped at 1.0.
        """
        num_edges = len(graph.edges)

        # Use ``is not None`` exactly like ``check``: a bound of 0 is a real
        # bound (e.g. max_edges=0 must penalize any edge), not "unset".
        if self.min_edges is not None and num_edges < self.min_edges:
            return min(1.0, (self.min_edges - num_edges) / max(1, self.min_edges))
        if self.max_edges is not None and num_edges > self.max_edges:
            # max(1, bound) avoids division by zero when max_edges == 0.
            return min(1.0, (num_edges - self.max_edges) / max(1, self.max_edges))
        return 0.0
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class CompositeConstraint(Constraint):
    """
    Composite constraint combining multiple constraints.

    ``mode="all"`` is satisfied only if every child constraint is (logical AND);
    ``mode="any"`` is satisfied if at least one child constraint is (logical OR).
    """

    _MODES = ("all", "any")

    def __init__(self, constraints: List[Constraint], mode: str = "all", name: str = "composite"):
        """
        Initialize composite constraint.

        Args:
            constraints: List of constraints
            mode: 'all' (AND) or 'any' (OR)
            name: Constraint name

        Raises:
            ValueError: If ``mode`` is not 'all' or 'any'. An unknown mode used
                to be accepted silently and produced a constraint that always
                failed ``check`` while reporting zero penalty.
        """
        if mode not in self._MODES:
            raise ValueError(f"Unknown composite mode {mode!r}; expected 'all' or 'any'")
        super().__init__(name)
        self.constraints = constraints
        self.mode = mode

    def check(self, graph: ModelGraph) -> bool:
        """Check if graph satisfies composite constraint."""
        if self.mode == "all":
            return all(c.check(graph) for c in self.constraints)
        if self.mode == "any":
            return any(c.check(graph) for c in self.constraints)
        # Only reachable if ``mode`` was changed after construction.
        return False

    def penalty(self, graph: ModelGraph) -> float:
        """
        Calculate composite penalty.

        'all': mean of the children's penalties (0.0 with no children).
        'any': smallest child penalty (1.0 with no children, matching
        ``check``, which reports an empty OR as violated).
        """
        penalties = [c.penalty(graph) for c in self.constraints]

        if self.mode == "all":
            return sum(penalties) / len(penalties) if penalties else 0.0
        if self.mode == "any":
            return min(penalties) if penalties else 1.0
        # Unknown mode: ``check`` reports a violation, so report the max penalty.
        return 1.0
|
morphml/core/__init__.py
ADDED