torchopticsy 0.8.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: torchopticsy
3
+ Version: 0.8.3
4
+ Summary: PyTorch-based optics calculation
5
+ Author: YuningYe
6
+ Author-email: 1956860113@qq.com
7
+ License: MIT
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: torch
10
+ Requires-Dist: opencv-python
11
+ Requires-Dist: matplotlib
12
+
13
+ # torchOpticsY
14
+ A PyTorch-based optics calculation library.
15
+
16
+
17
+ ## Features
18
+ - Luneburg integral
19
+ - Debye–Wolf integral.
20
+
21
+ ## Install
22
+ pip install torchOpticsY
23
+
24
+
25
+ 0.2 - Interference: OffAxisPhase
26
+ 0.1 - Diffraction: Debye_Wolf,Luneburg
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,16 @@
1
"""Packaging script for torchopticsy."""

import pathlib

import setuptools

# Resolve README.md next to this file, so building from another working
# directory still works. Decode it explicitly as UTF-8: the platform default
# encoding (e.g. GBK on a Chinese-locale Windows machine) turns non-ASCII
# characters such as the en dash in "Debye–Wolf" into mojibake.
_HERE = pathlib.Path(__file__).resolve().parent
_LONG_DESCRIPTION = (_HERE / "README.md").read_text(encoding="utf-8")

setuptools.setup(
    name="torchopticsy",
    version="0.8.3",
    description="PyTorch-based optics calculation",
    long_description=_LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    author="YuningYe",
    author_email="1956860113@qq.com",
    license="MIT",
    packages=setuptools.find_packages(),
    install_requires=["torch", "opencv-python", "matplotlib"],
    include_package_data=True,
)
@@ -0,0 +1,680 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
class BluesteinDFT:
    """Centred DFT on an arbitrary frequency window via Bluestein's chirp-z algorithm.

    Along one axis of length ``m_input``, :meth:`transform` evaluates::

        X[k] = sum_n x[n] * exp(-2j*pi * l[k] * (n - (m_input - 1) / 2) / fs)

    for the ``mout`` frequencies ``l[k] = f11 + k * (f22 - f11) / mout``
    (``k = 0 .. mout - 1``). Here ``f11``/``f22`` are ``f1``/``f2`` shifted by
    ``(mout * fs + f2 - f1) / (2 * mout)``.

    Known limitation: the FFT-based convolution is circular. Only the central
    region of the output is accurate; the surrounding area is affected by
    pseudo-diffraction from the periodic extension. Extra zero padding could
    address this in the future.

    Args:
        f1, f2: nominal frequency window, before the shift above.
        fs: sampling frequency of the input axis.
        mout: number of output frequencies.
        m_input: length of the transformed input axis.
        device: torch device holding the precomputed chirps.
    """

    def __init__(self, f1, f2, fs, mout, m_input, device="cpu"):
        self.device = torch.device(device)
        self.f1 = f1
        self.f2 = f2
        self.fs = fs
        self.mout = mout
        self.m_input = m_input

        # Frequency adjustment: both window edges get the same shift.
        shift = (mout * fs + f2 - f1) / (2 * mout)
        f11 = f1 + shift
        f22 = f2 + shift
        self.f11 = f11
        self.f22 = f22

        # Chirp-z parameters: ``a`` fixes the first output frequency and
        # ``w`` the step between consecutive output frequencies.
        a = torch.exp(1j * 2 * torch.pi * torch.tensor(f11 / fs))
        w = torch.exp(-1j * 2 * torch.pi * torch.tensor(f22 - f11) / (mout * fs))
        self.a = a.to(self.device)
        self.w = w.to(self.device)

        # Chirp w**(t**2 / 2) for t = -(m_input - 1) .. max(mout, m_input) - 1.
        t = torch.arange(
            -m_input + 1,
            max(mout - 1, m_input - 1) + 1,
            device=self.device,
            dtype=torch.float64,
        )
        h = self.w ** ((t**2) / 2)
        self.h = h

        self.mp = m_input + mout - 1
        # Smallest power of two >= mp, using exact integer arithmetic. A
        # float32 log2 can round down for large mp and under-pad, which would
        # corrupt the circular convolution.
        self.padded_len = 1 << max(self.mp - 1, 0).bit_length()

        # FFT of the inverse chirp: the convolution kernel of the algorithm.
        h_inv = torch.zeros(self.padded_len, dtype=torch.complex64, device=self.device)
        h_inv[: self.mp] = 1 / h[: self.mp]
        self.ft = torch.fft.fft(h_inv)

        # Per-sample input weights a**(-n) * w**(n**2 / 2).
        b_exp = torch.arange(0, m_input, device=self.device)
        self.b_phase = (self.a**-b_exp) * h[m_input - 1 : 2 * m_input - 1]

        # Phase factor that moves the time origin to the centre of the input
        # axis, i.e. n -> n - (m_input - 1) / 2.
        freqs = torch.linspace(0, mout - 1, mout, device=self.device)
        freqs = freqs / mout * (f22 - f11) + f11
        Mshift = -m_input / 2
        self.Mshift = torch.exp(-1j * 2 * torch.pi * freqs * (Mshift + 0.5) / fs)

    def transform(self, x, dim=-1):
        """Apply the windowed DFT along ``dim``.

        Args:
            x: tensor whose size along ``dim`` equals ``m_input``. It is moved
                to ``self.device``.
            dim: axis to transform; negative values count from the end.

        Returns:
            Tensor with the same shape as ``x`` except that ``dim`` has size
            ``mout``.

        Raises:
            ValueError: if ``x.shape[dim] != m_input``.
        """
        x = x.to(self.device)
        m = self.m_input

        dim = dim if dim >= 0 else x.ndim + dim

        if x.shape[dim] != m:
            raise ValueError(
                f"Expected dimension {dim} to be of size {m}, but got "
                f"{x.shape[dim]} (input shape {tuple(x.shape)})"
            )

        # Move the transformed axis last and flatten everything else into rows.
        x = x.transpose(dim, -1)
        b_phase = self.b_phase.view((1,) * (x.ndim - 1) + (-1,))
        x_weighted = x * b_phase

        original_shape = x_weighted.shape
        x_weighted = x_weighted.reshape(-1, m)

        b_padded = torch.zeros(
            (x_weighted.shape[0], self.padded_len),
            dtype=torch.complex64,
            device=self.device,
        )
        b_padded[:, :m] = x_weighted

        # Convolution with the inverse chirp via FFT.
        b_fft = torch.fft.fft(b_padded, dim=1)
        conv = b_fft * self.ft[None, :]
        result = torch.fft.ifft(conv, dim=1)

        # Keep the mout valid samples, apply the output chirp, then re-centre.
        result = (
            result[:, self.m_input - 1 : self.mp] * self.h[self.m_input - 1 : self.mp]
        )
        result = result * self.Mshift[None, :]

        new_shape = list(original_shape[:-1]) + [self.mout]
        result = result.reshape(*new_shape)

        return result.transpose(-1, dim)
