diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"ray_paths_one_bin",
|
|
6
|
+
"ray_paths",
|
|
7
|
+
"surface",
|
|
8
|
+
"get_optical_system_layout",
|
|
9
|
+
"_plot_surface",
|
|
10
|
+
"_plot_surface_recursively",
|
|
11
|
+
"plot"
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import torch
|
|
16
|
+
import plotly.express as px
|
|
17
|
+
import plotly.graph_objects as go
|
|
18
|
+
import numpy as np
|
|
19
|
+
#import matplotlib.pyplot as plt
|
|
20
|
+
#from PIL import Image
|
|
21
|
+
#import tempfile
|
|
22
|
+
import plotly.express as px
|
|
23
|
+
import plotly.graph_objects as go
|
|
24
|
+
import pandas as pd
|
|
25
|
+
import matplotlib.colors as mcolors
|
|
26
|
+
import plotly.io as pio
|
|
27
|
+
import copy
|
|
28
|
+
|
|
29
|
+
##2013FF
|
|
30
|
+
def ray_paths_one_bin(rays,ray_color,ray_linewidth):
    """
    Build a Plotly Express 3D line figure from one bin of equally shaped ray tensors.

    The tensors are stacked into one array of shape ``(K, M, 3)``; the last axis is
    read as x/y/z. Every flattened point is tagged with its index along the second
    axis (``0 .. M-1``) and that tag is used as Plotly's ``line_group``. Each trace
    therefore connects ``stacked[0, i], stacked[1, i], ..., stacked[K-1, i]``.

    NOTE(review): this yields one line per ray if each tensor holds one point per
    ray (shape ``(n_rays, 3)``) for successive path steps. Confirm against the
    tracer's ``ray_paths`` layout.

    Args:
        rays (list[torch.Tensor]): Equally shaped CPU tensors whose last axis is x, y, z.
        ray_color (str): Colour applied to every generated line trace.
        ray_linewidth (float): Width applied to every generated line trace.

    Returns:
        plotly.graph_objs.Figure: The ``px.line_3d`` figure with restyled traces.
    """
    # .numpy() needs CPU tensors that do not require grad.
    stacked = torch.tensor(np.array([ray.numpy() for ray in rays]))
    columns = {
        axis_name: stacked[:, :, axis].reshape(-1)
        for axis, axis_name in enumerate(("X", "Y", "Z"))
    }
    n_outer, n_inner = stacked.shape[0], stacked.shape[1]
    # Flat point k*n_inner + i gets group id i: 0..n_inner-1 repeated n_outer times.
    columns["ray id"] = torch.arange(n_inner).repeat(n_outer)
    frame = pd.DataFrame(columns)

    figure = px.line_3d(frame, x='X', y='Y', z='Z', line_group="ray id")
    for trace in figure.data:
        trace.line.color = ray_color
        trace.line.width = ray_linewidth
    return figure
