diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "ray_paths_one_bin",
6
+ "ray_paths",
7
+ "surface",
8
+ "get_optical_system_layout",
9
+ "_plot_surface",
10
+ "_plot_surface_recursively",
11
+ "plot"
12
+ ]
13
+
14
+ import pandas as pd
15
+ import torch
16
+ import plotly.express as px
17
+ import plotly.graph_objects as go
18
+ import numpy as np
19
+ #import matplotlib.pyplot as plt
20
+ #from PIL import Image
21
+ #import tempfile
22
+ import plotly.express as px
23
+ import plotly.graph_objects as go
24
+ import pandas as pd
25
+ import matplotlib.colors as mcolors
26
+ import plotly.io as pio
27
+ import copy
28
+
29
+ ##2013FF
30
def ray_paths_one_bin(rays, ray_color, ray_linewidth):
    """
    Generate a Plotly 3D line figure for one bin of equally-shaped ray tensors.

    The tensors in ``rays`` are stacked along a new leading axis, giving an
    array of shape ``(len(rays), N, C)`` whose last axis holds x, y, z in
    columns 0, 1, 2. Rows are grouped into polylines by their index along the
    second axis (the ``"ray id"`` column); within one polyline the points are
    ordered along the first (list) axis.

    Args:
        rays (list[torch.Tensor | numpy.ndarray]): Tensors of identical shape
            ``(N, C)`` with ``C >= 3``. They may live on any device and may
            require gradients; they are detached and copied to the CPU.
        ray_color (str): Color for the ray lines.
        ray_linewidth (float): Line width for the ray lines.

    Returns:
        plotly.graph_objs.Figure: Plotly figure containing the ray paths
        (an empty figure if ``rays`` is empty).
    """
    if len(rays) == 0:
        return go.Figure()

    # Detach before converting: Tensor.numpy() raises on tensors that require grad.
    points = np.stack([torch.as_tensor(elem).detach().cpu().numpy() for elem in rays])
    n_steps, n_ids = points.shape[0], points.shape[1]

    x = points[:, :, 0].reshape(-1)
    y = points[:, :, 1].reshape(-1)
    z = points[:, :, 2].reshape(-1)
    # Flattened row k * n_ids + j belongs to line j, so the id cycles 0..n_ids-1
    # once per list entry.
    line_id = np.tile(np.arange(n_ids), n_steps)
    df = pd.DataFrame({"X": x, "Y": y, "Z": z, "ray id": line_id})

    line_fig = px.line_3d(df, x="X", y="Y", z="Z", line_group="ray id")
    for trace in line_fig.data:
        trace.line.color = ray_color
        trace.line.width = ray_linewidth
    return line_fig
57
+
58
def ray_paths(rays, ray_color="#9673A6", ray_linewidth=3):
    """
    Build Plotly line traces for a collection of ray tensors.

    The tensors are bucketed by ``len(tensor)`` (first-seen order of lengths
    is kept), and each bucket is rendered by ``ray_paths_one_bin``. The colour
    is normalised to a hex string first, so any matplotlib colour spec works.

    Args:
        rays (list[torch.Tensor] or None): Ray tensors to plot; ``None`` yields
            no traces.
        ray_color: Any matplotlib-compatible colour for the rays.
        ray_linewidth (float): Line width of the rays.

    Returns:
        list: Plotly trace objects for all buckets, in bucket order.
    """
    hex_color = mcolors.to_hex(ray_color)
    traces = []
    if rays is None:
        return traces

    buckets = {}
    for ray in rays:
        buckets.setdefault(len(ray), []).append(ray)

    for same_length in buckets.values():
        bucket_fig = ray_paths_one_bin(same_length, hex_color, ray_linewidth)
        traces.extend(bucket_fig.data)
    return traces
