lsurf 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. lsurf/__init__.py +471 -0
  2. lsurf/analysis/__init__.py +107 -0
  3. lsurf/analysis/healpix_utils.py +418 -0
  4. lsurf/analysis/sphere_viz.py +1280 -0
  5. lsurf/cli/__init__.py +48 -0
  6. lsurf/cli/build.py +398 -0
  7. lsurf/cli/config_schema.py +318 -0
  8. lsurf/cli/gui_cmd.py +76 -0
  9. lsurf/cli/interactive.py +850 -0
  10. lsurf/cli/main.py +81 -0
  11. lsurf/cli/run.py +806 -0
  12. lsurf/detectors/__init__.py +266 -0
  13. lsurf/detectors/analysis.py +289 -0
  14. lsurf/detectors/base.py +284 -0
  15. lsurf/detectors/constant_size_rings.py +485 -0
  16. lsurf/detectors/directional.py +45 -0
  17. lsurf/detectors/extended/__init__.py +73 -0
  18. lsurf/detectors/extended/local_sphere.py +353 -0
  19. lsurf/detectors/extended/recording_sphere.py +368 -0
  20. lsurf/detectors/planar.py +45 -0
  21. lsurf/detectors/protocol.py +187 -0
  22. lsurf/detectors/recording_spheres.py +63 -0
  23. lsurf/detectors/results.py +1140 -0
  24. lsurf/detectors/small/__init__.py +79 -0
  25. lsurf/detectors/small/directional.py +330 -0
  26. lsurf/detectors/small/planar.py +401 -0
  27. lsurf/detectors/small/spherical.py +450 -0
  28. lsurf/detectors/spherical.py +45 -0
  29. lsurf/geometry/__init__.py +199 -0
  30. lsurf/geometry/builder.py +478 -0
  31. lsurf/geometry/cell.py +228 -0
  32. lsurf/geometry/cell_geometry.py +247 -0
  33. lsurf/geometry/detector_arrays.py +1785 -0
  34. lsurf/geometry/geometry.py +222 -0
  35. lsurf/geometry/surface_analysis.py +375 -0
  36. lsurf/geometry/validation.py +91 -0
  37. lsurf/gui/__init__.py +51 -0
  38. lsurf/gui/app.py +903 -0
  39. lsurf/gui/core/__init__.py +39 -0
  40. lsurf/gui/core/scene.py +343 -0
  41. lsurf/gui/core/simulation.py +264 -0
  42. lsurf/gui/renderers/__init__.py +40 -0
  43. lsurf/gui/renderers/ray_renderer.py +353 -0
  44. lsurf/gui/renderers/source_renderer.py +505 -0
  45. lsurf/gui/renderers/surface_renderer.py +477 -0
  46. lsurf/gui/views/__init__.py +48 -0
  47. lsurf/gui/views/config_editor.py +3199 -0
  48. lsurf/gui/views/properties.py +257 -0
  49. lsurf/gui/views/results.py +291 -0
  50. lsurf/gui/views/scene_tree.py +180 -0
  51. lsurf/gui/views/viewport_3d.py +555 -0
  52. lsurf/gui/views/visualizations.py +712 -0
  53. lsurf/materials/__init__.py +169 -0
  54. lsurf/materials/base/__init__.py +64 -0
  55. lsurf/materials/base/full_inhomogeneous.py +208 -0
  56. lsurf/materials/base/grid_inhomogeneous.py +319 -0
  57. lsurf/materials/base/homogeneous.py +342 -0
  58. lsurf/materials/base/material_field.py +527 -0
  59. lsurf/materials/base/simple_inhomogeneous.py +418 -0
  60. lsurf/materials/base/spectral_inhomogeneous.py +497 -0
  61. lsurf/materials/implementations/__init__.py +120 -0
  62. lsurf/materials/implementations/data/alpha_values_typical_atmosphere_updated.txt +24 -0
  63. lsurf/materials/implementations/duct_atmosphere.py +390 -0
  64. lsurf/materials/implementations/exponential_atmosphere.py +435 -0
  65. lsurf/materials/implementations/gaussian_lens.py +120 -0
  66. lsurf/materials/implementations/interpolated_data.py +123 -0
  67. lsurf/materials/implementations/layered_atmosphere.py +134 -0
  68. lsurf/materials/implementations/linear_gradient.py +109 -0
  69. lsurf/materials/implementations/linsley_atmosphere.py +764 -0
  70. lsurf/materials/implementations/standard_materials.py +126 -0
  71. lsurf/materials/implementations/turbulent_atmosphere.py +135 -0
  72. lsurf/materials/implementations/us_standard_atmosphere.py +149 -0
  73. lsurf/materials/utils/__init__.py +77 -0
  74. lsurf/materials/utils/constants.py +45 -0
  75. lsurf/materials/utils/device_functions.py +117 -0
  76. lsurf/materials/utils/dispersion.py +160 -0
  77. lsurf/materials/utils/factories.py +142 -0
  78. lsurf/propagation/__init__.py +91 -0
  79. lsurf/propagation/detector_gpu.py +67 -0
  80. lsurf/propagation/gpu_device_rays.py +294 -0
  81. lsurf/propagation/kernels/__init__.py +175 -0
  82. lsurf/propagation/kernels/absorption/__init__.py +61 -0
  83. lsurf/propagation/kernels/absorption/grid.py +240 -0
  84. lsurf/propagation/kernels/absorption/simple.py +232 -0
  85. lsurf/propagation/kernels/absorption/spectral.py +410 -0
  86. lsurf/propagation/kernels/detection/__init__.py +64 -0
  87. lsurf/propagation/kernels/detection/protocol.py +102 -0
  88. lsurf/propagation/kernels/detection/spherical.py +255 -0
  89. lsurf/propagation/kernels/device_functions.py +790 -0
  90. lsurf/propagation/kernels/fresnel/__init__.py +64 -0
  91. lsurf/propagation/kernels/fresnel/protocol.py +97 -0
  92. lsurf/propagation/kernels/fresnel/standard.py +258 -0
  93. lsurf/propagation/kernels/intersection/__init__.py +79 -0
  94. lsurf/propagation/kernels/intersection/annular_plane.py +207 -0
  95. lsurf/propagation/kernels/intersection/bounded_plane.py +205 -0
  96. lsurf/propagation/kernels/intersection/plane.py +166 -0
  97. lsurf/propagation/kernels/intersection/protocol.py +95 -0
  98. lsurf/propagation/kernels/intersection/signed_distance.py +742 -0
  99. lsurf/propagation/kernels/intersection/sphere.py +190 -0
  100. lsurf/propagation/kernels/propagation/__init__.py +85 -0
  101. lsurf/propagation/kernels/propagation/grid.py +527 -0
  102. lsurf/propagation/kernels/propagation/protocol.py +105 -0
  103. lsurf/propagation/kernels/propagation/simple.py +460 -0
  104. lsurf/propagation/kernels/propagation/spectral.py +875 -0
  105. lsurf/propagation/kernels/registry.py +331 -0
  106. lsurf/propagation/kernels/surface/__init__.py +72 -0
  107. lsurf/propagation/kernels/surface/bisection.py +232 -0
  108. lsurf/propagation/kernels/surface/detection.py +402 -0
  109. lsurf/propagation/kernels/surface/reduction.py +166 -0
  110. lsurf/propagation/propagator_protocol.py +222 -0
  111. lsurf/propagation/propagators/__init__.py +101 -0
  112. lsurf/propagation/propagators/detector_handler.py +354 -0
  113. lsurf/propagation/propagators/factory.py +200 -0
  114. lsurf/propagation/propagators/fresnel_handler.py +305 -0
  115. lsurf/propagation/propagators/gpu_gradient.py +566 -0
  116. lsurf/propagation/propagators/gpu_surface_propagator.py +707 -0
  117. lsurf/propagation/propagators/gradient.py +429 -0
  118. lsurf/propagation/propagators/intersection_handler.py +327 -0
  119. lsurf/propagation/propagators/material_propagator.py +398 -0
  120. lsurf/propagation/propagators/signed_distance_handler.py +522 -0
  121. lsurf/propagation/propagators/spectral_gpu_gradient.py +553 -0
  122. lsurf/propagation/propagators/surface_interaction.py +616 -0
  123. lsurf/propagation/propagators/surface_propagator.py +719 -0
  124. lsurf/py.typed +1 -0
  125. lsurf/simulation/__init__.py +70 -0
  126. lsurf/simulation/config.py +164 -0
  127. lsurf/simulation/orchestrator.py +462 -0
  128. lsurf/simulation/result.py +299 -0
  129. lsurf/simulation/simulation.py +262 -0
  130. lsurf/sources/__init__.py +128 -0
  131. lsurf/sources/base.py +264 -0
  132. lsurf/sources/collimated.py +252 -0
  133. lsurf/sources/custom.py +409 -0
  134. lsurf/sources/diverging.py +228 -0
  135. lsurf/sources/gaussian.py +272 -0
  136. lsurf/sources/parallel_from_positions.py +197 -0
  137. lsurf/sources/point.py +172 -0
  138. lsurf/sources/uniform_diverging.py +258 -0
  139. lsurf/surfaces/__init__.py +184 -0
  140. lsurf/surfaces/cpu/__init__.py +50 -0
  141. lsurf/surfaces/cpu/curved_wave.py +463 -0
  142. lsurf/surfaces/cpu/gerstner_wave.py +381 -0
  143. lsurf/surfaces/cpu/wave_params.py +118 -0
  144. lsurf/surfaces/gpu/__init__.py +72 -0
  145. lsurf/surfaces/gpu/annular_plane.py +453 -0
  146. lsurf/surfaces/gpu/bounded_plane.py +390 -0
  147. lsurf/surfaces/gpu/curved_wave.py +483 -0
  148. lsurf/surfaces/gpu/gerstner_wave.py +377 -0
  149. lsurf/surfaces/gpu/multi_curved_wave.py +520 -0
  150. lsurf/surfaces/gpu/plane.py +299 -0
  151. lsurf/surfaces/gpu/recording_sphere.py +587 -0
  152. lsurf/surfaces/gpu/sphere.py +311 -0
  153. lsurf/surfaces/protocol.py +336 -0
  154. lsurf/surfaces/registry.py +373 -0
  155. lsurf/utilities/__init__.py +175 -0
  156. lsurf/utilities/detector_analysis.py +814 -0
  157. lsurf/utilities/fresnel.py +628 -0
  158. lsurf/utilities/interactions.py +1215 -0
  159. lsurf/utilities/propagation.py +602 -0
  160. lsurf/utilities/ray_data.py +532 -0
  161. lsurf/utilities/recording_sphere.py +745 -0
  162. lsurf/utilities/time_spread.py +463 -0
  163. lsurf/visualization/__init__.py +329 -0
  164. lsurf/visualization/absorption_plots.py +334 -0
  165. lsurf/visualization/atmospheric_plots.py +754 -0
  166. lsurf/visualization/common.py +348 -0
  167. lsurf/visualization/detector_plots.py +1350 -0
  168. lsurf/visualization/detector_sphere_plots.py +1173 -0
  169. lsurf/visualization/fresnel_plots.py +1061 -0
  170. lsurf/visualization/ocean_simulation_plots.py +999 -0
  171. lsurf/visualization/polarization_plots.py +916 -0
  172. lsurf/visualization/raytracing_plots.py +1521 -0
  173. lsurf/visualization/ring_detector_plots.py +1867 -0
  174. lsurf/visualization/time_spread_plots.py +531 -0
  175. lsurf-1.0.0.dist-info/METADATA +381 -0
  176. lsurf-1.0.0.dist-info/RECORD +180 -0
  177. lsurf-1.0.0.dist-info/WHEEL +5 -0
  178. lsurf-1.0.0.dist-info/entry_points.txt +2 -0
  179. lsurf-1.0.0.dist-info/licenses/LICENSE +32 -0
  180. lsurf-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1215 @@