|
|
57
|
+
|
|
58
|
+
def ray_paths(rays,ray_color="#9673A6",ray_linewidth=3):
    """
    Turn a collection of ray tensors into Plotly 3D line traces.

    Rays are bucketed by ``len()`` so that each bucket can be stacked and drawn by
    ``ray_paths_one_bin``. Buckets are processed in first-seen order.

    Args:
        rays (list[torch.Tensor] or None): Ray tensors to draw; ``None`` yields no traces.
        ray_color (str): Any Matplotlib colour spec, normalised to ``#rrggbb``.
        ray_linewidth (float): Line width of the rays.

    Returns:
        list: The Plotly line traces of all buckets, concatenated.
    """
    # Normalise the colour first (this also validates it when rays is None).
    hex_color = mcolors.to_hex(ray_color)
    traces = []
    if rays is None:
        return traces

    buckets = {}
    for ray in rays:
        buckets.setdefault(len(ray), []).append(ray)

    for bucket in buckets.values():
        traces.extend(ray_paths_one_bin(bucket, hex_color, ray_linewidth).data)
    return traces
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def surface(transformation,surface,name,aperture_radius,resolution,colorscale,is_square=False):
    """
    Generate a Plotly surface plot for an explicitly described optical surface.

    The function proceeds in four steps:

    1. Sample a ``resolution`` x ``resolution`` grid over ``[-aperture_radius, aperture_radius]``
       in the local x/y plane.
    2. For a circular aperture (``is_square=False``), project grid points lying outside
       the aperture circle radially onto it, so the square grid covers a disc.
    3. Take the local sag from ``surface.explicit``.
    4. Map the points to global coordinates with the matrix returned by
       ``transformation.get_transformation_matrix()``.

    Args:
        transformation: Object providing ``get_transformation_matrix()``. The matrix is
            applied to homogeneous ``[x, y, z, 1]`` row vectors as ``v @ M.T``.
        surface: Object providing ``explicit(points)``. It receives an ``(N, 3)`` tensor
            whose x/y columns are filled and whose z column is zero.
        name (str): Name of the surface trace.
        aperture_radius (float): Half-width of the sampled grid and radius of the aperture.
        resolution (int): Number of samples per axis.
        colorscale (list): Plotly colour scale for the surface.
        is_square (bool): If True, keep the full square grid; otherwise map it onto a
            circular aperture.

    Returns:
        list: A one-element list containing the ``plotly.graph_objects.Surface`` trace.
    """
    grid = torch.linspace(-aperture_radius, aperture_radius, resolution)
    n = grid.shape[0]
    # indexing="ij" is torch's historical default; stating it avoids the deprecation warning.
    mesh_x, mesh_y = torch.meshgrid(grid, grid, indexing="ij")
    x = mesh_x.reshape(-1)
    y = mesh_y.reshape(-1)

    if not is_square:
        radius = torch.sqrt(x * x + y * y)
        outside = radius > aperture_radius
        # Project points outside the aperture radially onto its rim. Points inside or
        # exactly on the rim keep scale 1. Dividing only where the radius exceeds the
        # aperture avoids a NaN at the grid centre (radius 0). It also stops rim points
        # from collapsing to the origin; both cases occur for odd resolutions.
        scale = torch.ones_like(radius)
        scale[outside] = aperture_radius / radius[outside]
        x = x * scale
        y = y * scale

    points = torch.zeros((x.shape[0], 3))
    points[:, 0] = x
    points[:, 1] = y

    with torch.no_grad():
        z = surface.explicit(points)
    z = z.detach().reshape(-1)
    x = x.detach().reshape(-1)
    y = y.detach().reshape(-1)

    # Homogeneous coordinates [x, y, z, 1] for the (presumably 4x4) transformation matrix.
    v = torch.zeros((x.shape[0], 4))
    v[:, 0] = x
    v[:, 1] = y
    v[:, 2] = z
    v[:, 3] = 1.0

    with torch.no_grad():
        M = transformation.get_transformation_matrix().detach()
        Mv = v @ M.T

    gx = Mv[:, 0].reshape(n, n)
    gy = Mv[:, 1].reshape(n, n)
    gz = Mv[:, 2].reshape(n, n)

    return [go.Surface(x=gx, y=gy, z=gz, showscale=False, name=name, colorscale=colorscale)]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def get_optical_system_layout(show_grid,xlabel="x [mm]",ylabel="y [mm]",zlabel="z [mm]",xticks=None,yticks=None,zticks=None,axislabel_font_size=10,tick_font_size=10):
    """
    Create a Plotly layout for 3D visualization of the optical system.

    Args:
        show_grid (bool): Visibility of all three scene axes.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        xticks (list[float], optional): Fixed tick positions for the x-axis.
        yticks (list[float], optional): Fixed tick positions for the y-axis.
        zticks (list[float], optional): Fixed tick positions for the z-axis.
        axislabel_font_size (int): Font size for axis titles.
        tick_font_size (int): Font size for tick labels.

    Returns:
        plotly.graph_objs.Layout: Layout object for the plot.
    """
    #TODO write wrapper for plot3D!

    def _axis(label, ticks):
        # One scene axis: visibility, title/tick font sizes and optional fixed tick positions.
        axis = dict(
            visible=show_grid,
            title=dict(text=label, font=dict(size=axislabel_font_size)),
            tickfont=dict(size=tick_font_size),
        )
        if ticks is not None:
            axis["tickvals"] = ticks
        return axis

    scene = dict(
        xaxis=_axis(xlabel, xticks),
        yaxis=_axis(ylabel, yticks),
        zaxis=_axis(zlabel, zticks),
        aspectmode='data',
        # NOTE(review): with aspectmode='data', Plotly presumably ignores aspectratio
        # (it is used for 'manual'); it is kept unchanged for backward compatibility.
        aspectratio=dict(x=1, y=1, z=1),
    )
    # Camera "up" vector points along +x.
    camera = dict(up=dict(x=1., y=0., z=0))
    return go.Layout(scene_camera=camera, scene=scene)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _plot_surface(surface,name,resolution):
    """
    Generate Plotly surface objects for all 3D surface segments of an element.

    Segment ``k`` uses ``colorscale[k]``. If that entry is missing or rejected by
    Plotly, a message is printed and the first colour scale is used instead.

    Args:
        surface: Object with ``get_plot_points_3D(resolution)``, which returns a sequence
            of ``(x, y, z)`` triples, and ``get_plotly_color_scale()``.
        name (str): Base name; segment ``k`` is named ``f"{name}_{k}"``.
        resolution (int): Resolution passed to ``get_plot_points_3D``.

    Returns:
        list: List of Plotly surface objects (empty if there are no segments).
    """
    surface_list = surface.get_plot_points_3D(resolution)
    if len(surface_list) == 0:
        return []
    colorscale = surface.get_plotly_color_scale()

    data = []
    for k, (x, y, z) in enumerate(surface_list):
        trace_name = name + f"_{k}"
        try:
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=trace_name, colorscale=colorscale[k])
        except (IndexError, KeyError, TypeError, ValueError):
            # Too few colour scales, or colorscale[k] rejected by Plotly's validator.
            # A bare except here used to swallow everything, including KeyboardInterrupt.
            print("Wrong number of colorscales or colorscales is not correct, fallback to first colorscale!")
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=trace_name, colorscale=colorscale[0])
        data.append(trace)

    return data
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _plot_surface_recursively(current_elem,name,resolution):
    """
    Collect Plotly surface traces for an element and all of its plotable descendants.

    Traversal is depth-first pre-order. The traces of a node come first, followed by
    the traces of each child subtree in the order returned by ``get_plotable_childs()``.

    Args:
        current_elem: Root plotable element.
        name (str): Name used for the root's traces. Children use the names paired
            with them by ``get_plotable_childs()``.
        resolution (int): Resolution for the surface plot.

    Returns:
        list: Plotly surface objects for the element and its descendants.
    """
    collected = []
    pending = [(current_elem, name)]
    while pending:
        node, node_name = pending.pop()
        collected += _plot_surface(node, node_name, resolution)
        # Push children in reverse so the first child is popped next (pre-order).
        pending.extend(reversed(list(node.get_plotable_childs())))
    return collected
