torchopticsy 0.8.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchopticsy-0.8.3/PKG-INFO +26 -0
- torchopticsy-0.8.3/setup.cfg +4 -0
- torchopticsy-0.8.3/setup.py +16 -0
- torchopticsy-0.8.3/torchopticsy/Diffraction.py +680 -0
- torchopticsy-0.8.3/torchopticsy/Interference.py +284 -0
- torchopticsy-0.8.3/torchopticsy/__init__.py +1 -0
- torchopticsy-0.8.3/torchopticsy.egg-info/PKG-INFO +26 -0
- torchopticsy-0.8.3/torchopticsy.egg-info/SOURCES.txt +9 -0
- torchopticsy-0.8.3/torchopticsy.egg-info/dependency_links.txt +1 -0
- torchopticsy-0.8.3/torchopticsy.egg-info/requires.txt +3 -0
- torchopticsy-0.8.3/torchopticsy.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: torchopticsy
|
|
3
|
+
Version: 0.8.3
|
|
4
|
+
Summary: PyTorch-based optics calculation
|
|
5
|
+
Author: YuningYe
|
|
6
|
+
Author-email: 1956860113@qq.com
|
|
7
|
+
License: MIT
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: opencv-python
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
|
|
13
|
+
# torchOpticsY
|
|
14
|
+
A PyTorch-based optics calculation library.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
## Features
|
|
18
|
+
- Luneburg integral
|
|
19
|
+
- Debye–Wolf integral.
|
|
20
|
+
|
|
21
|
+
## Install
|
|
22
|
+
pip install torchOpticsY
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
0.2 - Interference: OffAxisPhase
|
|
26
|
+
0.1 - Diffraction: Debye_Wolf,Luneburg
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import pathlib
|
|
2
|
+
import setuptools
|
|
3
|
+
|
|
4
|
+
# Package metadata for torchopticsy.
#
# README.md is resolved next to this file and decoded explicitly as UTF-8.
# This keeps the build independent of the platform's default encoding, which
# previously garbled non-ASCII characters.
#
# A missing README is tolerated because this sdist's file list does not include
# README.md. Without that fallback, installing from the sdist fails with
# FileNotFoundError.
_readme = pathlib.Path(__file__).resolve().parent / "README.md"
_long_description = _readme.read_text(encoding="utf-8") if _readme.is_file() else ""