1
+ # The Clear BSD License
2
+ #
3
+ # Copyright (c) 2026 Tobias Heibges
4
+ # All rights reserved.
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted (subject to the limitations in the disclaimer
8
+ # below) provided that the following conditions are met:
9
+ #
10
+ # * Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # * Redistributions in binary form must reproduce the above copyright
14
+ # notice, this list of conditions and the following disclaimer in the
15
+ # documentation and/or other materials provided with the distribution.
16
+ #
17
+ # * Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from this
19
+ # software without specific prior written permission.
20
+ #
21
+ # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
22
+ # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
23
+ # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25
+ # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
26
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
28
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29
+ # BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
30
+ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
31
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32
+ # POSSIBILITY OF SUCH DAMAGE.
33
+
34
+ """
35
+ Ray-Surface Interactions
36
+
37
+ Handles the physics of ray-surface interactions including reflection,
38
+ refraction, and intensity updates based on Fresnel equations.
39
+
40
+ At each interface, rays are split into reflected and refracted components
41
+ with intensities determined by the Fresnel equations. This properly models
42
+ the physical behavior where light partially reflects and partially transmits
43
+ at each interface.
44
+
45
+ Functions
46
+ ---------
47
+ process_surface_interaction
48
+ Process rays intersecting a surface (generates both reflected + refracted rays)
49
+ reflect_rays
50
+ Apply reflection to rays at surface (modifies rays in place)
51
+ refract_rays
52
+ Apply refraction to rays at surface (modifies rays in place)
53
+ trace_rays_multi_bounce
54
+ Trace reflected rays through multiple bounces (legacy, reflected-only tracing)
55
+ trace_rays_with_splitting
56
+ Trace rays with proper Fresnel splitting (both reflected AND refracted rays)
57
+
58
+ References
59
+ ----------
60
+ .. [1] Glassner, A. S. (1989). An Introduction to Ray Tracing.
61
+ Academic Press.
62
+ """
63
+
64
+ import numpy as np
65
+ from numpy.typing import NDArray
66
+
67
+ from ..surfaces import Surface
68
+ from .fresnel import (
69
+ compute_reflection_direction,
70
+ compute_refraction_direction,
71
+ fresnel_coefficients,
72
+ initialize_polarization_vectors,
73
+ transform_polarization_reflection,
74
+ transform_polarization_refraction,
75
+ )
76
+ from .ray_data import RayBatch, create_ray_batch
77
+
78
+
79
def process_surface_interaction(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    generate_reflected: bool = True,
    generate_refracted: bool = True,
    polarization: str = "unpolarized",
    track_polarization_vector: bool = False,
    min_intensity: float = 1e-10,
    surface_offset: float = 0.01,
) -> tuple[RayBatch | None, RayBatch | None]:
    """
    Process rays intersecting a surface.

    Computes intersections, applies Fresnel equations, and generates
    reflected and/or refracted ray bundles.

    Parameters
    ----------
    rays : RayBatch
        Input rays to test for intersection.
    surface : Surface
        Surface to intersect with.
    wavelength : float or ndarray, optional
        Not used by this function. Refractive indices are always evaluated
        at each ray's own ``rays.wavelengths`` entry. The parameter is kept
        only for backward compatibility of the signature (default: 500 nm).
    generate_reflected : bool, optional
        Whether to generate reflected rays (default: True).
    generate_refracted : bool, optional
        Whether to generate refracted rays (default: True).
    polarization : str, optional
        Polarization state: 's', 'p', or 'unpolarized' (default).
        Used for the Fresnel coefficient calculation and, when polarization
        vectors have to be initialized, for their initial orientation.
    track_polarization_vector : bool, optional
        Whether to track 3D polarization vectors through the interaction.
        If True, polarization vectors are initialized (if not present) and
        transformed through reflection/refraction. (default: False)
    min_intensity : float, optional
        A child ray is created active only if its intensity is strictly
        greater than this threshold (default: 1e-10).
    surface_offset : float, optional
        Distance each child ray is moved along its new direction away from
        the hit point. This keeps it from immediately re-intersecting the
        same surface, so it should exceed the surface's intersection
        tolerance (default: 0.01).

    Returns
    -------
    reflected_rays : RayBatch or None
        Reflected rays, one per hit ray (None if generate_reflected=False
        or no hits).
    refracted_rays : RayBatch or None
        Refracted rays, one per hit ray (None if generate_refracted=False
        or no hits). Rays flagged as total internal reflection get zero
        intensity, zero polarization vector and are inactive.

    Notes
    -----
    Only active rays are tested for intersection.
    Input rays are not modified.

    The accumulated time of both child bundles is advanced by
    ``distance * n_front / c``. Here ``n_front`` is the refractive index
    of ``surface.material_front`` at the hit point, used whichever side
    the ray arrived from.

    When track_polarization_vector=True, the function:
    1. Initializes polarization vectors if rays.polarization_vector is None
    2. Transforms polarization vectors through reflection/refraction
    3. Stores the transformed vectors in the output ray batches

    Examples
    --------
    >>> # Create surface and rays
    >>> surface = PlanarSurface(
    ...     point=(0, 0, 1),
    ...     normal=(0, 0, -1),
    ...     material_front=BK7_GLASS,
    ...     material_back=AIR_STP
    ... )
    >>> rays = ...  # Create rays
    >>>
    >>> # Process interaction with polarization tracking
    >>> reflected, refracted = process_surface_interaction(
    ...     rays, surface, generate_reflected=True, generate_refracted=True,
    ...     track_polarization_vector=True
    ... )
    """
    # Only active rays take part in the interaction.
    active_mask = rays.active
    if not np.any(active_mask):
        return None, None

    active_origins = rays.positions[active_mask]
    active_directions = rays.directions[active_mask]

    # Polarization vectors are gathered (or initialized) for all active rays
    # before intersecting, then narrowed down to the hits below.
    active_polarization_vectors = None
    if track_polarization_vector:
        if rays.polarization_vector is not None:
            active_polarization_vectors = rays.polarization_vector[active_mask]
        else:
            active_polarization_vectors = initialize_polarization_vectors(
                active_directions, polarization=polarization
            )

    distances, hit_mask = surface.intersect(active_origins, active_directions)
    if not np.any(hit_mask):
        return None, None

    # Everything from here on is indexed per hit, not per input ray.
    hit_directions = active_directions[hit_mask]
    hit_distances = distances[hit_mask]
    hit_positions = (
        active_origins[hit_mask] + hit_distances[:, np.newaxis] * hit_directions
    )
    hit_wavelengths = rays.wavelengths[active_mask][hit_mask]
    hit_intensities = rays.intensities[active_mask][hit_mask]
    hit_times = rays.accumulated_time[active_mask][hit_mask]
    hit_generations = rays.generations[active_mask][hit_mask]
    num_hits = len(hit_positions)

    hit_polarization_vectors = None
    if active_polarization_vectors is not None:
        hit_polarization_vectors = active_polarization_vectors[hit_mask]

    normals = surface.normal_at(hit_positions, hit_directions)

    def _indices_at_hits(material):
        # get_refractive_index is called per ray with scalar x, y, z and
        # wavelength, so the lookup is a Python-level loop over the hits.
        return np.array(
            [
                material.get_refractive_index(pos[0], pos[1], pos[2], wl)
                for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
            ],
            dtype=np.float32,
        )

    n1_values = _indices_at_hits(surface.material_front)
    n2_values = _indices_at_hits(surface.material_back)

    # Travel time to the surface: distance / (c / n) = distance * n / c.
    c = 299792458.0  # speed of light in m/s
    updated_times = hit_times + hit_distances * n1_values / c

    # abs() makes the result independent of the normal's orientation.
    cos_theta_i = np.abs(-np.sum(hit_directions * normals, axis=1))

    R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    def _spawn_children(directions, intensities, alive):
        # Build one child batch (one ray per hit) sharing the hit's
        # wavelength, updated time and generation + 1.
        batch = create_ray_batch(
            num_rays=num_hits,
            enable_polarization_vector=track_polarization_vector,
        )
        # Offset along the new direction to avoid self-intersection.
        batch.positions[:] = hit_positions + surface_offset * directions
        batch.directions[:] = directions
        batch.wavelengths[:] = hit_wavelengths
        batch.intensities[:] = intensities
        batch.active[:] = alive
        batch.accumulated_time[:] = updated_times
        batch.generations[:] = hit_generations + 1
        return batch

    reflected_rays = None
    if generate_reflected:
        reflected_directions = compute_reflection_direction(hit_directions, normals)
        reflected_intensities = hit_intensities * R
        reflected_rays = _spawn_children(
            reflected_directions,
            reflected_intensities,
            reflected_intensities > min_intensity,
        )

        if hit_polarization_vectors is not None:
            # Separate s/p reflectances weight the polarization components.
            R_s, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "s")
            R_p, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "p")
            reflected_rays.polarization_vector[:] = transform_polarization_reflection(
                hit_polarization_vectors,
                hit_directions,
                reflected_directions,
                normals,
                R_s=R_s,
                R_p=R_p,
            )

    refracted_rays = None
    if generate_refracted:
        refracted_directions, tir_mask = compute_refraction_direction(
            hit_directions, normals, n1_values, n2_values
        )
        # No transmitted energy under total internal reflection.
        refracted_intensities = hit_intensities * T
        refracted_intensities[tir_mask] = 0.0
        refracted_rays = _spawn_children(
            refracted_directions,
            refracted_intensities,
            (refracted_intensities > min_intensity) & (~tir_mask),
        )

        if hit_polarization_vectors is not None:
            refracted_pol = transform_polarization_refraction(
                hit_polarization_vectors,
                hit_directions,
                refracted_directions,
                normals,
            )
            # TIR rays are inactive; zero their polarization for clarity.
            refracted_pol[tir_mask] = 0.0
            refracted_rays.polarization_vector[:] = refracted_pol

    return reflected_rays, refracted_rays