82
+
83
+
84
def surface(transformation, surface, name, aperture_radius, resolution, colorscale, is_square=False):
    """
    Generate a Plotly surface trace for an explicitly defined optical surface.

    A ``resolution x resolution`` grid spanning
    ``[-aperture_radius, aperture_radius]`` in x and y is sampled. Unless
    ``is_square`` is set, grid points outside the circular aperture are pulled
    radially onto its rim. Heights come from ``surface.explicit`` and the
    points are mapped to world coordinates with the element's transformation
    matrix applied to homogeneous points ``(x, y, z, 1)``.

    Args:
        transformation: Object providing ``get_transformation_matrix()``. The
            result is applied as ``v @ M.T`` with ``v`` of shape ``(n, 4)``, so
            it needs 4 columns (presumably a 4x4 homogeneous matrix).
        surface: Object providing ``explicit(O)``. It is called with an
            ``(n, 3)`` tensor whose z column is zero and must return ``n``
            heights.
        name (str): Name of the surface trace.
        aperture_radius (float): Half-width of the sampled grid and radius of
            the circular aperture.
        resolution (int): Number of samples per axis.
        colorscale: Plotly colorscale for the surface.
        is_square (bool): Keep the square grid instead of clipping to a circle.

    Returns:
        list: A single-element list containing a Plotly ``Surface`` trace.
    """
    _x = torch.linspace(-aperture_radius, aperture_radius, resolution)
    _y = torch.linspace(-aperture_radius, aperture_radius, resolution)
    # "ij" is torch's historical default; passing it explicitly keeps the same
    # grid layout without the deprecation warning.
    grid_x, grid_y = torch.meshgrid(_x, _y, indexing="ij")
    x = grid_x.reshape(-1)
    y = grid_y.reshape(-1)

    if not is_square:
        # Project points outside the aperture radially onto the rim and leave
        # the others unchanged. torch.where avoids the 0/0 = NaN at the grid
        # centre and keeps points lying exactly on the rim in place (the former
        # mask arithmetic scaled them to the origin).
        r = torch.sqrt(x * x + y * y)
        scale = torch.where(r > aperture_radius, aperture_radius / r, torch.ones_like(r))
        x = x * scale
        y = y * scale

    O = torch.zeros((x.shape[0], 3))
    O[:, 0] = x
    O[:, 1] = y

    with torch.no_grad():
        z = surface.explicit(O)
    z = z.detach().reshape(-1)
    x = x.detach().reshape(-1)
    y = y.detach().reshape(-1)

    # Homogeneous coordinates so the transformation is a single matrix product.
    v = torch.zeros((x.shape[0], 4))
    v[:, 0] = x
    v[:, 1] = y
    v[:, 2] = z
    v[:, 3] = 1.0

    with torch.no_grad():
        M = transformation.get_transformation_matrix().detach()
        Mv = v @ M.T

    n = _x.shape[0]
    x_world = Mv[:, 0].reshape(n, n)
    y_world = Mv[:, 1].reshape(n, n)
    z_world = Mv[:, 2].reshape(n, n)

    return [go.Surface(x=x_world, y=y_world, z=z_world, showscale=False, name=name, colorscale=colorscale)]
142
+
143
+
144
def get_optical_system_layout(show_grid, xlabel="x [mm]", ylabel="y [mm]", zlabel="z [mm]", xticks=None, yticks=None, zticks=None, axislabel_font_size=10, tick_font_size=10):
    """
    Create a Plotly layout for 3D visualization of the optical system.

    Args:
        show_grid (bool): Visibility of the three scene axes (lines, ticks,
            titles and grid).
        xlabel (str): Title of the x-axis.
        ylabel (str): Title of the y-axis.
        zlabel (str): Title of the z-axis.
        xticks (list[float], optional): Fixed tick positions for the x-axis.
        yticks (list[float], optional): Fixed tick positions for the y-axis.
        zticks (list[float], optional): Fixed tick positions for the z-axis.
        axislabel_font_size (int): Font size of the axis titles.
        tick_font_size (int): Font size of the tick labels.

    Returns:
        plotly.graph_objs.Layout: Layout object for the plot.
    """
    # TODO write wrapper for plot3D!

    def _axis(label, ticks):
        # One scene axis: visibility, title font and tick font, plus optional
        # fixed tick positions.
        axis = dict(
            visible=show_grid,
            title=dict(text=label, font=dict(size=axislabel_font_size)),
            tickfont=dict(size=tick_font_size),
        )
        if ticks is not None:
            axis["tickvals"] = ticks
        return axis

    scene = dict(
        xaxis=_axis(xlabel, xticks),
        yaxis=_axis(ylabel, yticks),
        zaxis=_axis(zlabel, zticks),
        # "data" keeps true proportions; Plotly only honours aspectratio for
        # aspectmode="manual".
        aspectmode='data',
        aspectratio=dict(x=1, y=1, z=1),
    )
    # Camera "up" vector along +x: the x axis points up on the page.
    camera = dict(up=dict(x=1., y=0., z=0))
    return go.Layout(scene_camera=camera, scene=scene)