setuptools.setup(
    name="torchopticsy",
    version="0.8.3",
    description="PyTorch-based optics calculation",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    author="YuningYe",
    author_email="1956860113@qq.com",
    license="MIT",
    packages=setuptools.find_packages(),
    install_requires=["torch", "opencv-python", "matplotlib"],
    include_package_data=True,
)
|
|
@@ -0,0 +1,680 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BluesteinDFT:
    """Zoomed DFT along one tensor axis via Bluestein's chirp-z algorithm.

    ``transform`` maps ``m_input`` samples taken at rate ``fs`` to ``mout``
    spectral values

        X[k] = sum_n x[n] * exp(-2j*pi * l_k * (n - m_input/2 + 0.5) / fs),
        l_k  = f11 + k * (f22 - f11) / mout,

    where ``f11``/``f22`` are ``f1``/``f2`` shifted by
    ``(mout*fs + f2 - f1) / (2*mout)``.

    The cost is one zero-padded, power-of-two FFT convolution instead of an
    explicit ``mout x m_input`` matrix product.

    Known limitation: because of the periodic extension, only the central
    region is accurate; the surrounding area is affected by pseudo-diffraction.
    Zero padding could address this in the future.
    """

    def __init__(self, f1, f2, fs, mout, m_input, device="cpu"):
        self.device = torch.device(device)
        self.f1 = f1
        self.f2 = f2
        self.fs = fs
        self.mout = mout
        self.m_input = m_input

        # Frequency adjustment: shift the requested band [f1, f2].
        shift = (mout * fs + f2 - f1) / (2 * mout)
        f11 = f1 + shift
        f22 = f2 + shift
        self.f11 = f11
        self.f22 = f22

        # Chirp parameters: `a` fixes the start frequency, `w` the frequency step.
        a = torch.exp(1j * 2 * torch.pi * torch.tensor(f11 / fs))
        w = torch.exp(-1j * 2 * torch.pi * torch.tensor(f22 - f11) / (mout * fs))
        self.a = a.to(self.device)
        self.w = w.to(self.device)

        # Chirp w**(j**2 / 2) for j in [-(m_input - 1), max(mout, m_input) - 1].
        j = torch.arange(
            -m_input + 1,
            max(mout - 1, m_input - 1) + 1,
            device=self.device,
            dtype=torch.float64,
        )
        h = self.w ** ((j**2) / 2)
        self.h = h

        # Length of the linear convolution, rounded up to a power of two.
        # Integer bit arithmetic avoids any float rounding in log2.
        self.mp = m_input + mout - 1
        padded_len = 1 << (self.mp - 1).bit_length()
        self.padded_len = padded_len

        # Spectrum of the inverse chirp, i.e. the convolution kernel.
        h_inv = torch.zeros(padded_len, dtype=torch.complex64, device=self.device)
        h_inv[: self.mp] = 1 / h[: self.mp]
        self.ft = torch.fft.fft(h_inv)

        # Per-input-sample weight a**(-n) * w**(n**2 / 2), n = 0..m_input-1.
        b_exp = torch.arange(0, m_input, device=self.device)
        self.b_phase = (self.a**-b_exp) * h[m_input - 1 : 2 * m_input - 1]

        # Output frequencies l_k, and the phase that references the input
        # samples to an offset of -m_input/2 + 0.5.
        l = torch.linspace(0, mout - 1, mout, device=self.device)
        l = l / mout * (f22 - f11) + f11
        Mshift = -m_input / 2
        self.Mshift = torch.exp(-1j * 2 * torch.pi * l * (Mshift + 0.5) / fs)

    def transform(self, x, dim=-1):
        """Apply the zoomed DFT along axis ``dim`` of ``x``.

        Args:
            x: tensor whose axis ``dim`` has length ``m_input``. It is moved to
                ``self.device``.
            dim: axis to transform; negative values count from the end.

        Returns:
            Complex tensor shaped like ``x``, with axis ``dim`` replaced by
            ``mout`` spectral samples.

        Raises:
            ValueError: if ``x.shape[dim] != m_input``.
        """
        x = x.to(self.device)
        m = self.m_input
        dim = dim if dim >= 0 else x.ndim + dim

        if x.shape[dim] != m:
            raise ValueError(
                f"Expected dimension {dim} to be of size {m}, but got {x.shape[dim]}"
            )

        # Move the target axis last and flatten everything else into rows.
        x = x.transpose(dim, -1)
        b_phase = self.b_phase.view((1,) * (x.ndim - 1) + (-1,))
        x_weighted = x * b_phase
        original_shape = x_weighted.shape
        x_weighted = x_weighted.reshape(-1, m)

        # Zero-pad each row and convolve it with the inverse chirp via FFT.
        b_padded = torch.zeros(
            (x_weighted.shape[0], self.padded_len),
            dtype=torch.complex64,
            device=self.device,
        )
        b_padded[:, :m] = x_weighted
        conv = torch.fft.fft(b_padded, dim=1) * self.ft[None, :]
        result = torch.fft.ifft(conv, dim=1)

        # Keep the mout valid samples, then apply the output chirp and the
        # centring phase.
        result = result[:, m - 1 : self.mp] * self.h[m - 1 : self.mp]
        result = result * self.Mshift[None, :]

        result = result.reshape(*original_shape[:-1], self.mout)
        return result.transpose(-1, dim)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class DebyeWolf:
    """Vectorial Debye–Wolf focusing integral evaluated with Bluestein DFTs.

    The pupil is sampled on a ``Min`` x ``Min`` grid. Each sample gets a polar
    angle ``th = asin(NA * r / (N * n1))``, with ``N = (Min - 1) / 2`` and the
    argument clamped at 1, and an azimuth ``phi`` in [0, 2*pi). Samples whose
    angle exceeds ``asin(NA / n1)`` are masked out.

    For every wavelength and field component, the focal field is obtained with
    one ``BluesteinDFT`` along y followed by one along x.

    Args:
        Min: number of pupil samples per side.
        xrange, yrange: (start, stop) offsets passed to the x / y
            BluesteinDFTs. These are presumably focal-plane coordinates in the
            unit of ``lams`` (TODO confirm).
        zrange: (start, stop) of the ``Mout[2]`` z samples.
        Mout: (Nx, Ny, Nz) number of output samples.
        lams: list of wavelengths.
        NA: numerical aperture.
        focal_length: focal length.
        n1: refractive index; enters as ``k0 * n1`` and ``NA / n1``.
        device: torch device for all tensors.
    """

    def __init__(
        self,
        Min,
        xrange,
        yrange,
        zrange,
        Mout,
        lams,  # list of wavelengths
        NA,
        focal_length,
        n1=1,
        device="cpu",
    ):
        self.device = device
        self.Min = Min
        self.xrange = xrange
        self.yrange = yrange
        self.z_arr = torch.linspace(zrange[0], zrange[1], Mout[2], device=device)
        self.Moutx, self.Mouty = Mout[0], Mout[1]
        lams = torch.tensor(lams, device=device)
        self.lams, self.k0, self.n1, self.NA, self.focal_length = (
            lams,
            2 * torch.pi / lams,
            n1,
            NA,
            focal_length,
        )

        self.N = (Min - 1) / 2

        m = torch.linspace(-Min / 2, Min / 2, Min, device=self.device)
        n = torch.linspace(-Min / 2, Min / 2, Min, device=self.device)
        self.m_grid, self.n_grid = torch.meshgrid(m, n, indexing="ij")

        # Polar angle of each pupil sample; the clamp keeps asin in its domain.
        self.th = torch.asin(
            torch.clamp(
                NA * torch.sqrt(self.m_grid**2 + self.n_grid**2) / (self.N * n1), max=1
            )
        )
        # True for samples outside the aperture cone.
        self.mask = self.th > torch.arcsin(torch.tensor(NA / n1))
        self.phi = torch.atan2(self.n_grid, self.m_grid)
        self.phi[self.phi < 0] += 2 * torch.pi

        # Despite its name, this holds 1 / sqrt(cos(th)), shape (Min, Min, 1).
        # It is zeroed where the value is NaN (cos(th) < 0 after rounding) and
        # outside the aperture.
        self._sqrt_costh = 1 / torch.sqrt(torch.cos(self.th).unsqueeze(-1))
        self._sqrt_costh[torch.isnan(self._sqrt_costh)] = 0
        self._sqrt_costh[self.mask] = 0

        # Per-wavelength "sampling rate" handed to the Bluestein transforms.
        fs = lams * (Min - 1) / (2 * NA)
        self.fs = fs
        self.bluesteins_y = []
        self.bluesteins_x = []
        # Per-wavelength prefactor applied to the transformed field.
        self.C = (
            -1j
            * torch.exp(1j * self.k0 * n1 * focal_length)
            * focal_length
            * (lams)
            / (self.n1)
            / fs
            / fs
        )
        fs = fs.cpu().tolist()
        for f in fs:
            self.bluesteins_y.append(
                BluesteinDFT(
                    f / 2 + self.yrange[0],
                    f / 2 + self.yrange[1],
                    f,
                    self.Mouty,
                    Min,
                    device=device,
                )
            )
            self.bluesteins_x.append(
                BluesteinDFT(
                    f / 2 + self.xrange[0],
                    f / 2 + self.xrange[1],
                    f,
                    self.Moutx,
                    Min,
                    device=device,
                )
            )
        # Per-wavelength divisor of C (initialised to ones).
        self.E_ideals = torch.ones_like(self.lams)
        # Unit vectors (-sin th cos phi, -sin th sin phi, cos th), shape (Min, Min, 1, 3).
        self.R = torch.stack(
            [
                -torch.sin(self.th) * torch.cos(self.phi),
                -torch.sin(self.th) * torch.sin(self.phi),
                torch.cos(self.th),
            ],
            dim=-1,
        )
        self.R = self.R.unsqueeze(-2)

    @staticmethod
    def _radial_correction(fields, directions):
        """Subtract half of each field's projection onto ``directions``.

        Args:
            fields: (..., x, y, z, lam, 3) stacked (Ex, Ey, Ez) components.
                Any leading batch axes are broadcast.
            directions: (x, y, 3) per-pupil-sample vectors.

        Returns:
            ``fields - 0.5 * d * sum(d * fields, -1)``, with ``d`` broadcast
            over the z and lam axes.
        """
        d = directions[:, :, None, None, :]
        return fields - 0.5 * d * torch.sum(d * fields, dim=-1, keepdim=True)

    def __call__(self, E, correct=False):
        """Evaluate the focal field.

        Args:
            E: input pupil field of shape (x, y, 2, lam), or
                (batch, x, y, 2, lam). The last-but-one axis holds (Ex, Ey);
                the z-component is not included.
            correct: if True, apply ``_radial_correction`` with ``self.R``
                before transforming.

        Returns:
            Complex tensor (Moutx, Mouty, z, 3, lam), with a leading batch axis
            when ``E`` is batched.

        Different wavelengths are handled with a plain for-loop for now.
        """
        Ex_in, Ey_in = E[..., 0:1, :], E[..., 1:2, :]
        th = self.th.unsqueeze(-1).unsqueeze(-1)
        phi = self.phi.unsqueeze(-1).unsqueeze(-1)
        z_arr = self.z_arr.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        k0, n1 = self.k0.view(1, 1, 1, -1), self.n1

        costh = torch.cos(th)
        _sqrt_costh = self._sqrt_costh.unsqueeze(-1)
        phase = torch.exp(1j * k0 * n1 * z_arr * costh)
        C = (self.C / self.E_ideals).view(1, 1, 1, -1)

        batched = E.dim() == 5
        if batched:
            # Prepend a singleton batch axis to every broadcast operand.
            th = th.unsqueeze(0)
            phi = phi.unsqueeze(0)
            z_arr = z_arr.unsqueeze(0)
            k0 = k0.unsqueeze(0)
            costh = costh.unsqueeze(0)
            _sqrt_costh = _sqrt_costh.unsqueeze(0)
            phase = phase.unsqueeze(0)
            C = C.unsqueeze(0)
        deltadim = 1 if batched else 0
        out_shape = ([E.size(0)] if batched else []) + [
            self.Moutx,
            self.Mouty,
            self.z_arr.numel(),
            3,
            self.lams.numel(),
        ]
        E_out = torch.zeros(out_shape, dtype=torch.complex64, device=self.device)

        Ex = (
            (
                Ex_in * (1 + (costh - 1) * torch.cos(phi) ** 2)
                + Ey_in * (costh - 1) * torch.cos(phi) * torch.sin(phi)
            )
            * phase
            * _sqrt_costh
        )

        Ey = (
            (
                Ex_in * (costh - 1) * torch.cos(phi) * torch.sin(phi)
                + Ey_in * (1 + (costh - 1) * torch.sin(phi) ** 2)
            )
            * phase
            * _sqrt_costh
        )

        Ez = (
            (Ex_in * torch.cos(phi) + Ey_in * torch.sin(phi))
            * torch.sin(th)
            * phase
            * _sqrt_costh
        )

        if correct:
            # Stack components on a new last axis, then split them again.
            # The old indexing temp[:, :, :, k] hit the lam axis, and self.R was
            # broadcast against misaligned axes.
            stacked = torch.stack([Ex, Ey, Ez], dim=-1)
            stacked = self._radial_correction(stacked, self.R.squeeze(-2))
            Ex, Ey, Ez = stacked.unbind(dim=-1)

        for i in range(self.lams.numel()):
            # Transform y first, then x, for each Cartesian component.
            E_out[..., 0, i] = self.bluesteins_x[i].transform(
                self.bluesteins_y[i].transform(Ex[..., i], dim=1 + deltadim),
                dim=0 + deltadim,
            )
            E_out[..., 1, i] = self.bluesteins_x[i].transform(
                self.bluesteins_y[i].transform(Ey[..., i], dim=1 + deltadim),
                dim=0 + deltadim,
            )
            E_out[..., 2, i] = self.bluesteins_x[i].transform(
                self.bluesteins_y[i].transform(Ez[..., i], dim=1 + deltadim),
                dim=0 + deltadim,
            )

        return C * E_out

    def Get_Z_offset_Phase(self, z):
        """Return the phase -k0 * n1 * z * cos(th) for every pupil sample.

        The result is real-valued (not exponentiated), with shape
        (Min, Min, lam).
        """
        k0, n1 = self.k0, self.n1
        costh = torch.cos(self.th)
        phase = k0.view(1, 1, -1) * n1 * z * (-costh).unsqueeze(-1)
        return phase
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def fft_circular_conv2d(E, G):
    """Circularly convolve ``E`` with ``G`` over their first two axes.

    Both operands are laid out as (x, y, channel). Each channel is convolved
    independently using the convolution theorem, and the complex result keeps
    the (x, y, channel) layout.
    """
    spatial_axes = (0, 1)
    product = torch.fft.fft2(E, dim=spatial_axes) * torch.fft.fft2(G, dim=spatial_axes)
    return torch.fft.ifft2(product, dim=spatial_axes)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def fourier_upsample2d(E: torch.Tensor, up_sample=1) -> torch.Tensor:
    """Upsample the first two axes of ``E`` by zero-padding its spectrum.

    Args:
        E: tensor whose first two axes are uniformly sampled; trailing axes are
            carried along unchanged. May be real or complex.
        up_sample: positive integer upsampling factor applied to each axis.

    Returns:
        Complex tensor of shape
        ``((x - 1) * up_sample + 1, (y - 1) * up_sample + 1, *E.shape[2:])``.
        It covers the same span as ``E`` with the step divided by
        ``up_sample``. For odd ``x``/``y``, every ``up_sample``-th output
        sample reproduces ``E``.

    Raises:
        ValueError: if ``up_sample < 1``.
    """
    if up_sample < 1:
        raise ValueError("up_sample must be >= 1")
    x, y = E.shape[0], E.shape[1]
    Xn, Yn = x * up_sample, y * up_sample

    # 2D FFT over the first two axes only. Orthonormal scaling avoids
    # size-dependent scaling differences between the two transforms.
    spectrum = torch.fft.fftshift(
        torch.fft.fft2(E, dim=(0, 1), norm="ortho"), dim=(0, 1)
    )

    # Zero-pad the centred spectrum onto the larger grid. The buffer must use
    # the spectrum's complex dtype: allocating it with E.dtype would silently
    # discard the imaginary part whenever E is real.
    padded = torch.zeros(
        (Xn, Yn) + tuple(E.shape[2:]), dtype=spectrum.dtype, device=E.device
    )
    x0 = (Xn - x + 1) // 2
    y0 = (Yn - y + 1) // 2
    padded[x0 : x0 + x, y0 : y0 + y, ...] = spectrum

    padded = torch.fft.ifftshift(padded, dim=(0, 1))
    up_E_full = torch.fft.ifft2(padded, dim=(0, 1), norm="ortho")

    # Keep only the samples spanning the original extent. The two orthonormal
    # transforms leave a net 1/up_sample amplitude factor, which is undone here.
    X_keep = (x - 1) * up_sample + 1
    Y_keep = (y - 1) * up_sample + 1
    return up_E_full[:X_keep, :Y_keep, ...] * up_sample
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def upsample_coords_xy(x: torch.Tensor, y: torch.Tensor, up_sample: int):
|
|
320
|
+
"""
|
|
321
|
+
输入:
|
|
322
|
+
x: 原 x 轴坐标,形状 (Nx,)
|
|
323
|
+
y: 原 y 轴坐标,形状 (Ny,)
|
|
324
|
+
up_sample: 正整数,上采样倍数
|
|
325
|
+
|
|
326
|
+
返回:
|
|
327
|
+
x_up: 形状 ((Nx-1)*up + 1,)
|
|
328
|
+
y_up: 形状 ((Ny-1)*up + 1,)
|
|
329
|
+
(范围不变,步长缩小 1/up)
|
|
330
|
+
说明:
|
|
331
|
+
假设 x、y 为等间距采样(FFT/零填充本身也要求等间距)
|
|
332
|
+
"""
|
|
333
|
+
if up_sample < 1:
|
|
334
|
+
raise ValueError("up_sample 必须 >= 1")
|
|
335
|
+
if up_sample == 1:
|
|
336
|
+
return x, y
|
|
337
|
+
|
|
338
|
+
if x.numel() < 2 or y.numel() < 2:
|
|
339
|
+
raise ValueError("x 和 y 至少需要 2 个点用于定义范围。")
|
|
340
|
+
|
|
341
|
+
# 直接按起点和终点线性插值(等间距)
|
|
342
|
+
dx = (x[1] - x[0]) / up_sample * 1.5 * (up_sample - 1)
|
|
343
|
+
dy = (y[1] - y[0]) / up_sample * 1.5 * (up_sample - 1)
|
|
344
|
+
x_up = torch.linspace(
|
|
345
|
+
x[0] + dx,
|
|
346
|
+
x[-1] + dx,
|
|
347
|
+
(x.numel() - 1) * up_sample + 1,
|
|
348
|
+
device=x.device,
|
|
349
|
+
dtype=x.dtype,
|
|
350
|
+
)
|
|
351
|
+
y_up = torch.linspace(
|
|
352
|
+
y[0] + dy,
|
|
353
|
+
y[-1] + dy,
|
|
354
|
+
(y.numel() - 1) * up_sample + 1,
|
|
355
|
+
device=y.device,
|
|
356
|
+
dtype=y.dtype,
|
|
357
|
+
)
|
|
358
|
+
return x_up, y_up
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class Luneburg:
    """Scalar Luneburg-type diffraction integral evaluated by FFT convolution.

    For every z plane and wavelength, the aperture field ``E`` is circularly
    convolved with a precomputed kernel, multiplied by ``z`` at call time:

        G = C * exp(i k R) / R**2 * (1 - 1 / (i k R)) * sinc(...) * sinc(...)

    where ``R = sqrt(u**2 + v**2 + z**2)``, ``k = 2 pi n1 / lam`` and
    ``C = -1j * dx * dy * n1 / lam``. The sinc factors presumably model the
    finite sample size (TODO confirm).

    The result is Fourier-upsampled in x and y by ``up_sample``. The Ez
    component is computed differently and is omitted here. The sampling
    theorem must be satisfied: T < lambda / 2.
    """

    def __init__(
        self,
        uv_len,  # side length of the integration region
        x_range,
        y_range,
        z_range,
        sampling_interval,  # (dx, dy, dz); dx and dy are followed strictly,
        # dz is adjusted so that z has an odd number of samples
        lams,
        focal_length,  # z coordinates are offset by this focal length
        n1=1,
        up_sample=1,  # integer >= 1, output upsampling factor in x and y
        device="cuda",
    ):
        self.device = device
        self.up_sample = up_sample
        lams = torch.tensor(lams, device=device)
        self.lams = lams
        self.k = 2 * torch.pi * n1 / self.lams
        self.n1 = n1
        self.focal_length = focal_length
        uv_len = uv_len - sampling_interval[0]
        # Full sample grids: output window plus half the integration region
        # on each side.
        ux_arr = torch.arange(
            -uv_len / 2 + x_range[0] - sampling_interval[0],
            x_range[1] + uv_len / 2,
            step=sampling_interval[0],
            device=device,
        )

        vy_arr = torch.arange(
            -uv_len / 2 + y_range[0] - sampling_interval[1],
            y_range[1] + uv_len / 2,
            step=sampling_interval[1],
            device=device,
        )
        z_arr = torch.linspace(
            z_range[0] + focal_length,
            z_range[1] + focal_length,
            int((z_range[1] - z_range[0]) / sampling_interval[2] / 2) * 2 + 1,
            device=device,
        )

        # Integration coordinates u: centred on the grid point nearest the
        # middle of the x window, cut to the integration region, offset by half
        # a step, then shifted to zero mean.
        center = (
            round((x_range[0] + x_range[1]) / 2 / sampling_interval[0])
            * sampling_interval[0]
        )
        u_arr = ux_arr - center
        u_start_id = torch.argmin(torch.abs(u_arr + uv_len / 2))
        u_end_id = torch.argmin(torch.abs(u_arr - uv_len / 2))
        u_arr = u_arr[u_start_id : u_end_id + 1] + sampling_interval[0] / 2
        du = torch.mean(u_arr)
        u_arr = u_arr - du
        # Same construction for v along y.
        center = (
            round((y_range[0] + y_range[1]) / 2 / sampling_interval[1])
            * sampling_interval[1]
        )
        v_arr = vy_arr - center
        v_start_id = torch.argmin(torch.abs(v_arr + uv_len / 2))
        v_end_id = torch.argmin(torch.abs(v_arr - uv_len / 2))
        v_arr = v_arr[v_start_id : v_end_id + 1] + sampling_interval[1] / 2
        dv = torch.mean(v_arr)
        v_arr = v_arr - dv

        # Output window indices within the full grids.
        x_start_id = torch.argmin(torch.abs(ux_arr - x_range[0]))
        x_end_id = torch.argmin(torch.abs(ux_arr - x_range[1]))
        y_start_id = torch.argmin(torch.abs(vy_arr - y_range[0]))
        y_end_id = torch.argmin(torch.abs(vy_arr - y_range[1]))

        x_arr = ux_arr[x_start_id : x_end_id + 1] - du
        y_arr = vy_arr[y_start_id : y_end_id + 1] - dv

        # The commented-out 1/(4*pi) factor is not applied.
        co_public = (
            -1j
            * sampling_interval[0]
            * sampling_interval[1]
            * n1
            / self.lams.view(1, 1, 1, -1)
        )  # / 4 / torch.pi
        R_q = torch.sqrt(
            (ux_arr.view(-1, 1, 1, 1)) ** 2
            + (vy_arr.view(1, -1, 1, 1)) ** 2
            + z_arr.view(1, 1, -1, 1) ** 2
        )
        G_2D = self.k * R_q * 1j
        G_2D = (
            torch.exp(G_2D)
            / R_q**2
            * co_public
            * (1 - 1 / G_2D)
            * torch.sinc(
                n1
                * ux_arr.view(-1, 1, 1, 1)
                * sampling_interval[0]
                / (R_q * lams.view(1, 1, 1, -1))
            )
            * torch.sinc(
                n1
                * vy_arr.view(1, -1, 1, 1)
                * sampling_interval[1]
                / (R_q * lams.view(1, 1, 1, -1))
            )
        )

        # Kernel of shape (len(ux_arr), len(vy_arr), z, lam); z is applied in __call__.
        self.G_2D = G_2D

        self.u_arr = u_arr
        self.v_arr = v_arr
        # Output coordinates, already matched to the upsampled output grid.
        self.x_arr, self.y_arr = upsample_coords_xy(x_arr, y_arr, up_sample)
        self.z_arr = z_arr

        self.x_num = x_end_id - x_start_id + 1
        self.y_num = y_end_id - y_start_id + 1

        # Zeros prepended to E along u / v so that it matches G_2D's size.
        self.u_pad_num = ux_arr.numel() - u_arr.numel()
        self.v_pad_num = vy_arr.numel() - v_arr.numel()
        self.dS = sampling_interval[0] * sampling_interval[1]

    def __call__(self, E):
        """Propagate the aperture field ``E`` to every (x, y, z) output sample.

        Args:
            E: field of shape (u, v, lam). Presumably u / v match
                ``self.u_arr`` / ``self.v_arr`` so that the padded field lines
                up with ``self.G_2D`` (confirm with the caller).

        Returns:
            Complex tensor (x, y, z, lam), Fourier-upsampled in x and y by
            ``self.up_sample``.

        Raises:
            ValueError: if ``E`` is not 3-dimensional.
        """
        if E.dim() != 3:
            raise ValueError(
                f"E must have shape (u, v, lam); got shape {tuple(E.shape)}"
            )
        # F.pad takes (before, after) pairs starting from the last axis. The lam
        # axis is untouched, while v and u get v_pad_num / u_pad_num zeros
        # prepended (nothing is appended).
        pad = (0, 0, self.v_pad_num, 0, self.u_pad_num, 0)
        E_padded = F.pad(E, pad, mode="constant", value=0)
        E_out = torch.zeros(
            [
                self.x_num,
                self.y_num,
                self.z_arr.numel(),
                self.lams.numel(),
            ],
            device=self.device,
            dtype=torch.complex64,
        )
        for z in range(self.z_arr.numel()):
            E_out[:, :, z, :] = fft_circular_conv2d(
                E_padded, self.G_2D[:, :, z] * self.z_arr[z]
            )[
                : self.x_num,
                : self.y_num,
            ]
        return fourier_upsample2d(E_out, self.up_sample)

    def Get_G_2D(self, x_f, y_f, z_f):  # Typically used for optimization and validation
        """Return the kernel for one detection point over all (u, v) samples.

        The detection point is (x_f, y_f, focal_length + z_f). The result has
        shape (u, v, lam) and is already multiplied by z.

        NOTE(review): unlike ``self.G_2D``, this kernel has no sinc factors.
        """
        z = self.focal_length + z_f  # Detection coordinate
        co_public = -1j * self.dS * self.n1 / self.lams  # / 4 / torch.pi
        R_q = torch.sqrt(
            z**2
            + (x_f - self.u_arr.view(-1, 1, 1)) ** 2
            + (y_f - self.v_arr.view(1, -1, 1)) ** 2
        )
        ik0R = self.k.view(1, 1, -1) * R_q * 1j
        G_2D = torch.exp(ik0R) / R_q**2 * (1 - 1 / ik0R) * co_public

        return z * G_2D
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
class LuneburgOld:
    """Legacy Luneburg integral that accepts a subsampled input field.

    It uses the same kernel as ``Luneburg``, but without the sinc sample-size
    factors and without output upsampling. Instead, ``E`` is given on every
    ``scale``-th integration sample. ``__call__`` repeats each sample
    ``scale`` times along the first two axes before the convolution.

    The Ez component is computed differently and is omitted here. The
    sampling theorem must be satisfied: T < lambda / 2.
    """

    def __init__(
        self,
        uv_len,  # side length of the integration region
        x_range,
        y_range,
        z_range,
        sampling_interval,  # (dx, dy, dz); dx and dy are followed strictly,
        # dz is adjusted so that z has an odd number of samples
        lams,
        focal_length,  # z coordinates are offset by this focal length
        n1=1,
        scale=1,  # integer >= 1: E is sampled every `scale` integration points
        device="cuda",
    ):
        self.device = device
        self.scale = scale
        lams = torch.tensor(lams, device=device)
        self.lams = lams
        self.k = 2 * torch.pi * n1 / self.lams
        self.n1 = n1
        self.focal_length = focal_length
        uv_len = uv_len - sampling_interval[0]
        # Full sample grids: output window plus half the integration region
        # on each side.
        ux_arr = torch.arange(
            -uv_len / 2 + x_range[0] - sampling_interval[0],
            x_range[1] + uv_len / 2,
            step=sampling_interval[0],
            device=device,
        )

        vy_arr = torch.arange(
            -uv_len / 2 + y_range[0] - sampling_interval[1],
            y_range[1] + uv_len / 2,
            step=sampling_interval[1],
            device=device,
        )
        z_arr = torch.linspace(
            z_range[0] + focal_length,
            z_range[1] + focal_length,
            int((z_range[1] - z_range[0]) / sampling_interval[2] / 2) * 2 + 1,
            device=device,
        )

        # Integration coordinates u, taken every `scale` samples, offset by half
        # a step and shifted to zero mean.
        center = (
            round((x_range[0] + x_range[1]) / 2 / sampling_interval[0])
            * sampling_interval[0]
        )
        u_arr = ux_arr - center
        u_start_id = torch.argmin(torch.abs(u_arr + uv_len / 2))
        u_end_id = torch.argmin(torch.abs(u_arr - uv_len / 2))
        u_arr = u_arr[u_start_id : u_end_id + 1 : scale] + sampling_interval[0] / 2
        du = torch.mean(u_arr)
        u_arr = u_arr - du
        # Same construction for v along y.
        center = (
            round((y_range[0] + y_range[1]) / 2 / sampling_interval[1])
            * sampling_interval[1]
        )
        v_arr = vy_arr - center
        v_start_id = torch.argmin(torch.abs(v_arr + uv_len / 2))
        v_end_id = torch.argmin(torch.abs(v_arr - uv_len / 2))
        v_arr = v_arr[v_start_id : v_end_id + 1 : scale] + sampling_interval[1] / 2
        dv = torch.mean(v_arr)
        v_arr = v_arr - dv

        # Output window indices within the full grids.
        x_start_id = torch.argmin(torch.abs(ux_arr - x_range[0]))
        x_end_id = torch.argmin(torch.abs(ux_arr - x_range[1]))
        y_start_id = torch.argmin(torch.abs(vy_arr - y_range[0]))
        y_end_id = torch.argmin(torch.abs(vy_arr - y_range[1]))

        x_arr = ux_arr[x_start_id : x_end_id + 1] - du
        y_arr = vy_arr[y_start_id : y_end_id + 1] - dv

        # The commented-out 1/(4*pi) factor is not applied.
        co_public = (
            -1j
            * sampling_interval[0]
            * sampling_interval[1]
            * n1
            / self.lams.view(1, 1, 1, -1)
        )  # / 4 / torch.pi
        R_q = torch.sqrt(
            (ux_arr.view(-1, 1, 1, 1)) ** 2
            + (vy_arr.view(1, -1, 1, 1)) ** 2
            + z_arr.view(1, 1, -1, 1) ** 2
        )
        G_2D = self.k * R_q * 1j
        G_2D = torch.exp(G_2D) / R_q**2 * co_public * (1 - 1 / G_2D)

        # Kernel of shape (len(ux_arr), len(vy_arr), z, lam); z is applied in __call__.
        self.G_2D = G_2D

        self.u_arr = u_arr
        self.v_arr = v_arr
        self.x_arr = x_arr
        self.y_arr = y_arr
        self.z_arr = z_arr

        self.x_num = x_end_id - x_start_id + 1
        self.y_num = y_end_id - y_start_id + 1

        # Zeros prepended along u / v after the input has been repeated
        # `scale` times.
        self.u_pad_num = ux_arr.numel() - u_arr.numel() * scale
        self.v_pad_num = vy_arr.numel() - v_arr.numel() * scale
        self.dS = sampling_interval[0] * sampling_interval[1]

    def __call__(self, E):
        """Propagate the subsampled field ``E`` to every (x, y, z) output sample.

        Args:
            E: field of shape (u, v, lam), presumably sampled on
                ``self.u_arr`` / ``self.v_arr`` (confirm with the caller).

        Returns:
            Complex tensor (x, y, z, lam).

        Raises:
            ValueError: if ``E`` is not 3-dimensional.
        """
        if E.dim() != 3:
            raise ValueError(
                f"E must have shape (u, v, lam); got shape {tuple(E.shape)}"
            )
        # Repeat every sample `scale` times along u and v (nearest neighbour).
        E_padded = E.repeat_interleave(self.scale, dim=0)
        E_padded = E_padded.repeat_interleave(self.scale, dim=1)
        # F.pad takes (before, after) pairs starting from the last axis. The lam
        # axis is untouched, while v and u get v_pad_num / u_pad_num zeros
        # prepended (nothing is appended).
        pad = (0, 0, self.v_pad_num, 0, self.u_pad_num, 0)
        E_padded = F.pad(E_padded, pad, mode="constant", value=0)
        E_out = torch.zeros(
            [
                self.x_num,
                self.y_num,
                self.z_arr.numel(),
                self.lams.numel(),
            ],
            device=self.device,
            dtype=torch.complex64,
        )
        for z in range(self.z_arr.numel()):
            E_out[:, :, z, :] = fft_circular_conv2d(
                E_padded, self.G_2D[:, :, z] * self.z_arr[z]
            )[
                : self.x_num,
                : self.y_num,
            ]
        return E_out

    def Get_G_2D(self, x_f, y_f, z_f):  # Typically used for optimization and validation
        """Return the kernel for one detection point over all (u, v) samples.

        The detection point is (x_f, y_f, focal_length + z_f). The area element
        is scaled by ``scale**2`` to account for the subsampled grid. The result
        has shape (u, v, lam) and is already multiplied by z.
        """
        z = self.focal_length + z_f  # Detection coordinate
        co_public = (
            -1j * self.dS * self.n1 / self.lams * self.scale**2
        )  # / 4 / torch.pi
        R_q = torch.sqrt(
            z**2
            + (x_f - self.u_arr.view(-1, 1, 1)) ** 2
            + (y_f - self.v_arr.view(1, -1, 1)) ** 2
        )
        ik0R = self.k.view(1, 1, -1) * R_q * 1j
        G_2D = torch.exp(ik0R) / R_q**2 * (1 - 1 / ik0R) * co_public

        return z * G_2D
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
# %%
|
|
2
|
+
# 滤波确实可以去掉伪影
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
import cv2
|
|
6
|
+
import os
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
|
|
9
|
+
# pixel_size=2.4/100*1.8/2
|
|
10
|
+
# class OffAxisPhaseOld:
|
|
11
|
+
# def __init__(
|
|
12
|
+
# self,
|
|
13
|
+
# file_name,
|
|
14
|
+
# continuous=True, # 默认是连续光
|
|
15
|
+
# block_size=1024, # 每个子块的大小
|
|
16
|
+
# overlap=0, # 子块之间是否重叠
|
|
17
|
+
# ):
|
|
18
|
+
# self.continuous = continuous
|
|
19
|
+
# self.block_size = block_size
|
|
20
|
+
# self.overlap = overlap
|
|
21
|
+
|
|
22
|
+
# # 支持的图片后缀
|
|
23
|
+
# self.valid_ext = [".jpg", ".png", ".bmp"]
|
|
24
|
+
|
|
25
|
+
# # 读取图像
|
|
26
|
+
# BACK = self._read_image("BACK", file_name)
|
|
27
|
+
# self.OBJ = torch.clamp(self._read_image("OBJ", file_name) - BACK, min=0)
|
|
28
|
+
# self.REF = torch.clamp(self._read_image("REF", file_name) - BACK, min=0)
|
|
29
|
+
# self.OBJ_REF = torch.clamp(self._read_image("OBJ_REF", file_name) - BACK, min=0)
|
|
30
|
+
# self.INC = torch.clamp(self._read_image("INC", file_name) - BACK, min=0)
|
|
31
|
+
# self.INC_REF = torch.clamp(self._read_image("INC_REF", file_name) - BACK, min=0)
|
|
32
|
+
|
|
33
|
+
# def _read_image(self, prefix, file_name):
|
|
34
|
+
# for ext in self.valid_ext:
|
|
35
|
+
# img_path = os.path.join(file_name, f"{prefix}{ext}")
|
|
36
|
+
# if os.path.exists(img_path):
|
|
37
|
+
# img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
|
38
|
+
# if img is None:
|
|
39
|
+
# raise ValueError(f"无法读取图像: {img_path}")
|
|
40
|
+
# img_tensor = torch.from_numpy(img)
|
|
41
|
+
|
|
42
|
+
# if img_tensor.dtype == torch.uint8:
|
|
43
|
+
# img_tensor = img_tensor.to(torch.float32) / 255.0
|
|
44
|
+
# elif img_tensor.dtype == torch.uint16:
|
|
45
|
+
# img_tensor = img_tensor.to(torch.float32) / 65535.0
|
|
46
|
+
# else:
|
|
47
|
+
# raise ValueError(f"不支持的图像位深: {img_tensor.dtype}")
|
|
48
|
+
# return img_tensor
|
|
49
|
+
# raise FileNotFoundError(f"在 {file_name} 中找不到 {prefix} 图像")
|
|
50
|
+
|
|
51
|
+
# def _split_blocks(self, img):
|
|
52
|
+
# H, W = img.shape
|
|
53
|
+
# step = self.block_size - self.overlap
|
|
54
|
+
# blocks = []
|
|
55
|
+
|
|
56
|
+
# y_list = list(range(0, H - self.block_size + 1, step))
|
|
57
|
+
# x_list = list(range(0, W - self.block_size + 1, step))
|
|
58
|
+
|
|
59
|
+
# # 确保边缘覆盖
|
|
60
|
+
# if y_list[-1] != H - self.block_size:
|
|
61
|
+
# y_list.append(H - self.block_size)
|
|
62
|
+
# if x_list[-1] != W - self.block_size:
|
|
63
|
+
# x_list.append(W - self.block_size)
|
|
64
|
+
|
|
65
|
+
# for y in y_list:
|
|
66
|
+
# for x in x_list:
|
|
67
|
+
# blocks.append(
|
|
68
|
+
# ((y, x), img[y : y + self.block_size, x : x + self.block_size])
|
|
69
|
+
# )
|
|
70
|
+
|
|
71
|
+
# return blocks, H, W
|
|
72
|
+
|
|
73
|
+
# def _estimate_angle(self, block, global_angle):
|
|
74
|
+
# """先用全局角度粗筛,再精确估计局部峰值角度"""
|
|
75
|
+
# f = torch.fft.fftshift(torch.fft.fft2(block))
|
|
76
|
+
|
|
77
|
+
# H, W = f.shape
|
|
78
|
+
# ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=pixel_size)).to(f.device)
|
|
79
|
+
# kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=pixel_size)).to(f.device)
|
|
80
|
+
# KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")
|
|
81
|
+
|
|
82
|
+
# angle = torch.deg2rad(torch.tensor(global_angle))
|
|
83
|
+
# vx, vy = torch.cos(angle), torch.sin(angle)
|
|
84
|
+
|
|
85
|
+
# cross = vx * KY - vy * KX
|
|
86
|
+
# mask = cross >= 0
|
|
87
|
+
# mag = torch.abs(f) * mask # 只保留一半
|
|
88
|
+
# # 去掉直流分量
|
|
89
|
+
# cy, cx = H // 2, W // 2
|
|
90
|
+
# mag[cy - 20 : cy + 20, cx - 20 : cx + 20] = 0
|
|
91
|
+
# # ==== Step 2: 找峰值 ====
|
|
92
|
+
# max_pos = torch.nonzero(mag == mag.max(), as_tuple=False)[0]
|
|
93
|
+
# dy, dx = max_pos[0] - mag.size(0) / 2, max_pos[1] - mag.size(1) / 2
|
|
94
|
+
# angle_loc = torch.atan2(-dx, dy) # 子图峰值方向
|
|
95
|
+
# Rmax = torch.sqrt(dy**2+dx**2)*0.75#滤波半径
|
|
96
|
+
# kmax=torch.sqrt(ky_vals[max_pos[0]]**2+kx_vals[max_pos[1]]**2)
|
|
97
|
+
|
|
98
|
+
# print(kmax)
|
|
99
|
+
# return torch.rad2deg(angle_loc).item(), f, mag,Rmax
|
|
100
|
+
|
|
101
|
+
# def _filter_one(self, input_tensor, angle_deg):
|
|
102
|
+
# f = torch.fft.fftshift(torch.fft.fft2(input_tensor))
|
|
103
|
+
|
|
104
|
+
# H, W = f.shape
|
|
105
|
+
# ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=1.0 / H)).to(f.device)
|
|
106
|
+
# kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=1.0 / W)).to(f.device)
|
|
107
|
+
# KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")
|
|
108
|
+
|
|
109
|
+
# angle = torch.deg2rad(torch.tensor(angle_deg))
|
|
110
|
+
# vx, vy = torch.cos(angle), torch.sin(angle)
|
|
111
|
+
|
|
112
|
+
# cross = vx * KY - vy * KX
|
|
113
|
+
# mask = cross >= 0
|
|
114
|
+
|
|
115
|
+
# f_filtered = f * mask
|
|
116
|
+
# result = torch.fft.ifft2(torch.fft.ifftshift(f_filtered))
|
|
117
|
+
# return result, f, mask
|
|
118
|
+
|
|
119
|
+
# def __call__(self, angle_deg=0, visualize=False, vis_num=1,lowpass=False):
|
|
120
|
+
# # 分块处理
|
|
121
|
+
# a_full = self.OBJ_REF - self.OBJ - self.REF
|
|
122
|
+
# b_full = self.INC_REF - self.INC - self.REF
|
|
123
|
+
|
|
124
|
+
# blocks, H, W = self._split_blocks(a_full)
|
|
125
|
+
# E_full = torch.zeros((H, W), dtype=torch.complex64)
|
|
126
|
+
|
|
127
|
+
# vis_count = 0
|
|
128
|
+
# x=torch.arange(self.block_size).float()
|
|
129
|
+
# x-=torch.mean(x)
|
|
130
|
+
# y=torch.arange(self.block_size).float()
|
|
131
|
+
# y-=torch.mean(y)
|
|
132
|
+
# R=torch.sqrt(x.view(-1,1)**2+y.view(1,-1)**2)
|
|
133
|
+
|
|
134
|
+
# for (y, x), a_block in blocks:
|
|
135
|
+
# b_block = b_full[y : y + self.block_size, x : x + self.block_size]
|
|
136
|
+
|
|
137
|
+
# # 每个子块独立估计方向
|
|
138
|
+
# block_angle, f_b, mag_b,Rmax = self._estimate_angle(b_block, angle_deg)
|
|
139
|
+
|
|
140
|
+
# a_filtered, f_a, mask_a = self._filter_one(a_block, block_angle)
|
|
141
|
+
# b_filtered, _, mask_b = self._filter_one(b_block, block_angle)
|
|
142
|
+
|
|
143
|
+
# E = a_filtered/ (b_filtered)
|
|
144
|
+
# E[~torch.isfinite(E)] = 0
|
|
145
|
+
|
|
146
|
+
# # 放回全图
|
|
147
|
+
# if lowpass:
|
|
148
|
+
# E = torch.fft.fftshift(torch.fft.fft2(E))
|
|
149
|
+
# plt.figure(figsize=(8,8))
|
|
150
|
+
# plt.subplot(2, 2, 1)
|
|
151
|
+
# plt.pcolormesh(torch.log1p(torch.abs(E[800:-800,800:-800])).cpu())
|
|
152
|
+
# plt.colorbar()
|
|
153
|
+
# plt.title("E FFT magnitude (before)")
|
|
154
|
+
# E[R>Rmax]=0
|
|
155
|
+
# plt.subplot(2, 2, 2)
|
|
156
|
+
# plt.pcolormesh(torch.log1p(torch.abs(E[800:-800,800:-800])).cpu())
|
|
157
|
+
# plt.colorbar()
|
|
158
|
+
# plt.title("E FFT magnitude (after)")
|
|
159
|
+
# plt.subplot(2, 2, 3)
|
|
160
|
+
# plt.pcolormesh((R[800:-800,800:-800]>Rmax).float().cpu())
|
|
161
|
+
# plt.title("R magnitude (after)")
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# E = torch.fft.ifft2(torch.fft.ifftshift(E))
|
|
165
|
+
# E_full[y : y + self.block_size, x : x + self.block_size] = E
|
|
166
|
+
|
|
167
|
+
# # 可视化部分子图
|
|
168
|
+
# if visualize and vis_count < vis_num:
|
|
169
|
+
# plt.figure(figsize=(12, 6))
|
|
170
|
+
# plt.suptitle(f"Block ({y},{x}), angle={block_angle:.2f}°")
|
|
171
|
+
|
|
172
|
+
# plt.subplot(2, 2, 1)
|
|
173
|
+
# plt.imshow(torch.log1p(mag_b).cpu(), cmap="gray")
|
|
174
|
+
# plt.title("b FFT magnitude (before)")
|
|
175
|
+
|
|
176
|
+
# plt.subplot(2, 2, 2)
|
|
177
|
+
# plt.imshow(torch.log1p(torch.abs(f_b * mask_b)).cpu(), cmap="gray")
|
|
178
|
+
# plt.title("b FFT after mask")
|
|
179
|
+
|
|
180
|
+
# plt.subplot(2, 2, 3)
|
|
181
|
+
# plt.imshow(torch.log1p(torch.abs(torch.abs(f_a))).cpu(), cmap="gray")
|
|
182
|
+
# plt.title("a FFT magnitude (before)")
|
|
183
|
+
|
|
184
|
+
# plt.subplot(2, 2, 4)
|
|
185
|
+
# plt.imshow(torch.log1p(torch.abs(f_a * mask_a)).cpu(), cmap="gray")
|
|
186
|
+
# plt.title("a FFT after mask")
|
|
187
|
+
|
|
188
|
+
# plt.tight_layout()
|
|
189
|
+
# plt.show()
|
|
190
|
+
|
|
191
|
+
# vis_count += 1
|
|
192
|
+
|
|
193
|
+
# return E_full
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class OffAxisPhase:
    """Off-axis interferometric phase retrieval from a directory of frames.

    ``file_name`` must be a directory containing the frames ``BACK``,
    ``OBJ``, ``REF``, ``OBJ_REF``, ``INC`` and ``INC_REF`` (each stored as
    ``.jpg``, ``.png`` or ``.bmp``). ``BACK`` is subtracted from every other
    frame and negative results are clipped to 0. Calling the instance
    returns the complex ratio of the sideband-filtered object and incident
    interference terms (see ``__call__``).
    """

    def __init__(self, file_name, pixel_size=2.4 / 100 * 1.8 / 2):
        """Load all frames from ``file_name`` and subtract the background.

        Args:
            file_name: Directory holding the image set.
            pixel_size: Sample spacing passed as ``d`` to
                ``torch.fft.fftfreq``; the frequency axes used for masking
                are in cycles per this unit (the unit itself is not stated
                in this code).

        Raises:
            FileNotFoundError: A required frame is missing.
            ValueError: A frame cannot be decoded or has an unsupported dtype.
        """
        self.pixel_size = pixel_size

        # Supported image suffixes, tried in this order.
        self.valid_ext = [".jpg", ".png", ".bmp"]

        # Read every frame and remove the background, clipping at 0.
        BACK = self._read_image("BACK", file_name)
        self.OBJ = torch.clamp(self._read_image("OBJ", file_name) - BACK, min=0)
        self.REF = torch.clamp(self._read_image("REF", file_name) - BACK, min=0)
        self.OBJ_REF = torch.clamp(self._read_image("OBJ_REF", file_name) - BACK, min=0)
        self.INC = torch.clamp(self._read_image("INC", file_name) - BACK, min=0)
        self.INC_REF = torch.clamp(self._read_image("INC_REF", file_name) - BACK, min=0)

    def _read_image(self, prefix, file_name):
        """Read ``<file_name>/<prefix><ext>`` as a float32 tensor in [0, 1].

        Extensions from ``self.valid_ext`` are tried in order and the first
        existing file is used.

        Raises:
            FileNotFoundError: No file with a supported extension exists.
            ValueError: The file cannot be decoded, or its dtype is neither
                uint8 nor uint16.
        """
        for ext in self.valid_ext:
            img_path = os.path.join(file_name, f"{prefix}{ext}")
            if os.path.exists(img_path):
                img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
                if img is None:
                    raise ValueError(f"无法读取图像: {img_path}")
                return self._to_float_tensor(img)
        raise FileNotFoundError(f"在 {file_name} 中找不到 {prefix} 图像")

    @staticmethod
    def _to_float_tensor(img):
        """Normalise an 8- or 16-bit integer image array to float32 in [0, 1].

        Args:
            img: NumPy array as returned by ``cv2.imread``.

        Returns:
            ``torch.float32`` tensor with the same shape as ``img``.

        Raises:
            ValueError: ``img`` is neither uint8 nor uint16.
        """
        # Scale in NumPy *before* torch.from_numpy: PyTorch releases older
        # than 2.3 cannot wrap uint16 arrays, so converting first made the
        # 16-bit branch unreachable there.
        # NOTE(review): cv2.IMREAD_UNCHANGED keeps colour/alpha channels, but
        # _get_mask unpacks `H, W = f.shape` and so needs 2-D frames —
        # confirm the inputs are single-channel.
        if img.dtype == "uint8":
            scale = 255.0
        elif img.dtype == "uint16":
            scale = 65535.0
        else:
            raise ValueError(f"不支持的图像位深: {img.dtype}")
        return torch.from_numpy(img.astype("float32") / scale)

    def _get_mask(self, block, global_angle):
        """Return a boolean disc mask around the strongest sideband of ``block``.

        The mask lies on the same fftshift-ed grid as
        ``torch.fft.fftshift(torch.fft.fft2(block))``.

        Args:
            block: 2-D real frame (``H, W = f.shape`` below requires 2-D).
            global_angle: Angle in degrees. Only the half plane where
                ``cos(a) * KY - sin(a) * KX >= 0`` is searched for the peak.

        Returns:
            Boolean tensor of shape ``(H, W)``. It is True inside the disc of
            radius ``0.9 * |(kyc, kxc)|`` centred on the detected peak
            ``(kyc, kxc)``.
        """
        f = torch.fft.fftshift(torch.fft.fft2(block))
        H, W = f.shape
        ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=self.pixel_size)).to(f.device)
        kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=self.pixel_size)).to(f.device)
        KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")

        # float() guarantees a floating-point angle tensor for int input and
        # avoids torch.tensor() copy-construct warnings for tensor input.
        angle = torch.deg2rad(torch.tensor(float(global_angle)))
        vx, vy = torch.cos(angle), torch.sin(angle)

        # Keep only one half plane of the spectrum (presumably so that the
        # conjugate sideband cannot be selected).
        cross = vx * KY - vy * KX
        mag = torch.abs(f) * (cross >= 0)

        # Suppress the DC term by zeroing a 30x30 window around the centre.
        # NOTE(review): for H or W < 30 the start index goes negative and the
        # slice no longer covers the centre.
        cy, cx = H // 2, W // 2
        mag[cy - 15 : cy + 15, cx - 15 : cx + 15] = 0

        # Locate the peak (first occurrence in row-major order on ties).
        max_pos = torch.nonzero(mag == mag.max(), as_tuple=False)[0]
        kyc, kxc = ky_vals[max_pos[0]], kx_vals[max_pos[1]]
        kmax = torch.sqrt(kyc**2 + kxc**2) * 0.9  # filter radius
        print("kmax:" + str(round(kmax.cpu().item(), 2)))
        return torch.sqrt((KY - kyc) ** 2 + (KX - kxc) ** 2) < kmax

    def __call__(self, angle_deg=0, visualize=False):
        """Return the complex ratio of the object and incident sidebands.

        Args:
            angle_deg: Half-plane angle in degrees, forwarded to ``_get_mask``.
            visualize: If True, draw the spectrum of ``b_full`` (before
                masking) and the mask into a new matplotlib figure.
                ``plt.show()`` is not called here.

        Returns:
            Complex tensor ``a / b``. Here ``a`` and ``b`` are the
            sideband-filtered interference terms of the object and incident
            recordings. Non-finite entries are set to 0.
        """
        # Interference terms: combined recording minus the two single beams.
        a_full = self.OBJ_REF - self.OBJ - self.REF
        b_full = self.INC_REF - self.INC - self.REF

        # The mask is found on the incident-beam term and applied to both.
        mask = self._get_mask(b_full, angle_deg)

        a_full = torch.fft.fftshift(torch.fft.fft2(a_full))
        a_full[~mask] = 0
        a_full = torch.fft.ifft2(torch.fft.ifftshift(a_full))

        b_full = torch.fft.fftshift(torch.fft.fft2(b_full))
        if visualize:
            plt.figure(figsize=(12, 6))
            plt.subplot(2, 1, 1)
            plt.imshow(torch.log1p(torch.abs(b_full)).cpu(), cmap="gray")
            plt.title("b_full")

            plt.subplot(2, 1, 2)
            plt.imshow(mask.cpu(), cmap="gray")
            plt.title("mask")
        b_full[~mask] = 0
        b_full = torch.fft.ifft2(torch.fft.ifftshift(b_full))

        # Zero the ratio wherever b vanished (0/0 or x/0).
        E = a_full / b_full
        E[~torch.isfinite(E)] = 0
        return E
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: torchopticsy
|
|
3
|
+
Version: 0.8.3
|
|
4
|
+
Summary: PyTorch-based optics calculation
|
|
5
|
+
Author: YuningYe
|
|
6
|
+
Author-email: 1956860113@qq.com
|
|
7
|
+
License: MIT
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: opencv-python
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
|
|
13
|
+
# torchOpticsY
|
|
14
|
+
A PyTorch-based optics calculation library.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
## Features
|
|
18
|
+
- Luneburg integral
|
|
19
|
+
- Debye–Wolf integral.
|
|
20
|
+
|
|
21
|
+
## Install
|
|
22
|
+
pip install torchOpticsY
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
0.2 - Interference: OffAxisPhase
|
|
26
|
+
0.1 - Diffraction: Debye_Wolf,Luneburg
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
setup.py
|
|
2
|
+
torchopticsy/Diffraction.py
|
|
3
|
+
torchopticsy/Interference.py
|
|
4
|
+
torchopticsy/__init__.py
|
|
5
|
+
torchopticsy.egg-info/PKG-INFO
|
|
6
|
+
torchopticsy.egg-info/SOURCES.txt
|
|
7
|
+
torchopticsy.egg-info/dependency_links.txt
|
|
8
|
+
torchopticsy.egg-info/requires.txt
|
|
9
|
+
torchopticsy.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torchopticsy
|