eizen-nsga 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eizen_nsga/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ eizen-nsga: Simple inference package for NSGA-Net trained models
3
+
4
+ Usage:
5
+ from eizen_nsga import NASModel
6
+
7
+ model = NASModel("trained_model.zip")
8
+ result = model.predict("image.jpg")
9
+ """
10
+
11
# Package version; keep in sync with the ``Version`` field of the wheel
# metadata (this release is published as 1.0.1).
__version__ = "1.0.1"

from .model import NASModel

__all__ = ['NASModel']
eizen_nsga/model.py ADDED
@@ -0,0 +1,372 @@
1
+ """
2
+ NASModel - Universal model class for NSGA-Net trained models
3
+
4
+ Simple interface like ultralytics YOLO:
5
+ model = NASModel("model.zip")
6
+ result = model("image.jpg")
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import ast
12
+ import sys
13
+ import zipfile
14
+ import tempfile
15
+ import shutil
16
+ from typing import Union, Optional, Dict, Any, List
17
+ from urllib.parse import urlparse
18
+ from urllib.request import urlretrieve
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import numpy as np
23
+ from PIL import Image
24
+ import torchvision.transforms as transforms
25
+
26
+
27
class NASModel:
    """
    Universal model class for NSGA-Net trained models.

    Simple interface like YOLO:
        model = NASModel("model.zip")
        result = model.predict("image.jpg")

    Supports:
        - SOTA: Computer vision (YOLOv8, ResNet, EfficientNet)
        - NN: Tabular data (coming soon)
        - Transformer: LLMs (coming soon)

    When the model is downloaded or a ZIP archive is extracted, a temporary
    directory is created. Call ``close()`` to remove it deterministically;
    it is also removed when the object is garbage-collected.
    """

    # URL schemes treated as remote sources and downloaded before loading.
    _URL_SCHEMES = ('http', 'https', 's3', 'gs', 'ftp')

    # Regexes (matched case-insensitively, multi-line) that pull configuration
    # values out of log.txt; group(1) holds the raw value.
    _LOG_PATTERNS = {
        'modelCategory': r'modelCategory[:\s]*(\w+)',
        'task': r'Task[:\s]*(\w+)',
        'genome': r'Genome[:\s]*(.+)',
        'backbone': r'backbone[:\s]*(\S+)',
        'search_space': r'search_space[:\s]*(\w+)',
        'num_cells': r'num_cells[:\s]*(\d+)',
        'num_classes': r'num_classes[:\s]*(\d+)',
        'image_size': r'image_size[:\s]*(\d+)',
    }

    def __init__(self, path: str, device: Optional[str] = None):
        """
        Load a trained NSGA-Net model.

        Args:
            path: Path or URL to model ZIP file or directory.
                  Supports local paths (str or os.PathLike), HTTP/HTTPS/FTP
                  URLs, and s3:// / gs:// URLs of publicly readable objects.
            device: 'cuda', 'cpu', or None (auto-detect)

        Raises:
            FileNotFoundError: if no weights.pt is found.
            ValueError: if the model category is unknown or the genome is
                missing from log.txt.
        """
        self.source_path = path
        self._temp_dir = None

        # Download and/or extract the archive if needed.
        model_dir = self._extract_model(path)

        # Locate weights and config.
        weights_path, log_path = self._find_model_files(model_dir)

        # Parse configuration. literal_eval may return non-string values,
        # so normalise the two fields used for dispatch to str.
        self.config = self._parse_log(log_path)
        self.model_category = str(self.config.get('modelCategory', 'SOTA'))
        self.task = str(self.config.get('task', 'classification'))

        # Select device.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Build the architecture and load the trained weights.
        self.model = self._build_model()
        self._load_weights(weights_path)
        self.model.eval()

        print(f"✅ Loaded {self.model_category} model for {self.task}")

    @staticmethod
    def _resolve_download_url(url: str) -> str:
        """Map s3:// and gs:// URLs to their public HTTPS endpoints; return others unchanged.

        ``urlretrieve`` only understands http(s)/ftp/file URLs, so object-store
        URLs are rewritten. Only publicly readable objects can be fetched this way.
        """
        parsed = urlparse(url)
        scheme = parsed.scheme.lower()
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        if scheme == 's3':
            return f"https://{bucket}.s3.amazonaws.com/{key}"
        if scheme == 'gs':
            return f"https://storage.googleapis.com/{bucket}/{key}"
        return url

    def _extract_model(self, path: str) -> str:
        """
        Resolve ``path`` to a local directory containing the model files.

        Remote sources are first downloaded to ``model.zip`` inside a new
        temporary directory. A local ``.zip`` file is extracted into the
        temporary directory, which is returned. Any other path is returned
        unchanged (assumed to be a model directory).
        """
        path = os.fspath(path)

        if urlparse(path).scheme.lower() in self._URL_SCHEMES:
            print("📥 Downloading model from URL...")
            self._temp_dir = tempfile.mkdtemp()
            local_zip = os.path.join(self._temp_dir, 'model.zip')
            urlretrieve(self._resolve_download_url(path), local_zip)
            print(f"✅ Downloaded to {local_zip}")
            path = local_zip

        if path.lower().endswith('.zip') and os.path.isfile(path):
            if not self._temp_dir:
                self._temp_dir = tempfile.mkdtemp()
            with zipfile.ZipFile(path, 'r') as zf:
                zf.extractall(self._temp_dir)
            return self._temp_dir
        return path

    def _find_model_files(self, model_dir: str) -> tuple:
        """
        Find weights.pt and its sibling log.txt under ``model_dir``.

        ``model_dir`` itself is checked first, then its subdirectories in
        sorted order. The returned log path may not exist;
        ``_parse_log`` handles that case.

        Raises:
            FileNotFoundError: if no weights.pt exists anywhere under model_dir.
        """
        # os.walk is top-down, so model_dir itself is examined first.
        for root, dirs, files in os.walk(model_dir):
            dirs.sort()  # deterministic search order across platforms
            if 'weights.pt' in files:
                return (
                    os.path.join(root, 'weights.pt'),
                    os.path.join(root, 'log.txt'),
                )

        raise FileNotFoundError(f"weights.pt not found in {model_dir}")

    def _parse_log(self, log_path: Optional[str]) -> Dict[str, Any]:
        """
        Parse log.txt into a configuration dict.

        For each entry in ``_LOG_PATTERNS``, the first match is kept. Python
        literals (numbers, lists, ...) are converted with ``ast.literal_eval``,
        and anything else is stored as the stripped string. A missing log
        yields an empty dict.
        """
        config: Dict[str, Any] = {}

        if not log_path or not os.path.exists(log_path):
            return config

        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            content = f.read()

        for key, pattern in self._LOG_PATTERNS.items():
            match = re.search(pattern, content, re.IGNORECASE | re.MULTILINE)
            if match is None:
                continue
            value = match.group(1).strip()
            try:
                config[key] = ast.literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal (e.g. "micro", "yolov8n"): keep the text.
                config[key] = value

        return config

    def _build_model(self) -> nn.Module:
        """Build the model architecture for ``self.model_category`` (case-insensitive)."""
        category = str(self.model_category).upper()

        if category == 'SOTA':
            return self._build_sota_model()
        elif category == 'NN':
            return self._build_nn_model()
        elif category == 'TRANSFORMER':
            return self._build_transformer_model()
        else:
            raise ValueError(f"Unknown model category: {category}")

    def _build_sota_model(self) -> nn.Module:
        """
        Build the SOTA adapter model from the genome in the parsed config.

        ``search_space`` 'macro' (case-insensitive) selects MacroAdapterNetwork;
        anything else selects the micro GeneralAdapterNetwork.

        Raises:
            ValueError: if the config has no genome.
        """
        genome = self.config.get('genome')
        if genome is None:
            raise ValueError("genome not found in log.txt")

        backbone = self.config.get('backbone', 'yolov8n')
        search_space = str(self.config.get('search_space', 'micro')).lower()
        num_cells = int(self.config.get('num_cells', 3))
        num_classes = int(self.config.get('num_classes', 10))

        # Import SOTA modules from package
        from .sota.micro_models import GeneralAdapterNetwork
        from .sota.macro_models import MacroAdapterNetwork
        from .sota import micro_encoding, macro_encoding

        if search_space == 'macro':
            network_cls, encoding = MacroAdapterNetwork, macro_encoding
        else:
            network_cls, encoding = GeneralAdapterNetwork, micro_encoding

        model = network_cls(
            genotype=encoding.decode(genome),
            backbone_name=backbone,
            num_classes=num_classes,
            num_cells=num_cells,
            pretrained=True,
            freeze_backbone=True
        )

        return model.to(self.device)

    def _build_nn_model(self) -> nn.Module:
        """Build NN model (tabular). Not implemented yet."""
        raise NotImplementedError("NN model support coming soon")

    def _build_transformer_model(self) -> nn.Module:
        """Build Transformer model (LLM). Not implemented yet."""
        raise NotImplementedError("Transformer model support coming soon")

    def _load_weights(self, weights_path: str):
        """
        Load trained weights into ``self.model`` non-strictly.

        Keys containing 'total_ops' / 'total_params' (FLOPs-counting buffers)
        are dropped first. The numbers of missing and unexpected keys are
        printed.
        """
        # NOTE(review): torch.load unpickles the file, and weights may come
        # from a downloaded archive. Only load models from trusted sources
        # (torch >= 2.6 defaults to weights_only=True).
        state_dict = torch.load(weights_path, map_location=self.device)

        filtered = {
            k: v for k, v in state_dict.items()
            if 'total_ops' not in k and 'total_params' not in k
        }

        missing, unexpected = self.model.load_state_dict(filtered, strict=False)

        if missing:
            print(f"⚠️ Missing keys: {len(missing)}")
        if unexpected:
            print(f"⚠️ Unexpected keys: {len(unexpected)}")

    def predict(self, source: Union[str, np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """
        Run inference on an image.

        Args:
            source: Image path, PIL Image, or numpy array
            **kwargs: Forwarded to the task handler
                      (classification: top_k; detection: conf_threshold, iou_threshold)

        Returns:
            Prediction results dictionary
        """
        # Task comes from a case-insensitive log match, so compare case-insensitively.
        if str(self.task).lower() == 'detection':
            return self._detect(source, **kwargs)
        return self._classify(source, **kwargs)

    def _classify(self, image: Union[str, np.ndarray, Image.Image],
                  top_k: int = 5) -> Dict[str, Any]:
        """
        Classification inference.

        Returns the top-k predictions, with k clamped to [1, num_classes] so
        that the top-1 'class' / 'confidence' fields always exist.
        """
        tensor = self._preprocess_image(image)

        with torch.no_grad():
            outputs, _ = self.model(tensor)
            probs = torch.softmax(outputs, dim=1)

        k = max(1, min(int(top_k), probs.shape[1]))
        top_probs, top_indices = torch.topk(probs[0], k)

        predictions = [
            {'class': idx.item(), 'confidence': prob.item()}
            for idx, prob in zip(top_indices, top_probs)
        ]

        return {
            'predictions': predictions,
            'class': predictions[0]['class'],
            'confidence': predictions[0]['confidence']
        }

    def _detect(self, image: Union[str, np.ndarray, Image.Image],
                conf_threshold: float = 0.25,
                iou_threshold: float = 0.45) -> Dict[str, Any]:
        """Detection inference on a 640x640 resized image; returns detections and their count."""
        tensor = self._preprocess_image(image, size=640)

        with torch.no_grad():
            outputs, _ = self.model(tensor)

        detections = self._postprocess_detections(outputs, conf_threshold, iou_threshold)

        return {
            'detections': detections,
            'count': len(detections)
        }

    def _preprocess_image(self, image: Union[str, np.ndarray, Image.Image],
                          size: Optional[int] = None) -> torch.Tensor:
        """
        Convert an image to a normalised (1, 3, size, size) tensor on ``self.device``.

        Paths (str or os.PathLike), numpy arrays and PIL images are all
        converted to RGB before resizing. ``size`` defaults to the config's
        image_size (224 if absent).
        """
        if size is None:
            size = int(self.config.get('image_size', 224))

        if isinstance(image, (str, os.PathLike)):
            # Close the file handle; convert() loads the pixel data first.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        elif isinstance(image, Image.Image):
            # Grayscale/RGBA inputs would break the 3-channel Normalize below.
            image = image.convert('RGB')

        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        return transform(image).unsqueeze(0).to(self.device)

    def _postprocess_detections(self, outputs: torch.Tensor,
                                conf_threshold: float,
                                iou_threshold: float) -> List[Dict]:
        """
        Decode raw detection outputs and apply NMS per image.

        Returns a flat list of dicts with 'box', 'confidence' and 'class'.
        If the detection helpers cannot be imported, a RuntimeWarning is
        emitted and whatever has been collected so far (normally nothing) is
        returned.
        """
        detections = []

        try:
            from .sota.detection_heads import decode_predictions, non_max_suppression

            # Grid size is read from dim 2 when present; 7 is the fallback.
            grid_size = outputs.size(2) if outputs.dim() >= 3 else 7
            pred_boxes, pred_scores, pred_labels = decode_predictions(
                outputs, conf_threshold=conf_threshold, grid_size=grid_size
            )

            for boxes, scores, labels in zip(pred_boxes, pred_scores, pred_labels):
                if len(boxes) == 0:
                    continue
                keep = non_max_suppression(
                    boxes, scores, labels,
                    iou_threshold=iou_threshold
                )
                for j in keep:
                    detections.append({
                        'box': boxes[j].cpu().numpy().tolist(),
                        'confidence': scores[j].item(),
                        'class': labels[j].item()
                    })
        except ImportError as err:
            import warnings
            warnings.warn(
                f"Detection post-processing unavailable ({err}); "
                "returning no detections.",
                RuntimeWarning,
            )

        return detections

    def __call__(self, source: Union[str, np.ndarray, Image.Image], **kwargs):
        """Call model like a function (similar to YOLO); same as ``predict``."""
        return self.predict(source, **kwargs)

    def to(self, device: str):
        """Move model to ``device`` and return ``self`` for chaining."""
        self.device = device
        self.model = self.model.to(device)
        return self

    def info(self) -> Dict[str, Any]:
        """Return a summary of category, task, backbone, num_classes and device."""
        return {
            'category': self.model_category,
            'task': self.task,
            'backbone': self.config.get('backbone'),
            'num_classes': self.config.get('num_classes'),
            'device': self.device
        }

    def close(self) -> None:
        """Remove the temporary download/extraction directory, if any. Safe to call twice."""
        temp_dir = getattr(self, '_temp_dir', None)
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        self._temp_dir = None

    def __del__(self):
        """Cleanup temp directory."""
        try:
            self.close()
        except Exception:
            # Module globals may already be torn down at interpreter exit.
            pass

    def __repr__(self):
        return f"NASModel(category={self.model_category}, task={self.task}, device={self.device})"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: eizen-nsga
3
- Version: 1.0.0
3
+ Version: 1.0.1
4
4
  Summary: Simple inference package for NSGA-Net trained models
5
5
  Home-page: https://github.com/eizen-ai/nsga-net
6
6
  Author: Eizen.ai Team
@@ -0,0 +1,15 @@
1
+ eizen_nsga/__init__.py,sha256=SYzEu-7pzFy8Ul-bkjnNzuRQtvVQcfbTLBG1Dapun5g,291
2
+ eizen_nsga/model.py,sha256=bIasvlnrVmaXnaYza7J9dMDoFIK-X1LC2PcPNIzdaBI,13224
3
+ eizen_nsga/sota/__init__.py,sha256=3hyX_CR86TvV2RYQaES0FhZdtklUBu3DB7HAv1Z0yGo,525
4
+ eizen_nsga/sota/detection_heads.py,sha256=Tx7qgqc0_BBtaIenxNHg_dIwbDHyOQz3QsamNEa9gX0,11206
5
+ eizen_nsga/sota/macro_encoding.py,sha256=lwV8Nptwt014UWW-2rAmivYqS-5Jfs3g8DCmzhRzHsA,3037
6
+ eizen_nsga/sota/macro_models.py,sha256=okitzfXyaRPjGb5lAh9BNwqUVjTaUqaj7jwypKKVVm0,12254
7
+ eizen_nsga/sota/micro_encoding.py,sha256=LrhgzhgnGMP2Up7Uqy0zT5nedzefqDXl2HJ8YL2TOVM,4820
8
+ eizen_nsga/sota/micro_models.py,sha256=piLM6jt4LEiiQuox6AL6I-K0W0XcGLcL2qHWZgPbpvA,22062
9
+ eizen_nsga/sota/micro_operations.py,sha256=3E2JL1eSzWYmtI4X9omBkocWcAoZslXdkIj8CS1h4dQ,8702
10
+ eizen_nsga/sota/model_registry.py,sha256=lQQKUKCVWEQEql5mDZrlwRJEIVaqdokNWEEN4SWrnA0,17750
11
+ eizen_nsga-1.0.1.dist-info/licenses/LICENSE,sha256=_HF_TY-jv6lFXR4QcG1iTnMYgDC3MfslE5S0OvUGlQA,1091
12
+ eizen_nsga-1.0.1.dist-info/METADATA,sha256=D-hn9vHIdz88l9C2MgzxLHZU2nH8gj3DVfSKipCiEdQ,5268
13
+ eizen_nsga-1.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
14
+ eizen_nsga-1.0.1.dist-info/top_level.txt,sha256=BPCkm-TWk4CpW-H-eKYfaa6KWJuepOHlKR3M5UhB4-4,11
15
+ eizen_nsga-1.0.1.dist-info/RECORD,,
@@ -0,0 +1 @@
1
+ eizen_nsga
@@ -1,13 +0,0 @@
1
- eizen_nsga-1.0.0.dist-info/licenses/LICENSE,sha256=_HF_TY-jv6lFXR4QcG1iTnMYgDC3MfslE5S0OvUGlQA,1091
2
- sota/__init__.py,sha256=3hyX_CR86TvV2RYQaES0FhZdtklUBu3DB7HAv1Z0yGo,525
3
- sota/detection_heads.py,sha256=Tx7qgqc0_BBtaIenxNHg_dIwbDHyOQz3QsamNEa9gX0,11206
4
- sota/macro_encoding.py,sha256=lwV8Nptwt014UWW-2rAmivYqS-5Jfs3g8DCmzhRzHsA,3037
5
- sota/macro_models.py,sha256=okitzfXyaRPjGb5lAh9BNwqUVjTaUqaj7jwypKKVVm0,12254
6
- sota/micro_encoding.py,sha256=LrhgzhgnGMP2Up7Uqy0zT5nedzefqDXl2HJ8YL2TOVM,4820
7
- sota/micro_models.py,sha256=piLM6jt4LEiiQuox6AL6I-K0W0XcGLcL2qHWZgPbpvA,22062
8
- sota/micro_operations.py,sha256=3E2JL1eSzWYmtI4X9omBkocWcAoZslXdkIj8CS1h4dQ,8702
9
- sota/model_registry.py,sha256=lQQKUKCVWEQEql5mDZrlwRJEIVaqdokNWEEN4SWrnA0,17750
10
- eizen_nsga-1.0.0.dist-info/METADATA,sha256=uQP38LPaimMrKbFtdFhUXuBaOsIMaQntO1oJRvjDLK4,5268
11
- eizen_nsga-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
12
- eizen_nsga-1.0.0.dist-info/top_level.txt,sha256=pPnDGDIcVIcY4aYyBjID0_x2rcMP9_Q3G-ZHKSZGwQI,5
13
- eizen_nsga-1.0.0.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- sota
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes