eizen_nsga-1.0.0-py3-none-any.whl → eizen_nsga-1.0.2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
eizen_nsga/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """
+ eizen-nsga: Simple inference package for NSGA-Net trained models
+
+ Usage:
+     from eizen_nsga import NASModel
+
+     model = NASModel("trained_model.zip")
+     result = model.predict("image.jpg")
+ """
+
+ __version__ = "1.0.0"
+
+ from .model import NASModel
+
+ __all__ = ['NASModel']
eizen_nsga/model.py ADDED
@@ -0,0 +1,382 @@
+ """
+ NASModel - Universal model class for NSGA-Net trained models
+
+ Simple interface like ultralytics YOLO:
+     model = NASModel("model.zip")
+     result = model("image.jpg")
+ """
+
+ import os
+ import re
+ import ast
+ import sys
+ import zipfile
+ import tempfile
+ import shutil
+ from typing import Union, Optional, Dict, Any, List
+ from urllib.parse import urlparse
+ from urllib.request import urlretrieve
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+
+ class NASModel:
+     """
+     Universal model class for NSGA-Net trained models.
+
+     Simple interface like YOLO:
+         model = NASModel("model.zip")
+         result = model.predict("image.jpg")
+
+     Supports:
+     - SOTA: Computer vision (YOLOv8, ResNet, EfficientNet)
+     - NN: Tabular data (coming soon)
+     - Transformer: LLMs (coming soon)
+     """
+
+     def __init__(self, path: str, device: Optional[str] = None):
+         """
+         Load a trained NSGA-Net model.
+
+         Args:
+             path: Path or URL to model ZIP file or directory.
+                 Supports local paths, HTTP/HTTPS URLs, S3, GCS, etc.
+             device: 'cuda', 'cpu', or None (auto-detect)
+         """
+         self.source_path = path
+         self._temp_dir = None
+
+         # Extract ZIP if needed
+         model_dir = self._extract_model(path)
+
+         # Find weights and config
+         weights_path, log_path = self._find_model_files(model_dir)
+
+         # Parse configuration
+         self.config = self._parse_log(log_path)
+         self.model_category = self.config.get('modelCategory', 'SOTA')
+         self.task = self.config.get('task', 'classification')
+
+         # Set up class names (like YOLO's model.names)
+         class_names = self.config.get('class_names', None)
+         if class_names and isinstance(class_names, list):
+             self.names = {i: name for i, name in enumerate(class_names)}
+         else:
+             # Fall back to indices if no names are available
+             num_classes = self.config.get('num_classes', 0)
+             self.names = {i: str(i) for i in range(num_classes)}
+
+         # Set device
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+
+         # Build and load model
+         self.model = self._build_model()
+         self._load_weights(weights_path)
+         self.model.eval()
+
+         print(f"✅ Loaded {self.model_category} model for {self.task}")
+
+     def _extract_model(self, path: str) -> str:
+         """Extract ZIP if needed, return model directory. Supports local paths and URLs."""
+         # Check if path is a URL
+         parsed = urlparse(path)
+         is_url = parsed.scheme in ('http', 'https', 's3', 'gs', 'ftp')
+
+         if is_url:
+             # Download from URL to a temporary file
+             print("📥 Downloading model from URL...")
+             self._temp_dir = tempfile.mkdtemp()
+             local_zip = os.path.join(self._temp_dir, 'model.zip')
+             urlretrieve(path, local_zip)
+             print(f"✅ Downloaded to {local_zip}")
+             path = local_zip
+
+         # Extract ZIP if it's a zip file
+         if path.endswith('.zip') and os.path.isfile(path):
+             if not self._temp_dir:
+                 self._temp_dir = tempfile.mkdtemp()
+             with zipfile.ZipFile(path, 'r') as zf:
+                 zf.extractall(self._temp_dir)
+             return self._temp_dir
+         return path
+
+     def _find_model_files(self, model_dir: str) -> tuple:
+         """Find weights.pt and log.txt in model directory."""
+         weights_path = None
+         log_path = None
+
+         # Check root first
+         if os.path.exists(os.path.join(model_dir, 'weights.pt')):
+             weights_path = os.path.join(model_dir, 'weights.pt')
+             log_path = os.path.join(model_dir, 'log.txt')
+         else:
+             # Search subdirectories
+             for root, dirs, files in os.walk(model_dir):
+                 if 'weights.pt' in files:
+                     weights_path = os.path.join(root, 'weights.pt')
+                     log_path = os.path.join(root, 'log.txt')
+                     break
+
+         if not weights_path or not os.path.exists(weights_path):
+             raise FileNotFoundError(f"weights.pt not found in {model_dir}")
+
+         return weights_path, log_path
+
+     def _parse_log(self, log_path: str) -> Dict[str, Any]:
+         """Parse log.txt to extract configuration."""
+         config = {}
+
+         if not log_path or not os.path.exists(log_path):
+             return config
+
+         with open(log_path, 'r') as f:
+             content = f.read()
+
+         # Extract all key-value pairs
+         patterns = {
+             'modelCategory': r'modelCategory[:\s]*(\w+)',
+             'task': r'Task[:\s]*(\w+)',
+             'genome': r'Genome[:\s]*(.+)',
+             'backbone': r'backbone[:\s]*(\S+)',
+             'search_space': r'search_space[:\s]*(\w+)',
+             'num_cells': r'num_cells[:\s]*(\d+)',
+             'num_classes': r'num_classes[:\s]*(\d+)',
+             'image_size': r'image_size[:\s]*(\d+)',
+             'class_names': r'class_names[:\s]*(.+)',
+         }
+
+         for key, pattern in patterns.items():
+             match = re.search(pattern, content, re.IGNORECASE | re.MULTILINE)
+             if match:
+                 value = match.group(1).strip()
+                 try:
+                     config[key] = ast.literal_eval(value)
+                 except (ValueError, SyntaxError):
+                     config[key] = value
+
+         return config
+
+     def _build_model(self) -> nn.Module:
+         """Build model architecture based on category."""
+         category = self.model_category.upper()
+
+         if category == 'SOTA':
+             return self._build_sota_model()
+         elif category == 'NN':
+             return self._build_nn_model()
+         elif category == 'TRANSFORMER':
+             return self._build_transformer_model()
+         else:
+             raise ValueError(f"Unknown model category: {category}")
+
+     def _build_sota_model(self) -> nn.Module:
+         """Build SOTA adapter model from genome."""
+         genome = self.config.get('genome')
+         backbone = self.config.get('backbone', 'yolov8n')
+         search_space = self.config.get('search_space', 'micro')
+         num_cells = int(self.config.get('num_cells', 3))
+         num_classes = int(self.config.get('num_classes', 10))
+
+         if genome is None:
+             raise ValueError("genome not found in log.txt")
+
+         # Import SOTA modules from package
+         from .sota.micro_models import GeneralAdapterNetwork
+         from .sota.macro_models import MacroAdapterNetwork
+         from .sota import micro_encoding, macro_encoding
+
+         # Decode and build
+         if search_space == 'macro':
+             genotype = macro_encoding.decode(genome)
+             model = MacroAdapterNetwork(
+                 genotype=genotype,
+                 backbone_name=backbone,
+                 num_classes=num_classes,
+                 num_cells=num_cells,
+                 pretrained=True,
+                 freeze_backbone=True
+             )
+         else:
+             genotype = micro_encoding.decode(genome)
+             model = GeneralAdapterNetwork(
+                 genotype=genotype,
+                 backbone_name=backbone,
+                 num_classes=num_classes,
+                 num_cells=num_cells,
+                 pretrained=True,
+                 freeze_backbone=True
+             )
+
+         return model.to(self.device)
+
+     def _build_nn_model(self) -> nn.Module:
+         """Build NN model (tabular)."""
+         raise NotImplementedError("NN model support coming soon")
+
+     def _build_transformer_model(self) -> nn.Module:
+         """Build Transformer model (LLM)."""
+         raise NotImplementedError("Transformer model support coming soon")
+
+     def _load_weights(self, weights_path: str):
+         """Load trained weights into model."""
+         state_dict = torch.load(weights_path, map_location=self.device)
+
+         # Filter out FLOPs-counting keys
+         filtered = {
+             k: v for k, v in state_dict.items()
+             if 'total_ops' not in k and 'total_params' not in k
+         }
+
+         # Load weights
+         missing, unexpected = self.model.load_state_dict(filtered, strict=False)
+
+         if missing:
+             print(f"⚠️ Missing keys: {len(missing)}")
+         if unexpected:
+             print(f"⚠️ Unexpected keys: {len(unexpected)}")
+
+     def predict(self, source: Union[str, np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+         """
+         Run inference on image(s).
+
+         Args:
+             source: Image path, PIL Image, or numpy array
+             **kwargs: Additional arguments (top_k, conf_threshold, etc.)
+
+         Returns:
+             Prediction results dictionary
+         """
+         if self.task == 'detection':
+             return self._detect(source, **kwargs)
+         else:
+             return self._classify(source, **kwargs)
+
+     def _classify(self, image: Union[str, np.ndarray, Image.Image],
+                   top_k: int = 5) -> Dict[str, Any]:
+         """Classification inference."""
+         # Preprocess
+         tensor = self._preprocess_image(image)
+
+         # Inference
+         with torch.no_grad():
+             outputs, _ = self.model(tensor)
+             probs = torch.softmax(outputs, dim=1)
+
+         # Get top-k
+         top_probs, top_indices = torch.topk(probs[0], min(top_k, probs.shape[1]))
+
+         predictions = [
+             {'class': idx.item(), 'confidence': prob.item()}
+             for idx, prob in zip(top_indices, top_probs)
+         ]
+
+         return {
+             'predictions': predictions,
+             'class': predictions[0]['class'],
+             'confidence': predictions[0]['confidence']
+         }
+
+     def _detect(self, image: Union[str, np.ndarray, Image.Image],
+                 conf_threshold: float = 0.25,
+                 iou_threshold: float = 0.45) -> Dict[str, Any]:
+         """Detection inference."""
+         # Preprocess with a larger size for detection
+         tensor = self._preprocess_image(image, size=640)
+
+         # Inference
+         with torch.no_grad():
+             outputs, _ = self.model(tensor)
+
+         # Post-process detections
+         detections = self._postprocess_detections(outputs, conf_threshold, iou_threshold)
+
+         return {
+             'detections': detections,
+             'count': len(detections)
+         }
+
+     def _preprocess_image(self, image: Union[str, np.ndarray, Image.Image],
+                           size: Optional[int] = None) -> torch.Tensor:
+         """Preprocess image for model input."""
+         if size is None:
+             size = self.config.get('image_size', 224)
+
+         # Load image
+         if isinstance(image, str):
+             image = Image.open(image).convert('RGB')
+         elif isinstance(image, np.ndarray):
+             image = Image.fromarray(image).convert('RGB')
+
+         # Transform
+         transform = transforms.Compose([
+             transforms.Resize((size, size)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+         return transform(image).unsqueeze(0).to(self.device)
+
+     def _postprocess_detections(self, outputs: torch.Tensor,
+                                 conf_threshold: float,
+                                 iou_threshold: float) -> List[Dict]:
+         """Post-process detection outputs."""
+         detections = []
+
+         try:
+             from .sota.detection_heads import decode_predictions, non_max_suppression
+
+             grid_size = outputs.size(2) if outputs.dim() >= 3 else 7
+             pred_boxes, pred_scores, pred_labels = decode_predictions(
+                 outputs, conf_threshold=conf_threshold, grid_size=grid_size
+             )
+
+             for i in range(len(pred_boxes)):
+                 if len(pred_boxes[i]) > 0:
+                     keep = non_max_suppression(
+                         pred_boxes[i], pred_scores[i], pred_labels[i],
+                         iou_threshold=iou_threshold
+                     )
+                     for j in keep:
+                         detections.append({
+                             'box': pred_boxes[i][j].cpu().numpy().tolist(),
+                             'confidence': pred_scores[i][j].item(),
+                             'class': pred_labels[i][j].item()
+                         })
+         except ImportError:
+             pass
+
+         return detections
+
+     def __call__(self, source: Union[str, np.ndarray, Image.Image], **kwargs):
+         """Call model like a function (similar to YOLO)."""
+         return self.predict(source, **kwargs)
+
+     def to(self, device: str):
+         """Move model to device."""
+         self.device = device
+         self.model = self.model.to(device)
+         return self
+
+     def info(self) -> Dict[str, Any]:
+         """Get model information."""
+         return {
+             'category': self.model_category,
+             'task': self.task,
+             'backbone': self.config.get('backbone'),
+             'num_classes': self.config.get('num_classes'),
+             'device': self.device
+         }
+
+     def __del__(self):
+         """Clean up temp directory."""
+         if self._temp_dir and os.path.exists(self._temp_dir):
+             shutil.rmtree(self._temp_dir, ignore_errors=True)
+
+     def __repr__(self):
+         return f"NASModel(category={self.model_category}, task={self.task}, device={self.device})"
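Taken together, NASModel mirrors the ultralytics-style workflow: construct from an archive, then call predict() or the instance itself. A minimal usage sketch, assuming a hypothetical classification archive trained_model.zip whose log.txt carries the fields parsed above:

    from eizen_nsga import NASModel

    model = NASModel("trained_model.zip")           # local ZIP, directory, or URL
    result = model.predict("image.jpg", top_k=3)    # dispatches to _classify for classification tasks
    print(result['class'], result['confidence'])    # top-1 class index and probability
    print(model.names.get(result['class']))         # map the index through model.names
    print(model.info())                             # category, task, backbone, num_classes, device

Note that predict() reports raw class indices; translating them via model.names is left to the caller, as in the sketch above.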
{eizen_nsga-1.0.0.dist-info → eizen_nsga-1.0.2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.4
+ Metadata-Version: 2.1
  Name: eizen-nsga
- Version: 1.0.0
+ Version: 1.0.2
  Summary: Simple inference package for NSGA-Net trained models
  Home-page: https://github.com/eizen-ai/nsga-net
  Author: Eizen.ai Team
@@ -21,28 +21,24 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch>=1.9.0
- Requires-Dist: torchvision>=0.10.0
- Requires-Dist: numpy>=1.19.0
- Requires-Dist: pillow>=8.0.0
- Provides-Extra: sota
- Requires-Dist: ultralytics>=8.0.0; extra == "sota"
+ Requires-Dist: numpy >=1.19.0
+ Requires-Dist: pillow >=8.0.0
+ Requires-Dist: torch >=1.9.0
+ Requires-Dist: torchvision >=0.10.0
+ Provides-Extra: all
+ Requires-Dist: pandas >=1.0.0 ; extra == 'all'
+ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'all'
+ Requires-Dist: tokenizers >=0.10.0 ; extra == 'all'
+ Requires-Dist: transformers >=4.0.0 ; extra == 'all'
+ Requires-Dist: ultralytics >=8.0.0 ; extra == 'all'
  Provides-Extra: nn
- Requires-Dist: pandas>=1.0.0; extra == "nn"
- Requires-Dist: scikit-learn>=0.24.0; extra == "nn"
+ Requires-Dist: pandas >=1.0.0 ; extra == 'nn'
+ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'nn'
+ Provides-Extra: sota
+ Requires-Dist: ultralytics >=8.0.0 ; extra == 'sota'
  Provides-Extra: transformer
- Requires-Dist: transformers>=4.0.0; extra == "transformer"
- Requires-Dist: tokenizers>=0.10.0; extra == "transformer"
- Provides-Extra: all
- Requires-Dist: ultralytics>=8.0.0; extra == "all"
- Requires-Dist: pandas>=1.0.0; extra == "all"
- Requires-Dist: scikit-learn>=0.24.0; extra == "all"
- Requires-Dist: transformers>=4.0.0; extra == "all"
- Requires-Dist: tokenizers>=0.10.0; extra == "all"
- Dynamic: author
- Dynamic: home-page
- Dynamic: license-file
- Dynamic: requires-python
+ Requires-Dist: tokenizers >=0.10.0 ; extra == 'transformer'
+ Requires-Dist: transformers >=4.0.0 ; extra == 'transformer'
 
  # eizen-nsga
 
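With 1.0.2 the extras are regrouped and alphabetized (sota, nn, and transformer plus a combined all), and the Dynamic fields disappear with the move back to Metadata-Version 2.1. Assuming the standard PyPI distribution name, installs would look like:

    pip install eizen-nsga           # core: numpy, pillow, torch, torchvision
    pip install "eizen-nsga[sota]"   # adds ultralytics for the SOTA vision path
    pip install "eizen-nsga[all]"    # pandas, scikit-learn, tokenizers, transformers, ultralytics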
eizen_nsga-1.0.2.dist-info/RECORD ADDED
@@ -0,0 +1,15 @@
+ eizen_nsga/__init__.py,sha256=SYzEu-7pzFy8Ul-bkjnNzuRQtvVQcfbTLBG1Dapun5g,291
+ eizen_nsga/model.py,sha256=NC1gyGdh8uwfzpQ8pYPLHd3pLOvFxtGauD1nZN6zgm8,13728
+ eizen_nsga/sota/__init__.py,sha256=3hyX_CR86TvV2RYQaES0FhZdtklUBu3DB7HAv1Z0yGo,525
+ eizen_nsga/sota/detection_heads.py,sha256=Tx7qgqc0_BBtaIenxNHg_dIwbDHyOQz3QsamNEa9gX0,11206
+ eizen_nsga/sota/macro_encoding.py,sha256=lwV8Nptwt014UWW-2rAmivYqS-5Jfs3g8DCmzhRzHsA,3037
+ eizen_nsga/sota/macro_models.py,sha256=okitzfXyaRPjGb5lAh9BNwqUVjTaUqaj7jwypKKVVm0,12254
+ eizen_nsga/sota/micro_encoding.py,sha256=LrhgzhgnGMP2Up7Uqy0zT5nedzefqDXl2HJ8YL2TOVM,4820
+ eizen_nsga/sota/micro_models.py,sha256=piLM6jt4LEiiQuox6AL6I-K0W0XcGLcL2qHWZgPbpvA,22062
+ eizen_nsga/sota/micro_operations.py,sha256=3E2JL1eSzWYmtI4X9omBkocWcAoZslXdkIj8CS1h4dQ,8702
+ eizen_nsga/sota/model_registry.py,sha256=lQQKUKCVWEQEql5mDZrlwRJEIVaqdokNWEEN4SWrnA0,17750
+ eizen_nsga-1.0.2.dist-info/LICENSE,sha256=_HF_TY-jv6lFXR4QcG1iTnMYgDC3MfslE5S0OvUGlQA,1091
+ eizen_nsga-1.0.2.dist-info/METADATA,sha256=klm_RoP1eL5nYeTxLGFmNX6yFDY2tEODN5an8AxGknw,5210
+ eizen_nsga-1.0.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ eizen_nsga-1.0.2.dist-info/top_level.txt,sha256=BPCkm-TWk4CpW-H-eKYfaa6KWJuepOHlKR3M5UhB4-4,11
+ eizen_nsga-1.0.2.dist-info/RECORD,,
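Each RECORD row has the form path,sha256=<digest>,size, where the digest is the file's SHA-256 in unpadded URL-safe base64 (per the wheel spec, PEP 427). A short sketch for checking an extracted file against its recorded digest; the local path is a hypothetical extraction location:

    import base64
    import hashlib

    def record_digest(path: str) -> str:
        # SHA-256 of the file contents, URL-safe base64 with '=' padding stripped
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).digest()
        return base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')

    # Should match the first RECORD entry above
    assert record_digest('eizen_nsga/__init__.py') == 'SYzEu-7pzFy8Ul-bkjnNzuRQtvVQcfbTLBG1Dapun5g'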
{eizen_nsga-1.0.0.dist-info → eizen_nsga-1.0.2.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.10.2)
+ Generator: bdist_wheel (0.42.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
eizen_nsga-1.0.2.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ eizen_nsga
eizen_nsga-1.0.0.dist-info/RECORD DELETED
@@ -1,13 +0,0 @@
- eizen_nsga-1.0.0.dist-info/licenses/LICENSE,sha256=_HF_TY-jv6lFXR4QcG1iTnMYgDC3MfslE5S0OvUGlQA,1091
- sota/__init__.py,sha256=3hyX_CR86TvV2RYQaES0FhZdtklUBu3DB7HAv1Z0yGo,525
- sota/detection_heads.py,sha256=Tx7qgqc0_BBtaIenxNHg_dIwbDHyOQz3QsamNEa9gX0,11206
- sota/macro_encoding.py,sha256=lwV8Nptwt014UWW-2rAmivYqS-5Jfs3g8DCmzhRzHsA,3037
- sota/macro_models.py,sha256=okitzfXyaRPjGb5lAh9BNwqUVjTaUqaj7jwypKKVVm0,12254
- sota/micro_encoding.py,sha256=LrhgzhgnGMP2Up7Uqy0zT5nedzefqDXl2HJ8YL2TOVM,4820
- sota/micro_models.py,sha256=piLM6jt4LEiiQuox6AL6I-K0W0XcGLcL2qHWZgPbpvA,22062
- sota/micro_operations.py,sha256=3E2JL1eSzWYmtI4X9omBkocWcAoZslXdkIj8CS1h4dQ,8702
- sota/model_registry.py,sha256=lQQKUKCVWEQEql5mDZrlwRJEIVaqdokNWEEN4SWrnA0,17750
- eizen_nsga-1.0.0.dist-info/METADATA,sha256=uQP38LPaimMrKbFtdFhUXuBaOsIMaQntO1oJRvjDLK4,5268
- eizen_nsga-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- eizen_nsga-1.0.0.dist-info/top_level.txt,sha256=pPnDGDIcVIcY4aYyBjID0_x2rcMP9_Q3G-ZHKSZGwQI,5
- eizen_nsga-1.0.0.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- sota