lsurf 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. lsurf/__init__.py +471 -0
  2. lsurf/analysis/__init__.py +107 -0
  3. lsurf/analysis/healpix_utils.py +418 -0
  4. lsurf/analysis/sphere_viz.py +1280 -0
  5. lsurf/cli/__init__.py +48 -0
  6. lsurf/cli/build.py +398 -0
  7. lsurf/cli/config_schema.py +318 -0
  8. lsurf/cli/gui_cmd.py +76 -0
  9. lsurf/cli/interactive.py +850 -0
  10. lsurf/cli/main.py +81 -0
  11. lsurf/cli/run.py +806 -0
  12. lsurf/detectors/__init__.py +266 -0
  13. lsurf/detectors/analysis.py +289 -0
  14. lsurf/detectors/base.py +284 -0
  15. lsurf/detectors/constant_size_rings.py +485 -0
  16. lsurf/detectors/directional.py +45 -0
  17. lsurf/detectors/extended/__init__.py +73 -0
  18. lsurf/detectors/extended/local_sphere.py +353 -0
  19. lsurf/detectors/extended/recording_sphere.py +368 -0
  20. lsurf/detectors/planar.py +45 -0
  21. lsurf/detectors/protocol.py +187 -0
  22. lsurf/detectors/recording_spheres.py +63 -0
  23. lsurf/detectors/results.py +1140 -0
  24. lsurf/detectors/small/__init__.py +79 -0
  25. lsurf/detectors/small/directional.py +330 -0
  26. lsurf/detectors/small/planar.py +401 -0
  27. lsurf/detectors/small/spherical.py +450 -0
  28. lsurf/detectors/spherical.py +45 -0
  29. lsurf/geometry/__init__.py +199 -0
  30. lsurf/geometry/builder.py +478 -0
  31. lsurf/geometry/cell.py +228 -0
  32. lsurf/geometry/cell_geometry.py +247 -0
  33. lsurf/geometry/detector_arrays.py +1785 -0
  34. lsurf/geometry/geometry.py +222 -0
  35. lsurf/geometry/surface_analysis.py +375 -0
  36. lsurf/geometry/validation.py +91 -0
  37. lsurf/gui/__init__.py +51 -0
  38. lsurf/gui/app.py +903 -0
  39. lsurf/gui/core/__init__.py +39 -0
  40. lsurf/gui/core/scene.py +343 -0
  41. lsurf/gui/core/simulation.py +264 -0
  42. lsurf/gui/renderers/__init__.py +40 -0
  43. lsurf/gui/renderers/ray_renderer.py +353 -0
  44. lsurf/gui/renderers/source_renderer.py +505 -0
  45. lsurf/gui/renderers/surface_renderer.py +477 -0
  46. lsurf/gui/views/__init__.py +48 -0
  47. lsurf/gui/views/config_editor.py +3199 -0
  48. lsurf/gui/views/properties.py +257 -0
  49. lsurf/gui/views/results.py +291 -0
  50. lsurf/gui/views/scene_tree.py +180 -0
  51. lsurf/gui/views/viewport_3d.py +555 -0
  52. lsurf/gui/views/visualizations.py +712 -0
  53. lsurf/materials/__init__.py +169 -0
  54. lsurf/materials/base/__init__.py +64 -0
  55. lsurf/materials/base/full_inhomogeneous.py +208 -0
  56. lsurf/materials/base/grid_inhomogeneous.py +319 -0
  57. lsurf/materials/base/homogeneous.py +342 -0
  58. lsurf/materials/base/material_field.py +527 -0
  59. lsurf/materials/base/simple_inhomogeneous.py +418 -0
  60. lsurf/materials/base/spectral_inhomogeneous.py +497 -0
  61. lsurf/materials/implementations/__init__.py +120 -0
  62. lsurf/materials/implementations/data/alpha_values_typical_atmosphere_updated.txt +24 -0
  63. lsurf/materials/implementations/duct_atmosphere.py +390 -0
  64. lsurf/materials/implementations/exponential_atmosphere.py +435 -0
  65. lsurf/materials/implementations/gaussian_lens.py +120 -0
  66. lsurf/materials/implementations/interpolated_data.py +123 -0
  67. lsurf/materials/implementations/layered_atmosphere.py +134 -0
  68. lsurf/materials/implementations/linear_gradient.py +109 -0
  69. lsurf/materials/implementations/linsley_atmosphere.py +764 -0
  70. lsurf/materials/implementations/standard_materials.py +126 -0
  71. lsurf/materials/implementations/turbulent_atmosphere.py +135 -0
  72. lsurf/materials/implementations/us_standard_atmosphere.py +149 -0
  73. lsurf/materials/utils/__init__.py +77 -0
  74. lsurf/materials/utils/constants.py +45 -0
  75. lsurf/materials/utils/device_functions.py +117 -0
  76. lsurf/materials/utils/dispersion.py +160 -0
  77. lsurf/materials/utils/factories.py +142 -0
  78. lsurf/propagation/__init__.py +91 -0
  79. lsurf/propagation/detector_gpu.py +67 -0
  80. lsurf/propagation/gpu_device_rays.py +294 -0
  81. lsurf/propagation/kernels/__init__.py +175 -0
  82. lsurf/propagation/kernels/absorption/__init__.py +61 -0
  83. lsurf/propagation/kernels/absorption/grid.py +240 -0
  84. lsurf/propagation/kernels/absorption/simple.py +232 -0
  85. lsurf/propagation/kernels/absorption/spectral.py +410 -0
  86. lsurf/propagation/kernels/detection/__init__.py +64 -0
  87. lsurf/propagation/kernels/detection/protocol.py +102 -0
  88. lsurf/propagation/kernels/detection/spherical.py +255 -0
  89. lsurf/propagation/kernels/device_functions.py +790 -0
  90. lsurf/propagation/kernels/fresnel/__init__.py +64 -0
  91. lsurf/propagation/kernels/fresnel/protocol.py +97 -0
  92. lsurf/propagation/kernels/fresnel/standard.py +258 -0
  93. lsurf/propagation/kernels/intersection/__init__.py +79 -0
  94. lsurf/propagation/kernels/intersection/annular_plane.py +207 -0
  95. lsurf/propagation/kernels/intersection/bounded_plane.py +205 -0
  96. lsurf/propagation/kernels/intersection/plane.py +166 -0
  97. lsurf/propagation/kernels/intersection/protocol.py +95 -0
  98. lsurf/propagation/kernels/intersection/signed_distance.py +742 -0
  99. lsurf/propagation/kernels/intersection/sphere.py +190 -0
  100. lsurf/propagation/kernels/propagation/__init__.py +85 -0
  101. lsurf/propagation/kernels/propagation/grid.py +527 -0
  102. lsurf/propagation/kernels/propagation/protocol.py +105 -0
  103. lsurf/propagation/kernels/propagation/simple.py +460 -0
  104. lsurf/propagation/kernels/propagation/spectral.py +875 -0
  105. lsurf/propagation/kernels/registry.py +331 -0
  106. lsurf/propagation/kernels/surface/__init__.py +72 -0
  107. lsurf/propagation/kernels/surface/bisection.py +232 -0
  108. lsurf/propagation/kernels/surface/detection.py +402 -0
  109. lsurf/propagation/kernels/surface/reduction.py +166 -0
  110. lsurf/propagation/propagator_protocol.py +222 -0
  111. lsurf/propagation/propagators/__init__.py +101 -0
  112. lsurf/propagation/propagators/detector_handler.py +354 -0
  113. lsurf/propagation/propagators/factory.py +200 -0
  114. lsurf/propagation/propagators/fresnel_handler.py +305 -0
  115. lsurf/propagation/propagators/gpu_gradient.py +566 -0
  116. lsurf/propagation/propagators/gpu_surface_propagator.py +707 -0
  117. lsurf/propagation/propagators/gradient.py +429 -0
  118. lsurf/propagation/propagators/intersection_handler.py +327 -0
  119. lsurf/propagation/propagators/material_propagator.py +398 -0
  120. lsurf/propagation/propagators/signed_distance_handler.py +522 -0
  121. lsurf/propagation/propagators/spectral_gpu_gradient.py +553 -0
  122. lsurf/propagation/propagators/surface_interaction.py +616 -0
  123. lsurf/propagation/propagators/surface_propagator.py +719 -0
  124. lsurf/py.typed +1 -0
  125. lsurf/simulation/__init__.py +70 -0
  126. lsurf/simulation/config.py +164 -0
  127. lsurf/simulation/orchestrator.py +462 -0
  128. lsurf/simulation/result.py +299 -0
  129. lsurf/simulation/simulation.py +262 -0
  130. lsurf/sources/__init__.py +128 -0
  131. lsurf/sources/base.py +264 -0
  132. lsurf/sources/collimated.py +252 -0
  133. lsurf/sources/custom.py +409 -0
  134. lsurf/sources/diverging.py +228 -0
  135. lsurf/sources/gaussian.py +272 -0
  136. lsurf/sources/parallel_from_positions.py +197 -0
  137. lsurf/sources/point.py +172 -0
  138. lsurf/sources/uniform_diverging.py +258 -0
  139. lsurf/surfaces/__init__.py +184 -0
  140. lsurf/surfaces/cpu/__init__.py +50 -0
  141. lsurf/surfaces/cpu/curved_wave.py +463 -0
  142. lsurf/surfaces/cpu/gerstner_wave.py +381 -0
  143. lsurf/surfaces/cpu/wave_params.py +118 -0
  144. lsurf/surfaces/gpu/__init__.py +72 -0
  145. lsurf/surfaces/gpu/annular_plane.py +453 -0
  146. lsurf/surfaces/gpu/bounded_plane.py +390 -0
  147. lsurf/surfaces/gpu/curved_wave.py +483 -0
  148. lsurf/surfaces/gpu/gerstner_wave.py +377 -0
  149. lsurf/surfaces/gpu/multi_curved_wave.py +520 -0
  150. lsurf/surfaces/gpu/plane.py +299 -0
  151. lsurf/surfaces/gpu/recording_sphere.py +587 -0
  152. lsurf/surfaces/gpu/sphere.py +311 -0
  153. lsurf/surfaces/protocol.py +336 -0
  154. lsurf/surfaces/registry.py +373 -0
  155. lsurf/utilities/__init__.py +175 -0
  156. lsurf/utilities/detector_analysis.py +814 -0
  157. lsurf/utilities/fresnel.py +628 -0
  158. lsurf/utilities/interactions.py +1215 -0
  159. lsurf/utilities/propagation.py +602 -0
  160. lsurf/utilities/ray_data.py +532 -0
  161. lsurf/utilities/recording_sphere.py +745 -0
  162. lsurf/utilities/time_spread.py +463 -0
  163. lsurf/visualization/__init__.py +329 -0
  164. lsurf/visualization/absorption_plots.py +334 -0
  165. lsurf/visualization/atmospheric_plots.py +754 -0
  166. lsurf/visualization/common.py +348 -0
  167. lsurf/visualization/detector_plots.py +1350 -0
  168. lsurf/visualization/detector_sphere_plots.py +1173 -0
  169. lsurf/visualization/fresnel_plots.py +1061 -0
  170. lsurf/visualization/ocean_simulation_plots.py +999 -0
  171. lsurf/visualization/polarization_plots.py +916 -0
  172. lsurf/visualization/raytracing_plots.py +1521 -0
  173. lsurf/visualization/ring_detector_plots.py +1867 -0
  174. lsurf/visualization/time_spread_plots.py +531 -0
  175. lsurf-1.0.0.dist-info/METADATA +381 -0
  176. lsurf-1.0.0.dist-info/RECORD +180 -0
  177. lsurf-1.0.0.dist-info/WHEEL +5 -0
  178. lsurf-1.0.0.dist-info/entry_points.txt +2 -0
  179. lsurf-1.0.0.dist-info/licenses/LICENSE +32 -0
  180. lsurf-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1521 @@
