diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "PhysicalObject",
6
+ "PhysicalSurface"
7
+ ]
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
class PhysicalObject(nn.Module):
    """
    Abstract base class for physical objects in the optical system.

    A subclass only has to provide :meth:`get_transform`; every coordinate
    helper below forwards to the transform object that method returns.
    The class can be used to define surface distance constraints and is
    also used for plotting.
    """

    def __init__(self):
        super().__init__()

    def get_transform(self):
        """
        Return the transformation object that places this object.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Returns:
            object: Transformation object.
        """
        raise NotImplementedError("PhysicalObject: get_transform not implemented")

    def get_transformation_matrix(self):
        """
        Return the transformation matrix of the object.

        Returns:
            torch.Tensor: The matrix reported by the object's transform.
        """
        transform = self.get_transform()
        return transform.get_transformation_matrix()

    def to_global_dir(self, direction):
        """
        Map a direction vector from local to global coordinates.

        Args:
            direction (torch.Tensor): Direction vector in local coordinates.

        Returns:
            torch.Tensor: Direction vector in global coordinates.
        """
        transform = self.get_transform()
        return transform.to_global_dir(direction)

    def to_local_dir(self, direction):
        """
        Map a direction vector from global to local coordinates.

        Args:
            direction (torch.Tensor): Direction vector in global coordinates.

        Returns:
            torch.Tensor: Direction vector in local coordinates.
        """
        transform = self.get_transform()
        return transform.to_local_dir(direction)

    def to_global_pos(self, position):
        """
        Map a position from local to global coordinates.

        Args:
            position (torch.Tensor): Position in local coordinates.

        Returns:
            torch.Tensor: Position in global coordinates.
        """
        transform = self.get_transform()
        return transform.to_global_pos(position)

    def to_local_pos(self, position):
        """
        Map a position from global to local coordinates.

        Args:
            position (torch.Tensor): Position in global coordinates.

        Returns:
            torch.Tensor: Position in local coordinates.
        """
        transform = self.get_transform()
        return transform.to_local_pos(position)