|
|
255
|
+
|
|
256
|
+
def plot(element=None,
         rays=None,
         resolution=32,
         show_grid=True,
         xlabel="x [mm]",
         ylabel="y [mm]",
         zlabel="z [mm]",
         xticks=None,
         yticks=None,
         zticks=None,
         axislabel_font_size=10,
         tick_font_size=10,
         ray_color="#9673A6",
         ray_linewidth=3.,
         show=True,
         html_file_name=None):
    """
    Visualize the optical system and ray paths in 3D using Plotly.

    Elements are deep-copied and moved to the CPU before plotting, so the caller's
    objects are left untouched.

    Args:
        element: Optical system element to plot (must implement the Plotable interface),
            or a list/tuple of such elements, or None.
        rays (list[torch.Tensor] or dict): List of ray tensors, or a dict holding them
            under the key ``"ray_paths"``.
        resolution (int): Resolution for the surface plot.
        show_grid (bool): Whether to show the grid.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        xticks (list[float], optional): Custom x-ticks.
        yticks (list[float], optional): Custom y-ticks.
        zticks (list[float], optional): Custom z-ticks.
        axislabel_font_size (int): Font size for axis labels.
        tick_font_size (int): Font size for tick labels.
        ray_color (str): Color of the rays.
        ray_linewidth (float): Line width of the rays.
        show (bool): Whether to display the plot immediately.
        html_file_name (str or os.PathLike, optional): If provided, the plot is also saved
            as an HTML file. The name must end with ``.html``.

    Returns:
        plotly.graph_objs.Figure or None: The Plotly figure object if show is False, otherwise None.

    Raises:
        RuntimeError: If ``html_file_name`` does not end with ``.html``. This is checked
            before any plotting work is done.
    """
    # Validate the output path up front, so an invalid name neither shows a figure nor
    # wastes work before failing.
    if html_file_name is not None and not str(html_file_name).endswith(".html"):
        raise RuntimeError("html_file_name should end with .html!")

    if element is None:
        elements = []
    elif isinstance(element, (list, tuple)):
        elements = element
    else:
        elements = [element]

    data = []
    for elem in elements:
        # Plot a CPU copy so the caller's (possibly GPU-resident) element is not moved.
        elem = copy.deepcopy(elem).to("cpu")
        data += _plot_surface_recursively(elem, "", resolution)

    if rays is not None:
        if isinstance(rays, dict):
            rays = rays["ray_paths"]
        # Ray paths from a differentiable trace may require grad. detach() is needed
        # because .numpy(), used downstream, refuses such tensors.
        rays = [ray.detach().cpu() for ray in rays]
        data += ray_paths(rays, ray_color, ray_linewidth)

    layout = get_optical_system_layout(show_grid, xlabel, ylabel, zlabel, xticks, yticks, zticks, axislabel_font_size, tick_font_size)
    fig = go.Figure(data=data, layout=layout)
    if show:
        fig.show()

    if html_file_name is not None:
        pio.write_html(fig, file=html_file_name, auto_open=False)

    if not show:
        return fig
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"PlotableWavelength",
|
|
6
|
+
"add_colour_bar",
|
|
7
|
+
"plot"
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from colour import XYZ_to_sRGB,wavelength_to_XYZ
|
|
14
|
+
from typing import Tuple,Optional,Union
|
|
15
|
+
|
|
16
|
+
class PlotableWavelength:
    """
    Mixin describing a wavelength-dependent quantity that ``plot`` can sample.

    When ``plot`` receives an instance without explicit values, it samples ``bounds``
    with ``np.linspace`` and calls the instance on those samples. Concrete subclasses
    are therefore expected to be callable (for example ``torch.nn.Module`` subclasses).

    Attributes:
        bounds (tuple): ``(min, max)`` wavelength range.
        ylabel (str): Default y-axis label used by ``plot`` when the caller leaves
            ``ylabel`` at ``"y"``.
    """

    def __init__(self, bounds: Tuple[float, float], ylabel: str):
        self.bounds, self.ylabel = bounds, ylabel