1
+ # The Clear BSD License
2
+ #
3
+ # Copyright (c) 2026 Tobias Heibges
4
+ # All rights reserved.
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted (subject to the limitations in the disclaimer
8
+ # below) provided that the following conditions are met:
9
+ #
10
+ # * Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # * Redistributions in binary form must reproduce the above copyright
14
+ # notice, this list of conditions and the following disclaimer in the
15
+ # documentation and/or other materials provided with the distribution.
16
+ #
17
+ # * Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from this
19
+ # software without specific prior written permission.
20
+ #
21
+ # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
22
+ # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
23
+ # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25
+ # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
26
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
28
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29
+ # BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
30
+ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
31
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32
+ # POSSIBILITY OF SUCH DAMAGE.
33
+
34
+ """
35
+ Ray Tracing Visualization - Individual Axis Functions
36
+
37
+ Functions for plotting ray paths, intersections, and propagation.
38
+ Each function draws on a single axis, enabling flexible composition.
39
+ """
40
+
41
+ from typing import TYPE_CHECKING, Optional
42
+
43
+ import matplotlib.pyplot as plt
44
+ import numpy as np
45
+ from matplotlib.axes import Axes
46
+ from matplotlib.figure import Figure
47
+
48
+ if TYPE_CHECKING:
49
+ from ..utilities.ray_data import RayBatch
50
+ from ..surfaces import Surface
51
+
52
+ from .common import (
53
+ DEFAULT_LINEWIDTH,
54
+ DEFAULT_MARKERSIZE,
55
+ INTENSITY_CMAP,
56
+ LINE_ALPHA,
57
+ SCATTER_ALPHA,
58
+ WAVELENGTH_CMAP,
59
+ add_colorbar,
60
+ get_color_mapping,
61
+ get_projection_config,
62
+ save_figure,
63
+ setup_axis_grid,
64
+ )
65
+
66
+ # =============================================================================
67
+ # Single-Axis Ray Path Functions
68
+ # =============================================================================
69
+
70
+
71
+ def plot_ray_paths_projection(
72
+ ax: Axes,
73
+ ray_history: list["RayBatch"],
74
+ projection: str = "xz",
75
+ max_rays: int = 100,
76
+ color_by: str = "wavelength",
77
+ alpha: float = LINE_ALPHA,
78
+ linewidth: float = DEFAULT_LINEWIDTH,
79
+ show_colorbar: bool = True,
80
+ ) -> plt.cm.ScalarMappable | None:
81
+ """
82
+ Plot ray paths as a 2D projection on given axis.
83
+
84
+ Parameters
85
+ ----------
86
+ ax : Axes
87
+ Matplotlib axes to draw on.
88
+ ray_history : List[RayBatch]
89
+ List of ray batches at different time steps.
90
+ projection : str
91
+ Projection plane: 'xy', 'xz', or 'yz'.
92
+ max_rays : int
93
+ Maximum rays to plot for performance.
94
+ color_by : str
95
+ Color rays by: 'wavelength', 'intensity', 'generation', 'index'.
96
+ alpha : float
97
+ Line transparency.
98
+ linewidth : float
99
+ Line width.
100
+ show_colorbar : bool
101
+ Whether to add colorbar.
102
+
103
+ Returns
104
+ -------
105
+ sm : ScalarMappable or None
106
+ ScalarMappable for external colorbar, or None.
107
+ """
108
+ if len(ray_history) == 0:
109
+ return None
110
+
111
+ initial_batch = ray_history[0]
112
+ n_rays = initial_batch.num_rays
113
+
114
+ # Sample rays
115
+ if n_rays > max_rays:
116
+ ray_indices = np.linspace(0, n_rays - 1, max_rays, dtype=int)
117
+ else:
118
+ ray_indices = np.arange(n_rays)
119
+
120
+ # Coordinate mapping
121
+ idx1, idx2, xlabel, ylabel = get_projection_config(projection)
122
+
123
+ # Color mapping
124
+ if color_by == "wavelength":
125
+ values = initial_batch.wavelengths[ray_indices] * 1e9
126
+ cmap = WAVELENGTH_CMAP
127
+ label = "Wavelength (nm)"
128
+ elif color_by == "intensity":
129
+ values = initial_batch.intensities[ray_indices]
130
+ cmap = INTENSITY_CMAP
131
+ label = "Intensity"
132
+ elif color_by == "generation":
133
+ values = initial_batch.generations[ray_indices]
134
+ cmap = "tab10"
135
+ label = "Generation"
136
+ else:
137
+ values = ray_indices.astype(float)
138
+ cmap = "tab20"
139
+ label = "Ray Index"
140
+
141
+ colors, norm, sm = get_color_mapping(values, cmap)
142
+
143
+ # Plot paths
144
+ for i, ray_idx in enumerate(ray_indices):
145
+ coords1 = [batch.positions[ray_idx, idx1] for batch in ray_history]
146
+ coords2 = [batch.positions[ray_idx, idx2] for batch in ray_history]
147
+ ax.plot(coords1, coords2, color=colors[i], alpha=alpha, linewidth=linewidth)
148
+
149
+ setup_axis_grid(ax, xlabel, ylabel, f"{projection.upper()} Projection")
150
+ ax.set_aspect("equal", adjustable="box")
151
+
152
+ if show_colorbar:
153
+ add_colorbar(ax, sm, label)
154
+
155
+ return sm
156
+
157
+
158
+ def plot_ray_endpoints_scatter(
159
+ ax: Axes,
160
+ rays: "RayBatch",
161
+ projection: str = "xy",
162
+ color_by: str = "intensity",
163
+ alpha: float = SCATTER_ALPHA,
164
+ size: float = DEFAULT_MARKERSIZE,
165
+ show_colorbar: bool = True,
166
+ ) -> plt.cm.ScalarMappable | None:
167
+ """
168
+ Plot ray endpoint positions as scatter plot.
169
+
170
+ Parameters
171
+ ----------
172
+ ax : Axes
173
+ Matplotlib axes.
174
+ rays : RayBatch
175
+ Ray batch.
176
+ projection : str
177
+ Plane: 'xy', 'xz', 'yz'.
178
+ color_by : str
179
+ Color by: 'intensity', 'wavelength', 'generation', 'time'.
180
+ alpha : float
181
+ Point transparency.
182
+ size : float
183
+ Point size.
184
+ show_colorbar : bool
185
+ Whether to add colorbar.
186
+
187
+ Returns
188
+ -------
189
+ sm : ScalarMappable or None
190
+ For external colorbar.
191
+ """
192
+ active_mask = rays.active
193
+ positions = rays.positions[active_mask]
194
+
195
+ # Coordinate selection
196
+ idx1, idx2, xlabel, ylabel = get_projection_config(projection)
197
+
198
+ x, y = positions[:, idx1], positions[:, idx2]
199
+
200
+ # Color values
201
+ if color_by == "intensity":
202
+ c = rays.intensities[active_mask]
203
+ cmap = INTENSITY_CMAP
204
+ clabel = "Intensity"
205
+ elif color_by == "wavelength":
206
+ c = rays.wavelengths[active_mask] * 1e9
207
+ cmap = WAVELENGTH_CMAP
208
+ clabel = "Wavelength (nm)"
209
+ elif color_by == "generation":
210
+ c = rays.generations[active_mask]
211
+ cmap = "tab10"
212
+ clabel = "Generation"
213
+ elif color_by == "time":
214
+ c = rays.accumulated_time[active_mask] * 1e6
215
+ cmap = "coolwarm"
216
+ clabel = "Time (μs)"
217
+ else:
218
+ c = np.arange(len(x))
219
+ cmap = "viridis"
220
+ clabel = "Index"
221
+
222
+ scatter = ax.scatter(x, y, c=c, s=size, alpha=alpha, cmap=cmap)
223
+ setup_axis_grid(ax, xlabel, ylabel, f"Ray Endpoints - {projection.upper()}")
224
+ ax.set_aspect("equal", adjustable="box")
225
+
226
+ if show_colorbar and c is not None:
227
+ add_colorbar(ax, scatter, clabel)
228
+
229
+ return scatter
230
+
231
+
232
def plot_ray_endpoints_histogram(
    ax: Axes,
    rays: "RayBatch",
    projection: str = "xy",
    bins: int = 50,
    cmap: str = "hot",
) -> None:
    """
    Draw a 2D histogram of the active rays' positions.

    Bins with zero counts are left blank (``cmin=1``) and a "Count"
    colorbar is attached.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch; only rays selected by ``rays.active`` are counted.
    projection : str
        Plane: 'xy', 'xz', 'yz'.
    bins : int
        Number of histogram bins (passed to ``ax.hist2d``).
    cmap : str
        Colormap name.
    """
    mask = rays.active
    xy = rays.positions[mask]

    first_axis, second_axis, xlabel, ylabel = get_projection_config(projection)

    # hist2d returns (counts, xedges, yedges, image); only the image is needed.
    *_, mesh = ax.hist2d(
        xy[:, first_axis], xy[:, second_axis], bins=bins, cmap=cmap, cmin=1
    )
    setup_axis_grid(ax, xlabel, ylabel, "Ray Density")
    ax.set_aspect("equal", adjustable="box")
    add_colorbar(ax, mesh, "Count")
266
+
267
+
268
+ # =============================================================================
269
+ # Surface Intersection Visualization
270
+ # =============================================================================
271
+
272
+
273
def plot_surface_profile(
    ax: Axes,
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    y: float = 0.0,
    n_points: int = 1000,
    color: str = "blue",
    linewidth: float = 2.0,
    label: str = "Surface",
) -> None:
    """
    Plot the surface height z(x, y) along a line of constant y.

    The profile is sampled at ``n_points`` evenly spaced x values. The area
    from the profile down to 0.5 below its minimum is shaded.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    surface : Surface
        Surface to sample. If it has no ``_surface_z`` method, a flat
        profile at z = 0 is drawn.
    x_range : tuple
        (x_min, x_max) sampling range.
    y : float
        Y-coordinate of the profile line.
    n_points : int
        Number of sample points.
    color : str
        Line and fill color.
    linewidth : float
        Line width.
    label : str
        Legend label for the profile line.
    """
    xs = np.linspace(x_range[0], x_range[1], n_points)

    if hasattr(surface, "_surface_z"):
        height = surface._surface_z
        ys = np.full_like(xs, y)
        # _surface_z is evaluated point by point; its array support is not assumed.
        zs = np.array([height(xv, yv) for xv, yv in zip(xs, ys, strict=False)])
    else:
        zs = np.zeros_like(xs)

    ax.plot(xs, zs, color=color, linewidth=linewidth, label=label)
    floor = zs.min() - 0.5
    ax.fill_between(xs, zs, floor, alpha=0.3, color=color)
    setup_axis_grid(ax, "X (m)", "Z (m)")
318
+
319
+
320
def plot_bounce_points(
    ax: Axes,
    bounce_positions: np.ndarray,
    bounce_number: int = 1,
    color: str | None = None,
    size: float = 20,
    alpha: float = 0.7,
    projection: str = "xz",
    label: str | None = None,
) -> None:
    """
    Scatter ray bounce points on a surface in a 2D projection.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    bounce_positions : array_like
        Bounce positions with shape (N, 3). A list of 3-vectors or a single
        3-vector is also accepted; input is converted with
        ``np.atleast_2d(np.asarray(...))``. Empty input draws nothing.
    bounce_number : int
        Bounce index. It picks the default color from a five-color palette
        (cyclically) and sets the default label ``"Bounce {n}"``.
    color : str, optional
        Override color.
    size : float
        Marker size.
    alpha : float
        Transparency.
    projection : str
        Coordinate projection ('xz', 'xy', 'yz').
    label : str, optional
        Legend label.
    """
    # Normalise lists / single points so column indexing below always works.
    points = np.atleast_2d(np.asarray(bounce_positions))
    if points.size == 0:
        return

    if color is None:
        bounce_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]
        color = bounce_colors[bounce_number % len(bounce_colors)]

    first_axis, second_axis, _, _ = get_projection_config(projection)

    if label is None:
        label = f"Bounce {bounce_number}"

    ax.scatter(
        points[:, first_axis],
        points[:, second_axis],
        c=color,
        s=size,
        alpha=alpha,
        label=label,
        edgecolors="none",
    )
