findcrack 0.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*' # Triggers on tags like v0.1.1
7
+
8
+ permissions:
9
+ contents: write # Required for creating GitHub Releases
10
+ id-token: write # Required for OIDC Trusted Publishing to PyPI/TestPyPI
11
+
12
+ jobs:
13
+ release:
14
+ name: Build, Publish and Release
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - name: Checkout code
18
+ uses: actions/checkout@v4
19
+ with:
20
+ fetch-depth: 0 # Fetch full history and tags so setuptools-scm can resolve the version
21
+
22
+ - name: Install uv
23
+ uses: astral-sh/setup-uv@v5
24
+ with:
25
+ enable-cache: true
26
+ python-version: ">=3.11"
27
+
28
+ - name: Build distributions
29
+ run: uv build
30
+
31
+ - name: Publish to TestPyPI
32
+ uses: pypa/gh-action-pypi-publish@release/v1
33
+ continue-on-error: true # Don't block the PyPI publish if the TestPyPI upload fails
34
+ with:
35
+ repository-url: https://test.pypi.org/legacy/
36
+
37
+ - name: Publish to PyPI
38
+ uses: pypa/gh-action-pypi-publish@release/v1
39
+
40
+ - name: Create GitHub Release
41
+ uses: softprops/action-gh-release@v2
42
+ with:
43
+ files: dist/*
44
+ generate_release_notes: true
@@ -0,0 +1,34 @@
1
+ # vscode
2
+ .vscode/
3
+
4
+ # model
5
+ models/
6
+ *.pt
7
+ *.onnx
8
+ *.h5
9
+ *.pkl
10
+ dist/
11
+ *.whl
12
+
13
+ # Python
14
+ __pycache__/
15
+ *.ipynb
16
+ *.pyc
17
+ *.pyo
18
+
19
+ # Models - critical: keep large files out of git!
20
+ checkpoints/
21
+ *.pt
22
+ *.onnx
23
+ *.h5
24
+ *.pkl
25
+ *.bin
26
+
27
+ # Build artifacts
28
+ dist/
29
+ build/
30
+ *.egg-info/
31
+
32
+ # Virtual environment
33
+ .venv/
34
+ .env
@@ -0,0 +1 @@
1
+ 3.11
@@ -0,0 +1,152 @@
1
+ Metadata-Version: 2.4
2
+ Name: findcrack
3
+ Version: 0.0.0
4
+ Summary: A deep learning crack detection package supporting U-Net and DeepCrack models with PyTorch and ONNX backends.
5
+ Author: StrikerEurika
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/StrikerEurika/findcrack
8
+ Project-URL: Repository, https://github.com/StrikerEurika/findcrack
9
+ Project-URL: Releases, https://github.com/StrikerEurika/findcrack/releases
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
13
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
14
+ Requires-Python: >=3.11
15
+ Description-Content-Type: text/markdown
16
+ Requires-Dist: albumentations>=2.0.8
17
+ Requires-Dist: onnxruntime>=1.27.0
18
+ Requires-Dist: opencv-python>=4.13.0.92
19
+ Requires-Dist: torch>=2.6.0
20
+ Requires-Dist: torchaudio>=2.6.0
21
+ Requires-Dist: torchvision>=0.21.0
22
+
23
+ # findcrack
24
+
25
+ `findcrack` is a deep learning crack detection package designed for pixel-level segmentation on high-resolution images. It supports U-Net and DeepCrack architectures, providing an easy-to-use API for inference, model caching, and multi-backend execution (PyTorch & ONNX).
26
+
27
+ ---
28
+
29
+ ## Features
30
+
31
+ - **Pre-trained Model Zoo**: Downloads pre-trained weights (e.g., `Seg_UNET_CFD_actual_v1`, `Seg_UNET_CFD_actual_v2`) on first use and caches them locally.
32
+ - **Unified Backend Engine**: Runs PyTorch (`.pth`/`.pt`) and ONNX (`.onnx`) models through the same interface.
33
+ - **Sliding-Window Inference**: Processes ultra-high-resolution images by splitting them into overlapping patches.
34
+ - **Gaussian & Average Blending**: Reassembles the full image from overlapping patch predictions using Gaussian-weighted or plain averaging; Gaussian weights de-emphasize patch borders to suppress seam artifacts.
35
+ - **Test-Time Augmentation (TTA)**: Runs the model on several augmented views (original, horizontal flip, vertical flip, and rotations) and combines them into a more robust prediction mask.
36
+ - **Validation Metrics**: Computes standard segmentation metrics: IoU, Dice coefficient, Precision, Recall, and Pixel Accuracy.
37
+
38
+ ---
39
+
40
+ ## Installation
41
+
42
+ Install `findcrack` from PyPI with pip or uv:
43
+
44
+ ```bash
45
+ # Install via pip
46
+ pip install findcrack
47
+
48
+ # Or using uv
49
+ uv add findcrack
50
+ ```
51
+
52
+ ---
53
+
54
+ ## Quickstart
55
+
56
+ Here is how to load a pre-trained model and run crack detection on a large image:
57
+
58
+ ```python
59
+ import cv2
60
+ from findcrack import CrackInferencePipeline, load_model
61
+
62
+ # 1. Load a pre-trained model from the official registry (or use your own URL)
63
+ # The weights are downloaded dynamically from GitHub releases on first use.
64
+ model = load_model("Seg_UNET_CFD_actual_v1", device="cuda")
65
+
66
+ # 2. Set up the inference pipeline
67
+ pipeline = CrackInferencePipeline(
68
+ model=model,
69
+ device="cuda",
70
+ patch_size=512,
71
+ overlap_ratio=0.2,
72
+ confidence_threhold=0.5,
73
+ use_tta=True # Enables multi-way Test-Time Augmentation
74
+ )
75
+
76
+ # 3. Perform inference
77
+ results = pipeline.predict("path/to/high_res_concrete.jpg")
78
+
79
+ # The results dictionary contains:
80
+ # - results["original_image"]: Original RGB image (numpy array)
81
+ # - results["confidence_map"]: Float probability map [0.0 - 1.0]
82
+ # - results["binary_mask"]: Binary segmentation mask [0 or 255]
83
+
84
+ # Save the output mask
85
+ cv2.imwrite("detected_cracks.png", results["binary_mask"])
86
+ ```
87
+
88
+ ---
89
+
90
+ ## API Reference
91
+
92
+ ### Model Loading & Caching
93
+
94
+ #### `load_model(variant: str, device: str = "cpu", force_download: bool = False, architecture = None, **kwargs)`
95
+ Loads a model variant from the local registry or directly from a remote HTTP(S) URL.
96
+
97
+ - **Parameters**:
98
+ - `variant`: The name of a registered variant (e.g., `"Seg_UNET_CFD_actual_v1"`) or a direct HTTP(S) URL to a weights file.
99
+ - `device`: Target execution device (`"cpu"`, `"cuda"`, or `"mps"`).
100
+ - `force_download`: If `True`, re-downloads weights even if cached locally.
101
+ - `architecture`: PyTorch architecture class (e.g., `UNet`, `DeepCrack`); required only when loading a raw `.pth`/`.pt` file from a custom URL.
102
+
103
+ ```python
104
+ from findcrack import load_model, UNet
105
+
106
+ # Load custom model weights directly from an external URL
107
+ model = load_model(
108
+ variant="https://my-domain.com/custom_unet.pth",
109
+ architecture=UNet,
110
+ device="cuda"
111
+ )
112
+ ```
113
+
114
+ #### `list_models()`
115
+ Returns a list of all pre-trained models available in the built-in registry.
116
+
117
+ #### `register_model(name: str, url: str, architecture = None, kwargs: dict = None, backend: str = "pytorch")`
118
+ Registers a custom variant dynamically at runtime.
119
+
120
+ ---
121
+
122
+ ### Pipeline Configuration
123
+
124
+ #### `CrackInferencePipeline(model, device: str = "cuda", patch_size: int = 512, overlap_ratio: float = 0.2, confidence_threhold: float = 0.5, use_tta: bool = False)`
125
+ Handles preprocessing, sliding-window patch extraction, model execution, optional TTA, and reconstruction of the full-resolution output from patch predictions.
126
+
127
+ ---
128
+
129
+ ## Directory Structure
130
+
131
+ ```text
132
+ src/
133
+ └── findcrack/
134
+ ├── __init__.py # Main API endpoints (load_model, CrackInferencePipeline, etc.)
135
+ ├── metrics.py # Segmentation evaluation metrics (IoU, Dice, etc.)
136
+ ├── patching.py # Sliding window extraction and blend reconstruction
137
+ ├── pipeline.py # Crack Inference Pipeline wrapper
138
+ ├── preprocess.py # CLAHE contrast enhancement in LAB color space, plus transforms
139
+ ├── tta.py # Test-Time Augmentation forward pass routines
140
+ └── models/
141
+ ├── __init__.py # Model module exports
142
+ ├── unet.py # U-Net model definition
143
+ ├── deepcrack.py # DeepCrack model definition
144
+ ├── onnx_wrapper.py # Wrapper for running ONNX models as nn.Modules
145
+ └── zoo.py # Remote weight registry and cached loaders
146
+ ```
147
+
148
+ ---
149
+
150
+ ## License
151
+
152
+ This project is licensed under the MIT License.
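The sliding-window inference and Gaussian blending listed under Features can be sketched in plain NumPy. This is not findcrack's `patching.py`, which isn't part of this diff: the helper names `gaussian_weight`, `window_starts`, and `sliding_window_predict`, the Gaussian sigma of one-eighth of the patch size (a common choice), and the `predict_patch` callable (a tile in, a same-sized probability map out) are all assumptions for illustration. The stride is derived from `patch_size` and `overlap_ratio` the way the pipeline's parameters suggest.

```python
import numpy as np


def gaussian_weight(patch_size: int, sigma_scale: float = 0.125) -> np.ndarray:
    """2-D Gaussian weight map that peaks at the patch center (sigma = sigma_scale * patch_size)."""
    coords = np.arange(patch_size) - (patch_size - 1) / 2.0
    g = np.exp(-(coords ** 2) / (2.0 * (sigma_scale * patch_size) ** 2))
    weight = np.outer(g, g)
    return weight / weight.max()


def window_starts(length: int, patch_size: int, stride: int) -> list[int]:
    """Window offsets covering [0, length); the last window is shifted flush with the edge."""
    if length <= patch_size:
        return [0]
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


def sliding_window_predict(image: np.ndarray, predict_patch, patch_size: int = 512,
                           overlap_ratio: float = 0.2) -> np.ndarray:
    """Predict overlapping tiles and blend them into one H x W confidence map."""
    h, w = image.shape[:2]
    stride = max(1, int(patch_size * (1 - overlap_ratio)))
    weight = gaussian_weight(patch_size)
    acc = np.zeros((h, w))   # sum of weighted patch predictions
    norm = np.zeros((h, w))  # sum of weights per pixel
    for y in window_starts(h, patch_size, stride):
        for x in window_starts(w, patch_size, stride):
            tile = image[y:y + patch_size, x:x + patch_size]
            th, tw = tile.shape[:2]      # smaller than patch_size only for small images
            pred = predict_patch(tile)   # assumed: (th, tw) probabilities in [0, 1]
            acc[y:y + th, x:x + tw] += pred * weight[:th, :tw]
            norm[y:y + th, x:x + tw] += weight[:th, :tw]
    return acc / norm
```

Each output pixel is a weighted average of every patch prediction that covers it, so a pixel-wise `predict_patch` reproduces its input exactly; the Gaussian weights only matter where neighboring patches disagree, which is where seams would otherwise appear. A binary mask like the pipeline's `binary_mask` could then be obtained with something like `(conf >= 0.5).astype(np.uint8) * 255`; whether findcrack compares with `>=` or `>` is not shown here.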
@@ -0,0 +1,130 @@
1
+ # findcrack
2
+
3
+ `findcrack` is a deep learning crack detection package designed for pixel-level segmentation on high-resolution images. It supports U-Net and DeepCrack architectures, providing an easy-to-use API for inference, model caching, and multi-backend execution (PyTorch & ONNX).
4
+
5
+ ---
6
+
7
+ ## Features
8
+
9
+ - **Pre-trained Model Zoo**: Downloads pre-trained weights (e.g., `Seg_UNET_CFD_actual_v1`, `Seg_UNET_CFD_actual_v2`) on first use and caches them locally.
10
+ - **Unified Backend Engine**: Runs PyTorch (`.pth`/`.pt`) and ONNX (`.onnx`) models through the same interface.
11
+ - **Sliding-Window Inference**: Processes ultra-high-resolution images by splitting them into overlapping patches.
12
+ - **Gaussian & Average Blending**: Reassembles the full image from overlapping patch predictions using Gaussian-weighted or plain averaging; Gaussian weights de-emphasize patch borders to suppress seam artifacts.
13
+ - **Test-Time Augmentation (TTA)**: Runs the model on several augmented views (original, horizontal flip, vertical flip, and rotations) and combines them into a more robust prediction mask.
14
+ - **Validation Metrics**: Computes standard segmentation metrics: IoU, Dice coefficient, Precision, Recall, and Pixel Accuracy.
15
+
16
+ ---
17
+
18
+ ## Installation
19
+
20
+ Install `findcrack` from PyPI with pip or uv:
21
+
22
+ ```bash
23
+ # Install via pip
24
+ pip install findcrack
25
+
26
+ # Or using uv
27
+ uv add findcrack
28
+ ```
29
+
30
+ ---
31
+
32
+ ## Quickstart
33
+
34
+ Here is how to load a pre-trained model and run crack detection on a large image:
35
+
36
+ ```python
37
+ import cv2
38
+ from findcrack import CrackInferencePipeline, load_model
39
+
40
+ # 1. Load a pre-trained model from the official registry (or use your own URL)
41
+ # The weights are downloaded dynamically from GitHub releases on first use.
42
+ model = load_model("Seg_UNET_CFD_actual_v1", device="cuda")
43
+
44
+ # 2. Set up the inference pipeline
45
+ pipeline = CrackInferencePipeline(
46
+ model=model,
47
+ device="cuda",
48
+ patch_size=512,
49
+ overlap_ratio=0.2,
50
+ confidence_threhold=0.5,
51
+ use_tta=True # Enables multi-way Test-Time Augmentation
52
+ )
53
+
54
+ # 3. Perform inference
55
+ results = pipeline.predict("path/to/high_res_concrete.jpg")
56
+
57
+ # The results dictionary contains:
58
+ # - results["original_image"]: Original RGB image (numpy array)
59
+ # - results["confidence_map"]: Float probability map [0.0 - 1.0]
60
+ # - results["binary_mask"]: Binary segmentation mask [0 or 255]
61
+
62
+ # Save the output mask
63
+ cv2.imwrite("detected_cracks.png", results["binary_mask"])
64
+ ```
65
+
66
+ ---
67
+
68
+ ## API Reference
69
+
70
+ ### Model Loading & Caching
71
+
72
+ #### `load_model(variant: str, device: str = "cpu", force_download: bool = False, architecture = None, **kwargs)`
73
+ Loads a model variant from the local registry or directly from a remote HTTP(S) URL.
74
+
75
+ - **Parameters**:
76
+ - `variant`: The name of a registered variant (e.g., `"Seg_UNET_CFD_actual_v1"`) or a direct HTTP(S) URL to a weights file.
77
+ - `device`: Target execution device (`"cpu"`, `"cuda"`, or `"mps"`).
78
+ - `force_download`: If `True`, re-downloads weights even if cached locally.
79
+ - `architecture`: PyTorch architecture class (e.g., `UNet`, `DeepCrack`); required only when loading a raw `.pth`/`.pt` file from a custom URL.
80
+
81
+ ```python
82
+ from findcrack import load_model, UNet
83
+
84
+ # Load custom model weights directly from an external URL
85
+ model = load_model(
86
+ variant="https://my-domain.com/custom_unet.pth",
87
+ architecture=UNet,
88
+ device="cuda"
89
+ )
90
+ ```
91
+
92
+ #### `list_models()`
93
+ Returns a list of all pre-trained models available in the built-in registry.
94
+
95
+ #### `register_model(name: str, url: str, architecture = None, kwargs: dict = None, backend: str = "pytorch")`
96
+ Registers a custom variant dynamically at runtime.
97
+
98
+ ---
99
+
100
+ ### Pipeline Configuration
101
+
102
+ #### `CrackInferencePipeline(model, device: str = "cuda", patch_size: int = 512, overlap_ratio: float = 0.2, confidence_threhold: float = 0.5, use_tta: bool = False)`
103
+ Handles preprocessing, sliding-window patch extraction, model execution, optional TTA, and reconstruction of the full-resolution output from patch predictions.
104
+
105
+ ---
106
+
107
+ ## Directory Structure
108
+
109
+ ```text
110
+ src/
111
+ └── findcrack/
112
+ ├── __init__.py # Main API endpoints (load_model, CrackInferencePipeline, etc.)
113
+ ├── metrics.py # Segmentation evaluation metrics (IoU, Dice, etc.)
114
+ ├── patching.py # Sliding window extraction and blend reconstruction
115
+ ├── pipeline.py # Crack Inference Pipeline wrapper
116
+ ├── preprocess.py # CLAHE contrast enhancement in LAB color space, plus transforms
117
+ ├── tta.py # Test-Time Augmentation forward pass routines
118
+ └── models/
119
+ ├── __init__.py # Model module exports
120
+ ├── unet.py # U-Net model definition
121
+ ├── deepcrack.py # DeepCrack model definition
122
+ ├── onnx_wrapper.py # Wrapper for running ONNX models as nn.Modules
123
+ └── zoo.py # Remote weight registry and cached loaders
124
+ ```
125
+
126
+ ---
127
+
128
+ ## License
129
+
130
+ This project is licensed under the MIT License.
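Test-time augmentation as described in the README (original, horizontal flip, vertical flip, and rotations) can be sketched as: transform the input, predict, undo the transform on the prediction, then average. This NumPy version is illustrative only. `TTA_TRANSFORMS` and `tta_predict` are hypothetical names, the rotations are assumed to be 90, 180, and 270 degrees, and plain averaging is assumed as the combination rule; findcrack's `tta.py` is not included in this diff.

```python
import numpy as np

# (forward, inverse) pairs acting on the first two axes (H, W) of an array.
TTA_TRANSFORMS = [
    (lambda a: a, lambda a: a),                                    # original
    (lambda a: np.flip(a, axis=1), lambda a: np.flip(a, axis=1)),  # horizontal flip
    (lambda a: np.flip(a, axis=0), lambda a: np.flip(a, axis=0)),  # vertical flip
    (lambda a: np.rot90(a, 1), lambda a: np.rot90(a, -1)),         # rotate 90 degrees
    (lambda a: np.rot90(a, 2), lambda a: np.rot90(a, -2)),         # rotate 180 degrees
    (lambda a: np.rot90(a, 3), lambda a: np.rot90(a, -3)),         # rotate 270 degrees
]


def tta_predict(image: np.ndarray, predict) -> np.ndarray:
    """Average predictions over augmented views, mapping each back to the original frame.

    `predict` is assumed to take an H x W (x C) array and return an H x W probability map.
    """
    preds = []
    for forward, inverse in TTA_TRANSFORMS:
        # Flips/rotations return strided views; a contiguous copy is safer for
        # consumers such as torch.from_numpy, which rejects negative strides.
        view = np.ascontiguousarray(forward(image))
        preds.append(inverse(predict(view)))
    return np.mean(preds, axis=0)
```

Undoing each transform before averaging is what keeps the views aligned; averaging un-inverted flips would smear the crack mask. For a PyTorch model on `N x C x H x W` batches, the same pairs can be written with `torch.flip(x, dims=[3])` and `torch.rot90(x, k, dims=[2, 3])`.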
@@ -0,0 +1,6 @@
1
+ def main():
2
+ print("Hello from findcrack!")
3
+
4
+
5
+ if __name__ == "__main__":
6
+ main()
@@ -0,0 +1,39 @@
1
+ [project]
2
+ name = "findcrack"
3
+ dynamic = ["version"]
4
+ description = "A deep learning crack detection package supporting U-Net and DeepCrack models with PyTorch and ONNX backends."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ authors = [
8
+ { name = "StrikerEurika" }
9
+ ]
10
+ license = "MIT"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "Operating System :: OS Independent",
14
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
15
+ "Topic :: Scientific/Engineering :: Image Processing",
16
+ ]
17
+ dependencies = [
18
+ "albumentations>=2.0.8",
19
+ "onnxruntime>=1.27.0",
20
+ "opencv-python>=4.13.0.92",
21
+ "torch>=2.6.0",
22
+ "torchaudio>=2.6.0",
23
+ "torchvision>=0.21.0",
24
+ ]
25
+
26
+ [project.urls]
27
+ Homepage = "https://github.com/StrikerEurika/findcrack"
28
+ Repository = "https://github.com/StrikerEurika/findcrack"
29
+ Releases = "https://github.com/StrikerEurika/findcrack/releases"
30
+
31
+ [tool.setuptools.packages.find]
32
+ where = ["src"]
33
+
34
+ [build-system]
35
+ requires = ["setuptools>=77.0", "setuptools-scm>=8.0"]
36
+ build-backend = "setuptools.build_meta"
37
+
38
+ [tool.setuptools-scm]
39
+ version_file = "src/findcrack/_version.py"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,20 @@
1
+ from .pipeline import CrackInferencePipeline
2
+ from .models import load_model, UNet, DeepCrack, list_models, register_model
3
+ from .metrics import calculate_metrics
4
+ from .preprocess import apply_lab_clahe
5
+
6
+ try:
7
+ from ._version import version as __version__
8
+ except ImportError:
9
+ __version__ = "unknown"
10
+
11
+ __all__ = [
12
+ "CrackInferencePipeline",
13
+ "load_model",
14
+ "UNet",
15
+ "DeepCrack",
16
+ "calculate_metrics",
17
+ "apply_lab_clahe",
18
+ "list_models",
19
+ "register_model",
20
+ ]
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+
3
+ def calculate_metrics(y_true: np.ndarray, y_prediction: np.ndarray, epsilon: float = 1e-7) -> dict:
4
+ """
5
+ Calculates IoU, Dice, Precision, Recall, and Pixel Accuracy for binary masks (non-zero pixels count as crack).
6
+ """
7
+
8
+ y_true = y_true.astype(bool)
9
+ y_prediction = y_prediction.astype(bool)
10
+
11
+ # True Positives, False Positives, False Negatives, True Negatives
12
+ TP = np.sum(y_true & y_prediction)
13
+ FP = np.sum(~y_true & y_prediction)
14
+ FN = np.sum(y_true & ~y_prediction)
15
+ TN = np.sum(~y_true & ~y_prediction)
16
+
17
+ return {
18
+ "IoU": TP / (TP + FP + FN + epsilon),
19
+ "Dice": (2 * TP) / (2 * TP + FP + FN + epsilon),
20
+ "Precision": TP / (TP + FP + epsilon),
21
+ "Recall": TP / (TP + FN + epsilon),
22
+ "Pixel Accuracy": (TP + TN) / (TP + TN + FP + FN + epsilon)
23
+ }
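All five values returned by `calculate_metrics` come from four confusion-matrix counts. A small worked example on made-up 3x3 masks shows the arithmetic. `calculate_metrics` also adds `epsilon` to each denominator, so its results differ from these exact fractions only negligibly; the epsilon prevents division by zero when, for example, both masks are empty.

```python
import numpy as np

# Illustrative 3x3 ground-truth and predicted masks (True = crack pixel).
y_true = np.array([[1, 1, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=bool)
y_pred = np.array([[1, 0, 0],
                   [0, 1, 1],
                   [0, 0, 0]], dtype=bool)

tp = int(np.sum(y_true & y_pred))    # 2: pixels (0,0) and (1,1)
fp = int(np.sum(~y_true & y_pred))   # 1: pixel (1,2)
fn = int(np.sum(y_true & ~y_pred))   # 1: pixel (0,1)
tn = int(np.sum(~y_true & ~y_pred))  # 5: the remaining background pixels

iou = tp / (tp + fp + fn)              # 2 / 4 = 0.5
dice = 2 * tp / (2 * tp + fp + fn)     # 4 / 6, about 0.667
precision = tp / (tp + fp)             # 2 / 3
recall = tp / (tp + fn)                # 2 / 3
accuracy = (tp + tn) / y_true.size     # 7 / 9

# Dice is a monotonic function of IoU: Dice = 2 * IoU / (1 + IoU)
assert abs(dice - 2 * iou / (1 + iou)) < 1e-12
```

The Dice/IoU identity in the last line holds for any pair of masks, so on a single image the two scores always rank predictions the same way.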
@@ -0,0 +1,16 @@
1
+ from .unet import UNet
2
+ from .deepcrack import DeepCrack
3
+ from .onnx_wrapper import ONNXModelWrapper
4
+ from .zoo import load_model, MODEL_REGISTRY, list_models, register_model
5
+
6
+ __all__ = [
7
+ "UNet",
8
+ "DeepCrack",
9
+ "ONNXModelWrapper",
10
+ "load_model",
11
+ "MODEL_REGISTRY",
12
+ "list_models",
13
+ "register_model",
14
+ ]
15
+
16
+
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class DoubleConv(nn.Module):
6
+ """(convolution => [BN] => ReLU) * 2"""
7
+ def __init__(self, in_channels, out_channels):
8
+ super().__init__()
9
+ self.double_conv = nn.Sequential(
10
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
11
+ nn.BatchNorm2d(out_channels),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
14
+ nn.BatchNorm2d(out_channels),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.double_conv(x)
20
+
21
+
22
+ class DeepCrack(nn.Module):
23
+ """
24
+ DeepCrack: A Deep Hierarchical Feature Learning Architecture for Crack Segmentation.
25
+ This implementation concatenates encoder features with upsampled decoder features at each scale
25
+ and fuses side predictions from all five decoder scales with a 1x1 convolution.
27
+ """
28
+ def __init__(self, n_channels: int = 3, n_classes: int = 1):
29
+ super().__init__()
30
+ self.n_channels = n_channels
31
+ self.n_classes = n_classes
32
+
33
+ # Encoder (downsampling blocks)
34
+ self.enc1 = DoubleConv(n_channels, 64)
35
+ self.enc2 = DoubleConv(64, 128)
36
+ self.enc3 = DoubleConv(128, 256)
37
+ self.enc4 = DoubleConv(256, 512)
38
+ self.enc5 = DoubleConv(512, 512)
39
+
40
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
41
+
42
+ # Decoder (upsampling & concatenation blocks)
43
+ self.dec5 = DoubleConv(512, 512)
44
+ self.dec4 = DoubleConv(512 + 512, 256)
45
+ self.dec3 = DoubleConv(256 + 256, 128)
46
+ self.dec2 = DoubleConv(128 + 128, 64)
47
+ self.dec1 = DoubleConv(64 + 64, 64)
48
+
49
+ # Side prediction layers (maps feature channels to class channels at each scale)
50
+ self.side1 = nn.Conv2d(64, n_classes, kernel_size=1)
51
+ self.side2 = nn.Conv2d(64, n_classes, kernel_size=1)
52
+ self.side3 = nn.Conv2d(128, n_classes, kernel_size=1)
53
+ self.side4 = nn.Conv2d(256, n_classes, kernel_size=1)
54
+ self.side5 = nn.Conv2d(512, n_classes, kernel_size=1)
55
+
56
+ # Fusion layer that combines all 5 side predictions into the final output
57
+ self.fuse = nn.Conv2d(n_classes * 5, n_classes, kernel_size=1)
58
+
59
+ def forward(self, x):
60
+ # 1. Encoder path
61
+ e1 = self.enc1(x)
62
+ e2 = self.enc2(self.pool(e1))
63
+ e3 = self.enc3(self.pool(e2))
64
+ e4 = self.enc4(self.pool(e3))
65
+ e5 = self.enc5(self.pool(e4))
66
+
67
+ # 2. Decoder path (with bilinear interpolation upsampling)
68
+ d5 = self.dec5(e5)
69
+
70
+ d4_up = F.interpolate(d5, size=e4.shape[2:], mode='bilinear', align_corners=True)
71
+ d4 = self.dec4(torch.cat([d4_up, e4], dim=1))
72
+
73
+ d3_up = F.interpolate(d4, size=e3.shape[2:], mode='bilinear', align_corners=True)
74
+ d3 = self.dec3(torch.cat([d3_up, e3], dim=1))
75
+
76
+ d2_up = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=True)
77
+ d2 = self.dec2(torch.cat([d2_up, e2], dim=1))
78
+
79
+ d1_up = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=True)
80
+ d1 = self.dec1(torch.cat([d1_up, e1], dim=1))
81
+
82
+ # 3. Extract side predictions and upsample to input dimensions
83
+ h, w = x.shape[2:]
84
+ s1 = F.interpolate(self.side1(d1), size=(h, w), mode='bilinear', align_corners=True)
85
+ s2 = F.interpolate(self.side2(d2), size=(h, w), mode='bilinear', align_corners=True)
86
+ s3 = F.interpolate(self.side3(d3), size=(h, w), mode='bilinear', align_corners=True)
87
+ s4 = F.interpolate(self.side4(d4), size=(h, w), mode='bilinear', align_corners=True)
88
+ s5 = F.interpolate(self.side5(d5), size=(h, w), mode='bilinear', align_corners=True)
89
+
90
+ # 4. Fuse side predictions
91
+ fused = self.fuse(torch.cat([s1, s2, s3, s4, s5], dim=1))
92
+
93
+ return fused
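The `forward` method upsamples with `F.interpolate(..., size=e_k.shape[2:])` instead of a fixed 2x scale factor. A short pure-Python check shows why: `nn.MaxPool2d(kernel_size=2, stride=2)` floors odd sizes, so doubling a deeper feature map does not always recover the size of the encoder map it is concatenated with. The helpers `encoder_sizes` and `naive_doubling_mismatches` are hypothetical and only trace spatial sizes; they don't touch the model, and each level is checked independently.

```python
def encoder_sizes(h: int, w: int, n_pools: int = 4) -> list[tuple[int, int]]:
    """Spatial size of e1..e5 for an (h, w) input; MaxPool2d(2, 2) floors odd sizes."""
    sizes = [(h, w)]
    for _ in range(n_pools):
        h, w = h // 2, w // 2
        sizes.append((h, w))
    return sizes


def naive_doubling_mismatches(sizes: list[tuple[int, int]]) -> list[str]:
    """Skip connections that would break if decoder maps were simply upsampled 2x."""
    problems = []
    for i in range(len(sizes) - 1, 0, -1):
        doubled = (sizes[i][0] * 2, sizes[i][1] * 2)
        target = sizes[i - 1]  # size of encoder map e{i}
        if doubled != target:
            problems.append(f"e{i}: 2x upsample gives {doubled}, skip needs {target}")
    return problems


sizes = encoder_sizes(500, 375)
print(sizes)  # [(500, 375), (250, 187), (125, 93), (62, 46), (31, 23)]
for line in naive_doubling_mismatches(sizes):
    print(line)
```

For a 500 x 375 input, the e3, e2, and e1 skip connections would all fail to concatenate under naive 2x upsampling, while size-matched interpolation handles them. Inputs still need at least 16 pixels per side, because four 2x2 poolings reduce anything smaller to a zero-sized map.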
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ class ONNXModelWrapper(nn.Module):
6
+ """
7
+ Wraps an ONNX Runtime InferenceSession inside a PyTorch nn.Module.
8
+ This allows ONNX models to be run with the same code and APIs as
9
+ PyTorch models, including test-time augmentation (TTA) pipelines. Input tensors
10
+ are copied to host memory for ONNX Runtime; outputs are moved back to the input's device.
11
+ """
12
+ def __init__(self, model_path: str, device: str = "cpu"):
13
+ super().__init__()
14
+ import onnxruntime as ort
15
+
16
+ # Select execution providers based on the target device
17
+ if str(device).split(":")[0] == "cuda": # Accepts "cuda", "cuda:1", or torch.device("cuda:1")
18
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
19
+ else:
20
+ providers = ["CPUExecutionProvider"]
21
+
22
+ self.session = ort.InferenceSession(model_path, providers=providers)
23
+ self.input_name = self.session.get_inputs()[0].name
24
+ self.output_name = self.session.get_outputs()[0].name
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ # 1. Convert PyTorch tensor to NumPy array (typically expected as float32)
28
+ x_np = x.detach().cpu().numpy().astype(np.float32)
29
+
30
+ # 2. Run inference using ONNX Runtime
31
+ outputs = self.session.run([self.output_name], {self.input_name: x_np})
32
+
33
+ # 3. Convert prediction back to PyTorch tensor and move to the original device
34
+ out_tensor = torch.from_numpy(outputs[0]).to(x.device)
35
+ return out_tensor
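`ONNXModelWrapper` picks ONNX Runtime execution providers from its `device` argument. A slightly more general mapping can be sketched as a standalone function; `select_providers` is a hypothetical helper, not part of findcrack. It relies on ONNX Runtime accepting `(provider_name, options)` tuples in `providers`, with the `device_id` option selecting the GPU for the CUDA provider.

```python
def select_providers(device) -> list:
    """Map "cpu", "cuda", "cuda:N", or a torch.device to an ONNX Runtime providers list."""
    # str(torch.device("cuda:1")) == "cuda:1", so strings and torch.device work alike.
    kind, _, index = str(device).partition(":")
    if kind == "cuda":
        # A (name, options) tuple lets ONNX Runtime target a specific GPU.
        cuda = ("CUDAExecutionProvider", {"device_id": int(index) if index else 0})
        return [cuda, "CPUExecutionProvider"]  # CPU handles ops the CUDA provider can't
    return ["CPUExecutionProvider"]            # "cpu", "mps", and anything else
```

The result can be passed directly as `ort.InferenceSession(model_path, providers=select_providers("cuda:1"))`. Listing `CPUExecutionProvider` last mirrors the original wrapper and gives ONNX Runtime a CPU provider for any operators the CUDA provider cannot run.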