|
|
28
|
+
|
|
29
|
+
def add_colour_bar(fig, ax, wl):
    """
    Draw a strip of approximate spectral colours underneath ``ax``.

    A new axes (height 0.03 in figure coordinates) is added 0.15 below the bottom
    edge of ``ax``, with the same left edge and width. Each interval
    ``[wl[i], wl[i + 1]]`` is filled with the sRGB colour of ``wl[i]``. Wavelengths
    outside 360-780 nm are drawn black.

    Args:
        fig (matplotlib.figure.Figure): Figure that receives the new axes.
        ax (matplotlib.axes.Axes): Main axes that the strip is aligned to.
        wl (array-like): Wavelengths in µm (converted to nm for the colour lookup).
    """
    def _rgb_for(wavelength_um):
        wavelength_nm = wavelength_um * 1000.
        if not (360.0 < wavelength_nm < 780.0):
            return (0., 0., 0.)
        # XYZ_to_sRGB can leave the [0, 1] gamut for saturated spectral colours.
        return np.clip(XYZ_to_sRGB(wavelength_to_XYZ(wavelength_nm)), 0.0, 1.0)

    x0, y0, strip_width, _ = ax.get_position().bounds
    strip_ax = fig.add_axes([x0, y0 - 0.15, strip_width, 0.03])  # Position further below to avoid overlap

    palette = [_rgb_for(w) for w in wl]
    for start, stop, rgb in zip(wl[:-1], wl[1:], palette):
        strip_ax.fill_between([start, stop], 0, 1, color=rgb)

    strip_ax.set_xlim(np.min(wl), np.max(wl))
    strip_ax.axis('off')  # Hide axis for a clean color strip
