vectorwaves 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,623 @@
1
+ """
2
+ Beam Representation & Generation
3
+ ================================
4
+
5
+ Translates experiment configurations into physical beam objects. This module
6
+ contains the `Beam` dataclass, which stores precomputed coefficients and
7
+ angular spectrum data, and the `BeamMaker` factory.
8
+
9
+ Pipeline Context:
10
+ 1. Config (config_stuff.py) -> Provided as input to BeamMaker.
11
+ 2. Beam (beam_stuff.py) -> **CURRENT STEP**. Precomputes spectrum.
12
+ 3. Engine (engine_stuff.py) -> Resulting Beam is passed to FieldEngine.
13
+ """
14
+
15
+ import numpy as np
16
+ import warnings
17
+ from dataclasses import dataclass
18
+ from typing import Tuple, Optional, Literal
19
+ from functools import cached_property
20
+
21
+ from .config_stuff import Config
22
+
23
@dataclass
class Beam:
    """
    Structured representation of a decomposed electromagnetic beam.

    This class serves two purposes:
    1. Storage: Holds the wavevectors (k) and complex
       vector amplitudes (c) used by backends for field superposition.
    2. Physics Analysis: Provides properties and visualization tools to
       inspect the beam's divergence, power, and spectrum.

    All derived quantities are ``cached_property`` values: they are computed
    once on first access and are NOT refreshed if the arrays are mutated
    afterwards.

    Attributes
    ----------
    k : np.ndarray
        (3, N) Wavevectors in units of rad/spatial_unit.
    c : np.ndarray
        (3, N) Complex vector amplitudes incorporating polarization,
        intensity scaling, and phase offsets.
    w : np.ndarray
        (N,) Angular frequencies (norm of k).
    inv_w : np.ndarray
        (N,) Precomputed inverse frequencies for normalization.
    a : np.ndarray
        (N,) Scalar complex amplitudes (A * exp(i*phi)) before polarization.
    """
    k: np.ndarray
    c: np.ndarray
    w: np.ndarray
    inv_w: np.ndarray
    a: np.ndarray

    # =========================================================================
    # PHYSICS PROPERTIES
    # =========================================================================

    def __repr__(self) -> str:
        return f"<Beam: {self.num_modes:,} modes | Power: {self.total_power:.2e}>"

    @cached_property
    def num_modes(self) -> int:
        """Total number of plane wave modes in the beam."""
        return self.k.shape[1]

    @cached_property
    def wavelengths(self) -> np.ndarray:
        """Physical wavelengths of each mode (N,). Modes with w == 0 report 0."""
        wl = np.zeros_like(self.w)
        mask = self.w > 0
        wl[mask] = 2 * np.pi / self.w[mask]
        return wl

    @cached_property
    def k_hat(self) -> np.ndarray:
        """Normalized propagation direction unit vectors (3, N)."""
        return self.k * self.inv_w

    @cached_property
    def amplitudes(self) -> np.ndarray:
        """The real-valued magnitude of each mode (N,)."""
        return np.abs(self.a)

    @cached_property
    def polarizations(self) -> np.ndarray:
        """
        The Jones vectors (unit complex vectors) of each mode.

        Extracted by dividing the vector amplitude (c) by the scalar
        amplitude (a). Modes whose scalar amplitude is below 1e-15 are
        reported as zero vectors rather than dividing by ~0.
        """
        mask = np.abs(self.a) > 1e-15
        pol = np.zeros_like(self.c, dtype=complex)
        pol[:, mask] = self.c[:, mask] / self.a[mask]
        return pol

    @cached_property
    def mode_irradiances(self) -> np.ndarray:
        """
        The spatiotemporally averaged irradiance contribution of each mode.

        For a single plane wave, the time-averaged energy density is uniform
        across all space. This value represents the 'weight' of each discrete
        spectral component in the superposition.

        Calculated as the squared norm of the complex vector amplitudes:
        |c_x|^2 + |c_y|^2 + |c_z|^2.
        """
        return np.sum(np.abs(self.c)**2, axis=0)

    @cached_property
    def total_power(self) -> float:
        """
        The integrated spectral norm of the beam.

        Calculated as the sum of all individual mode irradiances,
        it represents the total energy content of the angular spectrum.

        While the local spatial intensity (FieldResult.intensity_E) is shaped by
        interference, this value is a conserved quantity that defines the 'bulk'
        strength of the beam.
        """
        return float(np.sum(self.mode_irradiances))

    @cached_property
    def mode_weights(self) -> np.ndarray:
        """Normalized power contribution of each mode (sums to 1, or all zeros
        when the beam carries essentially no power)."""
        if self.total_power < 1e-15:
            return np.zeros_like(self.mode_irradiances)
        return self.mode_irradiances / self.total_power

    @cached_property
    def mean_direction(self) -> np.ndarray:
        """
        Intensity-weighted mean propagation direction (3,).

        Falls back to the direction of the strongest mode when the weighted
        mean cancels out (e.g. a standing wave), and to +z for a beam with no
        modes at all (there is no mode to take a direction from).
        """
        if self.num_modes == 0:
            return np.array([0.0, 0.0, 1.0])

        mean_k = np.sum(self.k_hat * self.mode_weights[np.newaxis, :], axis=1)
        mean_norm = np.linalg.norm(mean_k)

        if mean_norm < 1e-5:
            # Fallback: Use the axis of the mode with the highest irradiance
            max_idx = np.argmax(self.mode_irradiances)
            return self.k_hat[:, max_idx]

        return mean_k / mean_norm

    @cached_property
    def rms_divergence(self) -> float:
        """RMS divergence half-angle in radians (0.0 for a powerless beam)."""
        if self.total_power < 1e-15:
            return 0.0

        # Clip guards arccos against round-off pushing |cos| slightly above 1.
        cos_thetas = np.clip(np.dot(self.mean_direction, self.k_hat), -1.0, 1.0)
        thetas = np.arccos(cos_thetas)
        # Cast so the value matches the annotated built-in float type.
        return float(np.sqrt(np.sum(self.mode_weights * thetas**2)))

    @cached_property
    def wavelength_spectrum(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Groups plane waves by physical wavelength using a tight relative tolerance.

        Modes are sorted by wavelength and merged into the current group while
        they lie within a relative tolerance of 1e-6 of the group's first
        (smallest) wavelength.

        Returns
        -------
        unique_wls : np.ndarray
            1D array of unique physical wavelengths in the beam.
        spectra : np.ndarray
            1D array of the integrated irradiance matching each unique wavelength.
        """
        if self.num_modes == 0:
            return np.array([]), np.array([])

        order = np.argsort(self.wavelengths)
        # Plain Python floats: same arithmetic as np.isclose(rtol=1e-6, atol=0)
        # without the per-element NumPy call overhead inside the loop.
        sorted_wls = self.wavelengths[order].tolist()
        sorted_irrad = self.mode_irradiances[order].tolist()

        unique_wls = [sorted_wls[0]]
        spectra = [sorted_irrad[0]]

        for wl, irrad in zip(sorted_wls[1:], sorted_irrad[1:]):
            anchor = unique_wls[-1]
            if abs(wl - anchor) <= 1e-6 * abs(anchor):
                spectra[-1] += irrad
            else:
                unique_wls.append(wl)
                spectra.append(irrad)

        return np.array(unique_wls), np.array(spectra)

    # =========================================================================
    # USER TOOLS & VISUALIZATION
    # =========================================================================

    def summary(self):
        """Prints a physical summary of the beam including divergence and axis."""
        print("--- Beam Physics Summary ---")
        print(f"Modes          : {self.num_modes:,}")
        print(f"Total Power    : {self.total_power:.2e}")

        unique_wls, _ = self.wavelength_spectrum
        if len(unique_wls) == 1:
            # Use .3g to automatically handle scientific notation nicely
            print(f"Wavelength     : {unique_wls[0]:.3g} (Monochromatic)")
        elif len(unique_wls) < 10:
            wls_str = ", ".join([f"{w:.3g}" for w in unique_wls])
            print(f"Wavelengths    : [{wls_str}] ({len(unique_wls)} distinct lines)")
        else:
            print(f"Wavelengths    : {np.min(unique_wls):.3g} to {np.max(unique_wls):.3g} (Broadband)")

        # Direction/divergence are meaningless without power, so skip them.
        if self.total_power > 1e-15:
            md = self.mean_direction
            print(f"Mean Axis      : [{md[0]:.3f}, {md[1]:.3f}, {md[2]:.3f}]")
            print(f"RMS Divergence : ~{np.degrees(self.rms_divergence):.2f} degrees half-angle")

    def plot_kspace_3d(
        self, cmap='inferno', show: bool = True,
        plot_type: Literal['colored_vectors', 'colored_sphere', 'matplotlib_scatter'] = 'colored_vectors'
    ):
        """
        Renders a 3D visualization of the wavevectors and amplitudes.

        Parameters
        ----------
        cmap : str, optional
            Colormap for mode amplitudes (default is 'inferno').
        plot_type : Literal['colored_vectors', 'colored_sphere', 'matplotlib_scatter'], optional
            'colored_vectors'
                Interactive PyVista arrows.
            'colored_sphere'
                Interactive PyVista sphere heatmap.
            'matplotlib_scatter'
                Lightweight matplotlib 3D scatter.
        show: bool, optional
            If True, displays the plot. Default is True.

        Returns
        -------
        pyvista.Plotter, matplotlib 3D axes, or None
            The PyVista plotter for the two PyVista plot types, the matplotlib
            axes for 'matplotlib_scatter', or None if the required plotting
            library is not installed.

        Raises
        ------
        ValueError
            If `plot_type` is not one of the supported values.
        """
        # Validate before importing anything or opening a render window, so a
        # typo is always reported and never leaks a half-built plotter.
        if plot_type not in ('colored_vectors', 'colored_sphere', 'matplotlib_scatter'):
            raise ValueError(
                "plot_type must be colored_vectors, colored_sphere or matplotlib_scatter"
            )

        if plot_type == 'matplotlib_scatter':
            try:
                import matplotlib.pyplot as plt
            except ImportError:
                warnings.warn("matplotlib is required for 3D scatter visualization.")
                return None

            fig = plt.figure(figsize=(16, 9))
            ax = fig.add_subplot(projection="3d")

            sc = ax.scatter(
                self.k_hat[0],
                self.k_hat[1],
                self.k_hat[2],
                c=self.amplitudes,
                cmap=cmap,
                s=30
            )

            ax.set_xlabel(r"$k_x$", fontsize=20)
            ax.set_ylabel(r"$k_y$", fontsize=20)
            ax.set_zlabel(r"$k_z$", fontsize=20)

            ax.set_box_aspect((1, 1, 0.5))

            cbar = fig.colorbar(sc, ax=ax)
            cbar.set_label("Amplitude", fontsize=15)

            if show:
                plt.show()

            return ax

        try:
            import pyvista as pv
        except ImportError:
            warnings.warn("pyvista is required for 3D visualization.")
            return None

        plotter = pv.Plotter()
        plotter.set_scale(1)
        plotter.show_axes()

        if plot_type == 'colored_vectors':
            # One arrow per mode, all anchored at the origin.
            origins = np.zeros((self.num_modes, 3))
            mesh = pv.PolyData(origins)
            mesh["vec"] = self.k_hat.T
            mesh["amplitudes"] = self.amplitudes
            arrows = mesh.glyph(orient="vec", scale=False, factor=0.2)
            plotter.add_mesh(
                arrows, scalars='amplitudes',
                cmap=cmap, clim=[0, np.max(self.amplitudes)],
                scalar_bar_args={'vertical': True, 'title': 'Amplitude'}
            )
        else:  # 'colored_sphere'
            sphere = pv.Sphere(theta_resolution=60, phi_resolution=120)
            # map amplitudes to sphere points using nearest-neighbor
            from scipy.spatial import cKDTree
            tree = cKDTree(self.k_hat.T)
            _, idx = tree.query(sphere.points)  # find nearest k_hat for each sphere point
            sphere["amplitudes"] = self.amplitudes[idx]
            plotter.add_mesh(
                sphere, scalars='amplitudes',
                cmap=cmap, clim=[0, np.max(self.amplitudes)],
                scalar_bar_args={'vertical': True, 'title': 'Amplitude'}
            )

        if show:
            plotter.show()

        return plotter

    def plot_k_perp_profile(self, normal: Optional[Tuple[float, float, float]] = None, show: bool = True):
        """
        Plots Amplitude vs Transverse wave number (k_perp).

        Parameters
        ----------
        normal : tuple, optional
            The normal vector defining the longitudinal axis. If None, it attempts
            to find the intensity-weighted mean direction. If the beam is
            perfectly symmetric (e.g., a standing wave), it falls back to the
            direction of the dominant mode. A zero-length normal falls back
            to +z.
        show: bool, optional
            If True, displays the plot. Default is True.

        Returns
        -------
        Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] or None
            None is returned when matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            warnings.warn("matplotlib is required to plot k-space profiles.")
            return None

        if normal is not None:
            normal_vec = np.array(normal, dtype=float)
            norm = np.linalg.norm(normal_vec)
            normal_vec = normal_vec / norm if norm > 0 else np.array([0., 0., 1.])
        else:
            normal_vec = self.mean_direction

        # Split each k into its component along the normal and the remainder;
        # the norm of the remainder is the transverse wavenumber.
        k_parallel_mags = np.dot(normal_vec, self.k)
        k_parallel_vecs = normal_vec[:, np.newaxis] * k_parallel_mags
        k_perp = np.linalg.norm(self.k - k_parallel_vecs, axis=0)

        fig, ax = plt.subplots(figsize=(7, 4))
        ax.scatter(k_perp, self.amplitudes, s=15)
        ax.set_xlabel(r'Transverse Wavenumber $k_\perp$')
        ax.set_ylabel('Mode Amplitude')
        ax.set_title(f"K-Space Transverse Profile about\nNormal: [{normal_vec[0]:.2f}, {normal_vec[1]:.2f}, {normal_vec[2]:.2f}]")
        ax.grid(True, alpha=0.2)
        # An empty or all-zero beam has no meaningful upper limit; let
        # matplotlib autoscale instead of requesting a degenerate (0, 0) range.
        peak = float(np.max(self.amplitudes)) if self.num_modes else 0.0
        if peak > 0:
            ax.set_ylim(0, peak * 1.2)
        plt.tight_layout()

        if show:
            plt.show()
        return fig, ax

    def plot_wavelength_spectrum(self, show: bool = True):
        """
        Plots the intensity-weighted wavelength spectrum.

        Parameters
        ----------
        show: bool, optional
            If True, displays the plot. Default is True.

        Returns
        -------
        Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] or None
            None is returned when matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            warnings.warn("matplotlib is required to plot spectrum.")
            return None

        fig, ax = plt.subplots(figsize=(6, 4))
        unique_wls, spectra = self.wavelength_spectrum

        if len(unique_wls) == 1:
            ax.axvline(unique_wls[0], color='indigo', lw=3, label=fr'$\lambda={unique_wls[0]:.3g}$')
            ax.legend()
        elif len(unique_wls) > 1:
            # Bar width: 2% of the spectrum range, or 0.1% of the smallest wavelength
            ptp = np.ptp(unique_wls)
            bar_width = ptp * 0.02 if ptp > 0 else unique_wls[0] * 1e-3
            ax.bar(unique_wls, spectra, width=bar_width, color='indigo')
        # A beam with no modes yields an empty spectrum: leave the axes blank.

        # Optional: Format X-axis for scientific notation if very small
        ax.ticklabel_format(style='sci', axis='x', scilimits=(-3, 3))

        ax.set_xlabel("Wavelength")
        ax.set_ylabel("Intensity")
        if show:
            plt.show()
        return fig, ax
404
+
405
class BeamMaker:
    """
    Factory class that translates a `Config` object into a `Beam`.

    Handles the mathematical heavy lifting of sphere sampling (Fibonacci),
    polarization basis construction (Rodrigues), and spectral weight application.
    """
    # Annotations naming Config/Beam are quoted so they are not evaluated when
    # the class body is executed.
    def __init__(self, config: "Config"):
        """Validate `config` and seed the random generator from it."""
        self.config = config
        self.config.validate()
        self.rng = np.random.default_rng(config.source.randomize.seed)

    def generate_beam(self) -> "Beam":
        """
        Executes the generation pipeline to produce a superposition of plane waves.

        This method aggregates configurations (wavelength, angular sampling,
        stochastic noise, and k-space profiles) to create a fully quantified
        electromagnetic beam.

        Returns
        -------
        Beam
            Precomputed beam object ready for evaluation in the FieldEngine.

        Raises
        ------
        ValueError
            If `num_modes` is less than 1.
        ValueError
            If the generated beam evaluates to a total power near zero (< 1e-15).
            This typically occurs when the k-space profile evaluates to zero
            across all sampled angular grid points (e.g., mismatch between
            `beam_axis` and the profile's non-zero domain).
        """
        modes = self.config.source.num_modes
        if modes < 1:
            raise ValueError(f"num_modes must be >= 1. Got: {modes}.")
        elif modes == 1:
            warnings.warn("num_modes is 1. Generating a pure single plane wave.")
        elif modes < 10:
            warnings.warn(f"num_modes ({modes}) is very low. Beam profile may be under-sampled.")

        if self.config.verbose:
            print(f"--- Starting Beam Generation (Modes: {modes}) ---")

        wls = np.atleast_1d(self.config.source.wavelength)
        num_wls = len(wls)
        if num_wls > modes:
            # Modes are dealt round-robin to wavelengths below, so the trailing
            # wavelengths would receive no mode at all.
            warnings.warn(
                f"num_modes ({modes}) is smaller than the number of wavelengths "
                f"({num_wls}); the last {num_wls - modes} wavelength(s) get no modes."
            )

        # 1. Compute polychromatic envelope weights
        if num_wls > 1:
            poly_cfg = self.config.source.polychromatic
            weights = np.array([poly_cfg.profile(wl, **poly_cfg.params) for wl in wls])
            w_sum = np.linalg.norm(weights)
            # Degenerate (all ~zero) envelope: fall back to equal weights.
            weights = weights / w_sum if w_sum > 1e-12 else np.ones(num_wls) / np.sqrt(num_wls)
        else:
            weights = np.ones(num_wls)

        weights *= np.sqrt(self.config.source.intensity_scale)

        # 2. Generate sampling grid on unit sphere
        master_k_hats, master_d_omega = self._sample_sphere_fib(
            N=modes,
            beam_axis=self.config.source.beam_axis,
            theta_max=self.config.source.theta_max
        )

        # 3. Generate wave batches per wavelength
        all_ks, all_cs, all_amps = [], [], []

        for i, (wl, spectral_weight) in enumerate(zip(wls, weights)):
            # Interleave: wavelength i takes grid points i, i+num_wls, ...
            indices = slice(i, None, num_wls)
            k_chunk = master_k_hats[indices]
            d_omega_chunk = master_d_omega[indices]
            if len(k_chunk) == 0:
                continue

            ks, cs, amps = self._generate_monochromatic_batch(wl, k_chunk, d_omega_chunk, spectral_weight)
            all_ks.append(ks)
            all_cs.append(cs)
            all_amps.append(amps)

        # 4. Final Aggregation
        k_out = np.vstack(all_ks).T
        c_out = np.vstack(all_cs).T
        a_out = np.concatenate(all_amps)
        w_out = np.linalg.norm(k_out, axis=0)

        with np.errstate(divide='ignore'):
            inv_w_out = 1.0 / w_out
        inv_w_out[w_out == 0] = 0

        beam = Beam(k=k_out, c=c_out, w=w_out, inv_w=inv_w_out, a=a_out)
        if beam.total_power < 1e-15:
            raise ValueError(
                "Generated beam has essentially zero power (total_power < 1e-15). "
                "Check your k-space profile, beam_axis, and theta_max. The angular "
                "sampling grid may have completely missed the profile's non-zero region."
            )

        self._warn_if_clipped(beam)
        return beam

    def _warn_if_clipped(self, beam) -> None:
        """Warn when the spectrum is still significant at the edge of the sampled cone.

        `beam` only needs to expose `k_hat` (3, N) and `amplitudes` (N,).
        """
        actual_theta_max = self.config.source.theta_max
        if not actual_theta_max < (np.pi / 2 - 1e-4):
            return

        # Normalize the axis: a non-unit beam_axis would otherwise scale the
        # cosines and make every mode look on-axis after clipping.
        axis = np.asarray(self.config.source.beam_axis, dtype=float)
        axis_norm = np.linalg.norm(axis)
        # A zero axis means "no rotation" in _align_to_axis, i.e. the cone is about +z.
        axis = axis / axis_norm if axis_norm > 0 else np.array([0.0, 0.0, 1.0])

        cos_thetas = np.dot(axis, beam.k_hat)
        thetas = np.arccos(np.clip(cos_thetas, -1.0, 1.0))

        # Check the intensity of modes within the outer 5% of the sampled cone
        edge_mask = thetas > (0.95 * actual_theta_max)
        if not np.any(edge_mask):
            return

        max_edge_amp = np.max(beam.amplitudes[edge_mask])
        peak_amp = np.max(beam.amplitudes)

        if max_edge_amp > 0.01 * peak_amp:
            warnings.warn(
                "Beam Clipping Detected: The k-space spectrum is still active at the "
                f"boundary of theta_max ({np.degrees(actual_theta_max):.1f}°)."
            )

    def _generate_monochromatic_batch(self, wavelength: float, k_hats: np.ndarray,
                                      d_omega: np.ndarray, weight: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Internal helper for generating modes at a specific wavelength line.

        Parameters
        ----------
        wavelength : float
            Wavelength of this spectral line.
        k_hats : np.ndarray
            (N, 3) propagation directions.
        d_omega : np.ndarray
            (N,) solid-angle element associated with each direction.
        weight : float
            Spectral weight applied to this line.

        Returns
        -------
        ks : np.ndarray
            (N, 3) wavevectors.
        cs : np.ndarray
            (N, 3) complex vector amplitudes (all zeros if the line has ~no power).
        amps : np.ndarray
            (N,) complex scalar amplitudes.
        """
        N = len(k_hats)
        ks = (2 * np.pi / wavelength) * k_hats

        # --- Polarization Basis (Local Transverse Frames) ---
        e1, e2 = self._transverse_basis_batch_rod(k_hats, self.config.source.beam_axis)
        px, py = self.config.source.pol_vect

        # Handle Randomized Polarization
        pol_rot_max = self.config.source.randomize.pol_rot_max
        if self.config.source.randomize.pol_state:
            # Full-sphere sample used as (s1, s2, s3) components to build a
            # polarization state per mode.
            temp = self._sample_sphere_fib(N, (0, 0, 1), np.pi)
            s1, s2, s3 = temp[0][:, 0], temp[0][:, 1], temp[0][:, 2]
            P = np.sqrt((1.0+s1)/2.0)[:, None]*e1 + (np.sqrt((1.0-s1)/2.0)*np.exp(1j*np.arctan2(s3, s2)))[:, None]*e2
        elif pol_rot_max > 0:
            angles = self.rng.uniform(-pol_rot_max, pol_rot_max, size=N)
            c_a, s_a = np.cos(angles), np.sin(angles)
            P = (c_a*px - s_a*py)[:, None]*e1 + (s_a*px + c_a*py)[:, None]*e2
        else:
            P = px * e1 + py * e2

        # --- K-Space Amplitude Spectrum ---
        kspace_cfg = self.config.source.k_space
        if kspace_cfg.vectorised:
            # np.array copies (the in-place scaling below must not touch the
            # profile's own array); reshape(-1) flattens (N,), (N,1) or (1,N)
            # outputs and, unlike squeeze(), keeps one dimension when N == 1.
            amps = np.array(kspace_cfg.profile(ks.T, **kspace_cfg.params), dtype=complex).reshape(-1)
        else:
            amps = np.array([kspace_cfg.profile(k, **kspace_cfg.params) for k in ks], dtype=complex)

        # --- Stochastic Noise ---
        phase_max = self.config.source.randomize.phase_max
        if phase_max > 0:
            amps *= np.exp(1j * self.rng.uniform(-phase_max, phase_max, size=N))
        if self.config.source.randomize.amplitude:
            # 0.7071 ~ 1/sqrt(2): keeps the complex Gaussian factor at unit mean power.
            amps *= (self.rng.normal(0, 1, N) + 1j*self.rng.normal(0, 1, N)) * 0.7071

        # --- Power Normalization ---
        raw_power = np.sum(np.abs(amps)**2 * d_omega)
        if raw_power < 1e-15:
            return ks, np.zeros((N, 3), dtype=complex), np.zeros(N, dtype=complex)

        scaling = (1.0 / np.sqrt(raw_power)) * weight * d_omega
        amps *= scaling
        return ks, P * amps[:, np.newaxis], amps

    # =========================================================================
    # SAMPLING STRATEGIES
    # =========================================================================
    def _sample_sphere_fib(self, N: int, beam_axis: Tuple, theta_max: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample N directions on a spherical cap with a Fibonacci (golden-angle) spiral.

        The cap spans polar angles [0, theta_max] about `beam_axis`. A single
        random rotation about the cap axis is drawn from `self.rng`.

        Returns
        -------
        points : np.ndarray
            (N, 3) unit direction vectors.
        d_omega : np.ndarray
            (N,) equal solid-angle elements, cap area divided by N.
        """
        z_min = np.cos(theta_max)
        z_range = 1.0 - z_min
        i = np.arange(N)
        # Uniform steps in z give equal-area bands on the sphere.
        z = 1.0 - (i + 0.5) * z_range / N
        r = np.sqrt(np.maximum(0, 1 - z**2))
        phi = np.pi * (3.0 - np.sqrt(5.0)) * i  # golden-angle increment

        points = np.column_stack((r * np.cos(phi), r * np.sin(phi), z))
        ang = self.rng.uniform(0, 2*np.pi)
        c, s = np.cos(ang), np.sin(ang)
        points = points @ np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]).T

        return self._align_to_axis(points, beam_axis), np.full(N, (2 * np.pi * z_range) / N)

    def _align_to_axis(self, points: np.ndarray, target_axis: Tuple) -> np.ndarray:
        """
        Rotate (N, 3) `points` so that the +z axis maps onto `target_axis`.

        A zero-length target leaves the points untouched; a target opposite
        to +z is handled as a half-turn about the y axis.
        """
        target = np.array(target_axis)
        norm = np.linalg.norm(target)
        if norm == 0:
            return points
        target = target / norm
        z_hat = np.array([0.0, 0.0, 1.0])
        c = np.dot(z_hat, target)
        if c > 0.999999:
            return points
        if c < -0.999999:
            # Anti-parallel: the general formula divides by |v|^2 ~ 0, so
            # flip x and z explicitly instead.
            p2 = points.copy()
            p2[:, 2] *= -1
            p2[:, 0] *= -1
            return p2
        v = np.cross(z_hat, target)
        vx = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
        # Rodrigues rotation built from the (un-normalized) axis v = z x target.
        R = np.eye(3) + vx + (vx @ vx) * ((1 - c) / np.dot(v, v))
        return points @ R.T

    def _transverse_basis_batch_rod(self, ks: np.ndarray, beam_axis: Tuple) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build a transverse basis (e1, e2) for every direction in `ks`.

        A reference pair (u, v) perpendicular to the beam axis is rotated,
        per mode, by the rotation that carries the beam axis onto that mode's
        direction (Rodrigues formula). Directions (anti-)parallel to the axis
        keep the reference pair.

        Returns
        -------
        e1, e2 : np.ndarray
            (N, 3) arrays, one basis vector pair per row of `ks`.
        """
        beam_axis = np.array(beam_axis)
        n = beam_axis / np.linalg.norm(beam_axis) if np.linalg.norm(beam_axis) > 0 else np.array([0., 0., 1.])
        ks_norm = ks / np.linalg.norm(ks, axis=1, keepdims=True)

        # Pick a helper axis that is not (nearly) parallel to n.
        u = np.cross(n, [0.0, 0.0, 1.0] if np.abs(n[2]) < 0.9 else [0.0, 1.0, 0.0])
        u /= np.linalg.norm(u)
        v = np.cross(n, u)

        w = np.cross(n, ks_norm)                         # rotation axis (un-normalized)
        s = np.linalg.norm(w, axis=1, keepdims=True)     # sin of the rotation angle
        c = np.sum(n * ks_norm, axis=1, keepdims=True)   # cos of the rotation angle

        e1, e2 = np.tile(u, (len(ks), 1)), np.tile(v, (len(ks), 1))
        mask = s[:, 0] > 1e-9
        if np.any(mask):
            wn = w[mask] / s[mask]
            u_dot, v_dot = np.sum(wn * u, axis=1, keepdims=True), np.sum(wn * v, axis=1, keepdims=True)
            e1[mask] = (u * c[mask] + np.cross(wn, u) * s[mask] + wn * u_dot * (1 - c[mask]))
            e2[mask] = (v * c[mask] + np.cross(wn, v) * s[mask] + wn * v_dot * (1 - c[mask]))

        return e1, e2