311
+
312
+
313
def reflect_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply reflection to rays at surface.

    Handles every active ray that hits ``surface``:
    - moves it just past its hit point;
    - replaces its direction with the specular reflection;
    - scales its intensity by the Fresnel reflectance;
    - advances its accumulated time by the travel time to the surface;
    - increments its generation.

    Rays that don't hit the surface are deactivated. So are rays whose
    intensity ends up below 1e-10.

    Parameters
    ----------
    rays : RayBatch
        Input rays.
    surface : Surface
        Surface to reflect from.
    wavelength : float or ndarray, optional
        Not used. Refractive indices are evaluated at each ray's own
        ``rays.wavelengths`` entry; the parameter is kept for backward
        compatibility.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify rays in place. If False, return new batch.

    Returns
    -------
    RayBatch
        Reflected rays (same object if in_place=True).

    Notes
    -----
    Travel time is ``distance * n / c``. ``n`` is the refractive index of
    ``surface.material_front`` at the hit point, the same convention
    ``process_surface_interaction`` uses. Polarization vectors, if
    present, are not transformed.
    """
    if not in_place:
        rays = rays.clone()

    # Only process active rays
    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )

    if not np.any(hit_mask):
        rays.active[:] = False
        return rays

    # hit_mask is relative to the active subset; lift it to the full batch.
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask
    hit_indices = np.where(active_mask)[0][hit_mask]

    incident = rays.directions[hit_indices]
    hit_distances = distances[hit_mask]
    hit_wavelengths = rays.wavelengths[hit_indices]
    hit_positions = (
        rays.positions[hit_indices] + hit_distances[:, np.newaxis] * incident
    )
    normals = surface.normal_at(hit_positions, incident)

    reflected_directions = compute_reflection_direction(incident, normals)

    # abs() makes the result independent of the normal's orientation.
    cos_theta_i = np.abs(-np.sum(incident * normals, axis=1))

    n1_values = np.array(
        [
            surface.material_front.get_refractive_index(pos[0], pos[1], pos[2], wl)
            for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
        ],
        dtype=np.float32,
    )
    n2_values = np.array(
        [
            surface.material_back.get_refractive_index(pos[0], pos[1], pos[2], wl)
            for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
        ],
        dtype=np.float32,
    )

    R, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Rays are moved to the surface, so their time must advance as well:
    # distance / (c / n) = distance * n / c.
    c = 299792458.0  # speed of light in m/s
    rays.accumulated_time[hit_indices] += hit_distances * n1_values / c

    # Update rays - offset must exceed surface.intersect min_distance (0.01)
    rays.positions[hit_indices] = hit_positions + 0.02 * reflected_directions
    rays.directions[hit_indices] = reflected_directions
    rays.intensities[hit_indices] *= R
    rays.generations[hit_indices] += 1

    # Deactivate non-hits and weak rays
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False

    return rays
416
+
417
+
418
def refract_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply refraction to rays at surface.

    Handles every active ray that hits ``surface``:
    - moves it just past its hit point;
    - replaces its direction with the refracted direction;
    - scales its intensity by the Fresnel transmittance;
    - advances its accumulated time by the travel time to the surface;
    - increments its generation.

    Rays flagged as total internal reflection are deactivated, as are
    rays that miss the surface and rays whose intensity ends up below
    1e-10.

    Parameters
    ----------
    rays : RayBatch
        Input rays.
    surface : Surface
        Surface to refract through.
    wavelength : float or ndarray, optional
        Not used. Refractive indices are evaluated at each ray's own
        ``rays.wavelengths`` entry; the parameter is kept for backward
        compatibility.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify rays in place. If False, return new batch.

    Returns
    -------
    RayBatch
        Refracted rays (same object if in_place=True).

    Notes
    -----
    Travel time is ``distance * n / c``. ``n`` is the refractive index of
    ``surface.material_front`` at the hit point, the same convention
    ``process_surface_interaction`` uses. Polarization vectors, if
    present, are not transformed.
    """
    if not in_place:
        rays = rays.clone()

    # Only process active rays
    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )

    if not np.any(hit_mask):
        rays.active[:] = False
        return rays

    # hit_mask is relative to the active subset; lift it to the full batch.
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask
    hit_indices = np.where(active_mask)[0][hit_mask]

    incident = rays.directions[hit_indices]
    hit_distances = distances[hit_mask]
    hit_wavelengths = rays.wavelengths[hit_indices]
    hit_positions = (
        rays.positions[hit_indices] + hit_distances[:, np.newaxis] * incident
    )
    normals = surface.normal_at(hit_positions, incident)

    n1_values = np.array(
        [
            surface.material_front.get_refractive_index(pos[0], pos[1], pos[2], wl)
            for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
        ],
        dtype=np.float32,
    )
    n2_values = np.array(
        [
            surface.material_back.get_refractive_index(pos[0], pos[1], pos[2], wl)
            for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
        ],
        dtype=np.float32,
    )

    refracted_directions, tir_mask = compute_refraction_direction(
        incident, normals, n1_values, n2_values
    )

    # abs() makes the result independent of the normal's orientation.
    cos_theta_i = np.abs(-np.sum(incident * normals, axis=1))

    _, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Rays are moved to the surface, so their time must advance as well:
    # distance / (c / n) = distance * n / c.
    c = 299792458.0  # speed of light in m/s
    rays.accumulated_time[hit_indices] += hit_distances * n1_values / c

    # Update rays - offset must exceed surface.intersect min_distance (0.01)
    rays.positions[hit_indices] = hit_positions + 0.02 * refracted_directions
    rays.directions[hit_indices] = refracted_directions
    rays.intensities[hit_indices] *= T
    rays.generations[hit_indices] += 1

    # Deactivate TIR rays, non-hits, and weak rays
    rays.active[hit_indices[tir_mask]] = False
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False

    return rays
