eizen_nsga-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eizen_nsga-1.0.0.dist-info/METADATA ADDED
@@ -0,0 +1,191 @@
1
+ Metadata-Version: 2.4
2
+ Name: eizen-nsga
3
+ Version: 1.0.0
4
+ Summary: Simple inference package for NSGA-Net trained models
5
+ Home-page: https://github.com/eizen-ai/nsga-net
6
+ Author: Eizen.ai Team
7
+ Author-email: "Eizen.ai Team" <support@eizen.ai>
8
+ License: MIT
9
+ Project-URL: Homepage, https://eizen.ai
10
+ Project-URL: Repository, https://github.com/eizen-ai/nsga-net
11
+ Project-URL: Issues, https://github.com/eizen-ai/nsga-net/issues
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Requires-Python: >=3.8
22
+ Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: torch>=1.9.0
25
+ Requires-Dist: torchvision>=0.10.0
26
+ Requires-Dist: numpy>=1.19.0
27
+ Requires-Dist: pillow>=8.0.0
28
+ Provides-Extra: sota
29
+ Requires-Dist: ultralytics>=8.0.0; extra == "sota"
30
+ Provides-Extra: nn
31
+ Requires-Dist: pandas>=1.0.0; extra == "nn"
32
+ Requires-Dist: scikit-learn>=0.24.0; extra == "nn"
33
+ Provides-Extra: transformer
34
+ Requires-Dist: transformers>=4.0.0; extra == "transformer"
35
+ Requires-Dist: tokenizers>=0.10.0; extra == "transformer"
36
+ Provides-Extra: all
37
+ Requires-Dist: ultralytics>=8.0.0; extra == "all"
38
+ Requires-Dist: pandas>=1.0.0; extra == "all"
39
+ Requires-Dist: scikit-learn>=0.24.0; extra == "all"
40
+ Requires-Dist: transformers>=4.0.0; extra == "all"
41
+ Requires-Dist: tokenizers>=0.10.0; extra == "all"
42
+ Dynamic: author
43
+ Dynamic: home-page
44
+ Dynamic: license-file
45
+ Dynamic: requires-python
46
+
47
+ # eizen-nsga
48
+
49
+ A **standalone** inference package for NSGA-Net trained models, with a simple load-and-predict API modeled on Ultralytics YOLO.
50
+
51
+ ✅ **Fully independent** - no dependency on the NSGA-Net training codebase
52
+ ✅ **Simple API** - load model from ZIP or URL, run inference
53
+ ✅ **PyPI ready** - publish and use anywhere
54
+ ✅ **Supports URLs** - load models from HTTP, S3, GCS, and more
55
+
56
+ ## Installation
57
+
58
+ ```bash
59
+ # For SOTA (computer vision) models
60
+ pip install eizen-nsga[sota]
61
+
62
+ # For NN (tabular) models
63
+ pip install eizen-nsga[nn]
64
+
65
+ # For Transformer (LLM) models
66
+ pip install eizen-nsga[transformer]
67
+
68
+ # Install all dependencies
69
+ pip install eizen-nsga[all]
70
+ ```
71
+
72
+ ## Quick Start
73
+
74
+ ```python
75
+ from eizen_nsga import NASModel
76
+
77
+ # Load model from local ZIP file
78
+ model = NASModel("trained_model.zip")
79
+
80
+ # Or load from URL (HTTP, HTTPS, S3, GCS, etc.)
81
+ model = NASModel("https://example.com/models/model.zip")
82
+ model = NASModel("s3://mybucket/models/model.zip")
83
+
84
+ # Run inference (accepts file path, PIL Image, or numpy array)
85
+ result = model.predict("image.jpg")
86
+
87
+ # Classification results
88
+ print(f"Top class: {result['class']}")
89
+ print(f"Confidence: {result['confidence']:.2%}")
90
+ print(f"Top 5 predictions: {result['predictions']}")
91
+
92
+ # Or call directly like YOLO
93
+ result = model("image.jpg")
94
+ ```
95
+
96
+ ## Detection Models
97
+
98
+ ```python
99
+ from eizen_nsga import NASModel
100
+
101
+ # Load detection model
102
+ model = NASModel("detector.zip")
103
+
104
+ # Run detection
105
+ result = model.predict("image.jpg", conf_threshold=0.3, iou_threshold=0.45)
106
+
107
+ # Detection results
108
+ print(f"Found {result['count']} objects")
109
+ for det in result['detections']:
110
+ print(f" Box: {det['box']}, Class: {det['class']}, Conf: {det['confidence']}")
111
+ ```
112
+
113
+ ## Model Info
114
+
115
+ ```python
116
+ # Get model information
117
+ info = model.info()
118
+ print(f"Category: {info['category']}")
119
+ print(f"Task: {info['task']}")
120
+ print(f"Backbone: {info['backbone']}")
121
+ print(f"Classes: {info['num_classes']}")
122
+ ```
123
+
124
+ ## Device Selection
125
+
126
+ ```python
127
+ # Auto-detect device (default)
128
+ model = NASModel("model.zip")
129
+
130
+ # Specify device
131
+ model = NASModel("model.zip", device="cuda")
132
+ model = NASModel("model.zip", device="cpu")
133
+
134
+ # Move to different device
135
+ model.to("cuda")
136
+ ```
137
+
138
+ ## How It Works
139
+
140
+ 1. **Extract**: Automatically extracts model ZIP file
141
+ 2. **Parse**: Reads `log.txt` to get model configuration (genome, backbone, etc.)
142
+ 3. **Build**: Constructs model architecture from genome encoding
143
+ 4. **Load**: Loads trained weights from `weights.pt`
144
+ 5. **Predict**: Runs inference on your images
145
+
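+ For intuition, here is a rough, illustrative sketch of steps 1, 2 and 4 using only the standard library and PyTorch; the real `NASModel` loader also decodes the genome and builds the architecture (step 3) before loading the weights:
+ 
+ ```python
+ import os
+ import tempfile
+ import zipfile
+ 
+ import torch
+ 
+ zip_path = "trained_model.zip"   # your exported NSGA-Net model
+ workdir = tempfile.mkdtemp()
+ 
+ # 1. Extract: unpack the model archive
+ with zipfile.ZipFile(zip_path) as zf:
+     zf.extractall(workdir)
+ 
+ # 2. Parse: read the configuration written during training (genome, backbone, etc.)
+ with open(os.path.join(workdir, "log.txt")) as f:
+     config_text = f.read()
+ 
+ # 4. Load: read the trained weights checkpoint
+ ckpt = torch.load(os.path.join(workdir, "weights.pt"), map_location="cpu")
+ print(type(ckpt))   # typically a state_dict to load into the rebuilt architecture
+ ```
+ 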
146
+ ## Supported Model Categories
147
+
148
+ - **SOTA**: Computer vision models (YOLOv8, ResNet, EfficientNet backbones) ✅
149
+ - **NN**: Tabular data models (coming soon)
150
+ - **Transformer**: LLM models (coming soon)
151
+
152
+ ## Model ZIP Structure
153
+
154
+ Your trained model ZIP should contain:
155
+ ```
156
+ model.zip
157
+ ├── weights.pt # Trained model weights
158
+ └── log.txt # Model configuration (genome, backbone, etc.)
159
+ ```
160
+
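+ If you are assembling such a ZIP yourself, a minimal sketch with the standard library (the source paths below are placeholders for your own training outputs):
+ 
+ ```python
+ import zipfile
+ 
+ # Bundle the two artifacts expected by NASModel into model.zip
+ with zipfile.ZipFile("model.zip", "w", zipfile.ZIP_DEFLATED) as zf:
+     zf.write("runs/exp1/weights.pt", arcname="weights.pt")
+     zf.write("runs/exp1/log.txt", arcname="log.txt")
+ ```
+ 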
161
+ ## Package is Fully Standalone
162
+
163
+ This package includes all necessary SOTA modules internally:
164
+ - Model builders (micro/macro architectures)
165
+ - Genome encoders/decoders
166
+ - Neural operations
167
+ - Backbone registry
168
+ - Detection heads
169
+
170
+ No need to have the NSGA-Net training codebase installed!
171
+
172
+ ## Requirements
173
+
174
+ **Core dependencies:**
175
+ - torch >= 1.9.0
176
+ - torchvision >= 0.10.0
177
+ - numpy >= 1.19.0
178
+ - pillow >= 8.0.0
179
+
180
+ **Optional (for SOTA models):**
181
+ - ultralytics >= 8.0.0
182
+
183
+ ## License
184
+
185
+ MIT License - Copyright (c) 2024 Eizen.ai Team
186
+
187
+ ## Links
188
+
189
+ - Homepage: https://eizen.ai
190
+ - GitHub: https://github.com/eizen-ai/nsga-net
191
+ - Issues: https://github.com/eizen-ai/nsga-net/issues
eizen_nsga-1.0.0.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
1
+ eizen_nsga-1.0.0.dist-info/licenses/LICENSE,sha256=_HF_TY-jv6lFXR4QcG1iTnMYgDC3MfslE5S0OvUGlQA,1091
2
+ sota/__init__.py,sha256=3hyX_CR86TvV2RYQaES0FhZdtklUBu3DB7HAv1Z0yGo,525
3
+ sota/detection_heads.py,sha256=Tx7qgqc0_BBtaIenxNHg_dIwbDHyOQz3QsamNEa9gX0,11206
4
+ sota/macro_encoding.py,sha256=lwV8Nptwt014UWW-2rAmivYqS-5Jfs3g8DCmzhRzHsA,3037
5
+ sota/macro_models.py,sha256=okitzfXyaRPjGb5lAh9BNwqUVjTaUqaj7jwypKKVVm0,12254
6
+ sota/micro_encoding.py,sha256=LrhgzhgnGMP2Up7Uqy0zT5nedzefqDXl2HJ8YL2TOVM,4820
7
+ sota/micro_models.py,sha256=piLM6jt4LEiiQuox6AL6I-K0W0XcGLcL2qHWZgPbpvA,22062
8
+ sota/micro_operations.py,sha256=3E2JL1eSzWYmtI4X9omBkocWcAoZslXdkIj8CS1h4dQ,8702
9
+ sota/model_registry.py,sha256=lQQKUKCVWEQEql5mDZrlwRJEIVaqdokNWEEN4SWrnA0,17750
10
+ eizen_nsga-1.0.0.dist-info/METADATA,sha256=uQP38LPaimMrKbFtdFhUXuBaOsIMaQntO1oJRvjDLK4,5268
11
+ eizen_nsga-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
12
+ eizen_nsga-1.0.0.dist-info/top_level.txt,sha256=pPnDGDIcVIcY4aYyBjID0_x2rcMP9_Q3G-ZHKSZGwQI,5
13
+ eizen_nsga-1.0.0.dist-info/RECORD,,
eizen_nsga-1.0.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.10.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
eizen_nsga-1.0.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Eizen.ai Team
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
eizen_nsga-1.0.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ sota
sota/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ SOTA module for NSGA-Net inference package
3
+ Contains model builders, encoders, and operations for neural architecture search
4
+ """
5
+
6
+ from . import micro_models
7
+ from . import macro_models
8
+ from . import micro_encoding
9
+ from . import macro_encoding
10
+ from . import micro_operations
11
+ from . import model_registry
12
+ from . import detection_heads
13
+
14
+ __all__ = [
15
+ 'micro_models',
16
+ 'macro_models',
17
+ 'micro_encoding',
18
+ 'macro_encoding',
19
+ 'micro_operations',
20
+ 'model_registry',
21
+ 'detection_heads',
22
+ ]
sota/detection_heads.py ADDED
@@ -0,0 +1,312 @@
1
+ """
2
+ Detection heads for object detection tasks.
3
+ Implements YOLO-style detection head that works with adapter networks.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
+ import torchvision  # provides torchvision.ops.nms used in non_max_suppression below
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class YOLODetectionHead(nn.Module):
15
+ """
16
+ YOLO-style detection head for object detection.
17
+
18
+ Predicts bounding boxes, objectness scores, and class probabilities on a grid.
19
+ Each grid cell can predict multiple anchors.
20
+
21
+ Output format: [batch, num_anchors * (5 + num_classes), H, W]
22
+ where 5 = (x, y, w, h, objectness)
23
+ """
24
+
25
+ def __init__(self, in_channels, num_classes, num_anchors=3, grid_size=7):
26
+ """
27
+ Args:
28
+ in_channels: Number of input channels from backbone/adapter
29
+ num_classes: Number of object classes
30
+ num_anchors: Number of anchor boxes per grid cell
31
+ grid_size: Size of prediction grid (e.g., 7 means 7x7 grid)
32
+ """
33
+ super(YOLODetectionHead, self).__init__()
34
+
35
+ self.in_channels = in_channels
36
+ self.num_classes = num_classes
37
+ self.num_anchors = num_anchors
38
+ self.grid_size = grid_size
39
+
40
+ # Number of predictions per anchor: 4 (bbox) + 1 (objectness) + num_classes
41
+ self.num_predictions = 5 + num_classes
42
+
43
+ # Detection layers
44
+ # Reduce channels first for efficiency
45
+ self.conv1 = nn.Sequential(
46
+ nn.Conv2d(in_channels, 512, kernel_size=3, padding=1, bias=False),
47
+ nn.BatchNorm2d(512),
48
+ nn.LeakyReLU(0.1, inplace=True)
49
+ )
50
+
51
+ self.conv2 = nn.Sequential(
52
+ nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
53
+ nn.BatchNorm2d(256),
54
+ nn.LeakyReLU(0.1, inplace=True)
55
+ )
56
+
57
+ # Final prediction layer
58
+ self.pred_conv = nn.Conv2d(
59
+ 256,
60
+ num_anchors * self.num_predictions,
61
+ kernel_size=1
62
+ )
63
+
64
+ # Adaptive pooling to ensure output is correct grid size
65
+ self.adaptive_pool = nn.AdaptiveAvgPool2d((grid_size, grid_size))
66
+
67
+ self._initialize_weights()
68
+
69
+ logger.info(f"YOLODetectionHead initialized: {num_anchors} anchors, "
70
+ f"{num_classes} classes, {grid_size}x{grid_size} grid")
71
+
72
+ def _initialize_weights(self):
73
+ """Initialize weights for detection head."""
74
+ for m in self.modules():
75
+ if isinstance(m, nn.Conv2d):
76
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
77
+ if m.bias is not None:
78
+ nn.init.constant_(m.bias, 0)
79
+ elif isinstance(m, nn.BatchNorm2d):
80
+ nn.init.constant_(m.weight, 1)
81
+ nn.init.constant_(m.bias, 0)
82
+
83
+ def forward(self, x):
84
+ """
85
+ Args:
86
+ x: Input features of shape (batch, in_channels, H, W)
87
+
88
+ Returns:
89
+ predictions: Tensor of shape (batch, num_anchors, grid_size, grid_size, 5 + num_classes)
90
+ where last dim is [x, y, w, h, objectness, class_probs...]
91
+ """
92
+ # Apply detection layers
93
+ x = self.conv1(x)
94
+ x = self.conv2(x)
95
+
96
+ # Ensure correct spatial size
97
+ if x.size(2) != self.grid_size or x.size(3) != self.grid_size:
98
+ x = F.interpolate(x, size=(self.grid_size, self.grid_size),
99
+ mode='bilinear', align_corners=False)
100
+
101
+ # Prediction layer
102
+ predictions = self.pred_conv(x)
103
+
104
+ # Reshape: [batch, num_anchors * predictions, H, W]
105
+ # -> [batch, num_anchors, predictions, H, W]
106
+ # -> [batch, num_anchors, H, W, predictions]
107
+ batch_size = predictions.size(0)
108
+ predictions = predictions.view(
109
+ batch_size,
110
+ self.num_anchors,
111
+ self.num_predictions,
112
+ self.grid_size,
113
+ self.grid_size
114
+ )
115
+ predictions = predictions.permute(0, 1, 3, 4, 2).contiguous()
116
+
117
+ # predictions shape: [batch, num_anchors, grid_size, grid_size, 5 + num_classes]
118
+ return predictions
119
+
120
+
121
+ class MultiScaleDetectionHead(nn.Module):
122
+ """
123
+ Multi-scale detection head with predictions at multiple resolutions.
124
+ Similar to YOLOv3/v4 multi-scale approach.
125
+ """
126
+
127
+ def __init__(self, in_channels, num_classes, num_anchors=3, grid_sizes=[7, 14]):
128
+ """
129
+ Args:
130
+ in_channels: Number of input channels from backbone/adapter
131
+ num_classes: Number of object classes
132
+ num_anchors: Number of anchor boxes per grid cell
133
+ grid_sizes: List of grid sizes for multi-scale predictions
134
+ """
135
+ super(MultiScaleDetectionHead, self).__init__()
136
+
137
+ self.in_channels = in_channels
138
+ self.num_classes = num_classes
139
+ self.num_anchors = num_anchors
140
+ self.grid_sizes = grid_sizes
141
+
142
+ # Create detection head for each scale
143
+ self.detection_heads = nn.ModuleList([
144
+ YOLODetectionHead(in_channels, num_classes, num_anchors, grid_size)
145
+ for grid_size in grid_sizes
146
+ ])
147
+
148
+ logger.info(f"MultiScaleDetectionHead initialized with scales: {grid_sizes}")
149
+
150
+ def forward(self, x):
151
+ """
152
+ Args:
153
+ x: Input features of shape (batch, in_channels, H, W)
154
+
155
+ Returns:
156
+ predictions: List of tensors, one per scale
157
+ Each tensor has shape [batch, num_anchors, grid_size, grid_size, 5 + num_classes]
158
+ """
159
+ predictions = []
160
+ for head in self.detection_heads:
161
+ pred = head(x)
162
+ predictions.append(pred)
163
+
164
+ return predictions
165
+
166
+
167
+ def decode_predictions(predictions, conf_threshold=0.5, grid_size=7):
168
+ """
169
+ Decode raw predictions to actual bounding boxes.
170
+
171
+ Args:
172
+ predictions: Tensor of shape [batch, num_anchors, grid_size, grid_size, 5 + num_classes]
173
+ conf_threshold: Confidence threshold for filtering predictions
174
+ grid_size: Size of the prediction grid
175
+
176
+ Returns:
177
+ boxes: List of tensors, one per image in batch
178
+ Each tensor has shape [N, 4] with boxes in format [x_center, y_center, w, h] (normalized)
179
+ scores: List of tensors, one per image
180
+ Each tensor has shape [N] with objectness * class_prob scores
181
+ labels: List of tensors, one per image
182
+ Each tensor has shape [N] with predicted class indices
183
+ """
184
+ batch_size = predictions.size(0)
185
+ num_anchors = predictions.size(1)
186
+
187
+ # Apply sigmoid to objectness and class predictions
188
+ # predictions[..., :4] are bbox coords (will apply sigmoid)
189
+ # predictions[..., 4] is objectness
190
+ # predictions[..., 5:] are class logits
191
+
192
+ pred_boxes = torch.sigmoid(predictions[..., :4]) # [batch, anchors, H, W, 4]
193
+ pred_obj = torch.sigmoid(predictions[..., 4]) # [batch, anchors, H, W]
194
+ pred_cls = predictions[..., 5:] # [batch, anchors, H, W, num_classes]
195
+
196
+ # Apply softmax to class predictions
197
+ pred_cls = F.softmax(pred_cls, dim=-1)
198
+
199
+ # Get class scores and labels
200
+ class_scores, class_labels = torch.max(pred_cls, dim=-1) # [batch, anchors, H, W]
201
+
202
+ # Combine objectness and class confidence
203
+ confidence = pred_obj * class_scores # [batch, anchors, H, W]
204
+
205
+ # Decode boxes for each image in batch
206
+ batch_boxes = []
207
+ batch_scores = []
208
+ batch_labels = []
209
+
210
+ for b in range(batch_size):
211
+ # Get predictions above threshold
212
+ mask = confidence[b] > conf_threshold
213
+
214
+ if mask.sum() == 0:
215
+ # No detections
216
+ batch_boxes.append(torch.zeros((0, 4), device=predictions.device))
217
+ batch_scores.append(torch.zeros((0,), device=predictions.device))
218
+ batch_labels.append(torch.zeros((0,), dtype=torch.long, device=predictions.device))
219
+ continue
220
+
221
+ # Get valid predictions
222
+ valid_boxes = pred_boxes[b][mask] # [N, 4]
223
+ valid_scores = confidence[b][mask] # [N]
224
+ valid_labels = class_labels[b][mask] # [N]
225
+
226
+ # pred_boxes are in [0, 1] after the sigmoid and are treated here as
227
+ # image-normalized [x_center, y_center, w, h]; no per-grid-cell offset
228
+ # is added by this decoder.
229
+
230
+ batch_boxes.append(valid_boxes)
231
+ batch_scores.append(valid_scores)
232
+ batch_labels.append(valid_labels)
233
+
234
+ return batch_boxes, batch_scores, batch_labels
235
+
236
+
237
+ def non_max_suppression(boxes, scores, labels, iou_threshold=0.4):
238
+ """
239
+ Apply Non-Maximum Suppression to remove duplicate detections.
240
+
241
+ Args:
242
+ boxes: Tensor of shape [N, 4] in format [x_center, y_center, w, h]
243
+ scores: Tensor of shape [N]
244
+ labels: Tensor of shape [N]
245
+ iou_threshold: IoU threshold for NMS
246
+
247
+ Returns:
248
+ keep_indices: Tensor of indices to keep
249
+ """
250
+ if len(boxes) == 0:
251
+ return torch.tensor([], dtype=torch.long, device=boxes.device)
252
+
253
+ # Convert center format to corner format for NMS
254
+ # [x_center, y_center, w, h] -> [x1, y1, x2, y2]
255
+ x_center, y_center, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
256
+ x1 = x_center - w / 2
257
+ y1 = y_center - h / 2
258
+ x2 = x_center + w / 2
259
+ y2 = y_center + h / 2
260
+ boxes_corners = torch.stack([x1, y1, x2, y2], dim=1)
261
+
262
+ # Apply NMS per class
263
+ keep_indices = []
264
+ unique_labels = labels.unique()
265
+
266
+ for label in unique_labels:
267
+ label_mask = labels == label
268
+ label_boxes = boxes_corners[label_mask]
269
+ label_scores = scores[label_mask]
270
+ label_indices = torch.where(label_mask)[0]
271
+
272
+ # Use torchvision NMS
273
+ keep = torchvision.ops.nms(label_boxes, label_scores, iou_threshold)
274
+ keep_indices.append(label_indices[keep])
275
+
276
+ if len(keep_indices) > 0:
277
+ keep_indices = torch.cat(keep_indices)
278
+ else:
279
+ keep_indices = torch.tensor([], dtype=torch.long, device=boxes.device)
280
+
281
+ return keep_indices
282
+
283
+
284
+ def build_detection_head(head_type, in_channels, num_classes, **kwargs):
285
+ """
286
+ Factory function to build detection head.
287
+
288
+ Args:
289
+ head_type: Type of detection head ('yolo' or 'multiscale')
290
+ in_channels: Number of input channels
291
+ num_classes: Number of object classes
292
+ **kwargs: Additional arguments for the head
293
+
294
+ Returns:
295
+ Detection head module
296
+ """
297
+ if head_type == 'yolo':
298
+ return YOLODetectionHead(
299
+ in_channels=in_channels,
300
+ num_classes=num_classes,
301
+ num_anchors=kwargs.get('num_anchors', 3),
302
+ grid_size=kwargs.get('grid_size', 7)
303
+ )
304
+ elif head_type == 'multiscale':
305
+ return MultiScaleDetectionHead(
306
+ in_channels=in_channels,
307
+ num_classes=num_classes,
308
+ num_anchors=kwargs.get('num_anchors', 3),
309
+ grid_sizes=kwargs.get('grid_sizes', [7, 14])
310
+ )
311
+ else:
312
+ raise ValueError(f"Unknown detection head type: {head_type}")
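+ 
+ 
+ # Illustrative smoke test (not used by the package itself): shows the shapes produced
+ # when chaining build_detection_head -> decode_predictions -> non_max_suppression.
+ if __name__ == "__main__":
+     head = build_detection_head('yolo', in_channels=1024, num_classes=20, grid_size=7)
+     feats = torch.randn(2, 1024, 14, 14)              # stand-in for backbone/adapter features
+     preds = head(feats)                               # [2, 3, 7, 7, 25]
+     print("raw predictions:", tuple(preds.shape))
+ 
+     boxes, scores, labels = decode_predictions(preds, conf_threshold=0.05)
+     keep = non_max_suppression(boxes[0], scores[0], labels[0], iou_threshold=0.4)
+     print(f"image 0: kept {keep.numel()} of {len(boxes[0])} candidate boxes")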
sota/macro_encoding.py ADDED
@@ -0,0 +1,100 @@
1
+ # Macro Search Space Encoding for YOLO Adapter
2
+ # Search for connectivity pattern in phases (Genetic CNN style)
3
+ # Based on NSGA-Net reference implementation
4
+
5
+ import numpy as np
6
+ from collections import namedtuple
7
+
8
+ # MacroGenotype: per-phase DAG connectivity stored as adjacency lists (one entry per phase)
9
+ MacroGenotype = namedtuple('MacroGenotype', 'ops')
10
+
11
+ def phase_dencode(phase_bit_string):
12
+ """
13
+ Convert flat connectivity bits to adjacency structure.
14
+ Returns a list of lists where genome[i] contains connections for node i.
15
+ """
16
+ # Calculate the number of nodes n from the phase bit-string length L.
17
+ # L = n*(n+1)/2 + 1: the triangular connection bits plus one residual-skip bit.
18
+ # So 2L = n^2 + n + 2, i.e. n^2 + n + (2 - 2L) = 0, whose positive root is
19
+ # n = (-1 + sqrt(8L - 7)) / 2
20
+ #   = sqrt(2L - 7/4) - 1/2,
21
+ # which is exactly the formula used below (and in the NSGA-Net reference
22
+ # implementation).
23
+
24
+ n = int(np.sqrt(2 * len(phase_bit_string) - 7/4) - 1/2)
25
+ genome = []
26
+
27
+ # For each node i (0 to n-1)
28
+ for i in range(n):
29
+ operator = []
30
+ # Check connections from all previous nodes 0 to i
31
+ # Triangular number indexing
32
+ for j in range(i + 1):
33
+ idx = int(i * (i + 1) / 2 + j)
34
+ operator.append(phase_bit_string[idx])
35
+ genome.append(operator)
36
+
37
+ # Last bit is residual skip for the whole phase
38
+ genome.append([phase_bit_string[-1]])
39
+ return genome
40
+
41
+
42
+ def convert(bit_string, n_phases=3):
43
+ """
44
+ Convert flat integer array to list of phase bit-strings
45
+ """
46
+ # Assumes bit_string is a np array
47
+ if bit_string.shape[0] % n_phases != 0:
48
+ # If the length is not evenly divisible (e.g. bounds were built with a different phase count), re-infer n_phases from the chunk size
49
+ n_phases = len(bit_string) // (len(bit_string) // n_phases)
50
+
51
+ phase_length = bit_string.shape[0] // n_phases
52
+ genome = []
53
+ for i in range(0, bit_string.shape[0], phase_length):
54
+ sub_genome = bit_string[i:i+phase_length]
55
+ genome.append(sub_genome.tolist())
56
+
57
+ return genome
58
+
59
+
60
+ def decode(genome):
61
+ """
62
+ Decode list of phase bit-strings into Genotype structure (adjacency lists)
63
+ """
64
+ genotype = []
65
+ for gene in genome:
66
+ genotype.append(phase_dencode(gene))
67
+
68
+ return MacroGenotype(ops=genotype)
69
+
70
+
71
+ def get_search_space_bounds(n_nodes=4, n_phases=3):
72
+ """
73
+ Get bounds for macro search space (Binary encoding)
74
+
75
+ Args:
76
+ n_nodes: number of nodes inside each phase
77
+ n_phases: number of phases to stack
78
+
79
+ Returns:
80
+ n_var: number of bits
81
+ lb: lower bounds (all 0)
82
+ ub: upper bounds (all 1)
83
+ """
84
+ # Calculate bits per phase
85
+ # n_nodes nodes.
86
+ # Node 1 inputs: 1 bit (from Input)
87
+ # Node 2 inputs: 2 bits (from Input, Node 1)
88
+ # ...
89
+ # Node n inputs: n bits
90
+ # Total connection bits = n*(n+1)/2
91
+ # Plus 1 bit for global skip connection
92
+
93
+ bits_per_phase = int(n_nodes * (n_nodes + 1) / 2 + 1)
94
+ n_var = bits_per_phase * n_phases
95
+
96
+ lb = np.zeros(n_var)
97
+ ub = np.ones(n_var) # Binary encoding
98
+
99
+ return n_var, lb, ub
100
+
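+ 
+ # Illustrative round-trip (not part of the reference implementation): shows the
+ # shapes produced by the helpers above for a random macro genome.
+ if __name__ == "__main__":
+     n_var, lb, ub = get_search_space_bounds(n_nodes=4, n_phases=3)
+     print(n_var)                                   # 33 = 3 phases * (4*5/2 + 1) bits
+     bits = np.random.randint(0, 2, size=n_var)
+     phases = convert(bits, n_phases=3)             # three 11-bit phase strings
+     genotype = decode(phases)
+     # each phase decodes to 4 per-node connection lists plus a residual-skip flag
+     print([len(phase) for phase in genotype.ops])  # [5, 5, 5]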