209
+
210
+
211
def _plot_surface(surface, name, resolution):
    """
    Generate Plotly surface traces for all 3D surface segments of an element.

    Segment ``k`` uses ``colorscale[k]``. If that entry is missing or rejected
    by Plotly, a message is printed and ``colorscale[0]`` is used instead.

    Args:
        surface: Object with ``get_plot_points_3D(resolution)`` returning a
            sequence of ``(x, y, z)`` triples, and ``get_plotly_color_scale()``.
        name (str): Base name; segment ``k`` is named ``f"{name}_{k}"``.
        resolution (int): Resolution passed to ``get_plot_points_3D``.

    Returns:
        list: Plotly surface traces, empty if there are no segments.
    """
    surface_list = surface.get_plot_points_3D(resolution)
    if len(surface_list) == 0:
        return []
    colorscale = surface.get_plotly_color_scale()

    data = []
    for k, (x, y, z) in enumerate(surface_list):
        trace_name = name + f"_{k}"
        try:
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=trace_name, colorscale=colorscale[k])
        except (IndexError, KeyError, TypeError, ValueError):
            # Fewer colorscales than segments, or an entry Plotly's validator rejects.
            print("Wrong number of colorscales or colorscales is not correct, fallback to first colorscale!")
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=trace_name, colorscale=colorscale[0])
        data.append(trace)

    return data
237
+
238
+
239
def _plot_surface_recursively(current_elem, name, resolution):
    """
    Collect surface traces for an element and, depth-first, all its plotable children.

    Traversal is pre-order: an element's own traces come before those of its
    children, and children are visited in the order
    ``get_plotable_childs()`` yields them.

    Args:
        current_elem: Root element to plot.
        name (str): Trace name used for the root element.
        resolution (int): Resolution for the surface plot.

    Returns:
        list: Plotly surface traces for the whole subtree.
    """
    traces = []
    pending = [(current_elem, name)]
    while pending:
        node, node_name = pending.pop()
        traces.extend(_plot_surface(node, node_name, resolution))
        # Push children reversed so the first child is processed next.
        children = list(node.get_plotable_childs())
        pending.extend(reversed(children))
    return traces
