morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158) hide show
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,179 @@
1
+ """Architecture embedding for vector search.
2
+
3
+ Author: Eshan Roy <eshanized@proton.me>
4
+ Organization: TONMOY INFRASTRUCTURE & VISION
5
+ """
6
+
7
+ from typing import List
8
+
9
+ import numpy as np
10
+
11
+ from morphml.core.graph import ModelGraph
12
+ from morphml.logging_config import get_logger
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
class ArchitectureEmbedder:
    """
    Embed neural architectures as fixed-size vectors.

    Methods:
    - Simple: Feature-based embedding (no external deps)
    - GNN: Graph neural network embedding (not implemented yet; currently
      logs a warning and delegates to the simple embedding)

    Args:
        method: Embedding method ('simple' or 'gnn')
        embedding_dim: Dimension of output vectors (must be >= 1)

    Raises:
        ValueError: If ``embedding_dim`` is smaller than 1.

    Example:
        >>> embedder = ArchitectureEmbedder(method='simple', embedding_dim=128)
        >>> embedding = embedder.embed(graph)
        >>> print(embedding.shape)  # (128,)
    """

    # Operation types counted by the simple embedding. The order is part of
    # the embedding layout: feature i is the fraction of layers of type i.
    _OP_TYPES = (
        "conv2d",
        "maxpool2d",
        "avgpool2d",
        "dense",
        "relu",
        "tanh",
        "sigmoid",
        "batchnorm",
        "dropout",
        "flatten",
        "input",
        "output",
    )

    def __init__(self, method: str = "simple", embedding_dim: int = 128):
        """Initialize embedder."""
        # A non-positive dimension would otherwise yield empty vectors or a
        # negative-slice truncation inside _embed_simple; fail loudly instead.
        if embedding_dim < 1:
            raise ValueError(f"embedding_dim must be >= 1, got {embedding_dim}")

        self.method = method
        self.embedding_dim = embedding_dim

        logger.info(f"Initialized ArchitectureEmbedder (method={method}, dim={embedding_dim})")

    def embed(self, graph: ModelGraph) -> np.ndarray:
        """
        Embed architecture as vector.

        Args:
            graph: Architecture graph

        Returns:
            Embedding vector of shape [embedding_dim]

        Raises:
            ValueError: If ``self.method`` is neither 'simple' nor 'gnn'.
        """
        if self.method == "simple":
            return self._embed_simple(graph)
        if self.method == "gnn":
            return self._embed_gnn(graph)
        raise ValueError(f"Unknown embedding method: {self.method}")

    def _embed_simple(self, graph: ModelGraph) -> np.ndarray:
        """
        Simple feature-based embedding.

        Feature layout before padding/truncation (18 values):
        - 12 values: fraction of layers of each type in ``_OP_TYPES``
        - layer count / 100 and parameter count / 1e6
        - mean / max conv2d ``filters`` (scaled by 512 / 1024)
        - mean / max dense ``units`` (scaled by 1024 / 2048)

        The vector is zero-padded or truncated to ``embedding_dim`` and then
        scaled to unit L2 norm (an all-zero vector is returned unchanged).

        Args:
            graph: Architecture graph; must expose ``layers`` (objects with
                ``layer_type`` and ``config``) and ``count_parameters()``.

        Returns:
            float32 vector of shape [embedding_dim]
        """
        layers = list(graph.layers)
        num_layers = len(layers)

        # Single pass: tally the known op types and collect layer sizes.
        op_counts = dict.fromkeys(self._OP_TYPES, 0)
        conv_filters = []
        dense_units = []
        for layer in layers:
            layer_type = layer.layer_type
            if layer_type in op_counts:
                op_counts[layer_type] += 1
            if layer_type == "conv2d":
                conv_filters.append(layer.config.get("filters", 64))
            elif layer_type == "dense":
                dense_units.append(layer.config.get("units", 128))

        # Operation type frequencies; max(..., 1) guards the empty graph.
        denominator = max(num_layers, 1)
        features = [op_counts[op_type] / denominator for op_type in self._OP_TYPES]

        # Graph structure features
        features.append(num_layers / 100.0)  # Normalize depth
        features.append(graph.count_parameters() / 1000000.0)  # Params in millions

        # Statistics of layer sizes (zeros when the layer type is absent)
        if conv_filters:
            features.append(np.mean(conv_filters) / 512.0)
            features.append(np.max(conv_filters) / 1024.0)
        else:
            features.extend([0.0, 0.0])

        if dense_units:
            features.append(np.mean(dense_units) / 1024.0)
            features.append(np.max(dense_units) / 2048.0)
        else:
            features.extend([0.0, 0.0])

        # Pad with zeros or truncate so the result always has embedding_dim entries.
        vector = np.zeros(self.embedding_dim, dtype=np.float32)
        used = min(len(features), self.embedding_dim)
        vector[:used] = features[:used]

        # Normalize to unit length
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm

        return vector

    def _embed_gnn(self, graph: ModelGraph) -> np.ndarray:
        """
        GNN-based embedding (placeholder).

        No GNN encoder exists yet, so this logs a warning and returns the
        simple feature-based embedding.
        """
        # TODO: Implement GNN encoder (will require PyTorch)
        logger.warning("GNN embedding not implemented, using simple")
        return self._embed_simple(graph)

    def batch_embed(self, graphs: List[ModelGraph]) -> np.ndarray:
        """
        Embed multiple graphs.

        Args:
            graphs: List of architectures

        Returns:
            Array of embeddings, shape [num_graphs, embedding_dim]
            (shape [0, embedding_dim] for an empty input)
        """
        embeddings = [self.embed(g) for g in graphs]
        if not embeddings:
            # np.array([]) would have shape (0,), breaking the 2-D contract.
            return np.zeros((0, self.embedding_dim), dtype=np.float32)
        return np.array(embeddings)

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Compute cosine similarity between embeddings.

        Negative cosine values are clamped to 0, and a non-finite result
        (e.g. NaN inputs) is reported as 0.0.

        Args:
            embedding1: First embedding
            embedding2: Second embedding

        Returns:
            Similarity score as a Python float (0-1, 1=identical direction)
        """
        first = np.asarray(embedding1, dtype=np.float64).ravel()
        second = np.asarray(embedding2, dtype=np.float64).ravel()

        # The epsilon keeps the division defined for all-zero embeddings.
        denominator = np.linalg.norm(first) * np.linalg.norm(second) + 1e-8
        similarity = float(np.dot(first, second) / denominator)

        if not np.isfinite(similarity):
            return 0.0
        return max(0.0, min(1.0, similarity))
@@ -0,0 +1,313 @@
1
+ """Knowledge base for experiment history with vector search.
2
+
3
+ Author: Eshan Roy <eshanized@proton.me>
4
+ Organization: TONMOY INFRASTRUCTURE & VISION
5
+ """
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from morphml.core.graph import ModelGraph
10
+ from morphml.logging_config import get_logger
11
+ from morphml.meta_learning.experiment_database import TaskMetadata
12
+ from morphml.meta_learning.knowledge_base.embedder import ArchitectureEmbedder
13
+ from morphml.meta_learning.knowledge_base.meta_features import MetaFeatureExtractor
14
+ from morphml.meta_learning.knowledge_base.vector_store import VectorStore
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
class KnowledgeBase:
    """
    Complete knowledge base for architecture search history.

    Combines:
    - Architecture embedding
    - Meta-feature extraction
    - Vector-based similarity search
    - Persistent storage

    Args:
        embedding_dim: Dimension of architecture embeddings
        persist_path: Path for persistent storage
        embedding_method: Method for embedding ('simple' or 'gnn')

    Example:
        >>> kb = KnowledgeBase(embedding_dim=128)
        >>>
        >>> # Add architecture
        >>> kb.add_architecture(
        ...     architecture=graph,
        ...     task=task_metadata,
        ...     metrics={'accuracy': 0.92, 'latency': 0.05}
        ... )
        >>>
        >>> # Search similar
        >>> similar = kb.search_similar(query_graph, top_k=10)
        >>>
        >>> # Get best for task
        >>> best = kb.get_best_for_task(task_metadata, top_k=5)
    """

    def __init__(
        self,
        embedding_dim: int = 128,
        persist_path: Optional[str] = None,
        embedding_method: str = "simple",
    ):
        """Initialize knowledge base."""
        self.embedding_dim = embedding_dim
        self.persist_path = persist_path

        # Components
        self.embedder = ArchitectureEmbedder(method=embedding_method, embedding_dim=embedding_dim)
        self.meta_extractor = MetaFeatureExtractor()
        self.vector_store = VectorStore(embedding_dim=embedding_dim, persist_path=persist_path)

        # Statistics; also used to number auto-generated experiment IDs
        self.num_added = 0

        logger.info(
            f"Initialized KnowledgeBase (dim={embedding_dim}, method={embedding_method})"
        )

    def add_architecture(
        self,
        architecture: ModelGraph,
        task: TaskMetadata,
        metrics: Dict[str, float],
        experiment_id: Optional[str] = None,
    ) -> int:
        """
        Add architecture to knowledge base.

        Args:
            architecture: Architecture graph
            task: Task metadata
            metrics: Performance metrics (accuracy, latency, etc.)
            experiment_id: Optional experiment identifier; when omitted (or
                empty) an ID of the form ``exp_<num_added>`` is generated

        Returns:
            ID of added architecture, as returned by the vector store
        """
        # Generate embedding
        embedding = self.embedder.embed(architecture)

        # Extract meta-features
        arch_features = self.meta_extractor.extract_architecture_features(architecture)
        task_features = self.meta_extractor.extract_task_features(task)

        # Create metadata
        metadata = {
            "experiment_id": experiment_id or f"exp_{self.num_added}",
            "task_id": task.task_id,
            "dataset": task.dataset_name,
            "num_classes": task.num_classes,
            "problem_type": task.problem_type,
            "metrics": metrics,
            "arch_features": arch_features,
            "task_features": task_features,
            "num_layers": len(architecture.layers),
            "num_parameters": architecture.count_parameters(),
        }

        # Add to vector store
        idx = self.vector_store.add(embedding=embedding, metadata=metadata, data=architecture)

        self.num_added += 1

        logger.debug(
            f"Added architecture {idx}: "
            f"task={task.dataset_name}, "
            f"accuracy={metrics.get('accuracy', 0):.4f}"
        )

        return idx

    def search_similar(
        self,
        query_architecture: ModelGraph,
        top_k: int = 10,
        task_filter: Optional[TaskMetadata] = None,
        metric_threshold: Optional[float] = None,
    ) -> List[Tuple[ModelGraph, float, Dict[str, Any]]]:
        """
        Search for similar architectures.

        Args:
            query_architecture: Query architecture
            top_k: Number of results
            task_filter: Optional task to filter by (matched on dataset name)
            metric_threshold: Optional minimum accuracy threshold; entries
                without an 'accuracy' metric are treated as 0

        Returns:
            List of (architecture, similarity, metadata) tuples, where
            similarity is ``1 / (1 + distance)`` of the store's distance
        """
        # Generate query embedding
        query_embedding = self.embedder.embed(query_architecture)

        def filter_fn(metadata: Dict[str, Any]) -> bool:
            # Filter by task
            if task_filter is not None and metadata["dataset"] != task_filter.dataset_name:
                return False

            # Filter by metric
            if metric_threshold is not None:
                if metadata["metrics"].get("accuracy", 0) < metric_threshold:
                    return False

            return True

        # Compare against None explicitly: a threshold of 0.0 (or a falsy
        # task object) is still a request to filter.
        needs_filter = task_filter is not None or metric_threshold is not None

        # Search
        results = self.vector_store.search(
            query_embedding=query_embedding,
            top_k=top_k,
            filter_fn=filter_fn if needs_filter else None,
        )

        # Convert distance to a similarity in (0, 1]
        return [
            (architecture, 1.0 / (1.0 + distance), metadata)
            for _idx, distance, metadata, architecture in results
        ]

    def get_best_for_task(
        self,
        task: TaskMetadata,
        top_k: int = 10,
        metric: str = "accuracy",
    ) -> List[Tuple[ModelGraph, float]]:
        """
        Get best architectures for a specific task.

        Args:
            task: Task metadata (matched on dataset name)
            top_k: Number of results
            metric: Metric to sort by, highest first; entries lacking the
                metric count as 0.0

        Returns:
            List of (architecture, metric_value) tuples
        """
        # Collect every stored architecture recorded for this dataset
        candidates = []
        for idx in range(self.vector_store.size()):
            _, metadata, architecture = self.vector_store.get(idx)
            if metadata["dataset"] == task.dataset_name:
                candidates.append((architecture, metadata["metrics"].get(metric, 0.0)))

        # Sort by metric (stable, so ties keep insertion order) and keep top-k
        candidates.sort(key=lambda item: item[1], reverse=True)
        return candidates[:top_k]

    def cluster_tasks(self, num_clusters: int = 5) -> Dict[int, List[str]]:
        """
        Cluster tasks by similarity of their task meta-features.

        Args:
            num_clusters: Number of clusters; reduced (with a warning) to the
                number of distinct tasks when fewer tasks are stored

        Returns:
            Dict mapping cluster ID to list of task IDs; empty when the
            knowledge base holds no tasks

        Raises:
            ValueError: If ``num_clusters`` is smaller than 1.
        """
        if num_clusters < 1:
            raise ValueError(f"num_clusters must be >= 1, got {num_clusters}")

        # Extract one feature vector per distinct task (first occurrence wins)
        task_features_list = []
        task_ids = []
        seen_task_ids = set()

        for idx in range(self.vector_store.size()):
            _, metadata, _ = self.vector_store.get(idx)
            task_id = metadata["task_id"]
            if task_id in seen_task_ids:
                continue

            seen_task_ids.add(task_id)
            task_ids.append(task_id)
            task_features_list.append(
                self.meta_extractor.feature_vector(metadata["task_features"])
            )

        # Nothing to cluster; KMeans rejects n_clusters=0.
        if not task_ids:
            return {}

        if len(task_features_list) < num_clusters:
            logger.warning(
                f"Only {len(task_features_list)} tasks, requested {num_clusters} clusters"
            )
            num_clusters = len(task_features_list)

        # Simple k-means clustering (imported lazily: sklearn is only needed here)
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=num_clusters, random_state=42)
        labels = kmeans.fit_predict(task_features_list)

        # Group by cluster
        clusters = {}
        for task_id, label in zip(task_ids, labels):
            clusters.setdefault(int(label), []).append(task_id)

        logger.info(f"Clustered {len(task_ids)} tasks into {len(clusters)} clusters")

        return clusters

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get knowledge base statistics.

        Returns:
            Dictionary with the number of stored architectures, distinct
            task IDs, distinct datasets, and the embedding dimension
        """
        # Count tasks and datasets
        tasks = set()
        datasets = set()

        num_architectures = self.vector_store.size()
        for idx in range(num_architectures):
            _, metadata, _ = self.vector_store.get(idx)
            tasks.add(metadata["task_id"])
            datasets.add(metadata["dataset"])

        return {
            "num_architectures": num_architectures,
            "num_tasks": len(tasks),
            "num_datasets": len(datasets),
            "embedding_dim": self.embedding_dim,
        }

    def save(self, path: Optional[str] = None) -> None:
        """
        Save knowledge base to disk.

        Args:
            path: Path to save to (uses persist_path if None)

        Raises:
            ValueError: If neither ``path`` nor ``persist_path`` is set.
        """
        save_path = path or self.persist_path

        if save_path is None:
            raise ValueError("No save path specified")

        self.vector_store.save(save_path)
        logger.info(f"Saved knowledge base to {save_path}")

    def load(self, path: Optional[str] = None) -> None:
        """
        Load knowledge base from disk.

        Args:
            path: Path to load from (uses persist_path if None)

        Raises:
            ValueError: If neither ``path`` nor ``persist_path`` is set.
        """
        load_path = path or self.persist_path

        if load_path is None:
            raise ValueError("No load path specified")

        self.vector_store.load(load_path)

        # Resync the counter so auto-generated "exp_<n>" IDs continue after
        # the loaded entries instead of restarting and colliding with them.
        self.num_added = self.vector_store.size()

        logger.info(f"Loaded knowledge base from {load_path}")