|
|
56
|
+
|
|
57
|
+
#TODO change bmin and bmax to bounds
|
|
58
|
+
#refractive_index
|
|
59
|
+
def plot(wl,vals=None,title="",xlabel="Wavelength [µm]",ylabel="y",labels=None,colour_bar=True,linewidth=2,legend=True,resolution=500,show=True):
    """
    Plot one or more spectral curves, optionally with a colour strip below them.

    Args:
        wl (array-like, torch.Tensor or PlotableWavelength): Wavelengths in µm. If any
            value exceeds 100, the input is assumed to be in nm and divided by 1000.
            If ``vals`` is None, ``wl`` must be a ``PlotableWavelength``. It is then
            sampled at ``resolution`` points over its ``bounds``, and its ``ylabel``
            replaces the default ``"y"``.
        vals (array-like or torch.Tensor, optional): Values to plot.
            1-D values draw a single black curve with a grey fill.
            2-D values draw one curve per row. The array is transposed if its second
            axis does not match ``len(wl)``.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        labels (list): Labels for the individual curves (2-D ``vals``).
        colour_bar (bool): Whether to draw the spectral colour strip.
        linewidth (int): Line width of the curves.
        legend (bool): Whether to show a legend (only if ``labels`` is given).
        resolution (int): Number of samples used for a ``PlotableWavelength``.
        show (bool): Whether to call ``plt.show()``.

    Returns:
        matplotlib.figure.Figure or None: The figure if ``show`` is False, otherwise None.

    Raises:
        RuntimeError: If ``vals`` is None and ``wl`` is not a ``PlotableWavelength``.
    """
    def _to_numpy(values):
        # Tensors (e.g. from a RefractiveIndex) may require grad or live on a GPU,
        # and np.array() rejects both.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.array(values)

    if vals is None:
        if not isinstance(wl, PlotableWavelength):
            raise RuntimeError("if vals=None, wl must be a PlotableWavelength!")
        plotable_func = wl
        wl = np.linspace(*plotable_func.bounds, resolution)
        vals = plotable_func(wl)
        if ylabel == "y":
            ylabel = plotable_func.ylabel

    # Create figure and main axis; leave room at the bottom for the colour strip.
    fig, ax = plt.subplots(figsize=(10, 5))
    fig.subplots_adjust(bottom=0.3)
    vals = _to_numpy(vals)

    wl = _to_numpy(wl)
    if (wl > 100.).any():
        print("wl seems to be given in nm, but µm is expected! Converting wl to µm.")
        wl = wl / 1000.

    vmin = np.min(vals)
    vmax = np.max(vals)

    if vals.ndim == 1:
        ax.plot(wl, vals, color='black', linewidth=linewidth)
        ax.fill_between(wl, vals, color='gray', alpha=0.2)
    else:
        if vals.shape[1] != wl.shape[0]:
            vals = vals.T
        for i, val in enumerate(vals):
            label = None if labels is None else labels[i]
            ax.plot(wl, val, label=label, linewidth=linewidth)

    ax.set_xlim(np.min(wl), np.max(wl))
    span = vmax - vmin
    if span > 0:
        # 10% headroom above the largest value.
        ax.set_ylim(vmin, vmax + span * 0.1)
    else:
        # Constant data (e.g. a constant refractive index) would give identical limits,
        # which makes Matplotlib warn and collapse the axis; pad symmetrically instead.
        pad = 0.1 * abs(vmax) if vmax != 0 else 1.0
        ax.set_ylim(vmin - pad, vmax + pad)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if colour_bar:
        add_colour_bar(fig, ax, wl)

    if labels is not None and legend:
        ax.legend(loc='upper right')
    if show:
        plt.show()
        return None
    # Consistent with plotting.system3D.plot: hand the figure back when not shown.
    return fig
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
"""
|
|
135
|
+
import matplotlib.pyplot as plt
|
|
136
|
+
import numpy as np
|
|
137
|
+
import colour
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
import matplotlib.pyplot as plt
|
|
141
|
+
import numpy as np
|
|
142
|
+
import colour
|
|
143
|
+
|
|
144
|
+
# Define the wavelength range (in nm) and the spectrum curve (e.g., Gaussian example)
|
|
145
|
+
|
|
146
|
+
bmin = 300.
|
|
147
|
+
bmax = 3000.
|
|
148
|
+
wavelengths = np.linspace(bmin, bmax, 1000)
|
|
149
|
+
spectrum = np.exp(-((wavelengths - 550) / 40) ** 2) # Gaussian curve centered at 550 nm
|
|
150
|
+
|
|
151
|
+
# Function to map wavelength to RGB color
|
|
152
|
+
def wavelength_to_rgb(wavelength):
|
|
153
|
+
if 360.0 < wavelength and wavelength < 780.0:
|
|
154
|
+
rgb = colour.XYZ_to_sRGB(colour.wavelength_to_XYZ(wavelength))
|
|
155
|
+
return np.clip(rgb, 0.0, 1.0) # Ensure RGB values are within [0, 1]
|
|
156
|
+
else:
|
|
157
|
+
return (0.,0.,0.)
|
|
158
|
+
def add_color_strip(ax, bmin, bmax, resolution=1000):
|
|
159
|
+
wavelengths = np.linspace(bmin, bmax, resolution)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
colors = [wavelength_to_rgb(wl) for wl in wavelengths]
|
|
163
|
+
|
|
164
|
+
# Create the color strip as a series of filled segments
|
|
165
|
+
for i in range(len(wavelengths) - 1):
|
|
166
|
+
ax.fill_between([wavelengths[i], wavelengths[i + 1]], 0, 1, color=colors[i])
|
|
167
|
+
|
|
168
|
+
ax.set_xlim(bmin, bmax)
|
|
169
|
+
ax.axis('off') # Hide axis for a clean color strip
|
|
170
|
+
|
|
171
|
+
# Create figure and main axis
|
|
172
|
+
fig, ax = plt.subplots(figsize=(10, 5))
|
|
173
|
+
plt.subplots_adjust(bottom=0.3) # Increase space at the bottom
|
|
174
|
+
|
|
175
|
+
# Plot the spectrum curve
|
|
176
|
+
ax.plot(wavelengths, spectrum, color='black', linewidth=2)
|
|
177
|
+
ax.fill_between(wavelengths, spectrum, color='gray', alpha=0.2)
|
|
178
|
+
|
|
179
|
+
# Add labels and limits for the main plot
|
|
180
|
+
ax.set_xlim(bmin, bmax)
|
|
181
|
+
ax.set_ylim(0, 1.1)
|
|
182
|
+
ax.set_xlabel("Wavelength (nm)")
|
|
183
|
+
ax.set_ylabel("Intensity (a.u.)")
|
|
184
|
+
ax.set_title("Spectrum with Corresponding Colors")
|
|
185
|
+
|
|
186
|
+
# Add an extra axis for the color strip below the main plot, with extra spacing
|
|
187
|
+
left, bottom, width, height = ax.get_position().bounds
|
|
188
|
+
color_ax = fig.add_axes([left, bottom - 0.15, width, 0.03]) # Position further below to avoid overlap
|
|
189
|
+
add_color_strip(color_ax, bmin, bmax, resolution=1000)
|
|
190
|
+
|
|
191
|
+
plt.show()
|
|
192
|
+
#%%
|
|
193
|
+
|
|
194
|
+
"""
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
"""
|
|
198
|
+
|
|
199
|
+
#%%
|
|
200
|
+
import matplotlib.pyplot as plt
|
|
201
|
+
import numpy as np
|
|
202
|
+
import colour
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
import matplotlib.pyplot as plt
|
|
206
|
+
import numpy as np
|
|
207
|
+
import colour
|
|
208
|
+
|
|
209
|
+
# Define the wavelength range (in nm) and the spectrum curve (e.g., Gaussian example)
|
|
210
|
+
wavelengths = np.linspace(380, 780, 1000)
|
|
211
|
+
spectrum = pvlib.spectrum.get_am15g(wavelengths)
|
|
212
|
+
spectrum = np.array(spectrum)
|
|
213
|
+
|
|
214
|
+
# Function to map wavelength to RGB color
|
|
215
|
+
def wavelength_to_rgb(wavelength):
|
|
216
|
+
rgb = colour.XYZ_to_sRGB(colour.wavelength_to_XYZ(wavelength))
|
|
217
|
+
return np.clip(rgb, 0.0, 1.0) # Ensure RGB values are within [0, 1]
|
|
218
|
+
|
|
219
|
+
colors = [wavelength_to_rgb(wl) for wl in wavelengths]
|
|
220
|
+
|
|
221
|
+
# Create the color strip as a series of filled segments
|
|
222
|
+
fig, ax = plt.subplots(figsize=(10, 5))
|
|
223
|
+
for i in range(len(wavelengths) - 1):
|
|
224
|
+
plt.fill_between([wavelengths[i], wavelengths[i + 1]], 0, spectrum[i+1], color=colors[i],interpolate=True)
|
|
225
|
+
|
|
226
|
+
# Plot the spectrum curve
|
|
227
|
+
plt.xlim(380, 750)
|
|
228
|
+
plt.ylim(0., 2.0)
|
|
229
|
+
plt.plot(wavelengths, spectrum, color='black')
|
|
230
|
+
plt.show()
|
|
231
|
+
"""
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"RefractiveIndex",
|
|
7
|
+
"materials"
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch
|
|
12
|
+
import numpy as np
|
|
13
|
+
from .plotting.wavelength import PlotableWavelength
|
|
14
|
+
|
|
15
|
+
class RefractiveIndex(nn.Module,PlotableWavelength):
    r"""This class is used to calculate the refractive index of a material.

    At material interfaces, the transmitted direction :math:`\mathbf{D'}` is computed based on the surface normal
    :math:`\mathbf{N} = \nabla s / \|\nabla s\|` and the incident direction :math:`\mathbf{D}`, using Snell's law (see :cite:`do`):

    .. math::

        \mathbf{D'} = \mathbf{N} \sqrt{1 - (1 - \cos^2 \psi_i) \eta^2} + \eta (\mathbf{D} - \mathbf{N} \cos \psi_i),

    where :math:`\cos \psi_i = \mathbf{D} \cdot \mathbf{N}` and :math:`\eta = n / n'` is the ratio of the refractive indices
    of the two materials.

    We have implemented the refractive indices as a class that is initialized by a refractive index function and the start
    and end wavelengths for which the function is valid. This makes it very convenient to use with the RefractiveIndex.info
    database of optical constants (see :cite:`polyanskiy2024refractiveindex`), since this database often provides Python
    functions for wavelength-dependent refractive indices.

    Wavelengths outside ``bounds`` are clamped to the nearest bound before ``func`` is
    evaluated, and a message is printed. The index is therefore held at its boundary
    value outside the valid range.

    Example:
        Below is an example of how to set up an optical material in our ray tracing library:

        >>> import diffinytrace as dit
        >>> BaSF = dit.RefractiveIndex(
        ...     lambda x: (1 + 1.65554268 / (1 - 0.0104485644 / x**2) +
        ...                0.17131977 / (1 - 0.0499394756 / x**2) +
        ...                1.33664448 / (1 - 118.961472 / x**2))**0.5,
        ...     [0.365, 2.5]
        ... )
        >>> dit.plotting.wavelength.plot(
        ...     BaSF, title="Refractive index of BaSF (Barium dense flint)"
        ... )

    Args:
        func (callable): A function that takes a wavelength in μm and returns the refractive index.
        bounds (tuple): A tuple containing the minimum and maximum wavelength in μm.
    """
    def __init__(self,func,bounds):
        nn.Module.__init__(self)
        PlotableWavelength.__init__(self,bounds,"n [1]")
        self.func = func
        self.bounds = bounds

    def forward(self,wl):
        """Calculates the refractive index for given wavelengths.

        Args:
            wl (torch.Tensor, float or array-like): Wavelength(s) in μm.

        Returns:
            torch.Tensor: Refractive index at the given wavelengths. A scalar result of
            ``func`` (Python/NumPy scalar or 0-d tensor) is broadcast to the shape of ``wl``.
        """
        if not torch.is_tensor(wl):
            wl = torch.tensor(wl)
        if not torch.is_floating_point(wl):
            # clamp() against float bounds needs a floating-point tensor (e.g. wl=torch.tensor(1)).
            wl = wl.to(torch.get_default_dtype())
        vmin, vmax = self.bounds
        if not bool(((wl >= vmin) & (wl <= vmax)).all()):
            print(f"The wavelength should be given in μm and between {vmin} and {vmax}. Fallback to constant val.")

        # Evaluate func only inside its validity range. Out-of-range wavelengths are
        # clamped to the nearest bound, which gives the boundary value there.
        # Previously func ran on the raw values and was overwritten in place afterwards;
        # that could hit Sellmeier poles (inf/NaN), and it failed for 0-d results.
        out = self.func(torch.clamp(wl, min=vmin, max=vmax))

        if not torch.is_tensor(out):
            # Python/NumPy scalars (e.g. a constant-index lambda) and NumPy arrays.
            out = torch.tensor(out, device=wl.device, dtype=wl.dtype)
        if out.dim() == 0:
            # Broadcast a scalar index to the shape of wl.
            return out * torch.ones_like(wl)
        return out