96
+
97
+
98
class DebyeWolf:
    """Vectorial Debye–Wolf focusing integral evaluated with chirp-z DFTs.

    The pupil is sampled on a ``Min x Min`` grid of normalized coordinates.
    For every wavelength, the transverse output window is obtained with two
    1-D ``BluesteinDFT`` passes (y first, then x) at each z plane.

    Args:
        Min: number of pupil samples per axis.
        xrange, yrange: ``(start, stop)`` of the output window along x / y,
            in the same length unit as ``lams``.
        zrange: ``(start, stop)`` of the axial positions; ``z = 0`` gives a
            zero defocus phase.
        Mout: ``(Mx, My, Mz)`` number of output samples along x, y, z.
        lams: sequence of wavelengths. ``k0 = 2*pi/lams`` is multiplied by
            ``n1``, so these are presumably vacuum wavelengths.
        NA: numerical aperture.
        focal_length: focal length; enters the constant prefactor ``C``.
        n1: refractive index of the focal-space medium.
        device: torch device for all internal tensors.
    """

    def __init__(
        self,
        Min,
        xrange,
        yrange,
        zrange,
        Mout,
        lams,  # list of wavelengths
        NA,
        focal_length,
        n1=1,
        device="cpu",
    ):
        self.device = device
        self.Min = Min
        self.xrange = xrange
        self.yrange = yrange
        self.z_arr = torch.linspace(zrange[0], zrange[1], Mout[2], device=device)
        self.Moutx, self.Mouty = Mout[0], Mout[1]
        lams = torch.tensor(lams, device=device)
        self.lams, self.k0, self.n1, self.NA, self.focal_length = (
            lams,
            2 * torch.pi / lams,
            n1,
            NA,
            focal_length,
        )

        self.N = (Min - 1) / 2

        # Pupil grid and the corresponding polar (th) / azimuthal (phi) angles.
        m = torch.linspace(-Min / 2, Min / 2, Min, device=self.device)
        n = torch.linspace(-Min / 2, Min / 2, Min, device=self.device)
        self.m_grid, self.n_grid = torch.meshgrid(m, n, indexing="ij")

        self.th = torch.asin(
            torch.clamp(
                NA * torch.sqrt(self.m_grid**2 + self.n_grid**2) / (self.N * n1), max=1
            )
        )
        # True outside the aperture, i.e. where th exceeds asin(NA / n1).
        self.mask = self.th > torch.arcsin(torch.tensor(NA / n1))
        self.phi = torch.atan2(self.n_grid, self.m_grid)
        self.phi[self.phi < 0] += 2 * torch.pi

        # Apodization factor 1 / sqrt(cos th), shape (Min, Min, 1). Samples
        # where cos th <= 0 (nan/inf) and samples outside the aperture are
        # zeroed.
        self._sqrt_costh = 1 / torch.sqrt(torch.cos(self.th).unsqueeze(-1))
        self._sqrt_costh[~torch.isfinite(self._sqrt_costh)] = 0
        self._sqrt_costh[self.mask] = 0

        fs = lams * (Min - 1) / (2 * NA)
        self.fs = fs
        self.bluesteins_y = []
        self.bluesteins_x = []
        self.C = (
            -1j
            * torch.exp(1j * self.k0 * n1 * focal_length)
            * focal_length
            * (lams)
            / (self.n1)
            / fs
            / fs
        )
        # One pair of 1-D chirp-z transforms per wavelength.
        for f in fs.cpu().tolist():
            self.bluesteins_y.append(
                BluesteinDFT(
                    f / 2 + self.yrange[0],
                    f / 2 + self.yrange[1],
                    f,
                    self.Mouty,
                    Min,
                    device=device,
                )
            )
            self.bluesteins_x.append(
                BluesteinDFT(
                    f / 2 + self.xrange[0],
                    f / 2 + self.xrange[1],
                    f,
                    self.Moutx,
                    Min,
                    device=device,
                )
            )
        # Per-wavelength normalization divisor applied to C in __call__.
        self.E_ideals = torch.ones_like(self.lams)
        # Unit vector (-sin th cos phi, -sin th sin phi, cos th) per pupil
        # sample, stored with shape (Min, Min, 1, 3).
        self.R = torch.stack(
            [
                -torch.sin(self.th) * torch.cos(self.phi),
                -torch.sin(self.th) * torch.sin(self.phi),
                torch.cos(self.th),
            ],
            dim=-1,
        )
        self.R = self.R.unsqueeze(-2)

    def __call__(self, E, correct=False):
        """Focus the pupil field ``E``.

        Args:
            E: pupil field of shape ``(x, y, 2, lam)`` or
                ``(batch, x, y, 2, lam)`` holding the (Ex, Ey) components.
                The z-component is not included.
            correct: if True, subtract half of the component of the focal
                field along ``self.R`` before the transforms.

        Returns:
            complex tensor of shape ``(Mx, My, Mz, 3, lam)``, or
            ``(batch, Mx, My, Mz, 3, lam)``, holding (Ex, Ey, Ez) scaled by
            ``C / E_ideals``. Wavelengths are processed in a simple loop.

        Raises:
            ValueError: if ``E`` is neither 4-D nor 5-D.
        """
        if E.dim() not in (4, 5):
            raise ValueError(
                f"E must have shape (x, y, 2, lam) or (batch, x, y, 2, lam); "
                f"got {tuple(E.shape)}"
            )
        batched = E.dim() == 5

        Ex_in, Ey_in = E[..., 0:1, :], E[..., 1:2, :]
        th = self.th.unsqueeze(-1).unsqueeze(-1)  # (Min, Min, 1, 1)
        phi = self.phi.unsqueeze(-1).unsqueeze(-1)  # (Min, Min, 1, 1)
        z_arr = self.z_arr.unsqueeze(0).unsqueeze(0).unsqueeze(-1)  # (1, 1, Z, 1)
        k0, n1 = self.k0.view(1, 1, 1, -1), self.n1

        costh = torch.cos(th)
        _sqrt_costh = self._sqrt_costh.unsqueeze(-1)
        phase = torch.exp(1j * k0 * n1 * z_arr * costh)  # (Min, Min, Z, lam)
        C = (self.C / self.E_ideals).view(1, 1, 1, -1)
        # (Min, Min, 1, 1, 3): broadcasts over z, lam and the stacked
        # component axis of the field.
        R = self.R.unsqueeze(-2)

        out_shape = [self.Moutx, self.Mouty, self.z_arr.numel(), 3, self.lams.numel()]
        deltadim = 0
        if batched:
            # Leading batch axis: align every per-pupil factor with it.
            th, phi, costh, _sqrt_costh, phase, C, R = (
                t.unsqueeze(0) for t in (th, phi, costh, _sqrt_costh, phase, C, R)
            )
            out_shape.insert(0, E.size(0))
            deltadim = 1
        E_out = torch.zeros(out_shape, dtype=torch.complex64, device=self.device)

        cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
        Ex = (
            (
                Ex_in * (1 + (costh - 1) * cos_phi**2)
                + Ey_in * (costh - 1) * cos_phi * sin_phi
            )
            * phase
            * _sqrt_costh
        )
        Ey = (
            (
                Ex_in * (costh - 1) * cos_phi * sin_phi
                + Ey_in * (1 + (costh - 1) * sin_phi**2)
            )
            * phase
            * _sqrt_costh
        )
        Ez = (Ex_in * cos_phi + Ey_in * sin_phi) * torch.sin(th) * phase * _sqrt_costh

        if correct:
            # Stack the components on a new last axis, remove half of the
            # projection onto R, and split back along that same axis.
            temp = torch.stack([Ex, Ey, Ez], dim=-1)
            temp = temp - 0.5 * R * torch.sum(R * temp, dim=-1, keepdim=True)
            Ex, Ey, Ez = temp.unbind(dim=-1)

        for i in range(self.lams.numel()):
            bx, by = self.bluesteins_x[i], self.bluesteins_y[i]
            for c, comp in enumerate((Ex, Ey, Ez)):
                E_out[..., c, i] = bx.transform(
                    by.transform(comp[..., i], dim=1 + deltadim),
                    dim=0 + deltadim,
                )

        return C * E_out

    def Get_Z_offset_Phase(self, z):
        """Return the phase ``-k0 * n1 * z * cos(th)``, shape ``(Min, Min, lam)``.

        This is the negative of the defocus phase that ``__call__`` applies at
        axial position ``z``.
        """
        k0, n1 = self.k0, self.n1
        costh = torch.cos(self.th)
        phase = k0.view(1, 1, -1) * n1 * z * (-costh).unsqueeze(-1)
        return phase