524
+
525
+
526
+ def trace_rays_multi_bounce(
527
+ rays: RayBatch,
528
+ surface: Surface,
529
+ max_bounces: int = 10,
530
+ bounding_radius: float = 10000.0,
531
+ wavelength: float = 532e-9,
532
+ min_intensity: float = 1e-10,
533
+ track_refracted: bool = True,
534
+ ) -> tuple[RayBatch, RayBatch, dict]:
535
+ """
536
+ Trace rays through multiple surface interactions until termination.
537
+
538
+ Rays are traced until they:
539
+ - Exit the bounding sphere
540
+ - Reach maximum number of bounces
541
+ - Have intensity below threshold
542
+
543
+ Parameters
544
+ ----------
545
+ rays : RayBatch
546
+ Initial ray batch to trace
547
+ surface : Surface
548
+ Surface to interact with
549
+ max_bounces : int, optional
550
+ Maximum number of surface interactions (default: 10)
551
+ bounding_radius : float, optional
552
+ Radius of bounding sphere in meters (default: 10000)
553
+ wavelength : float, optional
554
+ Wavelength for Fresnel calculations (default: 532nm)
555
+ min_intensity : float, optional
556
+ Minimum intensity threshold (default: 1e-10)
557
+ track_refracted : bool, optional
558
+ Whether to track refracted rays (default: True)
559
+
560
+ Returns
561
+ -------
562
+ final_reflected : RayBatch
563
+ Final state of all reflected rays that exited bounding sphere
564
+ final_refracted : RayBatch
565
+ Final state of all refracted rays (combined from all bounces)
566
+ ray_paths : dict
567
+ Dictionary containing:
568
+ - 'reflected_paths': list of paths, each path is Nx3 array of positions
569
+ - 'refracted_paths': list of paths for refracted rays (start from refraction point)
570
+ - 'reflected_final_dirs': list of final direction vectors for each reflected path
571
+ - 'refracted_final_dirs': list of final direction vectors for each refracted path
572
+
573
+ Notes
574
+ -----
575
+ The function tracks reflected rays through multiple bounces on the
576
+ wavy surface. Each ray's complete path (all positions) is stored for
577
+ visualization. Refracted rays are collected but not further traced
578
+ (they go into the water and don't re-interact with the air-water
579
+ interface from below in typical scenarios).
580
+
581
+ Examples
582
+ --------
583
+ >>> # Trace rays with up to 5 bounces
584
+ >>> final_refl, final_refr, paths = trace_rays_multi_bounce(
585
+ ... rays, surface, max_bounces=5, bounding_radius=5000.0
586
+ ... )
587
+ >>> # Plot a ray's path
588
+ >>> path = paths['reflected_paths'][0] # First ray's path
589
+ >>> plt.plot(path[:, 0], path[:, 2]) # x-z projection
590
+ """
591
+ from .ray_data import create_ray_batch
592
+
593
+ # Clone input rays to avoid modifying original
594
+ current_rays = rays.clone()
595
+
596
+ # Track original ray indices - maps current index to original ray index
597
+ num_original = rays.num_rays
598
+ current_to_original = np.arange(num_original)
599
+
600
+ # Initialize paths for each original ray - start with their original positions
601
+ # Each path is a list of positions that will be converted to array at the end
602
+ ray_path_lists = [[] for _ in range(num_original)]
603
+ ray_final_dirs = [None for _ in range(num_original)] # Final direction for each ray
604
+ ray_is_reflected = [
605
+ True for _ in range(num_original)
606
+ ] # True if reflected, False if refracted
607
+
608
+ # Store initial positions
609
+ for i in range(num_original):
610
+ ray_path_lists[i].append(rays.positions[i].copy())
611
+
612
+ # Storage for rays that have exited the bounding sphere
613
+ exited_positions = []
614
+ exited_directions = []
615
+ exited_wavelengths = []
616
+ exited_intensities = []
617
+ exited_times = []
618
+ exited_generations = []
619
+
620
+ # Storage for all refracted rays
621
+ all_refracted_positions = []
622
+ all_refracted_directions = []
623
+ all_refracted_wavelengths = []
624
+ all_refracted_intensities = []
625
+ all_refracted_times = []
626
+ all_refracted_generations = []
627
+
628
+ # Storage for refracted ray paths (separate from reflected paths)
629
+ refracted_path_lists = []
630
+ refracted_final_dirs = []
631
+
632
+ for _bounce in range(max_bounces):
633
+ # Check which rays are still inside bounding sphere
634
+ distances_from_origin = np.linalg.norm(current_rays.positions, axis=1)
635
+ inside_sphere = distances_from_origin < bounding_radius
636
+
637
+ # Rays that have exited - store them and record final positions/directions
638
+ exited_mask = current_rays.active & ~inside_sphere
639
+ if np.any(exited_mask):
640
+ exited_indices = np.where(exited_mask)[0]
641
+ for local_idx in exited_indices:
642
+ orig_idx = current_to_original[local_idx]
643
+ # Record final position and direction
644
+ ray_path_lists[orig_idx].append(
645
+ current_rays.positions[local_idx].copy()
646
+ )
647
+ ray_final_dirs[orig_idx] = current_rays.directions[local_idx].copy()
648
+
649
+ exited_positions.append(current_rays.positions[exited_mask].copy())
650
+ exited_directions.append(current_rays.directions[exited_mask].copy())
651
+ exited_wavelengths.append(current_rays.wavelengths[exited_mask].copy())
652
+ exited_intensities.append(current_rays.intensities[exited_mask].copy())
653
+ exited_times.append(current_rays.accumulated_time[exited_mask].copy())
654
+ exited_generations.append(current_rays.generations[exited_mask].copy())
655
+ current_rays.active[exited_mask] = False
656
+
657
+ # Update active mask for rays still inside sphere
658
+ current_rays.active &= inside_sphere
659
+
660
+ # Check if any rays are still active
661
+ if not np.any(current_rays.active):
662
+ break
663
+
664
+ # Process surface interaction
665
+ reflected_rays, refracted_rays = process_surface_interaction(
666
+ current_rays,
667
+ surface,
668
+ wavelength=wavelength,
669
+ generate_reflected=True,
670
+ generate_refracted=track_refracted,
671
+ )
672
+
673
+ # Record intersection positions for active rays (reflected rays have the intersection point)
674
+ if reflected_rays is not None and reflected_rays.num_rays > 0:
675
+ # Map the reflected rays back to original indices
676
+ # Note: process_surface_interaction returns rays in same order as input active rays
677
+ active_indices = np.where(current_rays.active)[0]
678
+ for i, local_idx in enumerate(active_indices):
679
+ if i < reflected_rays.num_rays:
680
+ orig_idx = current_to_original[local_idx]
681
+ ray_path_lists[orig_idx].append(reflected_rays.positions[i].copy())
682
+
683
+ # Collect refracted rays and create separate paths for them
684
+ if refracted_rays is not None and refracted_rays.num_rays > 0:
685
+ active_refr = refracted_rays.active
686
+ if np.any(active_refr):
687
+ # Record refracted ray paths (they start from the refraction point)
688
+ active_indices = np.where(current_rays.active)[0]
689
+ refr_active_indices = np.where(active_refr)[0]
690
+ for i in refr_active_indices:
691
+ if i < len(active_indices):
692
+ local_idx = active_indices[i]
693
+ orig_idx = current_to_original[local_idx]
694
+ # Mark this ray as having been refracted
695
+ ray_is_reflected[orig_idx] = False
696
+ # Create a refracted path starting from refraction point
697
+ refracted_path_lists.append(
698
+ np.array([refracted_rays.positions[i].copy()])
699
+ )
700
+ refracted_final_dirs.append(refracted_rays.directions[i].copy())
701
+
702
+ all_refracted_positions.append(
703
+ refracted_rays.positions[active_refr].copy()
704
+ )
705
+ all_refracted_directions.append(
706
+ refracted_rays.directions[active_refr].copy()
707
+ )
708
+ all_refracted_wavelengths.append(
709
+ refracted_rays.wavelengths[active_refr].copy()
710
+ )
711
+ all_refracted_intensities.append(
712
+ refracted_rays.intensities[active_refr].copy()
713
+ )
714
+ all_refracted_times.append(
715
+ refracted_rays.accumulated_time[active_refr].copy()
716
+ )
717
+ all_refracted_generations.append(
718
+ refracted_rays.generations[active_refr].copy()
719
+ )
720
+
721
+ # Continue with reflected rays
722
+ if reflected_rays is None or reflected_rays.num_rays == 0:
723
+ break
724
+
725
+ # Filter out weak rays
726
+ reflected_rays.active &= reflected_rays.intensities > min_intensity
727
+
728
+ if not np.any(reflected_rays.active):
729
+ break
730
+
731
+ # Update index mapping for remaining active rays
732
+ active_mask = reflected_rays.active
733
+ active_indices = np.where(current_rays.active)[0]
734
+
735
+ # Build new mapping: new index -> original ray index
736
+ new_to_original = []
737
+ for i, is_active in enumerate(active_mask):
738
+ if is_active and i < len(active_indices):
739
+ local_idx = active_indices[i]
740
+ new_to_original.append(current_to_original[local_idx])
741
+ current_to_original = np.array(new_to_original)
742
+
743
+ # Use reflected rays for next iteration
744
+ current_rays = reflected_rays
745
+
746
+ # Handle any remaining active rays (didn't exit yet)
747
+ if np.any(current_rays.active):
748
+ remaining_mask = current_rays.active
749
+ remaining_indices = np.where(remaining_mask)[0]
750
+ for local_idx in remaining_indices:
751
+ if local_idx < len(current_to_original):
752
+ orig_idx = current_to_original[local_idx]
753
+ ray_final_dirs[orig_idx] = current_rays.directions[local_idx].copy()
754
+
755
+ remaining_positions = current_rays.positions[remaining_mask]
756
+ remaining_directions = current_rays.directions[remaining_mask]
757
+ remaining_wavelengths = current_rays.wavelengths[remaining_mask]
758
+ remaining_intensities = current_rays.intensities[remaining_mask]
759
+ remaining_times = current_rays.accumulated_time[remaining_mask]
760
+ remaining_generations = current_rays.generations[remaining_mask]
761
+
762
+ exited_positions.append(remaining_positions)
763
+ exited_directions.append(remaining_directions)
764
+ exited_wavelengths.append(remaining_wavelengths)
765
+ exited_intensities.append(remaining_intensities)
766
+ exited_times.append(remaining_times)
767
+ exited_generations.append(remaining_generations)
768
+
769
+ # Combine all exited rays into final reflected batch
770
+ if len(exited_positions) > 0:
771
+ all_exited_positions = np.vstack(exited_positions)
772
+ all_exited_directions = np.vstack(exited_directions)
773
+ all_exited_wavelengths = np.concatenate(exited_wavelengths)
774
+ all_exited_intensities = np.concatenate(exited_intensities)
775
+ all_exited_times = np.concatenate(exited_times)
776
+ all_exited_generations = np.concatenate(exited_generations)
777
+
778
+ num_exited = len(all_exited_positions)
779
+ final_reflected = create_ray_batch(num_rays=num_exited)
780
+ final_reflected.positions[:] = all_exited_positions
781
+ final_reflected.directions[:] = all_exited_directions
782
+ final_reflected.wavelengths[:] = all_exited_wavelengths
783
+ final_reflected.intensities[:] = all_exited_intensities
784
+ final_reflected.accumulated_time[:] = all_exited_times
785
+ final_reflected.generations[:] = all_exited_generations
786
+ final_reflected.active[:] = True
787
+ else:
788
+ # Return empty batch if no rays exited
789
+ final_reflected = create_ray_batch(num_rays=0)
790
+
791
+ # Combine all refracted rays
792
+ if len(all_refracted_positions) > 0:
793
+ combined_refr_positions = np.vstack(all_refracted_positions)
794
+ combined_refr_directions = np.vstack(all_refracted_directions)
795
+ combined_refr_wavelengths = np.concatenate(all_refracted_wavelengths)
796
+ combined_refr_intensities = np.concatenate(all_refracted_intensities)
797
+ combined_refr_times = np.concatenate(all_refracted_times)
798
+ combined_refr_generations = np.concatenate(all_refracted_generations)
799
+
800
+ num_refracted = len(combined_refr_positions)
801
+ final_refracted = create_ray_batch(num_rays=num_refracted)
802
+ final_refracted.positions[:] = combined_refr_positions
803
+ final_refracted.directions[:] = combined_refr_directions
804
+ final_refracted.wavelengths[:] = combined_refr_wavelengths
805
+ final_refracted.intensities[:] = combined_refr_intensities
806
+ final_refracted.accumulated_time[:] = combined_refr_times
807
+ final_refracted.generations[:] = combined_refr_generations
808
+ final_refracted.active[:] = True
809
+ else:
810
+ final_refracted = create_ray_batch(num_rays=0)
811
+
812
+ # Convert path lists to arrays
813
+ reflected_paths = []
814
+ reflected_final_directions = []
815
+ for i in range(num_original):
816
+ if len(ray_path_lists[i]) > 0:
817
+ reflected_paths.append(np.array(ray_path_lists[i]))
818
+ reflected_final_directions.append(ray_final_dirs[i])
819
+
820
+ ray_paths = {
821
+ "reflected_paths": reflected_paths,
822
+ "refracted_paths": refracted_path_lists,
823
+ "reflected_final_dirs": reflected_final_directions,
824
+ "refracted_final_dirs": refracted_final_dirs,
825
+ }
826
+
827
+ return final_reflected, final_refracted, ray_paths
828
+
829
+
830
def trace_rays_with_splitting(
    rays: RayBatch,
    surfaces: list,
    max_bounces: int = 10,
    bounding_radius: float = 10000.0,
    bounding_center: tuple = None,
    wavelength: float = 532e-9,
    min_intensity: float = 1e-10,
    polarization: str = "unpolarized",
) -> tuple[RayBatch, dict]:
    """
    Trace rays through multiple surfaces with proper Fresnel ray splitting.

    At each surface interaction, every ray is split into two child rays:
    - A reflected ray with intensity scaled by the Fresnel reflection coefficient R
    - A refracted ray with intensity scaled by the Fresnel transmission coefficient T

    This creates a ray tree where both reflected and refracted paths are traced
    until termination conditions are met.

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch to trace (not modified; a clone is traced)
    surfaces : list of Surface
        List of surfaces to interact with (tested in order for closest hit)
    max_bounces : int, optional
        Maximum number of surface interactions per ray tree branch (default: 10)
    bounding_radius : float, optional
        Radius of bounding sphere in meters (default: 10000)
    bounding_center : tuple of float, optional
        Center of bounding sphere (x, y, z) in meters. Default is (0, 0, 0).
        For curved Earth simulations, use (0, 0, -EARTH_RADIUS).
    wavelength : float, optional
        Currently unused: refractive indices and Fresnel coefficients are
        evaluated with each ray's own wavelength. Kept for API compatibility.
    min_intensity : float, optional
        Minimum intensity threshold for ray termination (default: 1e-10)
    polarization : str, optional
        Polarization state: 's', 'p', or 'unpolarized' (default)

    Returns
    -------
    final_rays : RayBatch
        All terminal rays (rays that exited the scene or hit max bounces).
        Each ray has its intensity weighted by the product of all Fresnel
        coefficients along its path.
    trace_info : dict
        Dictionary containing tracing statistics (plain ints):
        - 'total_rays_created': Total number of rays created during tracing
        - 'max_depth_reached': Maximum tree depth reached
        - 'terminated_by_intensity': Number of branches dropped for low
          intensity (queued rays below ``min_intensity`` plus child rays not
          spawned because their intensity would be <= ``min_intensity``;
          refracted children suppressed by TIR are not counted)
        - 'terminated_by_bounds': Number of rays that exited bounding sphere,
          including rays that missed every surface and were projected onto it
        - 'terminated_by_max_bounces': Number of rays that hit max bounce limit

    Notes
    -----
    This function implements proper physical ray splitting where each ray at
    an interface creates two child rays:

    - Reflected ray: direction from law of reflection, intensity = I_parent * R
    - Refracted ray: direction from Snell's law, intensity = I_parent * T

    where R and T are the Fresnel reflection and transmission coefficients
    computed from the refractive indices and incident angle.

    For total internal reflection (TIR), T=0 and R=1, so only a reflected
    ray is created.

    Rays that miss every surface are moved to the exit point on the bounding
    sphere and recorded as terminal. Their accumulated time is not advanced
    for that final leg.

    The number of rays grows exponentially with depth (up to 2^depth), so
    the min_intensity threshold is critical for pruning weak ray branches.

    Examples
    --------
    >>> from surface_roughness.surfaces import PlanarSurface
    >>> from surface_roughness.materials import Glass, Air
    >>>
    >>> # Create a glass slab
    >>> surface1 = PlanarSurface(
    ...     point=(0, 0, 0.01),
    ...     normal=(0, 0, -1),
    ...     material_front=Air(),
    ...     material_back=Glass()
    ... )
    >>> surface2 = PlanarSurface(
    ...     point=(0, 0, 0.02),
    ...     normal=(0, 0, -1),
    ...     material_front=Glass(),
    ...     material_back=Air()
    ... )
    >>>
    >>> # Trace rays with splitting
    >>> final_rays, info = trace_rays_with_splitting(
    ...     rays, [surface1, surface2], max_bounces=5
    ... )
    >>> print(f"Created {info['total_rays_created']} rays total")
    """
    from collections import deque

    from .ray_data import create_ray_batch

    speed_of_light = 299792458.0  # m/s
    # Children are nudged off the surface along their new direction. The
    # offset must exceed surface.intersect's min_distance (0.01) so a child
    # does not immediately re-hit the surface it was spawned on.
    surface_offset = 0.02

    # float64 so that |p|^2 - R^2 does not cancel catastrophically for large,
    # offset spheres (e.g. centred at (0, 0, -EARTH_RADIUS)).
    if bounding_center is None:
        center = np.zeros(3, dtype=np.float64)
    else:
        center = np.asarray(bounding_center, dtype=np.float64)

    # Statistics tracking
    total_rays_created = rays.num_rays
    terminated_by_intensity = 0
    terminated_by_bounds = 0
    terminated_by_max_bounces = 0
    max_depth_reached = 0

    # Breadth-first queue of (RayBatch, depth); deque gives O(1) popleft.
    ray_queue = deque([(rays.clone(), 0)])

    # Each entry: (positions, directions, wavelengths, intensities, times,
    # generations) for one group of terminal rays.
    terminal_chunks = []

    def _collect_terminal(batch, idx, positions=None):
        # Snapshot rays batch[idx] as terminal. Fancy indexing already copies.
        # *positions* overrides the stored positions (bounding-sphere exits).
        terminal_chunks.append(
            (
                batch.positions[idx] if positions is None else positions,
                batch.directions[idx],
                batch.wavelengths[idx],
                batch.intensities[idx],
                batch.accumulated_time[idx],
                batch.generations[idx],
            )
        )

    def _bounding_sphere_exit_distances(origins, directions):
        # Forward distance t to the far root of |p + t*d|^2 = R^2 with
        # p = origin - center. Rays whose line misses the sphere, or that
        # have a zero-length direction, get t = 0. Negative roots are
        # clamped to 0 so rays are never moved backwards.
        p = np.asarray(origins, dtype=np.float64) - center
        d = np.asarray(directions, dtype=np.float64)
        a = np.sum(d * d, axis=1)
        b = 2.0 * np.sum(p * d, axis=1)
        c = np.sum(p * p, axis=1) - float(bounding_radius) ** 2
        discriminant = b * b - 4.0 * a * c
        t_exit = np.zeros(len(p), dtype=np.float64)
        valid = (discriminant >= 0) & (a > 0)
        t_exit[valid] = (-b[valid] + np.sqrt(discriminant[valid])) / (2.0 * a[valid])
        return np.maximum(t_exit, 0.0)

    def _spawn_children(
        keep, origins, directions, wavelengths, intensities, times, generations
    ):
        # Build a child batch from the hit rays selected by *keep*. Children
        # start just off the surface, one generation deeper than the parent.
        children = create_ray_batch(num_rays=int(np.count_nonzero(keep)))
        children.positions[:] = origins[keep] + surface_offset * directions[keep]
        children.directions[:] = directions[keep]
        children.wavelengths[:] = wavelengths[keep]
        children.intensities[:] = intensities[keep]
        children.accumulated_time[:] = times[keep]
        children.generations[:] = generations[keep] + 1
        children.active[:] = True
        return children

    while ray_queue:
        current_rays, depth = ray_queue.popleft()
        max_depth_reached = max(max_depth_reached, depth)

        # Filter out already inactive rays
        if not np.any(current_rays.active):
            continue

        # Rays already outside the bounding sphere terminate where they are.
        # Distance is measured from the bounding center, not the origin.
        distances_from_center = np.linalg.norm(
            current_rays.positions - center, axis=1
        )
        exited_mask = current_rays.active & (distances_from_center >= bounding_radius)
        if np.any(exited_mask):
            terminated_by_bounds += int(np.count_nonzero(exited_mask))
            _collect_terminal(current_rays, exited_mask)
            current_rays.active[exited_mask] = False

        # Branches at the depth limit terminate where they are.
        if depth >= max_bounces:
            remaining_mask = current_rays.active
            if np.any(remaining_mask):
                terminated_by_max_bounces += int(np.count_nonzero(remaining_mask))
                _collect_terminal(current_rays, remaining_mask)
            continue

        # Check intensity threshold - terminate weak rays
        weak_mask = current_rays.active & (current_rays.intensities < min_intensity)
        if np.any(weak_mask):
            terminated_by_intensity += int(np.count_nonzero(weak_mask))
            current_rays.active[weak_mask] = False

        if not np.any(current_rays.active):
            continue

        # Snapshot the active set ONCE. current_rays.active is not mutated
        # again below, and every per-surface mask is indexed against this
        # fixed ordering.
        active_indices = np.flatnonzero(current_rays.active)
        active_origins = current_rays.positions[active_indices]
        active_directions = current_rays.directions[active_indices]

        # Find closest surface intersection among all surfaces.
        # Index -1 means "no surface hit".
        closest_distances = np.full(active_indices.size, np.inf, dtype=np.float64)
        closest_surface_idx = np.full(active_indices.size, -1, dtype=np.int32)
        for surf_idx, surface in enumerate(surfaces):
            distances, hit_mask = surface.intersect(active_origins, active_directions)
            closer = hit_mask & (distances < closest_distances)
            closest_distances[closer] = distances[closer]
            closest_surface_idx[closer] = surf_idx

        # Rays that hit nothing leave the scene. Project them to their exit
        # point on the bounding sphere. This applies even when other rays in
        # the same batch did hit a surface.
        # NOTE(review): accumulated_time is not advanced for this final leg
        # (original behaviour) - confirm whether callers expect it to be.
        missed = closest_surface_idx < 0
        if np.any(missed):
            miss_idx = active_indices[missed]
            miss_origins = current_rays.positions[miss_idx]
            miss_directions = current_rays.directions[miss_idx]
            t_exit = _bounding_sphere_exit_distances(miss_origins, miss_directions)
            exit_positions = miss_origins + t_exit[:, np.newaxis] * miss_directions
            terminated_by_bounds += int(miss_idx.size)
            _collect_terminal(current_rays, miss_idx, positions=exit_positions)

        # Process each hit surface separately
        for surf_idx, surface in enumerate(surfaces):
            on_surface = closest_surface_idx == surf_idx
            if not np.any(on_surface):
                continue

            hit_idx = active_indices[on_surface]
            hit_origins = current_rays.positions[hit_idx]
            hit_directions = current_rays.directions[hit_idx]
            hit_wavelengths = current_rays.wavelengths[hit_idx]
            hit_intensities = current_rays.intensities[hit_idx]
            hit_times = current_rays.accumulated_time[hit_idx]
            hit_generations = current_rays.generations[hit_idx]
            hit_distances = closest_distances[on_surface]

            # Compute intersection points and surface normals
            hit_positions = hit_origins + hit_distances[:, np.newaxis] * hit_directions
            normals = surface.normal_at(hit_positions, hit_directions)

            # Refractive indices on either side of the interface.
            # NOTE(review): n1 always comes from material_front and n2 from
            # material_back, whichever side the ray arrives from. Confirm
            # this is correct for rays hitting a surface from behind.
            n1_values = np.array(
                [
                    surface.material_front.get_refractive_index(
                        pos[0], pos[1], pos[2], wl
                    )
                    for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
                ],
                dtype=np.float32,
            )
            n2_values = np.array(
                [
                    surface.material_back.get_refractive_index(
                        pos[0], pos[1], pos[2], wl
                    )
                    for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
                ],
                dtype=np.float32,
            )

            # Travel time to the hit point through the incident medium
            updated_times = hit_times + hit_distances * n1_values / speed_of_light

            # Compute Fresnel coefficients
            cos_theta_i = np.abs(np.sum(hit_directions * normals, axis=1))
            R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

            # REFLECTED children (always physically present; pruned if weak)
            reflected_directions = compute_reflection_direction(hit_directions, normals)
            reflected_intensities = hit_intensities * R
            reflect_keep = reflected_intensities > min_intensity
            terminated_by_intensity += int(np.count_nonzero(~reflect_keep))
            if np.any(reflect_keep):
                ray_queue.append(
                    (
                        _spawn_children(
                            reflect_keep,
                            hit_positions,
                            reflected_directions,
                            hit_wavelengths,
                            reflected_intensities,
                            updated_times,
                            hit_generations,
                        ),
                        depth + 1,
                    )
                )
                total_rays_created += int(np.count_nonzero(reflect_keep))

            # REFRACTED children (absent under TIR; pruned if weak)
            refracted_directions, tir_mask = compute_refraction_direction(
                hit_directions, normals, n1_values, n2_values
            )
            refracted_intensities = hit_intensities * T
            refracted_intensities[tir_mask] = 0.0
            refract_keep = (refracted_intensities > min_intensity) & (~tir_mask)
            # TIR is physics, not pruning, so it is excluded from the count.
            terminated_by_intensity += int(np.count_nonzero(~refract_keep & ~tir_mask))
            if np.any(refract_keep):
                ray_queue.append(
                    (
                        _spawn_children(
                            refract_keep,
                            hit_positions,
                            refracted_directions,
                            hit_wavelengths,
                            refracted_intensities,
                            updated_times,
                            hit_generations,
                        ),
                        depth + 1,
                    )
                )
                total_rays_created += int(np.count_nonzero(refract_keep))

    # Combine all terminal rays
    if terminal_chunks:
        (
            all_positions,
            all_directions,
            all_wavelengths,
            all_intensities,
            all_times,
            all_generations,
        ) = (np.concatenate(column) for column in zip(*terminal_chunks))

        final_rays = create_ray_batch(num_rays=len(all_positions))
        final_rays.positions[:] = all_positions
        final_rays.directions[:] = all_directions
        final_rays.wavelengths[:] = all_wavelengths
        final_rays.intensities[:] = all_intensities
        final_rays.accumulated_time[:] = all_times
        final_rays.generations[:] = all_generations
        final_rays.active[:] = True
    else:
        final_rays = create_ray_batch(num_rays=0)

    trace_info = {
        "total_rays_created": int(total_rays_created),
        "max_depth_reached": int(max_depth_reached),
        "terminated_by_intensity": int(terminated_by_intensity),
        "terminated_by_bounds": int(terminated_by_bounds),
        "terminated_by_max_bounces": int(terminated_by_max_bounces),
    }

    return final_rays, trace_info