|
|
84
|
+
|
|
85
|
+
"""
|
|
86
|
+
All material data is from https://refractiveindex.info/. Please verify the equation and ranges by ur self and the references.
|
|
87
|
+
"""
|
|
88
|
+
materials = {
|
|
89
|
+
"NONE": RefractiveIndex(lambda x: 1.0,(0.0,torch.inf)),
|
|
90
|
+
"AIR": RefractiveIndex(lambda x: 1+0.05792105/(238.0185-x**-2)+0.00167917/(57.362-x**-2),(0.23,1.69)),#P. E. Ciddor. Refractive index of air: new equations for the visible and near infrared, Appl. Optics 35, 1566-1573 (1996)
|
|
91
|
+
"HELIUM": RefractiveIndex(lambda x: (1+4977.77e-8/(1-28.54e-6/x**2)+1856.94e-8/(1-7.76e-3/x**2))**.5,(0.48,2.06)),#C. R. Mansfield and E. R. Peck. Dispersion of helium, J. Opt. Soc. Am. 59, 199-203 (1969)
|
|
92
|
+
"PMMA": RefractiveIndex(lambda x: (1+0.99654/(1-0.00787/x**2)+0.18964/(1-0.02191/x**2)+0.00411/(1-3.85727/x**2))**.5,(0.405,1.08)),#Marcin Szczurowski
|
|
93
|
+
"NBK7": RefractiveIndex(lambda x: (1+1.03961212/(1-0.00600069867/x**2)+0.231792344/(1-0.0200179144/x**2)+1.01046945/(1-103.560653/x**2))**.5,(0.3,2.5)),#SCHOTT
|
|
94
|
+
"BAF10": RefractiveIndex(lambda x: (1+1.5851495/(1-0.00926681282/x**2)+0.143559385/(1-0.0424489805/x**2)+1.08521269/(1-105.613573/x**2))**.5,(0.35,2.5)),#SCHOTT
|
|
95
|
+
"BAK1": RefractiveIndex(lambda x: (1+1.12365662/(1-0.00644742752/x**2)+0.309276848/(1-0.0222284402/x**2)+0.881511957/(1-107.297751/x**2))**.5,(0.3,2.5)),#SCHOTT
|
|
96
|
+
"FK51A": RefractiveIndex(lambda x: (1+0.971247817/(1-0.00472301995/x**2)+0.216901417/(1-0.0153575612/x**2)+0.904651666/(1-168.68133/x**2))**.5,(0.29,2.5)),#SCHOTT
|
|
97
|
+
"LASF9": RefractiveIndex(lambda x: (1+2.00029547/(1-0.0121426017/x**2)+0.298926886/(1-0.0538736236/x**2)+1.80691843/(1-156.530829/x**2))**.5,(0.365,2.5)),#SCHOTT
|
|
98
|
+
"SF5": RefractiveIndex(lambda x: (1+1.52481889/(1-0.011254756/x**2)+0.187085527/(1-0.0588995392/x**2)+1.42729015/(1-129.141675/x**2))**.5,(0.37,2.5)),#SCHOTT
|
|
99
|
+
"SF10": RefractiveIndex(lambda x: (1+1.62153902/(1-0.0122241457/x**2)+0.256287842/(1-0.0595736775/x**2)+1.64447552/(1-147.468793/x**2))**.5,(0.38,2.5)),#SCHOTT
|
|
100
|
+
"SF11": RefractiveIndex(lambda x: (1+1.73759695/(1-0.013188707/x**2)+0.313747346/(1-0.0623068142/x**2)+1.89878101/(1-155.23629/x**2))**.5,(0.37,2.5)),#SCHOTT
|
|
101
|
+
}
|