285
+
286
+
287
def fft_circular_conv2d(E, G):
    """Circular 2-D convolution of ``E`` with ``G`` over their first two axes.

    Both inputs are expected as ``(H, W, C)`` stacks. Each of the ``C``
    channels is convolved independently (wrap-around boundaries) by
    multiplying spectra, and the result keeps the ``(H, W, C)`` layout.
    """
    axes = (0, 1)
    spectrum = torch.fft.fft2(E, dim=axes) * torch.fft.fft2(G, dim=axes)
    return torch.fft.ifft2(spectrum, dim=axes)
293
+
294
+
295
def fourier_upsample2d(E: torch.Tensor, up_sample=1) -> torch.Tensor:
    """Band-limited upsampling of ``E`` along its first two axes by spectral zero padding.

    Trailing axes are carried along unchanged. The periodic tail of the
    refined grid is dropped, so the original samples reappear at every
    ``up_sample``-th output index.

    Args:
        E: tensor of shape ``(x, y, ...)``, real or complex.
        up_sample: positive integer refinement factor.

    Returns:
        Complex tensor of shape
        ``((x - 1) * up_sample + 1, (y - 1) * up_sample + 1, ...)``.

    Raises:
        ValueError: if ``up_sample < 1``.
    """
    if up_sample < 1:
        raise ValueError("up_sample must be >= 1")

    x, y = E.shape[0], E.shape[1]
    Xn, Yn = x * up_sample, y * up_sample

    # 2-D FFT over the first two axes only. "ortho" normalization is used in
    # both directions and compensated by the final ``* up_sample``.
    spec = torch.fft.fftshift(torch.fft.fft2(E, dim=(0, 1), norm="ortho"), dim=(0, 1))

    # Zero-pad the centred spectrum onto the larger grid. The buffer takes the
    # spectrum's complex dtype: allocating it with a real input's dtype would
    # silently discard the imaginary part of the spectrum.
    padded = torch.zeros((Xn, Yn) + E.shape[2:], dtype=spec.dtype, device=E.device)
    x0 = (Xn - x + 1) // 2  # keeps the DC bin at index Xn // 2
    y0 = (Yn - y + 1) // 2
    padded[x0 : x0 + x, y0 : y0 + y, ...] = spec

    up_full = torch.fft.ifft2(
        torch.fft.ifftshift(padded, dim=(0, 1)), dim=(0, 1), norm="ortho"
    )

    X_keep = (x - 1) * up_sample + 1
    Y_keep = (y - 1) * up_sample + 1
    return up_full[:X_keep, :Y_keep, ...] * up_sample
317
+
318
+
319
def upsample_coords_xy(x: torch.Tensor, y: torch.Tensor, up_sample: int):
    """Build refined, evenly spaced x / y coordinate axes.

    Args:
        x: original x coordinates, shape ``(Nx,)``, assumed evenly spaced
            (the FFT zero-padding scheme requires this too).
        y: original y coordinates, shape ``(Ny,)``, same assumption.
        up_sample: positive integer refinement factor.

    Returns:
        ``(x_up, y_up)`` with shapes ``((Nx - 1) * up_sample + 1,)`` and
        ``((Ny - 1) * up_sample + 1,)``. The step shrinks by ``1 / up_sample``.
        For ``up_sample == 1`` the inputs are returned unchanged.

    NOTE(review): for ``up_sample > 1`` both endpoints are shifted by
    ``1.5 * (up_sample - 1) * step / up_sample``, so the range is not
    preserved exactly. The reason for this offset is not visible here;
    confirm the intent.

    Raises:
        ValueError: if ``up_sample < 1``, or if either axis has fewer than
            two points when ``up_sample > 1``.
    """
    if up_sample < 1:
        raise ValueError("up_sample 必须 >= 1")
    if up_sample == 1:
        return x, y
    if x.numel() < 2 or y.numel() < 2:
        raise ValueError("x 和 y 至少需要 2 个点用于定义范围。")

    def _refine(axis: torch.Tensor) -> torch.Tensor:
        # Evenly spaced samples between the (shifted) first and last points.
        offset = (axis[1] - axis[0]) / up_sample * 1.5 * (up_sample - 1)
        count = (axis.numel() - 1) * up_sample + 1
        return torch.linspace(
            axis[0] + offset,
            axis[-1] + offset,
            count,
            device=axis.device,
            dtype=axis.dtype,
        )

    return _refine(x), _refine(y)
