diffinytrace 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffinytrace/__init__.py +122 -0
- diffinytrace/basis_functions/__init__.py +14 -0
- diffinytrace/basis_functions/bspline.py +521 -0
- diffinytrace/basis_functions/chebyshev.py +3 -0
- diffinytrace/basis_functions/legendre.py +77 -0
- diffinytrace/basis_functions/zernike.py +235 -0
- diffinytrace/config.py +140 -0
- diffinytrace/constraints.py +54 -0
- diffinytrace/element.py +1660 -0
- diffinytrace/export/__init__.py +8 -0
- diffinytrace/export/cad.py +253 -0
- diffinytrace/gaussian_smoother.py +530 -0
- diffinytrace/hat_smoother.py +44 -0
- diffinytrace/integrators.py +452 -0
- diffinytrace/intersection.py +285 -0
- diffinytrace/optimize.py +808 -0
- diffinytrace/physical_object.py +150 -0
- diffinytrace/plotting/__init__.py +16 -0
- diffinytrace/plotting/core.py +92 -0
- diffinytrace/plotting/quantity2D.py +188 -0
- diffinytrace/plotting/system2D.py +220 -0
- diffinytrace/plotting/system3D.py +327 -0
- diffinytrace/plotting/wavelength.py +231 -0
- diffinytrace/refractive_index.py +101 -0
- diffinytrace/render.py +77 -0
- diffinytrace/source.py +661 -0
- diffinytrace/spectrum.py +79 -0
- diffinytrace/surface.py +468 -0
- diffinytrace/target_grid.py +399 -0
- diffinytrace/transforms.py +472 -0
- diffinytrace/utils/__init__.py +7 -0
- diffinytrace/utils/autograd.py +116 -0
- diffinytrace/utils/irradiance_importer.py +134 -0
- diffinytrace-2.1.dist-info/METADATA +26 -0
- diffinytrace-2.1.dist-info/RECORD +38 -0
- diffinytrace-2.1.dist-info/WHEEL +5 -0
- diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
- diffinytrace-2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# Public API of this module: the transform classes and the 3x3 rotation helpers.
__all__ = [
    "Transform", "Identity", "Compose",
    "Offset", "Distance", "Rotation",
    "rotation_matrix_x", "rotation_matrix_y", "rotation_matrix_z",
]
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
from .intersection import SemiFunctionalModule,cat_semi_functionals
|
|
19
|
+
import numpy as np
|
|
20
|
+
from .optimize import make_parameter_from_input
|
|
21
|
+
|
|
22
|
+
class Transform(SemiFunctionalModule):
    """
    Base class for coordinate transformations.

    This class provides interfaces to transform directions and positions between
    local and global coordinate systems using homogeneous coordinates.
    ``get_transformation_matrix`` maps local coordinates to global ones, while
    ``functional`` implements the opposite direction (global to local).

    Methods:
        get_functional_param_args(): Return parameters required for the transformation.
        functional(O, *params): Apply transformation in functional style (global to local).
        get_transformation_matrix(): Return the 4x4 (local to global) transformation matrix.
        to_global_dir(direction): Transform direction to global space.
        to_local_dir(direction): Transform direction to local space.
        to_global_pos(position): Transform position to global space.
        to_local_pos(position): Transform position to local space.
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        """
        Return parameters required for the transformation which constructs the surfaces through the functional.

        Returns:
            list: List of parameters required for the functional which constructs the surfaces.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("params_list not implemented")

    @staticmethod
    def functional(O, *params) -> torch.Tensor:
        """
        Apply transformation in functional style. This is global to local.

        Args:
            O (torch.Tensor): Input tensor to be transformed.
            *params: Parameters for the transformation, in the order returned by
                ``get_functional_param_args``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("functional not implemented")

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 transformation matrix.

        Args:
            device (torch.device, optional): Device for the matrix.
            dtype (torch.dtype, optional): Data type for the matrix.

        Returns:
            torch.Tensor: 4x4 transformation matrix (local to global).

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("get_transformation_matrix not implemented")

    def get_transform(self):
        """
        Returns itself.
        """
        return self

    def to_global_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform direction to global space.

        Only the linear 3x3 block of the transformation matrix is applied, so
        translations do not affect directions.

        Args:
            direction (torch.Tensor): Direction vector(s) in local space; the
                last dimension must have size 3.

        Returns:
            torch.Tensor: Direction vector(s) in global space, same shape as the input.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        # Plain slicing gives the linear part directly; no numpy index arrays needed.
        R = M[:3, :3]
        return direction @ R.T

    def to_local_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform direction to local space.

        Applies the inverse of the linear 3x3 block of the transformation matrix.

        Args:
            direction (torch.Tensor): Direction vector(s) in global space; the
                last dimension must have size 3.

        Returns:
            torch.Tensor: Direction vector(s) in local space, same shape as the input.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        R = M[:3, :3]
        R_inv = torch.inverse(R)
        return direction @ R_inv.T

    def to_global_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform position to global space.

        Args:
            position (torch.Tensor): Position vector(s) in local space; the last
                dimension must have size 3. Any leading batch shape (including
                none, i.e. a single point of shape ``(3,)``) is supported.

        Returns:
            torch.Tensor: Position vector(s) in global space, same shape as the input.
        """
        M = self.get_transformation_matrix(position.device, position.dtype)
        # Equivalent to appending a homogeneous 1, multiplying by M and dropping the
        # last component, but also correct for a single (3,) point, which the
        # (N, 4) homogeneous construction would silently broadcast to a 3x3 result.
        return position @ M[:3, :3].T + M[:3, 3]

    def to_local_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform position to local space.

        Delegates to ``functional`` with the parameters from
        ``get_functional_param_args``.

        Args:
            position (torch.Tensor): Position vector in global space.

        Returns:
            torch.Tensor: Position vector in local space.
        """
        return self.functional(position, *self.get_functional_param_args())
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class Identity(Transform):
    """
    Transformation that leaves coordinates untouched.

    It takes no functional parameters, its functional returns the input
    tensor as-is, and its transformation matrix is the 4x4 identity.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        # Nothing is needed to evaluate the identity.
        return list()

    @staticmethod
    def functional(O: torch.Tensor) -> torch.Tensor:
        # Global and local coordinates coincide.
        return O

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        return torch.eye(4, dtype=dtype, device=device)
|
|
156
|
+
|
|
157
|
+
class Compose(Transform):
    """
    Compose multiple transforms in sequence.

    The transformation matrix is the left-to-right product of the member
    matrices, ``M = M_0 @ M_1 @ ... @ M_{n-1}``. The functional form is built
    from the same list by ``cat_semi_functionals``.

    Args:
        transform_list (list[Transform]): List of transformations to apply in order.
    """

    def __init__(self, transform_list):
        super().__init__()
        self.transform_list = nn.ModuleList(transform_list)
        # Functional (global to local) assembled from the member transforms.
        self.functional = cat_semi_functionals(self.transform_list)

    def get_functional_param_args(self):
        # Concatenate member parameters in list order, matching the functional.
        out = []
        for elem in self.transform_list:
            out += elem.get_functional_param_args()
        return out

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the product of the member transformation matrices.

        Args:
            device (torch.device, optional): Device for the result. If omitted,
                the device of the first member's matrix is used.
            dtype (torch.dtype, optional): Data type for the result. If omitted,
                the dtype of the first member's matrix is used.

        Returns:
            torch.Tensor: 4x4 transformation matrix. For an empty list this is the
            identity created with the given ``device``/``dtype``.
        """
        matrices = [elem.get_transformation_matrix(device, dtype) for elem in self.transform_list]
        if not matrices:
            return torch.eye(4, device=device, dtype=dtype)
        # Members fall back to their own parameters' device/dtype when none is
        # requested; adopt the first one so every factor shares one placement
        # (converting with a None target would leave mismatches unresolved).
        if device is None:
            device = matrices[0].device
        if dtype is None:
            dtype = matrices[0].dtype
        out = matrices[0].to(device=device, dtype=dtype)
        for tmp in matrices[1:]:
            out = out @ tmp.to(device=device, dtype=dtype)
        return out
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class Offset(Transform):
    r"""
    Translation transform using an offset vector.

    The offset transformation shifts a position by a specified vector
    \( \vec{w} = (w_x, w_y, w_z) \). The transformation matrix \( M \)
    for an offset transformation is:

    .. math::

        M^{offset}(w_x, w_y, w_z) =
        \begin{bmatrix}
        1 & 0 & 0 & w_x \\
        0 & 1 & 0 & w_y \\
        0 & 0 & 1 & w_z \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    The functional (global to local) applies the parent transform first and then
    subtracts \( \vec{w} \).

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Offset([1.0, 2.0, 3.0], parent_transform=transf1)

    Args:
        pos (Tensor or list or float): The offset position as a 3D vector.
        parent_transform (Transform, optional): Optional parent transformation.
            Defaults to a new :class:`Identity` for each instance.
    """

    def __init__(self, pos, parent_transform=None):
        super().__init__()
        self.pos = make_parameter_from_input(pos)
        if parent_transform is None:
            # Create a fresh Identity per instance instead of sharing a single
            # module object through a default argument.
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        # Own offset first, followed by the parent's parameters.
        return [self.pos] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, pos: torch.Tensor, *parent_param_args) -> torch.Tensor:
        # Global -> parent-local first, then undo this offset.
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - pos

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M^{offset}``.

        Args:
            device (torch.device, optional): Defaults to the device of ``pos``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``pos``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.
        """
        if device is None:
            device = self.pos.device
        if dtype is None:
            dtype = self.pos.dtype
        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device=device, dtype=dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Translation column.
        this_matrix[:3, 3] = self.pos.to(device=device, dtype=dtype)
        return parent_transform_matrix @ this_matrix
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class Distance(Transform):
    r"""
    Applies a translation along a specific axis by a given distance.

    The distance transformation applies a translation by a specific distance along a given axis
    (e.g., \( x \)-, \( y \)-, or \( z \)-axis). The transformation matrix \( M \) for a distance
    transformation along the \( z \)-axis is given by:

    .. math::

        M^{dist}_z(d) =
        \begin{bmatrix}
        1 & 0 & 0 & 0 \\
        0 & 1 & 0 & 0 \\
        0 & 0 & 1 & d \\
        0 & 0 & 0 & 1
        \end{bmatrix},

    where \( d \) represents the distance of translation along the \( z \)-axis.

    Args:
        distance (float or Tensor): Distance to translate.
        axis (int): Axis along which translation is applied (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Optional parent transformation.
            Defaults to a new :class:`Identity` for each instance.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)

    Notes:
        For the local to global transformation (``get_transformation_matrix``) it
        applies the following transformation:

        .. math::

            \mathbf{x}_\text{parent} = \mathbf{x}_\text{local} + d \cdot \mathbf{e}_i

        where \( \mathbf{e}_i \) is the unit vector of ``axis``. ``functional``
        (global to local) applies the inverse, subtracting \( d \cdot \mathbf{e}_i \).
    """

    def __init__(self, distance, axis=2, parent_transform=None):
        super().__init__()
        self.distance = make_parameter_from_input(distance)
        # Constant unit vector of the translation axis. It is a plain tensor
        # attribute, not a registered buffer, so Module.to() does not move it; it is
        # aligned with `distance` wherever it is used.
        self.unit_vec = torch.tensor([0., 0., 0.])
        self.unit_vec[axis] = 1.0
        if parent_transform is None:
            # Fresh Identity per instance instead of one shared default module.
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        # Match device AND dtype of `distance`. With a float32 unit vector, a
        # zero-dim float64 distance would otherwise be multiplied in float32
        # (zero-dim tensors do not drive type promotion).
        unit_vec = self.unit_vec.to(device=self.distance.device, dtype=self.distance.dtype)
        return [self.distance, unit_vec] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, distance: torch.Tensor, unit_vec: torch.Tensor, *parent_param_args) -> torch.Tensor:
        # Global -> parent-local first, then undo this translation.
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - distance * unit_vec

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M^{dist}``.

        Args:
            device (torch.device, optional): Defaults to the device of ``distance``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``distance``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.
        """
        if device is None:
            device = self.distance.device
        if dtype is None:
            dtype = self.distance.dtype

        unit_vec = self.unit_vec.to(device=device, dtype=dtype)
        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Translation column: d * e_axis.
        this_matrix[:3, 3] = self.distance.to(device=device, dtype=dtype) * unit_vec
        return parent_transform_matrix @ this_matrix
|
|
313
|
+
|
|
314
|
+
def rotation_matrix_x(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the X-axis.

    Args:
        angle (Tensor or float): Angle in degrees, as a single-element tensor or a
            plain Python number. Tensors are used as-is, so gradients flow back
            to ``angle``.

    Returns:
        Tensor: 3x3 rotation matrix on the device of ``angle``, with the dtype
        obtained after the degree-to-radian conversion.
    """
    # Accept plain numbers too; tensors pass through unchanged.
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    # Start from a 3x3 identity; the X row and column stay untouched.
    rot_x = torch.eye(3, dtype=angle.dtype, device=angle.device)

    # Set the rotation entries (in-place writes keep the autograd link to angle).
    rot_x[1, 1] = cos_a
    rot_x[1, 2] = -sin_a
    rot_x[2, 1] = sin_a
    rot_x[2, 2] = cos_a

    return rot_x