90
+
91
class PhysicalSurface(PhysicalObject):
    """
    Abstract base class for physical surfaces in the optical system.

    This class can be used to define surface distance constraints and is
    also used for plotting. Every method below is a placeholder that raises
    NotImplementedError until a subclass overrides it.
    """

    def __init__(self):
        super().__init__()

    def get_constraint_funs_leq_zero(self):
        """
        Return constraint functions for the surface that must be less than or equal to zero.

        Raises:
            NotImplementedError: If not implemented in subclass.

        Returns:
            list[Callable]: List of constraint functions.
        """
        # The message names this method (it previously said "geq_zero").
        raise NotImplementedError("PhysicalSurface: get_constraint_funs_leq_zero is not implemented")

    # Disabled placeholders, kept for reference (they were never active code):
    #
    # def get_corners_in_parameter_space(self):
    #     raise NotImplementedError("get_corners: is not implemented")
    #
    # def get_edge_funcs_in_parameter_space(self):
    #     raise NotImplementedError("get_edge_funcs_in_parameter_space: is not implemented")

    def parametric_sample(self, num_points: int, method: str = "sobol") -> tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points on the surface in parameter space.

        Args:
            num_points (int): Number of points to sample.
            method (str, optional): Sampling method. Defaults to "sobol".

        Raises:
            NotImplementedError: If not implemented in subclass.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Sampled parameter positions and corresponding surface positions.
        """
        # The message names this method (it previously said "sample()").
        raise NotImplementedError("PhysicalSurface: parametric_sample is not implemented")

    def parametric_surface(self, parametric_pos: torch.Tensor) -> torch.Tensor:
        """
        Map parameter space positions to surface positions.

        Args:
            parametric_pos (torch.Tensor): Positions in parameter space.

        Raises:
            NotImplementedError: If not implemented in subclass.

        Returns:
            torch.Tensor: Surface positions.
        """
        raise NotImplementedError("PhysicalSurface: parametric_surface is not implemented")
150
+
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "quantity2D",
6
+ "system2D",
7
+ "system3D",
8
+ "Plotable",
9
+ "wavelength"
10
+ ]
11
+
12
+ from . import quantity2D
13
+ from . import system2D
14
+ from . import system3D
15
+ from .core import Plotable
16
+ from . import wavelength
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "Plotable"
6
+ ]
7
+
8
+ from matplotlib.colors import to_rgb
9
+ from typing import List,Tuple,Optional,Union
10
+
11
class Plotable:
    """
    Base class for objects that can be visualized in the optical system.

    It stores the colors used when drawing the object, can enumerate nested
    Plotable attributes for hierarchical plots, and offers default
    (empty) point generators that subclasses are expected to override.

    Attributes:
        fill_color (str): Color used to fill the object in plots.
        outline_color (str): Color used for the object's outline in plots.
        is_volume (bool): If True, the object is treated as a volumetric entity.
    """

    def __init__(self, fill_color:str = "white", outline_color:str = "black", is_volume:bool = False):
        """
        Store the drawing attributes of the plotable object.

        Args:
            fill_color (str): The color used to fill the object.
            outline_color (str): The color used for the outline of the object.
            is_volume (bool): If True, the object is treated as a volume.
        """
        self.is_volume = is_volume
        self.outline_color = outline_color
        self.fill_color = fill_color

    def get_plotly_color_scale(self)->List[List[Union[float,str]]]:
        """
        Build a two-stop Plotly color scale from the fill and outline colors.

        Returns:
            list: A one-element list whose single entry is the colorscale
            ``[[0, fill], [1, outline]]`` with colors as ``'rgb(r, g, b)'`` text.
        """
        def as_plotly_rgb(color):
            # to_rgb yields floats in [0, 1]; scale and truncate to integers.
            red, green, blue = to_rgb(color)
            return f'rgb({int(red * 255)}, {int(green * 255)}, {int(blue * 255)})'

        fill_text = as_plotly_rgb(self.fill_color)
        outline_text = as_plotly_rgb(self.outline_color)
        return [[[0, fill_text], [1, outline_text]]]

    def get_plotable_childs(self)->List:
        """
        Collect every attribute of this object that is itself a Plotable.

        Returns:
            list: ``[child, attribute_name]`` pairs in ``dir(self)`` order.
        """
        named_attrs = ((attr_name, getattr(self, attr_name)) for attr_name in dir(self))
        return [[attr, attr_name] for attr_name, attr in named_attrs if isinstance(attr, Plotable)]

    def get_plot_points_2D(self, resolution:int)->List:
        """
        Default 2D point generator: reports that nothing is implemented.

        Args:
            resolution (int): The resolution for the plot points (unused here).

        Returns:
            list: Always an empty list.
        """
        print("get_plot_points_2D not implemented")
        return []

    def get_plot_points_3D(self, resolution:int)->List:
        """
        Default 3D point generator: reports that nothing is implemented.

        Args:
            resolution (int): The resolution for the plot points (unused here).

        Returns:
            list: Always an empty list.
        """
        print("get_plot_points_3D not implemented")
        return []
92
+
@@ -0,0 +1,188 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "plot"
6
+ ]
7
+
8
+ import matplotlib.pyplot as plt
9
+ from torch import linspace,meshgrid,zeros,no_grad,is_tensor
10
+ from copy import deepcopy
11
+ import numpy as np
12
+ from typing import Callable,Tuple,Optional,Union
13
+
14
def plot(val: Union[Callable, np.ndarray],
         title: str = "",
         x_range: Optional[Tuple[float, float]] = None,
         y_range: Optional[Tuple[float, float]] = None,
         cmap: str = "jet",
         subtitle: str = "",
         title_fontsize: int = 14,
         suptitle_fontsize: int = 12,
         interpolation: str = "none",
         xlabel: str = "x [mm]",
         ylabel: str = "y [mm]",
         colorbar: bool = True,
         norm = None,
         show: bool = True,
         vmin: float | None = None,
         vmax: float | None = None,
         resolution: int = 501,
         **kwargs):
    """
    Plot a 2D quantity using matplotlib.

    Accepts a torch tensor, a numpy array, or a callable. Tensors are moved to
    the CPU and converted to numpy. A callable is evaluated on a regular
    ``resolution x resolution`` grid: it receives an ``(resolution**2, 2)``
    tensor whose columns are the x and y coordinates and must return
    ``resolution**2`` values. Rows of the resulting image follow y and columns
    follow x. The rows are reversed before drawing so that, with imshow's
    default origin, the first row is drawn at ``y_range[0]``.

    Args:
        val (callable, np.ndarray or torch.Tensor): The quantity to plot.
        title (str): Title of the plot.
        x_range (tuple or number): Range of the x-axis. A single number ``r``
            means ``[-r, r]``. Defaults to ``[0, 1]``.
        y_range (tuple or number): Range of the y-axis. A single number ``r``
            means ``[-r, r]``. Defaults to ``x_range`` when that is given,
            otherwise ``[0, 1]``.
        cmap (str): Colormap to use.
        subtitle (str): Subtitle of the plot. When non-empty, ``title`` becomes
            the figure suptitle and ``subtitle`` the axes title.
        title_fontsize (int): Font size of the title.
        suptitle_fontsize (int): Font size of the subtitle.
        interpolation (str): Interpolation method.
        xlabel (str): Label for x-axis.
        ylabel (str): Label for y-axis.
        colorbar (bool): Whether to show colorbar.
        norm: Normalization object forwarded to imshow.
        show (bool): Whether to show the plot.
        vmin (float): Minimum value for color normalization.
        vmax (float): Maximum value for color normalization.
        resolution (int): Number of grid samples per axis for callables.
        **kwargs: Additional arguments for imshow.

    Returns:
        None
    """
    if is_tensor(val):
        val = val.detach().cpu().numpy()

    # A bare number r is shorthand for the symmetric interval [-r, r].
    if isinstance(x_range, (int, float)):
        x_range = [-x_range, x_range]
    if isinstance(y_range, (int, float)):
        y_range = [-y_range, y_range]

    if y_range is None:
        y_range = [0, 1.] if x_range is None else x_range
    if x_range is None:
        x_range = [0, 1.]

    if not isinstance(val, np.ndarray):
        # Sample x over x_range and y over y_range so the image agrees with
        # the extent passed to imshow even when the two ranges differ.
        xs = linspace(*x_range, resolution)
        ys = linspace(*y_range, resolution)
        # "ij" indexing: first grid axis (rows) follows ys, second follows xs.
        grid_y, grid_x = meshgrid(ys, xs, indexing="ij")
        coords = zeros((resolution * resolution, 2))
        coords[:, 0] = grid_x.reshape(-1)
        coords[:, 1] = grid_y.reshape(-1)
        val = val(coords).reshape(resolution, resolution)
        if is_tensor(val):
            val = val.detach().cpu().numpy()

    # Reverse the row order so row 0 is not drawn at the top of the image.
    val = val[::-1]

    plt.cla()  # Clear axis
    plt.clf()  # Clear figure
    fig, ax = plt.subplots()
    mappable = ax.imshow(val,
                         cmap=cmap,
                         interpolation=interpolation,
                         extent=list(x_range) + list(y_range),
                         norm=norm,
                         vmin=vmin,
                         vmax=vmax,
                         **kwargs)

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    if subtitle != "":
        fig.suptitle(title, fontsize=title_fontsize)
        ax.set_title(subtitle, fontsize=suptitle_fontsize)
    else:
        ax.set_title(title, fontsize=title_fontsize)

    if colorbar:
        plt.colorbar(mappable, ax=ax)

    if show:
        plt.show()
110
+
111
+
112
def intensity(val,title="",x_range=None,y_range=None,cmap="jet",interpolation="none",xlabel="x [mm]",ylabel="y [mm]",norm=None,show=True,vmin: float | None = None,vmax: float | None = None,**kwargs):
    """
    Plot a 2D intensity distribution using matplotlib.

    Thin wrapper around ``plot`` that appends the unit ``[$W/mm^2$]`` to the
    title and forwards every other argument unchanged.

    Args:
        val (callable, np.ndarray, or torch.Tensor): The intensity data to plot.
        title (str, optional): Title of the plot; the unit suffix is appended.
        x_range (tuple or None, optional): Range of x-axis, forwarded to ``plot``.
        y_range (tuple or None, optional): Range of y-axis, forwarded to ``plot``.
        cmap (str, optional): Colormap to use. Default is "jet".
        interpolation (str, optional): Interpolation method for imshow. Default is "none".
        xlabel (str, optional): Label for x-axis. Default is "x [mm]".
        ylabel (str, optional): Label for y-axis. Default is "y [mm]".
        norm (matplotlib.colors.Normalize or None, optional): Normalization for color mapping.
        show (bool, optional): Whether to display the plot. Default is True.
        vmin (float or None, optional): Minimum value for color normalization.
        vmax (float or None, optional): Maximum value for color normalization.
        **kwargs: Additional keyword arguments passed on to ``plot``.

    Returns:
        None
    """
    titled = f"{title} [$W/mm^2$]"
    display_options = dict(
        cmap=cmap,
        interpolation=interpolation,
        xlabel=xlabel,
        ylabel=ylabel,
        norm=norm,
        show=show,
        vmin=vmin,
        vmax=vmax,
    )
    plot(val, titled, x_range, y_range, **display_options, **kwargs)
135
+
136
def height(val,title="",x_range=None,y_range=None,cmap="cool",interpolation="none",xlabel="x [mm]",ylabel="y [mm]",norm=None,show=True,vmin: float | None = None,vmax: float | None = None,**kwargs):
    """
    Plot a 2D height distribution using matplotlib.

    Thin wrapper around ``plot`` that prefixes the title with
    ``Height z [$mm$] of`` and forwards every other argument unchanged.

    Args:
        val (callable, np.ndarray, or torch.Tensor): The height data to plot.
        title (str, optional): Name inserted into the generated title.
        x_range (tuple or None, optional): Range of x-axis, forwarded to ``plot``.
        y_range (tuple or None, optional): Range of y-axis, forwarded to ``plot``.
        cmap (str, optional): Colormap to use. Default is "cool".
        interpolation (str, optional): Interpolation method for imshow. Default is "none".
        xlabel (str, optional): Label for x-axis. Default is "x [mm]".
        ylabel (str, optional): Label for y-axis. Default is "y [mm]".
        norm (matplotlib.colors.Normalize or None, optional): Normalization for color mapping.
        show (bool, optional): Whether to display the plot. Default is True.
        vmin (float or None, optional): Minimum value for color normalization.
        vmax (float or None, optional): Maximum value for color normalization.
        **kwargs: Additional keyword arguments passed on to ``plot``.

    Returns:
        None
    """
    titled = f"Height z [$mm$] of {title}"
    display_options = dict(
        cmap=cmap,
        interpolation=interpolation,
        xlabel=xlabel,
        ylabel=ylabel,
        norm=norm,
        show=show,
        vmin=vmin,
        vmax=vmax,
    )
    plot(val, titled, x_range, y_range, **display_options, **kwargs)
159
+
160
+ """
161
+ TODO: these functions need testing again before integration
162
+
163
+ def surface(surface,name,aperture_radius,resolution=256,is_square=True,norm=None,show=True,**kwargs):
164
+ surface = deepcopy(surface)
165
+ surface = surface.cpu()
166
+ x_range = (-aperture_radius,aperture_radius)
167
+ y_range = (-aperture_radius,aperture_radius)
168
+ _x = linspace(-aperture_radius,aperture_radius,resolution)
169
+ _y = linspace(-aperture_radius,aperture_radius,resolution)
170
+ mesh = meshgrid(_x,_y)
171
+ x = mesh[0].reshape(-1)
172
+ y = mesh[1].reshape(-1)
173
+ O = zeros((x.shape[0],3))
174
+
175
+ O[:,0] = x
176
+ O[:,1] = y
177
+ z = None
178
+
179
+ with no_grad():
180
+ z = surface.functional(O,*surface.get_functional_param_args())
181
+
182
+ if not is_square:
183
+ z[O[:,[0,1]].norm(dim=-1)>aperture_radius] = float("nan")
184
+
185
+ z = z.detach().reshape(resolution,resolution)
186
+ height(z,name,x_range,y_range,norm=norm,show=show,**kwargs)
187
+ """
188
+
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+ __all__ = [
5
+ "annotate_position_simple",
6
+ "annotate_position",
7
+ "annotated_arrow",
8
+ "layout",
9
+ "ray_paths",
10
+ "_plot_surface",
11
+ "_plot_surface_recursively",
12
+ "plot"
13
+ ]
14
+
15
+ import torch
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import matplotlib.patches as patches
19
+ import matplotlib.colors as mcolors
20
+ from copy import deepcopy
21
+
22
def annotate_position_simple(nz,ny,name):
    """
    Write a text label just above the highest point of a 2D outline.

    The label is anchored at the sample with the largest y value and shifted
    upwards by half the font size (in points).

    Args:
        nz (torch.Tensor): z-coordinates of the outline (assumed 1-D and
            aligned with ``ny`` -- TODO confirm against callers).
        ny (torch.Tensor): y-coordinates of the outline.
        name (str): Text label to annotate at the position.

    Returns:
        None
    """
    fontsize = 10
    top = torch.argmax(ny)
    # Hand matplotlib plain Python floats instead of 0-d tensors.
    anchor = (float(nz[top]), float(ny[top]))
    plt.annotate(name,
                 xy=anchor,
                 fontsize=fontsize,
                 xytext=(0.0, fontsize * 0.5),
                 textcoords='offset points')
43
+
44
+
45
def annotate_position(position,offset,name,color="black",**kwargs):
    """
    Label a point in the current plot with text and an arrow pointing at it.

    Args:
        position (tuple): (z, y) coordinates of the point.
        offset (tuple): Offset of the text from the point, in points.
        name (str): Text label to annotate.
        color (str): Color of the annotation and arrow.
        **kwargs: Extra keyword arguments forwarded to ``plt.annotate``.

    Returns:
        None
    """
    arrow_style = {
        "arrowstyle": "->",
        "color": color,
        "linewidth": 1.5,
        "mutation_scale": 10,
    }
    plt.annotate(
        name,
        xy=position,
        xytext=offset,
        textcoords='offset points',
        color=color,
        arrowprops=arrow_style,
        **kwargs,
    )
59
+
60
+
61
def annotated_arrow(start,end,offset,name,arrowstyle,color="black",**kwargs):
    """
    Draw an arrow between two points and label it at its midpoint.

    Args:
        start (tuple): Start position (z, y) of the arrow.
        end (tuple): End position (z, y) of the arrow.
        offset (tuple): Offset of the text from the midpoint, in points.
        name (str): Text label to annotate.
        arrowstyle (str): Matplotlib arrow style string.
        color (str): Color of the arrow and annotation.
        **kwargs: Extra keyword arguments forwarded to ``plt.annotate``.

    Returns:
        None
    """
    arrow = patches.FancyArrowPatch(
        start,
        end,
        arrowstyle=arrowstyle,
        linewidth=1.5,
        mutation_scale=10,
        color=color,
    )
    plt.gca().add_patch(arrow)

    # Midpoint of the segment start -> end, used as the label anchor.
    mid_z = start[0] + (end[0] - start[0]) * 0.5
    mid_y = start[1] + (end[1] - start[1]) * 0.5

    plt.annotate(
        name,
        xy=(mid_z, mid_y),
        xytext=offset,
        textcoords='offset points',
        color=color,
        **kwargs,
    )
82
+
83
def layout():
    """
    Prepare the current axes for a 2D system plot.

    Adds 10 % margins on both axes, enforces an equal aspect ratio and labels
    the horizontal axis "z [mm]" and the vertical axis "y [mm]".

    Returns:
        None
    """
    axes = plt.gca()
    axes.margins(x=0.1, y=0.1)
    axes.set_aspect('equal')
    axes.set_ylabel("y [mm]")
    axes.set_xlabel("z [mm]")
95
+
96
def ray_paths(rays,ray_color="#85549c",ray_linewidth=1.25):
    """
    Plot ray paths projected onto the y-z plane.

    ``rays`` is stacked into one array that is indexed as
    ``[step, ray, coordinate]``; for every ray, coordinate 2 (z) is drawn on
    the horizontal axis and coordinate 1 (y) on the vertical axis.

    Args:
        rays (list[torch.Tensor] or list[np.ndarray]): Ray positions, one
            entry per propagation step. Tensors are detached and moved to the
            CPU before conversion. An empty sequence draws nothing.
        ray_color (str): Color of the rays.
        ray_linewidth (float): Line width of the rays.

    Returns:
        None
    """
    ray_color = mcolors.to_hex(ray_color)
    print("WARNING: ray_paths will project the ray position onto the y-z plane!")

    # Nothing to draw; also avoids indexing rays[0] on an empty sequence.
    if len(rays) == 0:
        return

    if torch.is_tensor(rays[0]):
        # detach()/cpu() keep .numpy() from failing on tensors that require
        # grad or live on another device.
        rays = [elem.detach().cpu().numpy() for elem in rays]
    paths = np.array(rays)

    for iray in range(paths.shape[1]):
        plt.plot(paths[:, iray, 2], paths[:, iray, 1], color=ray_color, linewidth=ray_linewidth)
117
+
118
+
119
+
120
def _plot_surface(surface,name,resolution,annotate,fill_color,outline_color,linewidth):
    """
    Draw one plotable element into the current 2D axes.

    The element supplies its outline as a list of ``(z, y)`` tensor pairs.
    Volumes are drawn as a single filled polygon built from all segments;
    everything else is drawn as one line per segment.

    Args:
        surface: Object with a ``get_plot_points_2D`` method and the
            ``fill_color``, ``outline_color`` and ``is_volume`` attributes.
        name (str): Name for annotation.
        resolution (int): Resolution passed to ``get_plot_points_2D``.
        annotate (bool): Whether to annotate the surface.
        fill_color (str): Fill color; ``None`` uses the element's own color.
        outline_color (str): Outline color; ``None`` uses the element's own color.
        linewidth (float): Line width for the surface.

    Returns:
        None
    """
    segments = surface.get_plot_points_2D(resolution)
    if len(segments) == 0:
        return

    # Explicit colors win; otherwise fall back to the element's own settings.
    face = surface.fill_color if fill_color is None else fill_color
    edge = surface.outline_color if outline_color is None else outline_color

    all_z = torch.cat([seg_z for seg_z, _ in segments])
    all_y = torch.cat([seg_y for _, seg_y in segments])

    if annotate:
        annotate_position_simple(all_z, all_y, name)

    if not surface.is_volume:
        for seg_z, seg_y in segments:
            plt.plot(seg_z, seg_y, color=edge, label="", linewidth=linewidth)
        return

    plt.gca().fill(all_z, all_y, facecolor=face, edgecolor=edge, linewidth=linewidth)
153
+
154
+
155
def _plot_surface_recursively(current_elem,name,resolution=200,annotate=False,fill_color=None,outline_color=None,linewidth=None):
    """
    Draw an element and then, depth-first, all of its plotable children.

    Children are obtained from ``get_plotable_childs`` and are annotated with
    the attribute name under which they are stored on their parent.

    Args:
        current_elem: The current plotable element.
        name (str): Name for annotation.
        resolution (int): Resolution for the surface plot.
        annotate (bool): Whether to annotate the surface.
        fill_color (str): Fill color for the surface.
        outline_color (str): Outline color for the surface.
        linewidth (float): Line width for the surface.

    Returns:
        None
    """
    # The same drawing options apply to the element and all its descendants.
    shared = (resolution, annotate, fill_color, outline_color, linewidth)

    _plot_surface(current_elem, name, *shared)

    for child, child_name in current_elem.get_plotable_childs():
        _plot_surface_recursively(child, child_name, *shared)
174
+
175
+
176
def plot(element=None,rays=None,resolution=200,annotate=False,ray_color="#85549c",ray_linewidth=1.25,fill_color=None,outline_color=None,linewidth=None,show=True):
    """
    Plot one or more elements in 2D and optionally overlay ray paths.

    Each element is deep-copied and moved to the CPU before drawing, so the
    caller's objects are left untouched.

    Args:
        element: A plotable element, or a list/tuple of them. ``None`` draws no element.
        rays (list[torch.Tensor] or dict): Ray paths to plot. A dict is
            expected to hold them under the key ``"ray_paths"``.
        resolution (int): Resolution for the surface plot.
        annotate (bool): Whether to annotate the surface.
        ray_color (str): Color of the rays.
        ray_linewidth (float): Line width of the rays.
        fill_color (str): Fill color for the surface.
        outline_color (str): Outline color for the surface.
        linewidth (float): Line width for the surface.
        show (bool): Whether to display the plot immediately.

    Returns:
        None
    """
    layout()

    # Normalise the input to a sequence so a single loop handles every case.
    if isinstance(element, (list, tuple)):
        elements = element
    elif element is None:
        elements = ()
    else:
        elements = (element,)

    for item in elements:
        cpu_copy = deepcopy(item).to("cpu")
        _plot_surface_recursively(cpu_copy, "", resolution, annotate, fill_color, outline_color, linewidth)

    if rays is not None:
        paths = rays["ray_paths"] if isinstance(rays, dict) else rays
        if torch.is_tensor(paths[0]):
            paths = [step.cpu() for step in paths]
        ray_paths(paths, ray_color=ray_color, ray_linewidth=ray_linewidth)

    if show:
        plt.show()
220
+