359
+
360
+
361
class Luneburg:
    """Scalar Luneburg diffraction integral evaluated by FFT convolution.

    The aperture field is convolved with the kernel
    ``z * exp(i k R) / R**2 * (1 - 1 / (i k R)) * (-1j * n1 * dx * dy / lam)``,
    multiplied by two sinc pixel factors, on the padded grid
    ``ux_arr x vy_arr`` for every z plane. The output is then refined with
    ``fourier_upsample2d``.

    The calculation method for Ez is different and omitted here. The sampling
    theorem must be satisfied: sampling interval T < lambda / 2.

    Args:
        uv_len: side length of the integration (aperture) region.
        x_range, y_range: ``(start, stop)`` of the output window.
        z_range: ``(start, stop)`` of the axial offsets relative to
            ``focal_length``.
        sampling_interval: ``(dx, dy, dz)``. dx and dy are followed strictly;
            the number of z planes is ``int((z1 - z0) / dz / 2) * 2 + 1``
            (always odd), so dz is adjusted.
        lams: sequence of wavelengths.
        focal_length: evaluation is based on this focal length (z axis origin
            of the output).
        n1: refractive index of the medium.
        up_sample: integer >= 1; refinement factor of the output grid.
        device: torch device for all internal tensors.
    """

    def __init__(
        self,
        uv_len,
        x_range,
        y_range,
        z_range,
        sampling_interval,
        lams,
        focal_length,
        n1=1,
        up_sample=1,
        device="cuda",
    ):
        self.device = device
        self.up_sample = up_sample
        lams = torch.tensor(lams, device=device)
        self.lams = lams
        self.k = 2 * torch.pi * n1 / self.lams
        self.n1 = n1
        self.focal_length = focal_length
        uv_len = uv_len - sampling_interval[0]

        # Padded grids covering the output window plus half an aperture on
        # each side.
        ux_arr = torch.arange(
            -uv_len / 2 + x_range[0] - sampling_interval[0],
            x_range[1] + uv_len / 2,
            step=sampling_interval[0],
            device=device,
        )
        vy_arr = torch.arange(
            -uv_len / 2 + y_range[0] - sampling_interval[1],
            y_range[1] + uv_len / 2,
            step=sampling_interval[1],
            device=device,
        )
        z_arr = torch.linspace(
            z_range[0] + focal_length,
            z_range[1] + focal_length,
            int((z_range[1] - z_range[0]) / sampling_interval[2] / 2) * 2 + 1,
            device=device,
        )

        # Aperture coordinates: the samples of the padded grid (relative to
        # the rounded window centre) spanning +/- uv_len / 2, shifted by half
        # a step and then made zero-mean. The removed mean (du / dv) is also
        # subtracted from the output coordinates.
        center = (
            round((x_range[0] + x_range[1]) / 2 / sampling_interval[0])
            * sampling_interval[0]
        )
        u_arr = ux_arr - center
        u_start_id = torch.argmin(torch.abs(u_arr + uv_len / 2))
        u_end_id = torch.argmin(torch.abs(u_arr - uv_len / 2))
        u_arr = u_arr[u_start_id : u_end_id + 1] + sampling_interval[0] / 2
        du = torch.mean(u_arr)
        u_arr = u_arr - du

        center = (
            round((y_range[0] + y_range[1]) / 2 / sampling_interval[1])
            * sampling_interval[1]
        )
        v_arr = vy_arr - center
        v_start_id = torch.argmin(torch.abs(v_arr + uv_len / 2))
        v_end_id = torch.argmin(torch.abs(v_arr - uv_len / 2))
        v_arr = v_arr[v_start_id : v_end_id + 1] + sampling_interval[1] / 2
        dv = torch.mean(v_arr)
        v_arr = v_arr - dv

        # Index range of the output window inside the padded grids.
        x_start_id = torch.argmin(torch.abs(ux_arr - x_range[0]))
        x_end_id = torch.argmin(torch.abs(ux_arr - x_range[1]))
        y_start_id = torch.argmin(torch.abs(vy_arr - y_range[0]))
        y_end_id = torch.argmin(torch.abs(vy_arr - y_range[1]))

        x_arr = ux_arr[x_start_id : x_end_id + 1] - du
        y_arr = vy_arr[y_start_id : y_end_id + 1] - dv

        co_public = (
            -1j
            * sampling_interval[0]
            * sampling_interval[1]
            * n1
            / self.lams.view(1, 1, 1, -1)
        )
        R_q = torch.sqrt(
            (ux_arr.view(-1, 1, 1, 1)) ** 2
            + (vy_arr.view(1, -1, 1, 1)) ** 2
            + z_arr.view(1, 1, -1, 1) ** 2
        )
        ik_R = self.k * R_q * 1j
        # Kernel of shape (U, V, Z, lam).
        self.G_2D = (
            torch.exp(ik_R)
            / R_q**2
            * co_public
            * (1 - 1 / ik_R)
            * torch.sinc(
                n1
                * ux_arr.view(-1, 1, 1, 1)
                * sampling_interval[0]
                / (R_q * lams.view(1, 1, 1, -1))
            )
            * torch.sinc(
                n1
                * vy_arr.view(1, -1, 1, 1)
                * sampling_interval[1]
                / (R_q * lams.view(1, 1, 1, -1))
            )
        )

        self.u_arr = u_arr
        self.v_arr = v_arr
        self.x_arr, self.y_arr = upsample_coords_xy(x_arr, y_arr, up_sample)
        self.z_arr = z_arr

        # Plain ints (argmin returns 0-dim tensors) so they work as sizes and
        # slice bounds everywhere.
        self.x_num = int(x_end_id - x_start_id + 1)
        self.y_num = int(y_end_id - y_start_id + 1)

        self.u_pad_num = ux_arr.numel() - u_arr.numel()
        self.v_pad_num = vy_arr.numel() - v_arr.numel()
        self.dS = sampling_interval[0] * sampling_interval[1]

    def __call__(self, E):
        """Propagate the aperture field ``E`` to every z plane.

        Args:
            E: field of shape ``(u, v, lam)`` sampled on
                ``self.u_arr x self.v_arr``.

        Returns:
            complex tensor of shape
            ``((x_num - 1) * up_sample + 1, (y_num - 1) * up_sample + 1, Z, lam)``.

        Raises:
            ValueError: if ``E`` is not 3-D or its first two sizes do not
                match the aperture grid.
        """
        if E.dim() != 3:
            raise ValueError(f"E must have shape (u, v, lam); got {tuple(E.shape)}")
        if E.shape[0] != self.u_arr.numel() or E.shape[1] != self.v_arr.numel():
            raise ValueError(
                f"E must be sampled on a ({self.u_arr.numel()}, {self.v_arr.numel()}) "
                f"aperture grid; got {tuple(E.shape)}"
            )

        # F.pad lists padding from the last dim backwards: nothing on lam,
        # v_pad_num zeros before v, u_pad_num zeros before u. E_padded then
        # matches the (U, V) grid of G_2D.
        pad = (0, 0, self.v_pad_num, 0, self.u_pad_num, 0)
        E_padded = F.pad(E, pad, mode="constant", value=0)

        E_out = torch.zeros(
            [self.x_num, self.y_num, self.z_arr.numel(), self.lams.numel()],
            device=self.device,
            dtype=torch.complex64,
        )
        # The input spectrum is the same for every z plane; compute it once.
        E_fft = torch.fft.fft2(E_padded.permute(2, 0, 1))
        for iz in range(self.z_arr.numel()):
            kernel = self.G_2D[:, :, iz] * self.z_arr[iz]
            conv = torch.fft.ifft2(E_fft * torch.fft.fft2(kernel.permute(2, 0, 1)))
            E_out[:, :, iz, :] = conv.permute(1, 2, 0)[: self.x_num, : self.y_num]
        return fourier_upsample2d(E_out, self.up_sample)

    def Get_G_2D(self, x_f, y_f, z_f):
        """Kernel from every aperture sample to one detection point.

        Typically used for optimization and validation.

        Args:
            x_f, y_f: transverse detection coordinates.
            z_f: axial offset from ``focal_length``.

        Returns:
            ``z * G`` with shape ``(u, v, lam)``, for ``z = focal_length + z_f``.

        NOTE(review): unlike the kernel built in ``__init__``, this omits the
        two sinc pixel factors, so it may differ slightly from ``__call__``.
        Confirm whether that is intended.
        """
        z = self.focal_length + z_f  # Detection coordinate
        co_public = -1j * self.dS * self.n1 / self.lams
        R_q = torch.sqrt(
            z**2
            + (x_f - self.u_arr.view(-1, 1, 1)) ** 2
            + (y_f - self.v_arr.view(1, -1, 1)) ** 2
        )
        ik0R = self.k.view(1, 1, -1) * R_q * 1j
        G_2D = torch.exp(ik0R) / R_q**2 * (1 - 1 / ik0R) * co_public

        return z * G_2D