|
|
339
|
+
|
|
340
|
+
def rotation_matrix_y(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Y-axis.

    Args:
        angle (Tensor or float): Angle in degrees, as a single-element tensor or a
            plain Python number. Tensors are used as-is, so gradients flow back
            to ``angle``.

    Returns:
        Tensor: 3x3 rotation matrix on the device of ``angle``, with the dtype
        obtained after the degree-to-radian conversion.
    """
    # Accept plain numbers too; tensors pass through unchanged.
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    # Start from a 3x3 identity; the Y row and column stay untouched.
    rot_y = torch.eye(3, dtype=angle.dtype, device=angle.device)

    # Set the rotation entries (in-place writes keep the autograd link to angle).
    rot_y[0, 0] = cos_a
    rot_y[0, 2] = sin_a
    rot_y[2, 0] = -sin_a
    rot_y[2, 2] = cos_a

    return rot_y
|
|
365
|
+
|
|
366
|
+
def rotation_matrix_z(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Z-axis.

    Args:
        angle (Tensor or float): Angle in degrees, as a single-element tensor or a
            plain Python number. Tensors are used as-is, so gradients flow back
            to ``angle``.

    Returns:
        Tensor: 3x3 rotation matrix on the device of ``angle``, with the dtype
        obtained after the degree-to-radian conversion.
    """
    # Accept plain numbers too; tensors pass through unchanged.
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    # Start from a 3x3 identity; the Z row and column stay untouched.
    rot_z = torch.eye(3, dtype=angle.dtype, device=angle.device)

    # Set the rotation entries (in-place writes keep the autograd link to angle).
    rot_z[0, 0] = cos_a
    rot_z[0, 1] = -sin_a
    rot_z[1, 0] = sin_a
    rot_z[1, 1] = cos_a

    return rot_z
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
class Rotation(Transform):
    r"""
    Applies a 3D rotation around a principal axis.

    The rotational transformation rotates a point or direction around one of the
    principal axes (\( x \)-, \( y \)- or \( z \)-axis). For example, the rotation matrix
    around the \( z \)-axis is:

    .. math::

        M^{rot}_z(\theta_z) =
        \begin{bmatrix}
        \cos \theta_z & -\sin \theta_z & 0 & 0 \\
        \sin \theta_z & \cos \theta_z & 0 & 0 \\
        0 & 0 & 1 & 0 \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    Args:
        angle (float or Tensor): Rotation angle in degrees.
        axis (int): Axis index (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Optional parent transformation.
            Defaults to a new :class:`Identity` for each instance.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)
        >>> transf3 = dit.transforms.Rotation(45.,axis=0,parent_transform=transf2)
    """

    def __init__(self, angle: float, axis: int, parent_transform=None):
        # TODO: test rotations combining angle_x, angle_y and angle_z -- does the order matter?
        super().__init__()
        if axis not in (0, 1, 2):
            # Fail early: an unknown axis would otherwise surface later as an
            # opaque error when the (missing) rotation matrix is used.
            raise ValueError(f"axis must be 0, 1 or 2, got {axis!r}")
        self.angle = make_parameter_from_input(angle)
        self.axis = axis
        if parent_transform is None:
            # Fresh Identity per instance instead of one shared default module.
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def _axis_rotation_matrix(self, angle: torch.Tensor) -> torch.Tensor:
        # 3x3 rotation matrix (angle in degrees) about the configured axis.
        if self.axis == 0:
            return rotation_matrix_x(angle)
        if self.axis == 1:
            return rotation_matrix_y(angle)
        if self.axis == 2:
            return rotation_matrix_z(angle)
        raise ValueError(f"axis must be 0, 1 or 2, got {self.axis!r}")

    def get_functional_param_args(self):
        # Own angle first, followed by the parent's parameters.
        return [self.angle] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, angle: torch.Tensor, *parent_param_args) -> torch.Tensor:
        # Global -> parent-local first.
        O = self.parent_transform.functional(O, *parent_param_args)
        # R(360 - angle) == R(-angle) == R(angle)^T: the inverse rotation, applied
        # to row vectors.
        R = self._axis_rotation_matrix(360.0 - angle)
        return O @ R.T

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M^{rot}``.

        Args:
            device (torch.device, optional): Defaults to the device of ``angle``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``angle``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.
        """
        if device is None:
            device = self.angle.device
        if dtype is None:
            dtype = self.angle.dtype

        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        # Evaluate cos/sin directly in the requested dtype/device rather than
        # computing in the parameter's dtype and casting afterwards.
        R = self._axis_rotation_matrix(self.angle.to(device=device, dtype=dtype))
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        this_matrix[:3, :3] = R
        return parent_transform_matrix @ this_matrix
|
|
471
|
+
|
|
472
|
+
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# Copyright (c) 2025 Martin Pflaum
|
|
2
|
+
# This file is part of the diffinytrace project, licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Public API of this module.
__all__ = ["grad"]
|
|
8
|
+
|
|
9
|
+
import typing
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def grad(
|
|
14
|
+
outputs: torch.types._TensorOrTensors,
|
|
15
|
+
inputs: torch.types._TensorOrTensorsOrGradEdge,
|
|
16
|
+
grad_outputs: typing.Optional[torch.types._TensorOrTensors] = None,
|
|
17
|
+
retain_graph: typing.Optional[bool] = True,
|
|
18
|
+
create_graph: bool = True,
|
|
19
|
+
only_inputs: bool = True,
|
|
20
|
+
is_grads_batched: bool = False,
|
|
21
|
+
materialize_grads: bool = False,
|
|
22
|
+
remove_no_grad_outputs: bool = True
|
|
23
|
+
):
|
|
24
|
+
"""
|
|
25
|
+
Computes the gradients of the outputs with respect to the inputs.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
outputs (torch.Tensor or tuple of torch.Tensor): The output tensors.
|
|
29
|
+
inputs (torch.Tensor or tuple of torch.Tensor): The input tensors.
|
|
30
|
+
grad_outputs (torch.Tensor or tuple of torch.Tensor, optional): The gradients of the outputs.
|
|
31
|
+
retain_graph (bool, optional): Whether to retain the graph after computing gradients.
|
|
32
|
+
create_graph (bool, optional): Whether to create the graph for higher-order gradients.
|
|
33
|
+
only_inputs (bool, optional): Whether to only compute gradients for the inputs.
|
|
34
|
+
is_grads_batched (bool, optional): Whether the gradients are batched.
|
|
35
|
+
materialize_grads (bool, optional): Whether to materialize the gradients.
|
|
36
|
+
remove_no_grad_outputs (bool, optional): Whether to remove outputs that do not require gradients.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
list: A list of gradients for each input tensor.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
if torch.is_tensor(inputs):
|
|
44
|
+
inputs = [inputs]
|
|
45
|
+
inputs = [elem for elem in inputs]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
if remove_no_grad_outputs:
|
|
49
|
+
if torch.is_tensor(grad_outputs) or torch.is_tensor(outputs):
|
|
50
|
+
if torch.is_tensor(outputs):
|
|
51
|
+
if not outputs.requires_grad:
|
|
52
|
+
if torch.is_tensor(inputs):
|
|
53
|
+
raise RuntimeError("this branch should not be called!")
|
|
54
|
+
else:
|
|
55
|
+
out = []
|
|
56
|
+
for elem in inputs:
|
|
57
|
+
if torch.is_tensor(elem):
|
|
58
|
+
out += [torch.zeros_like(elem)]
|
|
59
|
+
else:
|
|
60
|
+
out += [None]
|
|
61
|
+
return out
|
|
62
|
+
else:
|
|
63
|
+
_grad_outputs = [elem for elem in grad_outputs]
|
|
64
|
+
|
|
65
|
+
new_grad_outputs = []
|
|
66
|
+
new_outputs = []
|
|
67
|
+
for k,elem in enumerate(outputs):
|
|
68
|
+
if torch.is_tensor(elem):
|
|
69
|
+
if elem.requires_grad:
|
|
70
|
+
grad_elem = _grad_outputs[k]
|
|
71
|
+
new_outputs += [elem]
|
|
72
|
+
new_grad_outputs += [grad_elem]
|
|
73
|
+
grad_outputs = new_grad_outputs
|
|
74
|
+
outputs = new_outputs
|
|
75
|
+
|
|
76
|
+
inputs_requires_grad = []
|
|
77
|
+
back_map_input = {}
|
|
78
|
+
inputs_map_i = 0
|
|
79
|
+
for k,param in enumerate(inputs):
|
|
80
|
+
#param = inputs[k]
|
|
81
|
+
|
|
82
|
+
if param is None:
|
|
83
|
+
continue
|
|
84
|
+
if param.requires_grad:
|
|
85
|
+
inputs_requires_grad += [param]
|
|
86
|
+
back_map_input[k] = inputs_map_i
|
|
87
|
+
inputs_map_i += 1
|
|
88
|
+
grad_tmp = []
|
|
89
|
+
if len(inputs_requires_grad)!=0:
|
|
90
|
+
grad_tmp = torch.autograd.grad(outputs=outputs,
|
|
91
|
+
inputs=inputs_requires_grad,
|
|
92
|
+
grad_outputs=grad_outputs,
|
|
93
|
+
retain_graph=retain_graph,
|
|
94
|
+
create_graph=create_graph,
|
|
95
|
+
only_inputs=only_inputs,
|
|
96
|
+
allow_unused=True,
|
|
97
|
+
is_grads_batched=is_grads_batched,
|
|
98
|
+
materialize_grads=materialize_grads)
|
|
99
|
+
else:
|
|
100
|
+
|
|
101
|
+
pass
|
|
102
|
+
grad = [None for input in inputs]
|
|
103
|
+
for k in range(len(grad)):
|
|
104
|
+
if k in back_map_input.keys():
|
|
105
|
+
inputs_map_i = back_map_input[k]
|
|
106
|
+
grad[k] = grad_tmp[inputs_map_i]
|
|
107
|
+
else:
|
|
108
|
+
if materialize_grads:
|
|
109
|
+
if inputs[k] is None:
|
|
110
|
+
grad[k] = None
|
|
111
|
+
else:
|
|
112
|
+
grad[k] = torch.zeros_like(inputs[k])
|
|
113
|
+
else:
|
|
114
|
+
grad[k] = None
|
|
115
|
+
|
|
116
|
+
return grad
|