255
+
256
def plot(element=None,
         rays=None,
         resolution=32,
         show_grid=True,
         xlabel="x [mm]",
         ylabel="y [mm]",
         zlabel="z [mm]",
         xticks=None,
         yticks=None,
         zticks=None,
         axislabel_font_size=10,
         tick_font_size=10,
         ray_color="#9673A6",
         ray_linewidth=3.,
         show=True,
         html_file_name=None):
    """
    Visualize optical elements and ray paths in 3D using Plotly.

    Every element is deep-copied and moved to the CPU before plotting, so the
    caller's objects and their device placement are left untouched.

    Args:
        element: An element, or a list/tuple of elements, providing ``to``,
            ``get_plot_points_3D``, ``get_plotly_color_scale`` and
            ``get_plotable_childs``. ``None`` plots no surfaces.
        rays (list[torch.Tensor] or dict, optional): Ray tensors, or a dict
            holding them under ``"ray_paths"``. Tensors may require gradients
            and may live on any device.
        resolution (int): Resolution for the surface plot.
        show_grid (bool): Whether to show the axes/grid.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        xticks (list[float], optional): Custom x-ticks.
        yticks (list[float], optional): Custom y-ticks.
        zticks (list[float], optional): Custom z-ticks.
        axislabel_font_size (int): Font size for axis labels.
        tick_font_size (int): Font size for tick labels.
        ray_color (str): Color of the rays.
        ray_linewidth (float): Line width of the rays.
        show (bool): Whether to display the plot immediately.
        html_file_name (str or os.PathLike, optional): If given, the figure is
            also written to this ``.html`` file.

    Returns:
        plotly.graph_objs.Figure or None: The figure if ``show`` is False,
        otherwise None.

    Raises:
        RuntimeError: If ``html_file_name`` does not end with ``.html``. This
            is checked before anything is rendered or shown.
    """
    if html_file_name is not None and not str(html_file_name).endswith(".html"):
        raise RuntimeError("html_file_name should end with .html!")

    if element is None:
        elements = []
    elif isinstance(element, (list, tuple)):
        elements = element
    else:
        elements = [element]

    data = []
    for elem in elements:
        # Plot a CPU copy so the caller's (possibly GPU-resident) element is not moved.
        cpu_elem = copy.deepcopy(elem).to("cpu")
        data += _plot_surface_recursively(cpu_elem, "", resolution)

    if rays is not None:
        if isinstance(rays, dict):
            rays = rays["ray_paths"]
        # Detach: rays from a differentiable trace cannot be converted to numpy otherwise.
        rays = [ray.detach().cpu() for ray in rays]
        data += ray_paths(rays, ray_color, ray_linewidth)

    layout = get_optical_system_layout(show_grid, xlabel, ylabel, zlabel, xticks, yticks, zticks, axislabel_font_size, tick_font_size)
    fig = go.Figure(data=data, layout=layout)
    if show:
        fig.show()

    if html_file_name is not None:
        pio.write_html(fig, file=html_file_name, auto_open=False)

    if not show:
        return fig
@@ -0,0 +1,231 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "PlotableWavelength",
6
+ "add_colour_bar",
7
+ "plot"
8
+ ]
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ from colour import XYZ_to_sRGB,wavelength_to_XYZ
14
+ from typing import Tuple,Optional,Union
15
+
16
class PlotableWavelength:
    """
    Mixin describing how a wavelength-dependent quantity is plotted.

    ``plot`` in this module evaluates ``instance(wl)`` when no values are
    passed, so concrete subclasses are expected to be callable on an array of
    wavelengths.

    Attributes:
        bounds (tuple): ``(lower, upper)`` wavelength range that is sampled.
        ylabel (str): Default y-axis label for plots of this quantity.
    """

    def __init__(self, bounds: Tuple[float, float], ylabel: str):
        # Both values are stored unchanged; no validation is performed here.
        self.bounds, self.ylabel = bounds, ylabel
28
+
29
def add_colour_bar(fig, ax, wl):
    """
    Draw a strip of spectral colours underneath ``ax``.

    A thin axes (same left edge and width as ``ax``, 0.15 figure units lower,
    0.03 high) is added to ``fig``. Each interval ``[wl[i], wl[i + 1]]`` is
    filled with the sRGB colour of ``wl[i]``. Wavelengths outside the open
    360-780 nm range are drawn black.

    Args:
        fig (matplotlib.figure.Figure): Figure that receives the strip axes.
        ax (matplotlib.axes.Axes): Main axes the strip is aligned with.
        wl (array-like): Wavelengths in µm (multiplied by 1000 to get nm).
    """
    def to_rgb(wavelength_um):
        nm = wavelength_um * 1000.
        if not (360.0 < nm < 780.0):
            return (0., 0., 0.)
        # Ensure RGB values are within [0, 1]
        return np.clip(XYZ_to_sRGB(wavelength_to_XYZ(nm)), 0.0, 1.0)

    left, bottom, width, _height = ax.get_position().bounds
    strip = fig.add_axes([left, bottom - 0.15, width, 0.03])  # Position further below to avoid overlap

    rgb_values = [to_rgb(w) for w in wl]
    for start, stop, rgb in zip(wl[:-1], wl[1:], rgb_values):
        strip.fill_between([start, stop], 0, 1, color=rgb)

    strip.set_xlim(np.min(wl), np.max(wl))
    strip.axis('off')  # Hide axis for a clean color strip