527
+
528
+
529
class LuneburgOld:
    """Earlier variant of ``Luneburg``: no sinc pixel factors, no output upsampling.

    The aperture may be sampled ``scale`` times more coarsely than the output
    grid. ``__call__`` expands it back with ``repeat_interleave``.

    The calculation method for Ez is different and omitted here. The sampling
    theorem must be satisfied: sampling interval T < lambda / 2.

    Args:
        uv_len: side length of the integration (aperture) region.
        x_range, y_range: ``(start, stop)`` of the output window.
        z_range: ``(start, stop)`` of the axial offsets relative to
            ``focal_length``.
        sampling_interval: ``(dx, dy, dz)``. dx and dy are followed strictly;
            the number of z planes is ``int((z1 - z0) / dz / 2) * 2 + 1``.
        lams: sequence of wavelengths.
        focal_length: evaluation is based on this focal length.
        n1: refractive index of the medium.
        scale: integer >= 1; reduced sampling factor of the aperture field E.
        device: torch device for all internal tensors.
    """

    def __init__(
        self,
        uv_len,
        x_range,
        y_range,
        z_range,
        sampling_interval,
        lams,
        focal_length,
        n1=1,
        scale=1,
        device="cuda",
    ):
        self.device = device
        self.scale = scale
        lams = torch.tensor(lams, device=device)
        self.lams = lams
        self.k = 2 * torch.pi * n1 / self.lams
        self.n1 = n1
        self.focal_length = focal_length
        uv_len = uv_len - sampling_interval[0]
        ux_arr = torch.arange(
            -uv_len / 2 + x_range[0] - sampling_interval[0],
            x_range[1] + uv_len / 2,
            step=sampling_interval[0],
            device=device,
        )
        vy_arr = torch.arange(
            -uv_len / 2 + y_range[0] - sampling_interval[1],
            y_range[1] + uv_len / 2,
            step=sampling_interval[1],
            device=device,
        )
        z_arr = torch.linspace(
            z_range[0] + focal_length,
            z_range[1] + focal_length,
            int((z_range[1] - z_range[0]) / sampling_interval[2] / 2) * 2 + 1,
            device=device,
        )

        # Aperture coordinates: every ``scale``-th sample of the padded grid
        # spanning +/- uv_len / 2 around the rounded window centre, shifted by
        # half a step and made zero-mean.
        center = (
            round((x_range[0] + x_range[1]) / 2 / sampling_interval[0])
            * sampling_interval[0]
        )
        u_arr = ux_arr - center
        u_start_id = torch.argmin(torch.abs(u_arr + uv_len / 2))
        u_end_id = torch.argmin(torch.abs(u_arr - uv_len / 2))
        u_arr = u_arr[u_start_id : u_end_id + 1 : scale] + sampling_interval[0] / 2
        du = torch.mean(u_arr)
        u_arr = u_arr - du

        center = (
            round((y_range[0] + y_range[1]) / 2 / sampling_interval[1])
            * sampling_interval[1]
        )
        v_arr = vy_arr - center
        v_start_id = torch.argmin(torch.abs(v_arr + uv_len / 2))
        v_end_id = torch.argmin(torch.abs(v_arr - uv_len / 2))
        v_arr = v_arr[v_start_id : v_end_id + 1 : scale] + sampling_interval[1] / 2
        dv = torch.mean(v_arr)
        v_arr = v_arr - dv

        x_start_id = torch.argmin(torch.abs(ux_arr - x_range[0]))
        x_end_id = torch.argmin(torch.abs(ux_arr - x_range[1]))
        y_start_id = torch.argmin(torch.abs(vy_arr - y_range[0]))
        y_end_id = torch.argmin(torch.abs(vy_arr - y_range[1]))

        x_arr = ux_arr[x_start_id : x_end_id + 1] - du
        y_arr = vy_arr[y_start_id : y_end_id + 1] - dv

        co_public = (
            -1j
            * sampling_interval[0]
            * sampling_interval[1]
            * n1
            / self.lams.view(1, 1, 1, -1)
        )
        R_q = torch.sqrt(
            (ux_arr.view(-1, 1, 1, 1)) ** 2
            + (vy_arr.view(1, -1, 1, 1)) ** 2
            + z_arr.view(1, 1, -1, 1) ** 2
        )
        ik_R = self.k * R_q * 1j
        # Kernel of shape (U, V, Z, lam).
        self.G_2D = torch.exp(ik_R) / R_q**2 * co_public * (1 - 1 / ik_R)

        self.u_arr = u_arr
        self.v_arr = v_arr
        self.x_arr = x_arr
        self.y_arr = y_arr
        self.z_arr = z_arr

        # Plain ints (argmin returns 0-dim tensors).
        self.x_num = int(x_end_id - x_start_id + 1)
        self.y_num = int(y_end_id - y_start_id + 1)

        self.u_pad_num = ux_arr.numel() - u_arr.numel() * scale
        self.v_pad_num = vy_arr.numel() - v_arr.numel() * scale
        self.dS = sampling_interval[0] * sampling_interval[1]

    def __call__(self, E):
        """Propagate the coarse aperture field ``E`` to every z plane.

        Args:
            E: field of shape ``(u, v, lam)`` sampled on
                ``self.u_arr x self.v_arr``.

        Returns:
            complex64 tensor of shape ``(x_num, y_num, Z, lam)``.

        Raises:
            ValueError: if ``E`` is not 3-D or its first two sizes do not
                match the aperture grid.
        """
        if E.dim() != 3:
            raise ValueError(f"E must have shape (u, v, lam); got {tuple(E.shape)}")
        if E.shape[0] != self.u_arr.numel() or E.shape[1] != self.v_arr.numel():
            raise ValueError(
                f"E must be sampled on a ({self.u_arr.numel()}, {self.v_arr.numel()}) "
                f"aperture grid; got {tuple(E.shape)}"
            )

        # Expand the coarse samples back onto the fine grid.
        E_padded = E.repeat_interleave(self.scale, dim=0)
        E_padded = E_padded.repeat_interleave(self.scale, dim=1)
        # F.pad lists padding from the last dim backwards: nothing on lam,
        # v_pad_num zeros before v, u_pad_num zeros before u.
        pad = (0, 0, self.v_pad_num, 0, self.u_pad_num, 0)
        E_padded = F.pad(E_padded, pad, mode="constant", value=0)

        E_out = torch.zeros(
            [self.x_num, self.y_num, self.z_arr.numel(), self.lams.numel()],
            device=self.device,
            dtype=torch.complex64,
        )
        for iz in range(self.z_arr.numel()):
            E_out[:, :, iz, :] = fft_circular_conv2d(
                E_padded, self.G_2D[:, :, iz] * self.z_arr[iz]
            )[: self.x_num, : self.y_num]
        return E_out

    def Get_G_2D(self, x_f, y_f, z_f):
        """Kernel from every (coarse) aperture sample to one detection point.

        Typically used for optimization and validation. The ``scale**2``
        factor accounts for each coarse sample covering ``scale x scale``
        fine cells.

        Returns:
            ``z * G`` with shape ``(u, v, lam)``, for ``z = focal_length + z_f``.
        """
        z = self.focal_length + z_f  # Detection coordinate
        co_public = -1j * self.dS * self.n1 / self.lams * self.scale**2
        R_q = torch.sqrt(
            z**2
            + (x_f - self.u_arr.view(-1, 1, 1)) ** 2
            + (y_f - self.v_arr.view(1, -1, 1)) ** 2
        )
        ik0R = self.k.view(1, 1, -1) * R_q * 1j
        G_2D = torch.exp(ik0R) / R_q**2 * (1 - 1 / ik0R) * co_public

        return z * G_2D
