torch-projectors 0.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_projectors-0.12.0/LICENSE.txt +21 -0
- torch_projectors-0.12.0/MANIFEST.in +11 -0
- torch_projectors-0.12.0/PKG-INFO +13 -0
- torch_projectors-0.12.0/README.md +359 -0
- torch_projectors-0.12.0/csrc/cpu/2d/backprojection_2d_kernels.cpp +537 -0
- torch_projectors-0.12.0/csrc/cpu/2d/backprojection_2d_kernels.h +24 -0
- torch_projectors-0.12.0/csrc/cpu/2d/projection_2d_kernels.cpp +375 -0
- torch_projectors-0.12.0/csrc/cpu/2d/projection_2d_kernels.h +24 -0
- torch_projectors-0.12.0/csrc/cpu/3d/backprojection_2d_to_3d_kernels.cpp +590 -0
- torch_projectors-0.12.0/csrc/cpu/3d/backprojection_2d_to_3d_kernels.h +73 -0
- torch_projectors-0.12.0/csrc/cpu/3d/projection_3d_to_2d_kernels.cpp +402 -0
- torch_projectors-0.12.0/csrc/cpu/3d/projection_3d_to_2d_kernels.h +95 -0
- torch_projectors-0.12.0/csrc/cpu/common/atomic_ops.h +47 -0
- torch_projectors-0.12.0/csrc/cpu/common/cubic_kernels.h +70 -0
- torch_projectors-0.12.0/csrc/cpu/common/fftw_sampling.h +162 -0
- torch_projectors-0.12.0/csrc/cpu/common/interpolation_kernels.h +444 -0
- torch_projectors-0.12.0/csrc/cpu/common/projection_utils.h +497 -0
- torch_projectors-0.12.0/csrc/cuda/2d/backprojection_2d_kernels.cu +1192 -0
- torch_projectors-0.12.0/csrc/cuda/2d/backprojection_2d_kernels.h +28 -0
- torch_projectors-0.12.0/csrc/cuda/2d/projection_2d_kernels.cu +921 -0
- torch_projectors-0.12.0/csrc/cuda/2d/projection_2d_kernels.h +26 -0
- torch_projectors-0.12.0/csrc/cuda/3d/backprojection_2d_to_3d_kernels.cu +1363 -0
- torch_projectors-0.12.0/csrc/cuda/3d/backprojection_2d_to_3d_kernels.h +28 -0
- torch_projectors-0.12.0/csrc/cuda/3d/projection_3d_to_2d_kernels.cu +1181 -0
- torch_projectors-0.12.0/csrc/cuda/3d/projection_3d_to_2d_kernels.h +26 -0
- torch_projectors-0.12.0/csrc/mps/2d/backproject_2d_back.metal +244 -0
- torch_projectors-0.12.0/csrc/mps/2d/backproject_2d_forw.metal +119 -0
- torch_projectors-0.12.0/csrc/mps/2d/backproject_utilities_2d.metal +203 -0
- torch_projectors-0.12.0/csrc/mps/2d/backprojection_2d_kernels.h +87 -0
- torch_projectors-0.12.0/csrc/mps/2d/backprojection_2d_kernels.mm +518 -0
- torch_projectors-0.12.0/csrc/mps/2d/project_2d_back.metal +294 -0
- torch_projectors-0.12.0/csrc/mps/2d/project_2d_forw.metal +99 -0
- torch_projectors-0.12.0/csrc/mps/2d/projection_2d_kernels.h +26 -0
- torch_projectors-0.12.0/csrc/mps/2d/projection_2d_kernels.mm +415 -0
- torch_projectors-0.12.0/csrc/mps/2d/utilities_2d.metal +286 -0
- torch_projectors-0.12.0/csrc/mps/3d/backproject_2d_to_3d_back.metal +287 -0
- torch_projectors-0.12.0/csrc/mps/3d/backproject_2d_to_3d_forw.metal +135 -0
- torch_projectors-0.12.0/csrc/mps/3d/backproject_utilities_3d.metal +290 -0
- torch_projectors-0.12.0/csrc/mps/3d/backprojection_2d_to_3d_kernels.h +87 -0
- torch_projectors-0.12.0/csrc/mps/3d/backprojection_2d_to_3d_kernels.mm +521 -0
- torch_projectors-0.12.0/csrc/mps/3d/project_3d_to_2d_back.metal +339 -0
- torch_projectors-0.12.0/csrc/mps/3d/project_3d_to_2d_forw.metal +113 -0
- torch_projectors-0.12.0/csrc/mps/3d/projection_3d_to_2d_kernels.h +76 -0
- torch_projectors-0.12.0/csrc/mps/3d/projection_3d_to_2d_kernels.mm +416 -0
- torch_projectors-0.12.0/csrc/mps/3d/utilities_3d.metal +344 -0
- torch_projectors-0.12.0/csrc/torch_projectors.cpp +101 -0
- torch_projectors-0.12.0/generate_metal_headers.py +149 -0
- torch_projectors-0.12.0/pyproject.toml +108 -0
- torch_projectors-0.12.0/setup.cfg +4 -0
- torch_projectors-0.12.0/setup.py +274 -0
- torch_projectors-0.12.0/tests/test_utils.py +384 -0
- torch_projectors-0.12.0/torch_projectors/__init__.py +30 -0
- torch_projectors-0.12.0/torch_projectors/_version.py +1 -0
- torch_projectors-0.12.0/torch_projectors/ops.py +574 -0
- torch_projectors-0.12.0/torch_projectors.egg-info/PKG-INFO +13 -0
- torch_projectors-0.12.0/torch_projectors.egg-info/SOURCES.txt +57 -0
- torch_projectors-0.12.0/torch_projectors.egg-info/dependency_links.txt +1 -0
- torch_projectors-0.12.0/torch_projectors.egg-info/requires.txt +1 -0
- torch_projectors-0.12.0/torch_projectors.egg-info/top_level.txt +1 -0

@@ -0,0 +1,21 @@

The MIT License (MIT)

Copyright (c) 2025 Genentech, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,11 @@

```
# Files required to build the C++/Metal extension from a source distribution.
include LICENSE.txt
include README.md
include pyproject.toml
include generate_metal_headers.py

# Native sources (CPU, CUDA, MPS) and headers needed to compile from sdist.
recursive-include csrc *.cpp *.h *.cu *.cuh *.mm *.metal

# Version marker written by the PyPI publish workflow (clean, suffix-free).
include torch_projectors/_version.py
```

@@ -0,0 +1,13 @@

```
Metadata-Version: 2.4
Name: torch-projectors
Version: 0.12.0
Summary: Differentiable forward and backward projectors for cryo-EM with fast native implementations for CPU, CUDA, and MPS backends.
Author: Dimitry Tegunov
Author-email: tegunov@gmail.com
License-File: LICENSE.txt
Requires-Dist: torch>=2.6.0
Dynamic: author
Dynamic: author-email
Dynamic: license-file
Dynamic: requires-dist
Dynamic: summary
```

@@ -0,0 +1,359 @@

# torch-projectors

A high-performance, differentiable 2D and 3D projection library for PyTorch, designed for cryogenic electron microscopy (cryo-EM) and tomography applications. The library provides forward and backward projection operators that work in Fourier space, following the Projection-Slice Theorem.
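
You can check the Projection-Slice Theorem numerically in a few lines. The sketch below uses plain `torch.fft` with no `torch_projectors` calls, applied to a random 2D image. It shows that summing along one axis in real space matches the inverse 1D FFT of the central line of the 2D spectrum. The axis choice and image size are arbitrary.

```python
import torch

# Projection-slice theorem in 2D: integrating an image along one axis in real
# space equals inverse-transforming the matching central line of its 2D spectrum.
torch.manual_seed(0)
image = torch.randn(32, 32, dtype=torch.float64)

# Real-space projection: sum over rows (axis 0, the y axis).
projection_real = image.sum(dim=0)

# Fourier-space route: the ky = 0 row of the 2D FFT is the central slice.
central_slice = torch.fft.fft2(image)[0, :]
projection_fourier = torch.fft.ifft(central_slice).real

print(torch.allclose(projection_real, projection_fourier))  # True
```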

## Features

- **Multi-Platform Support**: CPU on Linux, Windows, and macOS; CUDA on Linux; Metal Performance Shaders (MPS) on Apple Silicon
- **Multiple Backends**: Optimized kernels for different hardware platforms
- **Interpolation Methods**: Linear and cubic interpolation in 2D and 3D
- **Fourier Space Operations**: Efficient projections using PyTorch's RFFT format
- **Full Differentiability**: Gradient support for reconstructions, rotations, and shifts
- **Batch Processing**: Efficient handling of multiple reconstructions and poses
- **Oversampling Support**: Computationally efficient and accurate interpolation
- **Fourier Filtering**: Optional radius cutoff for low-pass filtering

## Core API

The library provides four main high-level functions:

### 2D-to-2D Operations
- `project_2d_forw()`: Forward-project 2D Fourier reconstructions to 2D projections
- `backproject_2d_forw()`: Back-project 2D projections into 2D reconstructions (adjoint operation)

### 3D-to-2D Operations
- `project_3d_to_2d_forw()`: Forward-project 3D Fourier volumes to 2D projections

### 2D-to-3D Operations
- `backproject_2d_to_3d_forw()`: Back-project 2D projections into 3D reconstructions (adjoint operation)
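
Here "adjoint" means the back-projector is the transpose of the forward projector: `<A x, y> = <x, A^T y>` holds for every image `x` and projection `y`. A toy real-space pair makes this concrete: a sum along one axis, and its broadcast "smear" back along that axis. The pair also shows why the adjoint is exactly what autograd returns as the gradient of a projection. The two functions are hypothetical stand-ins, not the library's Fourier-space operators.

```python
import torch

# Hypothetical toy forward projector A: integrate a 2D image along axis 0 (one view).
def forward_project(image):
    return image.sum(dim=0)

# Its adjoint A^T: smear each projection value back along axis 0.
def back_project(projection, height):
    return projection.unsqueeze(0).expand(height, -1)

torch.manual_seed(0)
x = torch.randn(16, 16, dtype=torch.float64)  # image space
y = torch.randn(16, dtype=torch.float64)      # projection space

lhs = torch.dot(forward_project(x), y)            # <A x, y>
rhs = torch.sum(x * back_project(y, x.shape[0]))  # <x, A^T y>
print(torch.allclose(lhs, rhs))  # True

# The gradient of <A x, y> with respect to x is exactly A^T y.
x_req = x.clone().requires_grad_(True)
torch.dot(forward_project(x_req), y).backward()
print(torch.allclose(x_req.grad, back_project(y, 16)))  # True
```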

## Installation

### User Installation

Install pre-built wheels from our wheelhouses for your platform:

```bash
# CPU-only on Linux, Windows, macOS (+ MPS support on macOS) (requires torch==2.6.0)
pip install torch-projectors --index-url https://warpem.github.io/torch-projectors/cpu/simple/

# CUDA 12.6 on Linux (requires torch==2.6.0)
pip install torch-projectors --index-url https://warpem.github.io/torch-projectors/cu126/simple/

# CUDA 12.8 on Linux (requires torch==2.7.0)
pip install torch-projectors --index-url https://warpem.github.io/torch-projectors/cu128/simple/

# CUDA 12.9 on Linux (requires torch==2.8.0)
pip install torch-projectors --index-url https://warpem.github.io/torch-projectors/cu129/simple/
```

**Note**: Each wheelhouse is built against the PyTorch version given in the comment above its command. Install that exact PyTorch version, with the matching CUDA build, before installing torch-projectors.
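
To find which wheelhouse matches an existing PyTorch install, you can compare the version pairs from the commands above against `torch.__version__` and `torch.version.cuda`. The helper `suggest_index_url` and the `WHEELHOUSES` table are hypothetical conveniences written for this sketch and are not part of torch-projectors. The table only mirrors the pairs listed above, so it needs updating whenever new wheelhouses are added.

```python
import torch

# Hypothetical table: wheelhouse tag -> (required torch version, CUDA version),
# copied from the install commands above.
WHEELHOUSES = {
    "cpu": ("2.6.0", None),
    "cu126": ("2.6.0", "12.6"),
    "cu128": ("2.7.0", "12.8"),
    "cu129": ("2.8.0", "12.9"),
}

def suggest_index_url(torch_version, cuda_version):
    """Return the matching wheelhouse URL, or None if no listed wheel fits."""
    base = torch_version.split("+")[0]  # drop local tags such as '+cu126'
    for tag, (required_torch, required_cuda) in WHEELHOUSES.items():
        if base == required_torch and cuda_version == required_cuda:
            return f"https://warpem.github.io/torch-projectors/{tag}/simple/"
    return None

print("torch", torch.__version__, "| CUDA", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())
print("Suggested index:", suggest_index_url(torch.__version__, torch.version.cuda))
```

A `None` result means the installed PyTorch matches none of the listed wheels. In that case, one option is building from source as described below.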

### Development Setup

For development, you'll need to build from source. Requires Python 3.9–3.13:

```bash
# Create environment
conda create -n torch-projectors python=3.11 -y
conda activate torch-projectors

# Install PyTorch (version depends on your CUDA requirements)
# For CPU-only or MPS:
pip install torch==2.6.0

# For CUDA 12.6:
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu126

# For CUDA 12.8:
pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128

# For CUDA 12.9:
pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129

# Install development dependencies
pip install pytest matplotlib pybind11

# Install in editable mode (compiles C++ extensions)
python -m pip install -e . --no-build-isolation
```

The build system automatically detects and enables:
- **CUDA support** on Linux and Windows when the CUDA Toolkit is available
- **MPS support** on macOS with Apple Silicon
- **CPU fallback** on all platforms
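
The three detection rules above can be restated as a small function that predicts which backends a given machine should end up with. This is a hypothetical mirror of the bullets, not the project's `setup.py` logic. Treating an `nvcc` on `PATH` as "the CUDA Toolkit is available" is an assumption; the real build may check differently.

```python
import platform
import shutil

def expected_backends(system, machine, cuda_toolkit_found):
    """Hypothetical restatement of the detection rules above (not the project's setup.py)."""
    backends = ["cpu"]  # CPU fallback on all platforms
    if system in ("Linux", "Windows") and cuda_toolkit_found:
        backends.append("cuda")
    if system == "Darwin" and machine == "arm64":  # macOS on Apple Silicon
        backends.append("mps")
    return backends

# Assumption: an `nvcc` executable on PATH stands in for "CUDA Toolkit is available".
nvcc_found = shutil.which("nvcc") is not None
print(expected_backends(platform.system(), platform.machine(), nvcc_found))
```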

## Usage Examples

This section demonstrates minimal usage patterns for the main projection operations with oversampling:

### 2D-to-2D Forward Projection

```python
import torch
import torch_projectors

# Helper function to pad and prepare real-space data
def pad_and_fftshift(tensor, oversampling_factor):
    H, W = tensor.shape[-2:]
    new_size = int(H * oversampling_factor)
    if new_size % 2 != 0:
        new_size += 1
    pad_total = new_size - H
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    padded = torch.nn.functional.pad(tensor, (pad_before, pad_after, pad_before, pad_after))
    return torch.fft.fftshift(padded, dim=(-2, -1))

# Start with real-space image
real_image = torch.randn(32, 32)

# 1. Zero pad 2x and fftshift
padded_image = pad_and_fftshift(real_image, 2.0)

# 2. Convert to Fourier space
fourier_image = torch.fft.rfft2(padded_image, norm='forward')

# 3. Set up projection parameters (90-degree rotation)
rotations = torch.tensor([[0., -1.], [1., 0.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
shifts = torch.zeros(1, 1, 2, dtype=torch.float32)

# 4. Forward project with oversampling=2.0
projection = torch_projectors.project_2d_forw(
    fourier_image.unsqueeze(0),  # Add batch dimension
    rotations,
    shifts=shifts,
    output_shape=(32, 32),
    interpolation='linear',
    oversampling=2.0
)

# 5. Convert back to real space
result = torch.fft.irfft2(projection[0, 0], s=(32, 32))
result = torch.fft.ifftshift(result)
```
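
Zero-padding produces oversampling because the spectrum of the 2x-padded image contains the unpadded spectrum at every second sample. The samples in between form the finer grid that interpolation benefits from. The check below copies `pad_and_fftshift` from the example and compares its output against `rfft2` of the unpadded, `fftshift`-ed image. It uses the default FFT normalization. With `norm='forward'` on both transforms, as in the example, the padded spectrum would come out 4x smaller, which is the ratio of pixel counts.

```python
import torch

# Copied from the example above (square images, even padded size).
def pad_and_fftshift(tensor, oversampling_factor):
    H, W = tensor.shape[-2:]
    new_size = int(H * oversampling_factor)
    if new_size % 2 != 0:
        new_size += 1
    pad_total = new_size - H
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    padded = torch.nn.functional.pad(tensor, (pad_before, pad_after, pad_before, pad_after))
    return torch.fft.fftshift(padded, dim=(-2, -1))

torch.manual_seed(0)
image = torch.randn(32, 32, dtype=torch.float64)

spectrum = torch.fft.rfft2(torch.fft.fftshift(image))        # (32, 17)
oversampled = torch.fft.rfft2(pad_and_fftshift(image, 2.0))  # (64, 33)

# Every second sample of the padded spectrum coincides with the unpadded one.
print(torch.allclose(oversampled[::2, ::2], spectrum))  # True
```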

### 2D-to-2D Backward Projection

```python
import torch
import torch_projectors

# Helper function to crop and ifftshift real-space data
def ifftshift_and_crop(real_tensor, oversampling_factor):
    shifted = torch.fft.ifftshift(real_tensor, dim=(-2, -1))
    current_size = real_tensor.shape[-1]
    original_size = int(current_size / oversampling_factor)
    crop_total = current_size - original_size
    crop_start = crop_total // 2
    crop_end = crop_start + original_size
    return shifted[..., crop_start:crop_end, crop_start:crop_end]

# Start with real-space image (e.g., a projection to backproject)
real_projection = torch.randn(32, 32)

# 1. fftshift and convert to Fourier space
shifted_projection = torch.fft.fftshift(real_projection)
fourier_projection = torch.fft.rfft2(shifted_projection, norm='forward')

# 2. Set up backprojection parameters
rotations = torch.tensor([[0., -1.], [1., 0.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
shifts = torch.zeros(1, 1, 2, dtype=torch.float32)

# 3. Backward project with oversampling=2.0
data_rec, weight_rec = torch_projectors.backproject_2d_forw(
    fourier_projection.unsqueeze(0).unsqueeze(0),  # Add batch and pose dimensions
    rotations,
    shifts=shifts,
    interpolation='linear',
    oversampling=2.0
)

# 4. Convert reconstruction to real space
real_reconstruction = torch.fft.irfft2(data_rec[0], norm='forward')

# 5. ifftshift and crop to 0.5x size (original size from 2x oversampling)
result = ifftshift_and_crop(real_reconstruction, 2.0)
```
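
`ifftshift_and_crop` is meant to undo the padding applied on the forward side. A quick round-trip check confirms that the two helpers are exact inverses on real-space data, assuming square inputs and an even padded size as in these examples. Both helpers are copied from the examples above.

```python
import torch

def pad_and_fftshift(tensor, oversampling_factor):
    H, W = tensor.shape[-2:]
    new_size = int(H * oversampling_factor)
    if new_size % 2 != 0:
        new_size += 1
    pad_total = new_size - H
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    padded = torch.nn.functional.pad(tensor, (pad_before, pad_after, pad_before, pad_after))
    return torch.fft.fftshift(padded, dim=(-2, -1))

def ifftshift_and_crop(real_tensor, oversampling_factor):
    shifted = torch.fft.ifftshift(real_tensor, dim=(-2, -1))
    current_size = real_tensor.shape[-1]
    original_size = int(current_size / oversampling_factor)
    crop_total = current_size - original_size
    crop_start = crop_total // 2
    crop_end = crop_start + original_size
    return shifted[..., crop_start:crop_end, crop_start:crop_end]

torch.manual_seed(0)
image = torch.randn(32, 32)
padded = pad_and_fftshift(image, 2.0)       # (64, 64), image center moved to index [0, 0]
restored = ifftshift_and_crop(padded, 2.0)  # back to (32, 32)
print(padded.shape, torch.equal(restored, image))  # torch.Size([64, 64]) True
```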

### 3D-to-2D Forward Projection

```python
import torch
import torch_projectors

# Helper function to pad 3D volumes
def pad_and_fftshift_3d(tensor, oversampling_factor):
    D, H, W = tensor.shape[-3:]
    new_size = int(D * oversampling_factor)
    if new_size % 2 != 0:
        new_size += 1
    pad_total = new_size - D
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    padded = torch.nn.functional.pad(tensor,
                                     (pad_before, pad_after,   # W
                                      pad_before, pad_after,   # H
                                      pad_before, pad_after))  # D
    return torch.fft.fftshift(padded, dim=(-3, -2, -1))

# Start with 3D real-space volume
real_volume = torch.randn(32, 32, 32)

# 1. Zero pad 2x and fftshift
padded_volume = pad_and_fftshift_3d(real_volume, 2.0)

# 2. Convert to Fourier space
fourier_volume = torch.fft.rfftn(padded_volume, dim=(-3, -2, -1), norm='forward')

# 3. Set up projection parameters (90-degree rotation around Y axis)
rotations = torch.tensor([
    [0., 0., 1.],   # x' = z
    [0., 1., 0.],   # y' = y
    [-1., 0., 0.]   # z' = -x
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
shifts = torch.zeros(1, 1, 2, dtype=torch.float32)

# 4. Forward project 3D->2D with oversampling=2.0
projection = torch_projectors.project_3d_to_2d_forw(
    fourier_volume.unsqueeze(0),  # Add batch dimension
    rotations,
    shifts=shifts,
    output_shape=(32, 32),
    interpolation='linear',
    oversampling=2.0
)

# 5. Convert back to real space
result = torch.fft.irfft2(projection[0, 0], s=(32, 32))
result = torch.fft.ifftshift(result)
```
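
Rotation matrices for other angles, and several poses at once, can be built the same way. The sketch assumes, as the examples suggest, that `rotations` has shape `(num_reconstructions, num_poses, 3, 3)` and `shifts` has shape `(num_reconstructions, num_poses, 2)`. `rotation_about_y` is a hypothetical helper. It reproduces the hard-coded 90-degree matrix above, and the code then checks that every matrix is a proper rotation.

```python
import math
import torch

def rotation_about_y(angle_rad):
    """Hypothetical helper: right-handed rotation about Y, acting on column vectors."""
    c, s = math.cos(angle_rad), math.sin(angle_rad)
    return torch.tensor([[c, 0.0, s],
                         [0.0, 1.0, 0.0],
                         [-s, 0.0, c]], dtype=torch.float32)

# Reproduces the hard-coded 90-degree matrix from the example above.
expected = torch.tensor([[0., 0., 1.], [0., 1., 0.], [-1., 0., 0.]])
assert torch.allclose(rotation_about_y(math.pi / 2), expected, atol=1e-6)

# Eight poses for one reconstruction (assumed layout: reconstructions, poses, 3, 3).
angles = torch.linspace(0, math.pi, steps=8)
rotations = torch.stack([rotation_about_y(a.item()) for a in angles]).unsqueeze(0)
shifts = torch.zeros(1, len(angles), 2)

# Proper rotations are orthonormal with determinant +1.
identity = torch.eye(3).expand_as(rotations)
assert torch.allclose(rotations @ rotations.transpose(-1, -2), identity, atol=1e-6)
assert torch.allclose(torch.linalg.det(rotations), torch.ones(1, 8), atol=1e-5)
print(rotations.shape, shifts.shape)  # torch.Size([1, 8, 3, 3]) torch.Size([1, 8, 2])
```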

### 2D-to-3D Backward Projection

```python
import torch
import torch_projectors

# Helper function to crop 3D volumes
def ifftshift_and_crop_3d(real_tensor, oversampling_factor):
    shifted = torch.fft.ifftshift(real_tensor, dim=(-3, -2, -1))
    current_size = real_tensor.shape[-3]
    original_size = int(current_size / oversampling_factor)
    crop_total = current_size - original_size
    crop_start = crop_total // 2
    crop_end = crop_start + original_size
    return shifted[..., crop_start:crop_end, crop_start:crop_end, crop_start:crop_end]

# Start with 2D real-space projection
real_projection = torch.randn(32, 32)

# 1. fftshift and convert to Fourier space
shifted_projection = torch.fft.fftshift(real_projection)
fourier_projection = torch.fft.rfft2(shifted_projection, norm='forward')

# 2. Set up backprojection parameters (3x3 rotation matrix; identity here)
rotations = torch.tensor([
    [1., 0., 0.],
    [0., 1., 0.],
    [0., 0., 1.]
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
shifts = torch.zeros(1, 1, 2, dtype=torch.float32)

# 3. Backward project 2D->3D with oversampling=2.0
data_rec, weight_rec = torch_projectors.backproject_2d_to_3d_forw(
    fourier_projection.unsqueeze(0).unsqueeze(0),  # Add batch and pose dimensions
    rotations,
    shifts=shifts,
    interpolation='linear',
    oversampling=2.0
)

# 4. Convert reconstruction to real space
real_reconstruction = torch.fft.irfftn(data_rec[0], dim=(-3, -2, -1), norm='forward')

# 5. ifftshift and crop to 0.5x size (original size from 2x oversampling)
result = ifftshift_and_crop_3d(real_reconstruction, 2.0)
```

## Architecture

### Core Components

- **Python API**: `torch_projectors/ops.py` - Main user interface
- **C++ Kernels**:
  - `csrc/cpu/2d/projection_2d_kernels.cpp` - 2D forward/backward projection
  - `csrc/cpu/2d/backprojection_2d_kernels.cpp` - 2D back-projection (adjoint)
  - `csrc/cpu/3d/projection_3d_to_2d_kernels.cpp` - 3D-to-2D projection
  - `csrc/cpu/3d/backprojection_2d_to_3d_kernels.cpp` - 2D-to-3D back-projection (adjoint)
- **CUDA Kernels**: `csrc/cuda/2d/*.cu` and `csrc/cuda/3d/*.cu` - GPU acceleration (when available)
- **Metal Shaders**: `csrc/mps/2d/*.metal` and `csrc/mps/3d/*.metal`, with Objective-C++ host code in `*.mm` - Apple Silicon optimization
- **Operator Registration**: `csrc/torch_projectors.cpp` - PyTorch integration

### Design Pattern

- **C++ Kernels**: Performance-critical forward/backward operations
- **TORCH_LIBRARY Registration**: Operators registered in the `torch_projectors` namespace
- **Python Autograd**: `torch.library.register_autograd` links C++ operators for seamless differentiation

## Data Format

- **Fourier Space**: Uses PyTorch's RFFT format (last dimension is `N/2 + 1`)
- **Coordinate System**: Origin `(0,0,0)` at index `[..., 0, 0, 0]`
- **Batch Dimensions**: Two batch dimensions - first for reconstructions, second for poses
- **Friedel Symmetry**: Automatically handled for real-valued reconstructions
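
The RFFT layout and Friedel symmetry can be seen directly with `torch.fft`. A real image's half spectrum keeps only the first `N/2 + 1` columns, and the zero frequency sits at index `[0, 0]`. Any frequency with a negative last-axis index can be recovered as the complex conjugate of its mirror `-k`. The lookup helper `sample_half_spectrum` is a hypothetical illustration of what "automatically handled" involves, not the library's internal code.

```python
import torch

torch.manual_seed(0)
n = 8
image = torch.randn(n, n, dtype=torch.float64)

full = torch.fft.fft2(image)   # (8, 8): every frequency
half = torch.fft.rfft2(image)  # (8, 5): last dimension is N/2 + 1
assert half.shape == (n, n // 2 + 1)
assert torch.allclose(half, full[:, : n // 2 + 1])

# Zero frequency at index [0, 0] holds the sum of all pixels.
assert torch.allclose(half[0, 0].real, image.sum())

def sample_half_spectrum(half, ky, kx):
    """Hypothetical helper: look up signed frequency (ky, kx) in an rfft2 half spectrum."""
    size = half.shape[0]
    if kx < 0:
        # Friedel symmetry of a real image: F(-k) = conj(F(k)).
        return half[(-ky) % size, -kx].conj()
    return half[ky % size, kx]

# Negative-kx frequencies are not stored but are recovered from their mirror.
for ky, kx in [(3, -2), (-1, -4), (2, 3)]:
    assert torch.allclose(sample_half_spectrum(half, ky, kx), full[ky % n, kx % n])
print("RFFT layout and Friedel symmetry checks passed")
```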

## Testing

Comprehensive test suite with visual validation:

```bash
# Run all tests
pytest

# Run specific test categories
pytest tests/test_basic_projection.py    # Core functionality
pytest tests/test_gradients.py           # Gradient verification
pytest tests/test_cross_platform.py      # Multi-platform consistency
pytest tests/test_performance.py         # Performance benchmarks
pytest tests/test_visual_validation.py   # Visual output validation
```

Tests generate visualization outputs in `test_outputs/` for manual inspection and include:
- Numerical correctness validation
- Gradient checking via autograd
- Visual validation with matplotlib plots
- Cross-platform consistency verification
- Performance benchmarking
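
A test in the spirit of the gradient checks listed above might look like the following pytest module. It exercises a Fourier-space phase-ramp shift, the kind of operation behind the differentiable `shifts` argument. The helper `fourier_shift_2d`, its sign convention, and the test names are hypothetical and are not taken from the project's test suite. Under that convention, a positive shift moves content toward higher indices, as `torch.roll` does.

```python
import torch

def fourier_shift_2d(spectrum, shift, size):
    """Hypothetical helper: shift an rfft2 spectrum by (dy, dx) pixels via a phase ramp."""
    ky = torch.fft.fftfreq(size, dtype=torch.float64)   # cycles/pixel, full axis
    kx = torch.fft.rfftfreq(size, dtype=torch.float64)  # cycles/pixel, half axis
    phase = -2 * torch.pi * (ky[:, None] * shift[0] + kx[None, :] * shift[1])
    return spectrum * torch.exp(1j * phase)

def test_integer_shift_matches_roll():
    image = torch.randn(16, 16, dtype=torch.float64)
    shift = torch.tensor([3.0, -2.0], dtype=torch.float64)
    shifted = torch.fft.irfft2(fourier_shift_2d(torch.fft.rfft2(image), shift, 16), s=(16, 16))
    assert torch.allclose(shifted, torch.roll(image, shifts=(3, -2), dims=(0, 1)))

def test_gradient_wrt_shift():
    spectrum = torch.fft.rfft2(torch.randn(8, 8, dtype=torch.float64))
    shift = torch.tensor([0.3, -0.7], dtype=torch.float64, requires_grad=True)

    def shifted_image(shift_vec):
        return torch.fft.irfft2(fourier_shift_2d(spectrum, shift_vec, 8), s=(8, 8))

    # Compares autograd's analytical Jacobian with finite differences.
    assert torch.autograd.gradcheck(shifted_image, (shift,))
```

Saved under a name such as `tests/test_fourier_shift.py` (a hypothetical file), it runs with the plain `pytest` invocation shown above.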

## Key Features

### 2D Back-Projection (New!)
- **Adjoint Operations**: Mathematical transpose of forward projection
- **Weight Accumulation**: Support for CTF² or other weight functions
- **Full Differentiability**: Gradients w.r.t. projections, weights, rotations, and shifts
- **Conjugate Phase Shifts**: Proper mathematical adjoint with conjugate phase corrections
- **Wiener Filtering Ready**: Separate data/weight accumulation enables downstream filtering
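
Two of these points are easy to verify in isolation. First, the adjoint of multiplying by a phase ramp is multiplying by its complex conjugate. Second, accumulating `data = sum_i CTF_i * observation_i` and `weight = sum_i CTF_i^2` separately lets a Wiener-style division be applied afterwards. The sketch uses 1D spectra, made-up CTF curves, and a hypothetical regularization constant `wiener_eps`. It illustrates the general technique rather than the library's exact accumulation.

```python
import torch

torch.manual_seed(0)
n = 16
k = torch.fft.fftfreq(n, dtype=torch.float64)

# 1) Adjoint of a phase shift: <S x, y> == <x, S^H y>, where S^H uses the conjugate ramp.
ramp = torch.exp(-2j * torch.pi * k * 2.5)  # forward phase ramp for a 2.5-pixel shift
x = torch.randn(n, dtype=torch.complex128)
y = torch.randn(n, dtype=torch.complex128)
assert torch.allclose(torch.vdot(ramp * x, y), torch.vdot(x, ramp.conj() * y))

# 2) Separate data/weight accumulation, then Wiener-style normalization.
truth = torch.randn(n, dtype=torch.complex128)  # "true" spectrum
t = torch.linspace(0, 6, n, dtype=torch.float64)
ctfs = torch.cos(t[None, :] * torch.arange(1, 4, dtype=torch.float64)[:, None])  # 3 made-up CTFs
observations = ctfs * truth                     # noise-free CTF-modulated views

data_acc = (ctfs * observations).sum(dim=0)  # sum_i CTF_i * observation_i
weight_acc = (ctfs ** 2).sum(dim=0)          # sum_i CTF_i^2
wiener_eps = 1e-3                            # hypothetical regularization constant
reconstruction = data_acc / (weight_acc + wiener_eps)

# Where the accumulated weight is large, normalization recovers the true spectrum.
well_sampled = weight_acc > 0.5
assert torch.allclose(reconstruction[well_sampled], truth[well_sampled], rtol=1e-2)
print(f"{int(well_sampled.sum())} of {n} frequencies recovered")
```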

### Interpolation & Filtering
- **Interpolation Methods**: Linear (bilinear/trilinear) and cubic (bicubic/tricubic)
- **Oversampling Support**: Coordinate scaling for computational efficiency
- **Fourier Filtering**: Optional radius cutoff for low-pass filtering
- **Friedel Symmetry**: Automatic handling for real-valued reconstructions
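
For reference, the sketch below shows bilinear interpolation on a Fourier grid and a radius-based low-pass mask in RFFT layout. Both functions are hypothetical, generic sketches: they do no bounds checking, do no Friedel handling, and measure the radius in Fourier pixels. They do not reproduce the library's kernels, its cubic variants, or its cutoff parameter.

```python
import math
import torch

def bilinear_sample(grid, y, x):
    """Bilinear interpolation of a 2D (real or complex) grid at fractional (y, x)."""
    y0, x0 = math.floor(y), math.floor(x)
    fy, fx = y - y0, x - x0
    top = (1 - fx) * grid[y0, x0] + fx * grid[y0, x0 + 1]
    bottom = (1 - fx) * grid[y0 + 1, x0] + fx * grid[y0 + 1, x0 + 1]
    return (1 - fy) * top + fy * bottom

def radius_mask(size, max_radius):
    """Boolean low-pass mask for the rfft2 spectrum of a size x size image."""
    ky = torch.fft.fftfreq(size) * size   # signed row frequencies in Fourier pixels
    kx = torch.fft.rfftfreq(size) * size  # non-negative column frequencies
    radius = torch.sqrt(ky[:, None] ** 2 + kx[None, :] ** 2)
    return radius <= max_radius

# Bilinear interpolation reproduces linear functions exactly: f(y, x) = 2y + 3x.
yy, xx = torch.meshgrid(torch.arange(8.0), torch.arange(8.0), indexing="ij")
grid = 2 * yy + 3 * xx
print(float(bilinear_sample(grid, 1.25, 2.5)))  # 10.0

# Zero every frequency beyond radius 4.
spectrum = torch.fft.rfft2(torch.randn(16, 16))
filtered = spectrum * radius_mask(16, max_radius=4).to(spectrum.dtype)
```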

## Development Status

This project is under active development. Current capabilities include:
- ✅ 2D-to-2D forward projection with full gradient support
- ✅ 2D-to-2D back-projection (adjoint) with weight accumulation
- ✅ 3D-to-2D forward projection with full gradient support
- ✅ 2D-to-3D back-projection (adjoint) with full gradient support
- 🚧 3D-to-3D projection operations
- 🚧 3D-to-3D back-projection (adjoint) with weight accumulation
- 🚧 CUDA and MPS backend implementations

The architecture is designed to support future expansion to additional projection geometries and backend optimizations.