56
+
57
def plot(wl, vals=None, title="", xlabel="Wavelength [µm]", ylabel="y", labels=None, colour_bar=True, linewidth=2, legend=True, resolution=500, show=True):
    """
    Plot one or more curves over wavelength, optionally with a colour strip below.

    Args:
        wl (array-like or PlotableWavelength): Wavelengths in µm. If any value
            exceeds 100, the whole array is assumed to be in nm and divided by
            1000. When ``vals`` is None, ``wl`` must instead be a callable
            ``PlotableWavelength``. It is evaluated at ``resolution`` points
            spanning its ``bounds``.
        vals (array-like or torch.Tensor, optional): Values at ``wl``. 1-D for
            a single curve (black line with grey fill). 2-D for several curves;
            it is transposed when its second axis does not match ``len(wl)``.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis. The default ``"y"`` is replaced by
            the PlotableWavelength's own ``ylabel`` when one is plotted.
        labels (list, optional): One label per curve (2-D case).
        colour_bar (bool): Whether to draw the spectral colour strip.
        linewidth (float): Line width of the curves.
        legend (bool): Whether to show a legend (only used when ``labels`` is given).
        resolution (int): Number of samples when evaluating a PlotableWavelength.
        show (bool): Whether to call ``plt.show()``.

    Returns:
        None

    Raises:
        RuntimeError: If ``vals`` is None and ``wl`` is not a PlotableWavelength.
    """
    if vals is None:
        if not isinstance(wl, PlotableWavelength):
            raise RuntimeError("if vals=None, wl must be a PlotableWavelength!")
        plotable_func = wl
        wl = np.linspace(*plotable_func.bounds, resolution)
        vals = plotable_func(wl)
        if ylabel == "y":
            ylabel = plotable_func.ylabel

    fig, ax = plt.subplots(figsize=(10, 5))
    plt.subplots_adjust(bottom=0.3)  # leave room below the axes for the colour strip

    # np.array() fails on tensors that require grad or live on a GPU.
    if torch.is_tensor(vals):
        vals = vals.detach().cpu().numpy()
    if torch.is_tensor(wl):
        wl = wl.detach().cpu().numpy()
    vals = np.array(vals)
    wl = np.array(wl)

    if (wl > 100.).any():
        print("wl seems to be given in nm, but µm is expected! Converting wl to µm")
        wl = wl / 1000.

    if vals.ndim == 1:
        ax.plot(wl, vals, color='black', linewidth=linewidth)
        ax.fill_between(wl, vals, color='gray', alpha=0.2)
    else:
        # Curves are expected along the first axis; accept the transposed layout too.
        if vals.shape[1] != wl.shape[0]:
            vals = vals.T
        for i, val in enumerate(vals):
            label = labels[i] if labels is not None else None
            ax.plot(wl, val, label=label, linewidth=linewidth)

    ax.set_xlim(np.min(wl), np.max(wl))

    # NaN-tolerant y-range with 10 % headroom. Constant data would otherwise
    # produce identical lower and upper limits.
    vmin = np.nanmin(vals)
    vmax = np.nanmax(vals)
    headroom = (vmax - vmin) * 0.1
    if headroom == 0:
        headroom = abs(vmax) * 0.1 or 1.0
        ax.set_ylim(vmin - headroom, vmax + headroom)
    else:
        ax.set_ylim(vmin, vmax + headroom)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if colour_bar:
        add_colour_bar(fig, ax, wl)

    if labels is not None and legend:
        ax.legend(loc='upper right')
    if show:
        plt.show()