@@ -0,0 +1,284 @@
1
+ # %%
2
+ # 滤波确实可以去掉伪影
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ import os
7
+ import matplotlib.pyplot as plt
8
+
9
+ # pixel_size=2.4/100*1.8/2
10
+ # class OffAxisPhaseOld:
11
+ # def __init__(
12
+ # self,
13
+ # file_name,
14
+ # continuous=True, # 默认是连续光
15
+ # block_size=1024, # 每个子块的大小
16
+ # overlap=0, # 子块之间是否重叠
17
+ # ):
18
+ # self.continuous = continuous
19
+ # self.block_size = block_size
20
+ # self.overlap = overlap
21
+
22
+ # # 支持的图片后缀
23
+ # self.valid_ext = [".jpg", ".png", ".bmp"]
24
+
25
+ # # 读取图像
26
+ # BACK = self._read_image("BACK", file_name)
27
+ # self.OBJ = torch.clamp(self._read_image("OBJ", file_name) - BACK, min=0)
28
+ # self.REF = torch.clamp(self._read_image("REF", file_name) - BACK, min=0)
29
+ # self.OBJ_REF = torch.clamp(self._read_image("OBJ_REF", file_name) - BACK, min=0)
30
+ # self.INC = torch.clamp(self._read_image("INC", file_name) - BACK, min=0)
31
+ # self.INC_REF = torch.clamp(self._read_image("INC_REF", file_name) - BACK, min=0)
32
+
33
+ # def _read_image(self, prefix, file_name):
34
+ # for ext in self.valid_ext:
35
+ # img_path = os.path.join(file_name, f"{prefix}{ext}")
36
+ # if os.path.exists(img_path):
37
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
38
+ # if img is None:
39
+ # raise ValueError(f"无法读取图像: {img_path}")
40
+ # img_tensor = torch.from_numpy(img)
41
+
42
+ # if img_tensor.dtype == torch.uint8:
43
+ # img_tensor = img_tensor.to(torch.float32) / 255.0
44
+ # elif img_tensor.dtype == torch.uint16:
45
+ # img_tensor = img_tensor.to(torch.float32) / 65535.0
46
+ # else:
47
+ # raise ValueError(f"不支持的图像位深: {img_tensor.dtype}")
48
+ # return img_tensor
49
+ # raise FileNotFoundError(f"在 {file_name} 中找不到 {prefix} 图像")
50
+
51
+ # def _split_blocks(self, img):
52
+ # H, W = img.shape
53
+ # step = self.block_size - self.overlap
54
+ # blocks = []
55
+
56
+ # y_list = list(range(0, H - self.block_size + 1, step))
57
+ # x_list = list(range(0, W - self.block_size + 1, step))
58
+
59
+ # # 确保边缘覆盖
60
+ # if y_list[-1] != H - self.block_size:
61
+ # y_list.append(H - self.block_size)
62
+ # if x_list[-1] != W - self.block_size:
63
+ # x_list.append(W - self.block_size)
64
+
65
+ # for y in y_list:
66
+ # for x in x_list:
67
+ # blocks.append(
68
+ # ((y, x), img[y : y + self.block_size, x : x + self.block_size])
69
+ # )
70
+
71
+ # return blocks, H, W
72
+
73
+ # def _estimate_angle(self, block, global_angle):
74
+ # """先用全局角度粗筛,再精确估计局部峰值角度"""
75
+ # f = torch.fft.fftshift(torch.fft.fft2(block))
76
+
77
+ # H, W = f.shape
78
+ # ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=pixel_size)).to(f.device)
79
+ # kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=pixel_size)).to(f.device)
80
+ # KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")
81
+
82
+ # angle = torch.deg2rad(torch.tensor(global_angle))
83
+ # vx, vy = torch.cos(angle), torch.sin(angle)
84
+
85
+ # cross = vx * KY - vy * KX
86
+ # mask = cross >= 0
87
+ # mag = torch.abs(f) * mask # 只保留一半
88
+ # # 去掉直流分量
89
+ # cy, cx = H // 2, W // 2
90
+ # mag[cy - 20 : cy + 20, cx - 20 : cx + 20] = 0
91
+ # # ==== Step 2: 找峰值 ====
92
+ # max_pos = torch.nonzero(mag == mag.max(), as_tuple=False)[0]
93
+ # dy, dx = max_pos[0] - mag.size(0) / 2, max_pos[1] - mag.size(1) / 2
94
+ # angle_loc = torch.atan2(-dx, dy) # 子图峰值方向
95
+ # Rmax = torch.sqrt(dy**2+dx**2)*0.75#滤波半径
96
+ # kmax=torch.sqrt(ky_vals[max_pos[0]]**2+kx_vals[max_pos[1]]**2)
97
+
98
+ # print(kmax)
99
+ # return torch.rad2deg(angle_loc).item(), f, mag,Rmax
100
+
101
+ # def _filter_one(self, input_tensor, angle_deg):
102
+ # f = torch.fft.fftshift(torch.fft.fft2(input_tensor))
103
+
104
+ # H, W = f.shape
105
+ # ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=1.0 / H)).to(f.device)
106
+ # kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=1.0 / W)).to(f.device)
107
+ # KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")
108
+
109
+ # angle = torch.deg2rad(torch.tensor(angle_deg))
110
+ # vx, vy = torch.cos(angle), torch.sin(angle)
111
+
112
+ # cross = vx * KY - vy * KX
113
+ # mask = cross >= 0
114
+
115
+ # f_filtered = f * mask
116
+ # result = torch.fft.ifft2(torch.fft.ifftshift(f_filtered))
117
+ # return result, f, mask
118
+
119
+ # def __call__(self, angle_deg=0, visualize=False, vis_num=1,lowpass=False):
120
+ # # 分块处理
121
+ # a_full = self.OBJ_REF - self.OBJ - self.REF
122
+ # b_full = self.INC_REF - self.INC - self.REF
123
+
124
+ # blocks, H, W = self._split_blocks(a_full)
125
+ # E_full = torch.zeros((H, W), dtype=torch.complex64)
126
+
127
+ # vis_count = 0
128
+ # x=torch.arange(self.block_size).float()
129
+ # x-=torch.mean(x)
130
+ # y=torch.arange(self.block_size).float()
131
+ # y-=torch.mean(y)
132
+ # R=torch.sqrt(x.view(-1,1)**2+y.view(1,-1)**2)
133
+
134
+ # for (y, x), a_block in blocks:
135
+ # b_block = b_full[y : y + self.block_size, x : x + self.block_size]
136
+
137
+ # # 每个子块独立估计方向
138
+ # block_angle, f_b, mag_b,Rmax = self._estimate_angle(b_block, angle_deg)
139
+
140
+ # a_filtered, f_a, mask_a = self._filter_one(a_block, block_angle)
141
+ # b_filtered, _, mask_b = self._filter_one(b_block, block_angle)
142
+
143
+ # E = a_filtered/ (b_filtered)
144
+ # E[~torch.isfinite(E)] = 0
145
+
146
+ # # 放回全图
147
+ # if lowpass:
148
+ # E = torch.fft.fftshift(torch.fft.fft2(E))
149
+ # plt.figure(figsize=(8,8))
150
+ # plt.subplot(2, 2, 1)
151
+ # plt.pcolormesh(torch.log1p(torch.abs(E[800:-800,800:-800])).cpu())
152
+ # plt.colorbar()
153
+ # plt.title("E FFT magnitude (before)")
154
+ # E[R>Rmax]=0
155
+ # plt.subplot(2, 2, 2)
156
+ # plt.pcolormesh(torch.log1p(torch.abs(E[800:-800,800:-800])).cpu())
157
+ # plt.colorbar()
158
+ # plt.title("E FFT magnitude (after)")
159
+ # plt.subplot(2, 2, 3)
160
+ # plt.pcolormesh((R[800:-800,800:-800]>Rmax).float().cpu())
161
+ # plt.title("R magnitude (after)")
162
+
163
+
164
+ # E = torch.fft.ifft2(torch.fft.ifftshift(E))
165
+ # E_full[y : y + self.block_size, x : x + self.block_size] = E
166
+
167
+ # # 可视化部分子图
168
+ # if visualize and vis_count < vis_num:
169
+ # plt.figure(figsize=(12, 6))
170
+ # plt.suptitle(f"Block ({y},{x}), angle={block_angle:.2f}°")
171
+
172
+ # plt.subplot(2, 2, 1)
173
+ # plt.imshow(torch.log1p(mag_b).cpu(), cmap="gray")
174
+ # plt.title("b FFT magnitude (before)")
175
+
176
+ # plt.subplot(2, 2, 2)
177
+ # plt.imshow(torch.log1p(torch.abs(f_b * mask_b)).cpu(), cmap="gray")
178
+ # plt.title("b FFT after mask")
179
+
180
+ # plt.subplot(2, 2, 3)
181
+ # plt.imshow(torch.log1p(torch.abs(torch.abs(f_a))).cpu(), cmap="gray")
182
+ # plt.title("a FFT magnitude (before)")
183
+
184
+ # plt.subplot(2, 2, 4)
185
+ # plt.imshow(torch.log1p(torch.abs(f_a * mask_a)).cpu(), cmap="gray")
186
+ # plt.title("a FFT after mask")
187
+
188
+ # plt.tight_layout()
189
+ # plt.show()
190
+
191
+ # vis_count += 1
192
+
193
+ # return E_full
194
+
195
+
196
class OffAxisPhase:
    """Off-axis holography: ratio of the filtered object and incident interference terms.

    Reads BACK, OBJ, REF, OBJ_REF, INC and INC_REF images from a directory,
    subtracts the background, and on call:

    1. forms the interference terms ``a = OBJ_REF - OBJ - REF`` and
       ``b = INC_REF - INC - REF``;
    2. locates one sideband peak in the spectrum of ``b`` and keeps a disc
       around it in both spectra;
    3. returns ``a_filtered / b_filtered``, with non-finite values set to 0.

    Args:
        file_name: directory containing the images (.jpg/.png/.bmp).
        pixel_size: detector pixel pitch; used only to scale the frequency
            axes.
    """

    def __init__(self, file_name, pixel_size=2.4 / 100 * 1.8 / 2):
        self.pixel_size = pixel_size

        # Supported image extensions, tried in this order.
        self.valid_ext = [".jpg", ".png", ".bmp"]

        # Background-subtracted intensities, clamped to be non-negative.
        BACK = self._read_image("BACK", file_name)
        self.OBJ = torch.clamp(self._read_image("OBJ", file_name) - BACK, min=0)
        self.REF = torch.clamp(self._read_image("REF", file_name) - BACK, min=0)
        self.OBJ_REF = torch.clamp(self._read_image("OBJ_REF", file_name) - BACK, min=0)
        self.INC = torch.clamp(self._read_image("INC", file_name) - BACK, min=0)
        self.INC_REF = torch.clamp(self._read_image("INC_REF", file_name) - BACK, min=0)

    def _read_image(self, prefix, file_name):
        """Load ``<file_name>/<prefix><ext>`` as a 2-D float32 tensor in [0, 1].

        Colour (BGR / BGRA) images are converted to a single grey channel.

        Raises:
            ValueError: unreadable file or unsupported bit depth.
            FileNotFoundError: no file with a supported extension exists.
        """
        for ext in self.valid_ext:
            img_path = os.path.join(file_name, f"{prefix}{ext}")
            if not os.path.exists(img_path):
                continue
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                raise ValueError(f"无法读取图像: {img_path}")
            if img.ndim == 3:
                # IMREAD_UNCHANGED keeps colour channels; the FFT pipeline
                # needs a single 2-D intensity image.
                if img.shape[2] == 1:
                    img = img[:, :, 0]
                else:
                    code = (
                        cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
                    )
                    img = cv2.cvtColor(img, code)
            img_tensor = torch.from_numpy(img)

            if img_tensor.dtype == torch.uint8:
                img_tensor = img_tensor.to(torch.float32) / 255.0
            elif img_tensor.dtype == torch.uint16:
                img_tensor = img_tensor.to(torch.float32) / 65535.0
            else:
                raise ValueError(f"不支持的图像位深: {img_tensor.dtype}")
            return img_tensor
        raise FileNotFoundError(f"在 {file_name} 中找不到 {prefix} 图像")

    def _get_mask(
        self, block, global_angle, dc_half_width=15, radius_factor=0.9, verbose=True
    ):
        """Return a boolean disc mask around the sideband peak of ``block``'s spectrum.

        Only the half-plane selected by ``global_angle`` (degrees) is searched.
        A ``(2*dc_half_width)``-wide square around DC is ignored. The disc
        radius is ``radius_factor`` times the peak's distance from DC.
        """
        f = torch.fft.fftshift(torch.fft.fft2(block))
        H, W = f.shape
        ky_vals = torch.fft.fftshift(torch.fft.fftfreq(H, d=self.pixel_size)).to(
            f.device
        )
        kx_vals = torch.fft.fftshift(torch.fft.fftfreq(W, d=self.pixel_size)).to(
            f.device
        )
        KY, KX = torch.meshgrid(ky_vals, kx_vals, indexing="ij")

        angle = torch.deg2rad(torch.tensor(float(global_angle)))
        vx, vy = torch.cos(angle), torch.sin(angle)

        # Keep only one half-plane so that a single sideband is searched.
        half_plane = (vx * KY - vy * KX) >= 0
        mag = torch.abs(f) * half_plane
        # Suppress the DC term; clamp the slice start so it cannot wrap on
        # small images.
        cy, cx = H // 2, W // 2
        mag[
            max(cy - dc_half_width, 0) : cy + dc_half_width,
            max(cx - dc_half_width, 0) : cx + dc_half_width,
        ] = 0

        # First (row-major) maximum of the remaining spectrum.
        iy, ix = divmod(int(torch.argmax(mag)), W)
        kyc, kxc = ky_vals[iy], kx_vals[ix]
        kmax = torch.sqrt(kyc**2 + kxc**2) * radius_factor  # filter radius
        if verbose:
            print("kmax:" + str(round(kmax.cpu().item(), 2)))
        return torch.sqrt((KY - kyc) ** 2 + (KX - kxc) ** 2) < kmax

    def __call__(
        self,
        angle_deg=0,
        visualize=False,
        dc_half_width=15,
        radius_factor=0.9,
        verbose=True,
    ):
        """Return the complex ratio ``a_filtered / b_filtered`` (non-finite values -> 0).

        Args:
            angle_deg: direction (degrees) selecting the half-plane searched
                for the sideband.
            visualize: if True, plot the spectrum of ``b`` and the mask with
                matplotlib (no ``plt.show()`` is called).
            dc_half_width, radius_factor, verbose: forwarded to ``_get_mask``.
        """
        a_full = self.OBJ_REF - self.OBJ - self.REF
        b_full = self.INC_REF - self.INC - self.REF

        mask = self._get_mask(b_full, angle_deg, dc_half_width, radius_factor, verbose)

        a_spec = torch.fft.fftshift(torch.fft.fft2(a_full))
        a_spec[~mask] = 0
        a_filt = torch.fft.ifft2(torch.fft.ifftshift(a_spec))

        b_spec = torch.fft.fftshift(torch.fft.fft2(b_full))
        if visualize:
            plt.figure(figsize=(12, 6))
            plt.subplot(2, 1, 1)
            plt.imshow(torch.log1p(torch.abs(b_spec)).cpu(), cmap="gray")
            plt.title("b_full")

            plt.subplot(2, 1, 2)
            plt.imshow(mask.cpu(), cmap="gray")
            plt.title("mask")
        b_spec[~mask] = 0
        b_filt = torch.fft.ifft2(torch.fft.ifftshift(b_spec))

        E = a_filt / b_filt
        E[~torch.isfinite(E)] = 0

        return E
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: torchopticsy
3
+ Version: 0.8.3
4
+ Summary: PyTorch-based optics calculation
5
+ Author: YuningYe
6
+ Author-email: 1956860113@qq.com
7
+ License: MIT
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: torch
10
+ Requires-Dist: opencv-python
11
+ Requires-Dist: matplotlib
12
+
13
+ # torchOpticsY
14
+ A PyTorch-based optics calculation library.
15
+
16
+
17
+ ## Features
18
+ - Luneburg integral
19
+ - Debye–Wolf integral.
20
+
21
+ ## Install
22
+ pip install torchOpticsY
23
+
24
+
25
+ 0.2 - Interference: OffAxisPhase
26
+ 0.1 - Diffraction: Debye_Wolf,Luneburg
@@ -0,0 +1,9 @@
1
+ setup.py
2
+ torchopticsy/Diffraction.py
3
+ torchopticsy/Interference.py
4
+ torchopticsy/__init__.py
5
+ torchopticsy.egg-info/PKG-INFO
6
+ torchopticsy.egg-info/SOURCES.txt
7
+ torchopticsy.egg-info/dependency_links.txt
8
+ torchopticsy.egg-info/requires.txt
9
+ torchopticsy.egg-info/top_level.txt
@@ -0,0 +1,3 @@
1
+ torch
2
+ opencv-python
3
+ matplotlib
@@ -0,0 +1 @@
1
+ torchopticsy