369
+
370
+
371
def plot_incoming_rays(
    ax: Axes,
    rays: "RayBatch",
    surface: "Surface",
    projection: str = "xz",
    color: str = "gold",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Plot incoming ray segments from each ray's position to its surface hit.

    Up to ``max_rays`` rays are sampled at evenly spaced indices, and
    inactive rays are skipped. Intersections for all remaining rays are
    computed with a single batched ``surface.intersect`` call. A segment is
    drawn only for rays that hit the surface at a positive distance.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch before intersection.
    surface : Surface
        Surface for intersection calculation.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    n_total = rays.num_rays
    sample_idx = (
        np.linspace(0, n_total - 1, max_rays, dtype=int)
        if n_total > max_rays
        else np.arange(n_total)
    )

    # Drop inactive rays up front (dtype=bool mirrors per-element truthiness).
    active = np.asarray(rays.active, dtype=bool)[sample_idx]
    sample_idx = sample_idx[active]
    if sample_idx.size == 0:
        return

    positions = rays.positions[sample_idx]
    directions = rays.directions[sample_idx]

    # One batched call instead of one call per ray; the all-True third
    # argument is presumably the surface's active-ray mask — TODO confirm.
    t, hit = surface.intersect(
        positions.astype(np.float32),
        directions.astype(np.float32),
        np.ones(len(sample_idx), dtype=bool),
    )

    for k, (start, direction) in enumerate(zip(positions, directions)):
        if hit[k] and t[k] > 0:
            intersection = start + t[k] * direction
            ax.plot(
                [start[idx1], intersection[idx1]],
                [start[idx2], intersection[idx2]],
                color=color,
                alpha=alpha,
                linewidth=linewidth,
            )
435
+
436
+
437
def plot_reflected_rays(
    ax: Axes,
    rays: "RayBatch",
    length: float = 100.0,
    projection: str = "xz",
    color: str = "cyan",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Draw fixed-length segments along each ray's current direction.

    Each active sampled ray gets a segment from ``positions[i]`` to
    ``positions[i] + length * directions[i]``. Up to ``max_rays`` rays are
    sampled at evenly spaced indices.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Reflected ray batch.
    length : float
        Length multiplier applied to each direction vector.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    first_axis, second_axis, _, _ = get_projection_config(projection)

    count = rays.num_rays
    if count > max_rays:
        chosen = np.linspace(0, count - 1, max_rays, dtype=int)
    else:
        chosen = np.arange(count)

    for i in (j for j in chosen if rays.active[j]):
        start = rays.positions[i]
        stop = start + length * rays.directions[i]
        ax.plot(
            [start[first_axis], stop[first_axis]],
            [start[second_axis], stop[second_axis]],
            color=color,
            alpha=alpha,
            linewidth=linewidth,
        )
493
+
494
+
495
+ # =============================================================================
496
+ # Multi-Bounce Visualization
497
+ # =============================================================================
498
+
499
+
500
def plot_multi_bounce_paths(
    ax: Axes,
    ray_paths: dict[str, list[np.ndarray]],
    projection: str = "xz",
    reflected_color: str = "cyan",
    refracted_color: str = "orange",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_paths: int = 100,
) -> None:
    """
    Plot multi-bounce ray paths from trace_rays_multi_bounce output.

    Reflected paths are drawn first, then refracted paths. Within each
    group, at most ``max_paths`` paths are drawn, chosen at evenly spaced
    indices. Paths with fewer than two points are skipped.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    ray_paths : dict
        Dictionary with 'reflected_paths' and/or 'refracted_paths' lists.
        Each path is an (M, 3) point array; a list of 3-vectors is also
        accepted because each path is converted with ``np.asarray``.
    projection : str
        Coordinate projection.
    reflected_color : str
        Color for reflected paths.
    refracted_color : str
        Color for refracted paths.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_paths : int
        Maximum paths to plot per group.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    # Both groups share identical sampling/drawing logic; only key and color differ.
    for key, path_color in (
        ("reflected_paths", reflected_color),
        ("refracted_paths", refracted_color),
    ):
        if key not in ray_paths:
            continue

        paths = ray_paths[key]
        n_paths = len(paths)
        sample_idx = (
            np.linspace(0, n_paths - 1, max_paths, dtype=int)
            if n_paths > max_paths
            else range(n_paths)
        )

        for i in sample_idx:
            path = np.asarray(paths[i])
            if len(path) > 1:
                ax.plot(
                    path[:, idx1],
                    path[:, idx2],
                    color=path_color,
                    alpha=alpha,
                    linewidth=linewidth,
                )
575
+
576
+
577
+ # =============================================================================
578
+ # Composite Figure Builders
579
+ # =============================================================================
580
+
581
+
582
def create_ray_overview_figure(
    rays: "RayBatch",
    surface: "Surface",
    reflected_rays: Optional["RayBatch"] = None,
    bounce_points: list[np.ndarray] | None = None,
    figsize: tuple[float, float] = (16, 10),
    x_range: tuple[float, float] = (-500, 500),
    title: str = "Ray Tracing Overview",
    save_path: str | None = None,
) -> Figure:
    """
    Build a 2x2 overview figure of a ray tracing run.

    Panels:
    - top-left: XZ side view with the surface profile, incoming rays,
      optional reflected segments and per-bounce points;
    - top-right: XY scatter of ``rays`` colored by intensity;
    - bottom-left: XZ density histogram of ``reflected_rays`` if given,
      otherwise of ``rays``;
    - bottom-right: surface profile over ``x_range`` scaled by 1/5, with
      first-bounce points falling inside that window.

    Parameters
    ----------
    rays : RayBatch
        Initial rays.
    surface : Surface
        Wave surface.
    reflected_rays : RayBatch, optional
        Reflected rays.
    bounce_points : List[ndarray], optional
        One array (or list) of bounce positions per bounce.
    figsize : tuple
        Figure size.
    x_range : tuple
        X-axis range of the side view.
    title : str
        Figure title.
    save_path : str, optional
        If truthy, the figure is saved there via ``save_figure``.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle(title, fontsize=14, fontweight="bold")
    (ax_side, ax_top), (ax_density, ax_zoom) = axes

    # Side view: surface, incoming/reflected rays, bounce markers.
    plot_surface_profile(ax_side, surface, x_range=x_range)
    plot_incoming_rays(ax_side, rays, surface, projection="xz")
    if reflected_rays is not None:
        plot_reflected_rays(ax_side, reflected_rays, projection="xz", length=50)
    per_bounce = bounce_points if bounce_points is not None else []
    for bounce_idx, pts in enumerate(per_bounce):
        if len(pts) > 0:
            plot_bounce_points(
                ax_side, np.array(pts), bounce_number=bounce_idx + 1, projection="xz"
            )
    ax_side.legend(loc="upper right")
    ax_side.set_title("XZ View (Side)")

    # Top-down view of the initial rays.
    plot_ray_endpoints_scatter(ax_top, rays, projection="xy", color_by="intensity")
    ax_top.set_title("XY View (Top)")

    # Density of reflected rays when available, else of the initial rays.
    density_source = rays if reflected_rays is None else reflected_rays
    plot_ray_endpoints_histogram(ax_density, density_source, projection="xz")
    ax_density.set_title("Ray Density")

    # Zoomed surface detail over one fifth of the side-view range.
    zoom_lo, zoom_hi = x_range[0] / 5, x_range[1] / 5
    plot_surface_profile(
        ax_zoom, surface, x_range=(zoom_lo, zoom_hi), color="darkblue"
    )
    has_first_bounce = (
        bounce_points is not None
        and len(bounce_points) > 0
        and len(bounce_points[0]) > 0
    )
    if has_first_bounce:
        first = np.array(bounce_points[0])
        in_view = (first[:, 0] >= zoom_lo) & (first[:, 0] <= zoom_hi)
        if in_view.any():
            plot_bounce_points(
                ax_zoom, first[in_view], bounce_number=1, projection="xz"
            )
    ax_zoom.set_title("Surface Detail")

    if save_path:
        save_figure(fig, save_path)

    return fig
670
+
671
+
672
+ # =============================================================================
673
+ # Production Simulation Figure
674
+ # =============================================================================
675
+
676
+
677
+ def plot_production_ray_overview(
678
+ original_rays: "RayBatch",
679
+ surface: "Surface",
680
+ config: dict,
681
+ output_path: str,
682
+ timestamp: str,
683
+ max_bounces: int = 2,
684
+ ) -> Figure:
685
+ """
686
+ Create production simulation ray overview with surface bounce points.
687
+
688
+ Shows incoming rays, bounce points on wave surface (colored by bounce number),
689
+ and reflected rays toward recording sphere.
690
+
691
+ Parameters
692
+ ----------
693
+ original_rays : RayBatch
694
+ Original rays before tracing.
695
+ surface : Surface
696
+ The wave surface (e.g., CurvedWaveSurface).
697
+ config : dict
698
+ Simulation configuration with keys:
699
+ - grazing_angle: Beam grazing angle in degrees
700
+ - beam_radius: Beam radius in meters
701
+ - earth_radius: Earth radius in meters
702
+ - recording_altitude: Recording sphere altitude in meters
703
+ - source_distance: Source distance in meters
704
+ output_path : str
705
+ Directory to save figure.
706
+ timestamp : str
707
+ Timestamp for filename.
708
+ max_bounces : int
709
+ Maximum number of bounces to visualize (default: 2).
710
+
711
+ Returns
712
+ -------
713
+ Figure
714
+ Matplotlib figure.
715
+ """
716
+ from pathlib import Path
717
+
718
+ from ..surfaces import EARTH_RADIUS
719
+ from ..utilities.ray_data import create_ray_batch
720
+
721
+ output_path = Path(output_path)
722
+ output_path.mkdir(parents=True, exist_ok=True)
723
+
724
+ # Subsample rays for visualization
725
+ n_vis_rays = min(500, original_rays.num_rays)
726
+ vis_sample_idx = np.random.choice(original_rays.num_rays, n_vis_rays, replace=False)
727
+
728
+ original_rays_vis = create_ray_batch(num_rays=n_vis_rays)
729
+ original_rays_vis.positions[:] = original_rays.positions[vis_sample_idx]
730
+ original_rays_vis.directions[:] = original_rays.directions[vis_sample_idx]
731
+ original_rays_vis.wavelengths[:] = original_rays.wavelengths[vis_sample_idx]
732
+ original_rays_vis.intensities[:] = original_rays.intensities[vis_sample_idx]
733
+ original_rays_vis.active[:] = True
734
+
735
+ # Trace rays through multiple bounces using batch intersection
736
+ bounce_points = [[] for _ in range(max_bounces)]
737
+
738
+ # Start with original rays
739
+ current_rays = original_rays_vis.clone()
740
+
741
+ for bounce_num in range(max_bounces):
742
+ # Batch intersection for all active rays
743
+ hit_distances, hit_mask = surface.intersect(
744
+ current_rays.positions, current_rays.directions
745
+ )
746
+
747
+ if not np.any(hit_mask):
748
+ break
749
+
750
+ # Get hit positions for rays that intersected
751
+ hit_positions = (
752
+ current_rays.positions[hit_mask]
753
+ + hit_distances[hit_mask, np.newaxis] * current_rays.directions[hit_mask]
754
+ )
755
+
756
+ # Store bounce points
757
+ for hit_pos in hit_positions:
758
+ bounce_points[bounce_num].append(hit_pos.copy())
759
+
760
+ # Get normals at hit points
761
+ normals = surface.normal_at(hit_positions, current_rays.directions[hit_mask])
762
+
763
+ # Compute reflected directions
764
+ dot_prod = np.sum(
765
+ current_rays.directions[hit_mask] * normals, axis=1, keepdims=True
766
+ )
767
+ reflected_dirs = current_rays.directions[hit_mask] - 2 * dot_prod * normals
768
+
769
+ # Create new ray batch for next bounce (only rays that hit)
770
+ n_hits = np.sum(hit_mask)
771
+ if n_hits == 0:
772
+ break
773
+
774
+ next_rays = create_ray_batch(num_rays=n_hits)
775
+ next_rays.positions[:] = hit_positions + 0.01 * reflected_dirs # Small offset
776
+ next_rays.directions[:] = reflected_dirs
777
+ next_rays.intensities[:] = current_rays.intensities[hit_mask]
778
+ next_rays.active[:] = True
779
+
780
+ current_rays = next_rays
781
+
782
+ # Convert bounce points to arrays and filter to beam footprint region
783
+ grazing_angle_rad = np.radians(config["grazing_angle"])
784
+ elongation_factor = 1.0 / np.sin(grazing_angle_rad)
785
+ footprint_radius = config["beam_radius"] * elongation_factor
786
+ filter_radius = footprint_radius * 2.0 # Keep points within 2x the footprint
787
+
788
+ for i in range(max_bounces):
789
+ if len(bounce_points[i]) > 0:
790
+ bounce_points[i] = np.array(bounce_points[i])
791
+ # Filter to remove rays that escaped to infinity
792
+ distances = np.sqrt(
793
+ bounce_points[i][:, 0] ** 2 + bounce_points[i][:, 1] ** 2
794
+ )
795
+ valid_mask = distances < filter_radius
796
+ bounce_points[i] = bounce_points[i][valid_mask]
797
+ else:
798
+ bounce_points[i] = np.empty((0, 3))
799
+
800
+ # Print bounce position statistics
801
+ print("Bounce Position Statistics:")
802
+ for i in range(max_bounces):
803
+ if len(bounce_points[i]) > 0:
804
+ x_positions = bounce_points[i][:, 0]
805
+ z_positions = bounce_points[i][:, 2]
806
+ print(f" Bounce {i+1}:")
807
+ print(
808
+ f" X: mean={np.mean(x_positions):6.1f} m, std={np.std(x_positions):6.1f} m, "
809
+ f"range=[{np.min(x_positions):6.1f}, {np.max(x_positions):6.1f}]"
810
+ )
811
+ print(
812
+ f" Z: mean={np.mean(z_positions):6.3f} m, std={np.std(z_positions):6.3f} m, "
813
+ f"range=[{np.min(z_positions):6.3f}, {np.max(z_positions):6.3f}]"
814
+ )
815
+ print(f" Count: {len(bounce_points[i])} rays")
816
+
817
+ # Create figure
818
+ fig, axes = plt.subplots(1, 2, figsize=(16, 8))
819
+
820
+ # Left panel: Full scale view
821
+ ax1 = axes[0]
822
+
823
+ earth_center = np.array([0, 0, -EARTH_RADIUS])
824
+ theta = np.linspace(-0.01, 0.01, 100)
825
+ earth_x = EARTH_RADIUS * np.sin(theta)
826
+ earth_z = earth_center[2] + EARTH_RADIUS * np.cos(theta)
827
+ ax1.fill_between(
828
+ earth_x / 1000, earth_z / 1000, -10, color="#4a90d9", alpha=0.3, label="Ocean"
829
+ )
830
+ ax1.plot(earth_x / 1000, earth_z / 1000, "b-", linewidth=2, label="Sea surface")
831
+
832
+ # Recording sphere
833
+ recording_radius = (
834
+ config.get("earth_radius", EARTH_RADIUS) + config["recording_altitude"]
835
+ )
836
+ rec_x = recording_radius * np.sin(theta)
837
+ rec_z = earth_center[2] + recording_radius * np.cos(theta)
838
+ ax1.plot(
839
+ rec_x / 1000,
840
+ rec_z / 1000,
841
+ "g--",
842
+ linewidth=1.5,
843
+ label=f'Recording sphere ({config["recording_altitude"]/1000:.0f} km)',
844
+ )
845
+
846
+ # Plot sample rays (show incoming and reflected rays)
847
+ if len(bounce_points[0]) > 0:
848
+ n_plot = min(200, len(bounce_points[0]))
849
+ plot_indices = np.linspace(0, len(bounce_points[0]) - 1, n_plot, dtype=int)
850
+
851
+ # Incoming rays
852
+ if len(plot_indices) > 0:
853
+ idx = plot_indices[0]
854
+ start = original_rays_vis.positions[idx]
855
+ end = bounce_points[0][idx]
856
+ ax1.plot(
857
+ [start[0] / 1000, end[0] / 1000],
858
+ [start[2] / 1000, end[2] / 1000],
859
+ "r-",
860
+ alpha=0.6,
861
+ linewidth=0.8,
862
+ label="Incoming rays",
863
+ )
864
+
865
+ for idx in plot_indices[1:]:
866
+ if idx < len(original_rays_vis.positions) and idx < len(bounce_points[0]):
867
+ start = original_rays_vis.positions[idx]
868
+ end = bounce_points[0][idx]
869
+ ax1.plot(
870
+ [start[0] / 1000, end[0] / 1000],
871
+ [start[2] / 1000, end[2] / 1000],
872
+ "r-",
873
+ alpha=0.5,
874
+ linewidth=0.8,
875
+ )
876
+
877
+ # Reflected rays (from first bounce)
878
+ ray_length = config["recording_altitude"] * 1.5
879
+ if len(bounce_points[0]) > 0:
880
+ # Get intersection and normals for sample rays
881
+ sample_positions = original_rays_vis.positions[plot_indices]
882
+ sample_directions = original_rays_vis.directions[plot_indices]
883
+
884
+ hit_distances, hit_mask = surface.intersect(
885
+ sample_positions, sample_directions
886
+ )
887
+
888
+ if np.any(hit_mask):
889
+ hit_positions = (
890
+ sample_positions[hit_mask]
891
+ + hit_distances[hit_mask, np.newaxis] * sample_directions[hit_mask]
892
+ )
893
+ normals = surface.normal_at(hit_positions, sample_directions[hit_mask])
894
+ dot_prod = np.sum(
895
+ sample_directions[hit_mask] * normals, axis=1, keepdims=True
896
+ )
897
+ reflected_dirs = sample_directions[hit_mask] - 2 * dot_prod * normals
898
+
899
+ for i, (hit_pos, reflected_dir) in enumerate(
900
+ zip(hit_positions, reflected_dirs, strict=False)
901
+ ):
902
+ end = hit_pos + reflected_dir * ray_length
903
+ if i == 0:
904
+ ax1.plot(
905
+ [hit_pos[0] / 1000, end[0] / 1000],
906
+ [hit_pos[2] / 1000, end[2] / 1000],
907
+ "g-",
908
+ alpha=0.6,
909
+ linewidth=0.8,
910
+ label="Reflected rays",
911
+ )
912
+ else:
913
+ ax1.plot(
914
+ [hit_pos[0] / 1000, end[0] / 1000],
915
+ [hit_pos[2] / 1000, end[2] / 1000],
916
+ "g-",
917
+ alpha=0.5,
918
+ linewidth=0.8,
919
+ )
920
+
921
+ ax1.set_xlabel("X (km)", fontsize=12)
922
+ ax1.set_ylabel("Z (km)", fontsize=12)
923
+ ax1.set_title("Ray Paths Overview (X-Z Plane)", fontsize=14)
924
+ ax1.legend(loc="upper right")
925
+ ax1.set_aspect("equal")
926
+ ax1.grid(True, alpha=0.3)
927
+
928
+ max_range = (
929
+ max(config.get("source_distance", 10000), config["recording_altitude"]) * 1.5
930
+ )
931
+ ax1.set_xlim(-max_range / 1000 * 0.5, max_range / 1000 * 1.5)
932
+ ax1.set_ylim(-5, config["recording_altitude"] / 1000 * 1.3)
933
+
934
+ # Right panel: Surface detail with multi-bounce points
935
+ ax2 = axes[1]
936
+
937
+ x_range = np.linspace(-footprint_radius * 1.2, footprint_radius * 1.2, 400)
938
+ z_wave = []
939
+ for x in x_range:
940
+ if hasattr(surface, "_compute_wave_displacement"):
941
+ _, _, dz = surface._compute_wave_displacement(
942
+ np.array([x]), np.array([0.0])
943
+ )
944
+ z_wave.append(dz[0])
945
+ elif hasattr(surface, "_surface_z"):
946
+ z_wave.append(surface._surface_z(x, 0.0))
947
+ else:
948
+ z_wave.append(0.0)
949
+
950
+ z_wave = np.array(z_wave)
951
+ ax2.fill_between(x_range, z_wave, z_wave.min() - 2, color="#4a90d9", alpha=0.5)
952
+ ax2.plot(x_range, z_wave, "b-", linewidth=2, label="Wave surface")
953
+
954
+ # Plot bounce points with different colors for each generation
955
+ bounce_colors = ["red", "cyan", "magenta", "yellow"]
956
+ bounce_sizes = [20, 15, 12, 10]
957
+ bounce_labels = ["1st bounce", "2nd bounce", "3rd bounce", "4th bounce"]
958
+
959
+ for bounce_idx in range(min(max_bounces, len(bounce_points))):
960
+ if len(bounce_points[bounce_idx]) > 0:
961
+ bounce_x = bounce_points[bounce_idx][:, 0]
962
+ bounce_z = bounce_points[bounce_idx][:, 2]
963
+ mean_x = np.mean(bounce_x)
964
+
965
+ label_text = f"{bounce_labels[bounce_idx]} (mean X={mean_x:.1f}m)"
966
+ ax2.scatter(
967
+ bounce_x,
968
+ bounce_z,
969
+ c=bounce_colors[bounce_idx],
970
+ s=bounce_sizes[bounce_idx],
971
+ alpha=0.7,
972
+ label=label_text,
973
+ zorder=3 + bounce_idx,
974
+ edgecolors="black",
975
+ linewidths=0.5,
976
+ )
977
+
978
+ # Draw vertical line at mean X position
979
+ ax2.axvline(
980
+ mean_x,
981
+ color=bounce_colors[bounce_idx],
982
+ linestyle=":",
983
+ alpha=0.4,
984
+ linewidth=1.5,
985
+ )
986
+
987
+ # Mark beam footprint
988
+ ax2.axvline(-footprint_radius, color="orange", linestyle="--", alpha=0.5)
989
+ ax2.axvline(
990
+ footprint_radius,
991
+ color="orange",
992
+ linestyle="--",
993
+ alpha=0.5,
994
+ label=f"Beam footprint (±{footprint_radius:.0f}m)",
995
+ )
996
+
997
+ ax2.set_xlabel("X (m)", fontsize=12)
998
+ ax2.set_ylabel("Z (m)", fontsize=12)
999
+ ax2.set_title("Wave Surface Detail with Multi-Bounce Points", fontsize=14)
1000
+ ax2.legend()
1001
+ ax2.grid(True, alpha=0.3)
1002
+
1003
+ plt.tight_layout()
1004
+ fig_path = output_path / f"simulation_{timestamp}_overview.png"
1005
+ plt.savefig(fig_path, dpi=150, bbox_inches="tight")
1006
+ plt.close()
1007
+
1008
+ return fig
1009
+
1010
+
1011
def plot_wave_surface_detail(
    reflected_rays: "RayBatch",
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    figsize: tuple[float, float] = (12, 6),
    save_path: str | None = None,
    hit_offset: float = 0.01,
) -> Figure:
    """
    Plot wave surface detail with ray intersection points.

    The surface height is sampled along the line ``y = 0`` over ``x_range``
    and drawn as a filled profile; the intersection points recovered from
    ``reflected_rays`` are scattered on top.

    Parameters
    ----------
    reflected_rays : RayBatch
        Batch of reflected rays. Each ray's position is taken to lie
        ``hit_offset`` metres past its surface hit point along its direction,
        so the hit point is recovered as ``position - hit_offset * direction``.
    surface : Surface
        The surface object. If it provides ``_surface_z(x, y)`` that is used
        for the profile; otherwise a flat ``z = 0`` plane is drawn.
    x_range : tuple
        X-axis range for plotting (metres).
    figsize : tuple
        Figure size (width, height).
    save_path : str, optional
        Path to save figure.
    hit_offset : float
        Distance (m) the reflected rays were moved off the surface along
        their direction. Defaults to 0.01, the small offset applied to
        bounced rays in this module (``hit + 0.01 * direction``).

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Sample the surface profile along the y = 0 line.
    x_detail = np.linspace(x_range[0], x_range[1], 1000)
    y_detail = np.zeros_like(x_detail)
    if hasattr(surface, "_surface_z"):
        z_detail = np.asarray(surface._surface_z(x_detail, y_detail))
        surface_label = "Wave Surface"
    else:
        # Surfaces without a height function are drawn as the z = 0 plane,
        # the same fallback plot_ray_paths_with_surface uses.
        z_detail = np.zeros_like(x_detail)
        surface_label = "Planar Surface"

    z_low = z_detail.min()
    z_high = z_detail.max()

    ax.plot(x_detail, z_detail, "b-", linewidth=3, label=surface_label, zorder=3)
    ax.fill_between(
        x_detail,
        z_detail,
        z_low - 0.5,
        color="lightblue",
        alpha=0.4,
        zorder=1,
    )

    # Undo the offset the tracer applied along each reflected direction to
    # recover where the ray actually met the surface.
    hit_positions = reflected_rays.positions - hit_offset * reflected_rays.directions
    ax.scatter(
        hit_positions[:, 0],
        hit_positions[:, 2],
        c="red",
        s=8,
        alpha=0.6,
        zorder=5,
        label="Intersection Points",
    )

    ax.set_xlim(x_range[0], x_range[1])
    # A flat profile has zero height range, which would hand set_ylim
    # identical limits (a degenerate axis); fall back to a 1 m span.
    z_span = z_high - z_low
    if z_span <= 0:
        z_span = 1.0
    ax.set_ylim(z_low - z_span * 0.3, z_high + z_span * 0.5)
    ax.set_xlabel("X Position (m)", fontsize=11, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=11, fontweight="bold")
    ax.set_title("Wave Surface Detail with Ray Intersections", fontweight="bold")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper right", fontsize=10)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig
1084
+
1085
+
1086
def plot_ray_paths_with_surface(
    rays: "RayBatch",
    reflected_rays: "RayBatch",
    surface: "Surface",
    detector_distance: float = 1000.0,
    source_distance: float = 1000.0,
    refracted_rays: Optional["RayBatch"] = None,
    ray_paths: dict | None = None,
    figsize: tuple[float, float] = (16, 10),
    save_path: str | None = None,
) -> Figure:
    """
    Plot full ray paths (incoming, reflected, and refracted) with wave surface.

    Two drawing modes are supported:

    * When ``ray_paths`` contains ``"reflected_paths"``, the recorded
      multi-bounce polylines are drawn (up to 100 reflected and 50 refracted
      paths). Bounce points are marked and each path is extended along its
      final direction.
    * Otherwise a single-bounce view is drawn from ``rays`` and
      ``reflected_rays`` (plus ``refracted_rays`` if given), sampling up to
      100 rays.

    The X axis is shown in km when ``detector_distance >= 100``; the Z axis
    is always in metres. Empty ray batches are tolerated: they simply
    contribute nothing to the plot.

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch (before interaction).
    reflected_rays : RayBatch
        Reflected ray batch (after interaction). In single-bounce mode each
        position is assumed to sit 0.01 m past the hit point along the ray
        direction; that offset is removed before drawing.
    surface : Surface
        The surface object. ``_surface_z(x, y)`` is used if present;
        otherwise a flat z = 0 plane is drawn.
    detector_distance : float
        Detector distance in meters. Also sets the length of the drawn
        outgoing ray segments and the plot extent.
    source_distance : float
        Source distance in meters (unused, for API compatibility).
    refracted_rays : RayBatch, optional
        Refracted ray batch (after interaction). Only drawn in single-bounce
        mode.
    ray_paths : dict, optional
        Dictionary with ray path data from trace_rays_multi_bounce containing:
        - 'reflected_paths': list of Nx3 arrays, one per ray
        - 'refracted_paths': list of Nx3 arrays for refracted rays
        - 'reflected_final_dirs': final direction for each reflected path
        - 'refracted_final_dirs': final direction for each refracted path
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure. When given, the figure is saved and then closed
        with ``plt.close`` before being returned.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Use km on the X axis for long baselines; Z is always plotted in metres.
    use_km = detector_distance >= 100.0
    scale_factor = 1000.0 if use_km else 1.0
    distance_label = "km" if use_km else "m"

    # X extent of the plotted surface. Only non-empty batches contribute, so
    # an empty reflected batch (no ray hit the surface) does not make
    # ``.min()`` raise on a zero-size array.
    x_samples = [
        positions[:, 0]
        for positions in (rays.positions, reflected_rays.positions)
        if len(positions) > 0
    ]
    if x_samples:
        all_x = np.concatenate(x_samples)
        x_min, x_max = all_x.min(), all_x.max()
    else:
        x_min = x_max = 0.0
    x_span = x_max - x_min
    x_surf = np.linspace(x_min - x_span * 0.2, x_max + detector_distance * 0.3, 1000)
    y_surf = np.zeros_like(x_surf)

    # Handle both wave surfaces and planar surfaces
    if hasattr(surface, "_surface_z"):
        z_surf = surface._surface_z(x_surf, y_surf)
        surface_label = "Wave Surface"
    else:
        # Planar surface - assume z=0 horizontal plane
        z_surf = np.zeros_like(x_surf)
        surface_label = "Planar Surface"

    ax.plot(
        x_surf / scale_factor,
        z_surf,
        "b-",
        linewidth=3,
        label=surface_label,
        zorder=3,
    )
    # Keep at least 1 cm of "water" visible even for a perfectly flat surface.
    z_fill_bottom = z_surf.min() - max(abs(z_surf.max() - z_surf.min()) * 0.5, 0.01)
    ax.fill_between(
        x_surf / scale_factor,
        z_surf,
        z_fill_bottom,
        color="lightblue",
        alpha=0.3,
        zorder=1,
    )

    # Sample rays to plot (for performance and clarity)
    num_rays_to_plot = min(100, rays.num_rays)
    indices_to_plot = np.linspace(0, rays.num_rays - 1, num_rays_to_plot, dtype=int)

    # Set whenever a green refracted segment is actually drawn, so the legend
    # entry reflects what is on the axes rather than which inputs were passed.
    refracted_drawn = False
    bounce_colors = ["red", "orange", "yellow", "pink", "purple"]

    multi_bounce = ray_paths is not None and "reflected_paths" in ray_paths

    if multi_bounce:
        reflected_paths = ray_paths["reflected_paths"]
        reflected_final_dirs = ray_paths.get("reflected_final_dirs", [])
        refracted_paths = ray_paths.get("refracted_paths", [])
        refracted_final_dirs = ray_paths.get("refracted_final_dirs", [])

        # Sample paths to plot
        num_paths = len(reflected_paths)
        num_to_plot = min(100, num_paths)
        path_indices = (
            np.linspace(0, num_paths - 1, num_to_plot, dtype=int)
            if num_paths > 0
            else []
        )

        for path_idx in path_indices:
            path = reflected_paths[path_idx]
            if path is None or len(path) < 2:
                continue

            # Plot the complete path
            ax.plot(
                path[:, 0] / scale_factor,
                path[:, 2],
                "b-",
                linewidth=0.8,
                alpha=0.4,
                zorder=2,
            )

            # Mark bounce points, cycling colours by bounce number; the first
            # vertex is the start position, not a bounce.
            for i in range(1, len(path)):
                ax.scatter(
                    path[i, 0] / scale_factor,
                    path[i, 2],
                    c=bounce_colors[(i - 1) % len(bounce_colors)],
                    s=10,
                    alpha=0.6,
                    zorder=4,
                )

            # Plot final direction as an extending ray
            if (
                path_idx < len(reflected_final_dirs)
                and reflected_final_dirs[path_idx] is not None
            ):
                final_pos = path[-1]
                final_dir = reflected_final_dirs[path_idx]
                end_pos = final_pos + final_dir * detector_distance * 0.3
                ax.plot(
                    [final_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [final_pos[2], end_pos[2]],
                    "r-",
                    linewidth=0.8,
                    alpha=0.5,
                    zorder=2,
                )

        # Plot refracted paths (dashed green lines going into water)
        num_refr_paths = len(refracted_paths)
        num_refr_to_plot = min(50, num_refr_paths)
        refr_path_indices = (
            np.linspace(0, num_refr_paths - 1, num_refr_to_plot, dtype=int)
            if num_refr_paths > 0
            else []
        )

        for path_idx in refr_path_indices:
            path = refracted_paths[path_idx]
            if path is None or len(path) < 1:
                continue

            if (
                path_idx < len(refracted_final_dirs)
                and refracted_final_dirs[path_idx] is not None
            ):
                start_pos = path[0]
                refr_dir = refracted_final_dirs[path_idx]
                end_pos = start_pos + refr_dir * detector_distance * 0.2
                ax.plot(
                    [start_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [start_pos[2], end_pos[2]],
                    "g--",
                    linewidth=0.6,
                    alpha=0.4,
                    zorder=2,
                )
                refracted_drawn = True
    else:
        # Original single-bounce plotting
        for idx in indices_to_plot:
            if idx >= reflected_rays.num_rays:
                continue

            # Back-calculate actual hit position (rays are offset by 0.01m along direction)
            reflect_dir = reflected_rays.directions[idx, :]
            actual_hit_pos = reflected_rays.positions[idx, :] - 0.01 * reflect_dir
            start_pos = rays.positions[idx, :]
            ax.plot(
                [start_pos[0] / scale_factor, actual_hit_pos[0] / scale_factor],
                [start_pos[2], actual_hit_pos[2]],
                "b-",
                linewidth=0.8,
                alpha=0.4,
                zorder=2,
            )

            # Reflected ray: from actual hit position outward
            ray_length = detector_distance * 0.5
            end_pos = actual_hit_pos + reflect_dir * ray_length
            ax.plot(
                [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
                [actual_hit_pos[2], end_pos[2]],
                "r-",
                linewidth=0.8,
                alpha=0.5,
                zorder=2,
            )

        # Plot refracted rays (if provided)
        if refracted_rays is not None and refracted_rays.num_rays > 0:
            refr_indices = np.linspace(
                0,
                min(refracted_rays.num_rays - 1, rays.num_rays - 1),
                num_rays_to_plot,
                dtype=int,
            )
            for idx in refr_indices:
                if idx >= refracted_rays.num_rays:
                    continue
                refract_dir = refracted_rays.directions[idx, :]
                # Back-calculate actual hit position (rays are offset by 0.01m along direction)
                actual_hit_pos = refracted_rays.positions[idx, :] - 0.01 * refract_dir
                ray_length = detector_distance * 0.3
                end_pos = actual_hit_pos + refract_dir * ray_length
                ax.plot(
                    [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [actual_hit_pos[2], end_pos[2]],
                    "g--",
                    linewidth=0.8,
                    alpha=0.5,
                    zorder=2,
                )
                refracted_drawn = True

    # Add beam source indicator (the first input ray's start position), if any
    source_drawn = rays.num_rays > 0
    if source_drawn:
        beam_source = rays.positions[0, :]
        ax.scatter(
            [beam_source[0] / scale_factor],
            [beam_source[2]],
            c="blue",
            s=300,
            marker="*",
            edgecolors="black",
            linewidths=2,
            label="Beam Source",
            zorder=6,
        )

    ax.set_xlabel(f"X Position ({distance_label})", fontsize=13, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=13, fontweight="bold")

    # Update title based on what's shown
    if multi_bounce:
        num_paths = len(ray_paths["reflected_paths"])
        title_text = f"Ray Paths: Multi-Bounce Reflection ({num_paths} ray paths)"
    elif refracted_rays is not None and refracted_rays.num_rays > 0:
        title_text = (
            f"Ray Paths: Reflection & Refraction ({num_rays_to_plot} rays shown)"
        )
    else:
        title_text = (
            f"Ray Paths: Reflection from Wave Surface ({num_rays_to_plot} rays shown)"
        )
    ax.set_title(title_text, fontweight="bold", fontsize=15)
    ax.grid(True, alpha=0.3, linewidth=0.5)

    # Build legend from proxy artists for the categories actually drawn
    legend_elements = [
        plt.Line2D([0], [0], color="b", linewidth=2, alpha=0.5, label="Incoming"),
        plt.Line2D([0], [0], color="r", linewidth=2, alpha=0.5, label="Reflected"),
    ]
    if refracted_drawn:
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                color="g",
                linewidth=2,
                linestyle="--",
                alpha=0.5,
                label="Refracted",
            )
        )
    if source_drawn:
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                marker="*",
                color="w",
                markerfacecolor="blue",
                markersize=12,
                label="Beam Source",
            )
        )
    ax.legend(handles=legend_elements, loc="upper left", fontsize=11, framealpha=0.9)

    ax.set_xlim(
        x_min / scale_factor - x_span * 0.2 / scale_factor,
        (x_max + detector_distance * 0.3) / scale_factor,
    )
    ax.set_ylim(
        min(
            z_surf.min() - abs(z_surf.max() - z_surf.min()) * 0.5,
            -detector_distance * 0.01,
        ),
        max(z_surf.max(), detector_distance * 0.1),
    )

    if not use_km:
        ax.set_aspect("equal", adjustable="datalim")

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
1404
+
1405
+
1406
+ # =============================================================================
1407
+ # Legacy Convenience Functions (Backward Compatibility)
1408
+ # =============================================================================
1409
+
1410
+
1411
def plot_ray_paths_2d(
    ray_history: list["RayBatch"],
    max_rays: int = 100,
    color_by: str = "wavelength",
    alpha: float = 0.4,
    linewidth: float = 0.8,
    figsize: tuple[float, float] = (15, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Draw ray paths as three side-by-side 2D projections (XY, XZ, YZ).

    Quick-look helper layered on plot_ray_paths_projection(); call that
    function directly on your own axes when a custom layout is needed.

    Parameters
    ----------
    ray_history : List[RayBatch]
        Ray batches recorded at successive propagation steps.
    max_rays : int
        Upper bound on the number of rays drawn per panel (sampled uniformly
        if exceeded).
    color_by : str
        Ray colouring: 'wavelength', 'intensity', 'generation', or 'index'.
    alpha : float
        Line transparency.
    linewidth : float
        Line width.
    figsize : tuple
        Figure size.
    save_path : str, optional
        If truthy, the figure is written there via save_figure().

    Returns
    -------
    Figure
        Figure holding the XY, XZ and YZ panels, left to right.
    """
    projections = ("xy", "xz", "yz")
    # Only the right-most panel requests a colorbar.
    colorbar_plane = projections[-1]

    figure, panels = plt.subplots(1, 3, figsize=figsize, constrained_layout=True)
    figure.suptitle("Ray Paths - 2D Projections", fontsize=14, fontweight="bold")

    shared_style = {
        "max_rays": max_rays,
        "color_by": color_by,
        "alpha": alpha,
        "linewidth": linewidth,
    }
    for panel, plane in zip(panels, projections, strict=False):
        plot_ray_paths_projection(
            panel,
            ray_history,
            projection=plane,
            show_colorbar=plane == colorbar_plane,
            **shared_style,
        )

    if save_path:
        save_figure(figure, save_path)
    return figure
1467
+
1468
+
1469
+ # Re-export Fresnel/Brewster functions from fresnel_plots module for backward compatibility
1470
+
1471
+ # Re-export polarization functions from polarization_plots module for backward compatibility
1472
+
1473
+
1474
def plot_ray_endpoints(
    rays: "RayBatch",
    plane: str = "xy",
    color_by: str = "wavelength",
    bins: int = 50,
    figsize: tuple[float, float] = (12, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Build a two-panel figure of ray endpoints: scatter on the left, histogram
    on the right.

    Quick-look helper; for custom layouts call plot_ray_endpoints_scatter()
    and plot_ray_endpoints_histogram() on your own axes.

    Parameters
    ----------
    rays : RayBatch
        Batch whose positions are treated as ray endpoints.
    plane : str
        Projection plane: 'xy', 'xz', or 'yz'.
    color_by : str
        Scatter colouring: 'wavelength' or 'intensity'.
    bins : int
        Number of histogram bins.
    figsize : tuple
        Figure size.
    save_path : str, optional
        If truthy, the figure is written there via save_figure().

    Returns
    -------
    Figure
        Figure with the scatter and histogram panels.
    """
    figure, (scatter_ax, hist_ax) = plt.subplots(
        1, 2, figsize=figsize, constrained_layout=True
    )
    heading = f"Ray Endpoints - {plane.upper()} Plane"
    figure.suptitle(heading, fontsize=14, fontweight="bold")

    plot_ray_endpoints_scatter(
        scatter_ax,
        rays,
        projection=plane,
        color_by=color_by,
        show_colorbar=True,
    )
    plot_ray_endpoints_histogram(hist_ax, rays, projection=plane, bins=bins)

    if save_path:
        save_figure(figure, save_path)
    return figure