132
# Scratch prototypes of the colour-strip plot (formerly kept here as dead
# module-level string literals, one of them using the undeclared `pvlib`
# package) have been removed; `add_colour_bar` and `plot` implement this
# functionality.
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+
5
+ __all__ = [
6
+ "RefractiveIndex",
7
+ "materials"
8
+ ]
9
+
10
+ import torch.nn as nn
11
+ import torch
12
+ import numpy as np
13
+ from .plotting.wavelength import PlotableWavelength
14
+
15
class RefractiveIndex(nn.Module, PlotableWavelength):
    r"""This class is used to calculate the refractive index of a material.

    At material interfaces, the transmitted direction :math:`\mathbf{D'}` is computed based on the surface normal
    :math:`\mathbf{N} = \nabla s / \|\nabla s\|` and the incident direction :math:`\mathbf{D}`, using Snell's law (see :cite:`do`):

    .. math::

        \mathbf{D'} = \mathbf{N} \sqrt{1 - (1 - \cos^2 \psi_i) \eta^2} + \eta (\mathbf{D} - \mathbf{N} \cos \psi_i),

    where :math:`\cos \psi_i = \mathbf{D} \cdot \mathbf{N}` and :math:`\eta = n / n'` is the ratio of the refractive indices
    of the two materials.

    We have implemented the refractive indices as a class that is initialized by a refractive index function and the start
    and end wavelengths for which the function is valid. This makes it very convenient to use with the RefractiveIndex.info
    database of optical constants (see :cite:`polyanskiy2024refractiveindex`), since this database often provides Python
    functions for wavelength-dependent refractive indices.

    Example:
        Below is an example of how to set up an optical material in our ray tracing library:

        >>> import diffinytrace as dit
        >>> BaSF = dit.RefractiveIndex(
        ...     lambda x: (1 + 1.65554268 / (1 - 0.0104485644 / x**2) +
        ...                0.17131977 / (1 - 0.0499394756 / x**2) +
        ...                1.33664448 / (1 - 118.961472 / x**2))**0.5,
        ...     [0.365, 2.5]
        ... )
        >>> dit.plotting.wavelength.plot(
        ...     BaSF, title="Refractive index of BaSF (Barium dense flint)"
        ... )

    Args:
        func (callable): Maps wavelength(s) in μm to the refractive index. It is
            called with a tensor and may return a tensor, a NumPy array or a
            Python number (treated as a constant index).
        bounds (tuple): ``(min, max)`` wavelength in μm for which ``func`` is
            valid. Outside this range the index is held at ``func(min)`` /
            ``func(max)``.
    """

    def __init__(self, func, bounds):
        nn.Module.__init__(self)
        # Stores self.bounds and the y-axis label used when plotting.
        PlotableWavelength.__init__(self, bounds, "n [1]")
        self.func = func

    def forward(self, wl):
        """Calculate the refractive index for the given wavelength(s).

        Wavelengths outside ``self.bounds`` trigger a printed message and are
        clamped to the nearest bound before ``func`` is evaluated, so the index
        is held constant outside the valid range. Clamping, rather than
        overwriting out-of-range results afterwards, means ``func`` is never
        evaluated outside its range. There, e.g. a Sellmeier pole could produce
        inf/NaN that would leak into the gradient even after masking.

        Args:
            wl (torch.Tensor, numpy.ndarray or float): Wavelength(s) in μm.
                Non-floating inputs are converted to the default float dtype.

        Returns:
            torch.Tensor: Refractive index. Constant (scalar) results are
            broadcast to the shape of ``wl``.
        """
        wl = torch.as_tensor(wl)
        if not wl.is_floating_point():
            wl = wl.to(torch.get_default_dtype())

        vmin, vmax = self.bounds
        in_range = (wl >= vmin) & (wl <= vmax)
        if not in_range.all():
            print(f"The wavelength should be given in μm and between {vmin} and {vmax}. Fallback to constant val.")

        out = self.func(torch.clamp(wl, vmin, vmax))
        if isinstance(out, (int, float)):
            return out * torch.ones_like(wl)
        if isinstance(out, np.ndarray):
            out = torch.tensor(out, device=wl.device, dtype=wl.dtype)

        if torch.is_tensor(out) and out.dim() == 0:
            return out * torch.ones_like(wl)
        return out
84
# All material data is from https://refractiveindex.info/. Please verify the
# equations and valid ranges yourself against the cited references.
# Each entry: RefractiveIndex(n(λ) with λ in μm, (λ_min, λ_max) in μm).
materials = {
    "NONE": RefractiveIndex(lambda x: 1.0, (0.0, torch.inf)),
    "AIR": RefractiveIndex(lambda x: 1 + 0.05792105 / (238.0185 - x**-2) + 0.00167917 / (57.362 - x**-2), (0.23, 1.69)),  # P. E. Ciddor. Refractive index of air: new equations for the visible and near infrared, Appl. Optics 35, 1566-1573 (1996)
    "HELIUM": RefractiveIndex(lambda x: (1 + 4977.77e-8 / (1 - 28.54e-6 / x**2) + 1856.94e-8 / (1 - 7.76e-3 / x**2))**.5, (0.48, 2.06)),  # C. R. Mansfield and E. R. Peck. Dispersion of helium, J. Opt. Soc. Am. 59, 199-203 (1969)
    "PMMA": RefractiveIndex(lambda x: (1 + 0.99654 / (1 - 0.00787 / x**2) + 0.18964 / (1 - 0.02191 / x**2) + 0.00411 / (1 - 3.85727 / x**2))**.5, (0.405, 1.08)),  # Marcin Szczurowski
    "NBK7": RefractiveIndex(lambda x: (1 + 1.03961212 / (1 - 0.00600069867 / x**2) + 0.231792344 / (1 - 0.0200179144 / x**2) + 1.01046945 / (1 - 103.560653 / x**2))**.5, (0.3, 2.5)),  # SCHOTT
    "BAF10": RefractiveIndex(lambda x: (1 + 1.5851495 / (1 - 0.00926681282 / x**2) + 0.143559385 / (1 - 0.0424489805 / x**2) + 1.08521269 / (1 - 105.613573 / x**2))**.5, (0.35, 2.5)),  # SCHOTT
    "BAK1": RefractiveIndex(lambda x: (1 + 1.12365662 / (1 - 0.00644742752 / x**2) + 0.309276848 / (1 - 0.0222284402 / x**2) + 0.881511957 / (1 - 107.297751 / x**2))**.5, (0.3, 2.5)),  # SCHOTT
    "FK51A": RefractiveIndex(lambda x: (1 + 0.971247817 / (1 - 0.00472301995 / x**2) + 0.216901417 / (1 - 0.0153575612 / x**2) + 0.904651666 / (1 - 168.68133 / x**2))**.5, (0.29, 2.5)),  # SCHOTT
    "LASF9": RefractiveIndex(lambda x: (1 + 2.00029547 / (1 - 0.0121426017 / x**2) + 0.298926886 / (1 - 0.0538736236 / x**2) + 1.80691843 / (1 - 156.530829 / x**2))**.5, (0.365, 2.5)),  # SCHOTT
    "SF5": RefractiveIndex(lambda x: (1 + 1.52481889 / (1 - 0.011254756 / x**2) + 0.187085527 / (1 - 0.0588995392 / x**2) + 1.42729015 / (1 - 129.141675 / x**2))**.5, (0.37, 2.5)),  # SCHOTT
    "SF10": RefractiveIndex(lambda x: (1 + 1.62153902 / (1 - 0.0122241457 / x**2) + 0.256287842 / (1 - 0.0595736775 / x**2) + 1.64447552 / (1 - 147.468793 / x**2))**.5, (0.38, 2.5)),  # SCHOTT
    "SF11": RefractiveIndex(lambda x: (1 + 1.73759695 / (1 - 0.013188707 / x**2) + 0.313747346 / (1 - 0.0623068142 / x**2) + 1.89878101 / (1 - 155.23629 / x**2))**.5, (0.37, 2.5)),  # SCHOTT
}