diffinytrace 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. diffinytrace/__init__.py +122 -0
  2. diffinytrace/basis_functions/__init__.py +14 -0
  3. diffinytrace/basis_functions/bspline.py +521 -0
  4. diffinytrace/basis_functions/chebyshev.py +3 -0
  5. diffinytrace/basis_functions/legendre.py +77 -0
  6. diffinytrace/basis_functions/zernike.py +235 -0
  7. diffinytrace/config.py +140 -0
  8. diffinytrace/constraints.py +54 -0
  9. diffinytrace/element.py +1660 -0
  10. diffinytrace/export/__init__.py +8 -0
  11. diffinytrace/export/cad.py +253 -0
  12. diffinytrace/gaussian_smoother.py +530 -0
  13. diffinytrace/hat_smoother.py +44 -0
  14. diffinytrace/integrators.py +452 -0
  15. diffinytrace/intersection.py +285 -0
  16. diffinytrace/optimize.py +808 -0
  17. diffinytrace/physical_object.py +150 -0
  18. diffinytrace/plotting/__init__.py +16 -0
  19. diffinytrace/plotting/core.py +92 -0
  20. diffinytrace/plotting/quantity2D.py +188 -0
  21. diffinytrace/plotting/system2D.py +220 -0
  22. diffinytrace/plotting/system3D.py +327 -0
  23. diffinytrace/plotting/wavelength.py +231 -0
  24. diffinytrace/refractive_index.py +101 -0
  25. diffinytrace/render.py +77 -0
  26. diffinytrace/source.py +661 -0
  27. diffinytrace/spectrum.py +79 -0
  28. diffinytrace/surface.py +468 -0
  29. diffinytrace/target_grid.py +399 -0
  30. diffinytrace/transforms.py +472 -0
  31. diffinytrace/utils/__init__.py +7 -0
  32. diffinytrace/utils/autograd.py +116 -0
  33. diffinytrace/utils/irradiance_importer.py +134 -0
  34. diffinytrace-2.1.dist-info/METADATA +26 -0
  35. diffinytrace-2.1.dist-info/RECORD +38 -0
  36. diffinytrace-2.1.dist-info/WHEEL +5 -0
  37. diffinytrace-2.1.dist-info/licenses/LICENSE +21 -0
  38. diffinytrace-2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,472 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
__all__ = [
    # Transform classes
    "Transform",
    "Identity",
    "Compose",
    "Offset",
    "Distance",
    "Rotation",
    # 3x3 rotation-matrix helpers
    "rotation_matrix_x",
    "rotation_matrix_y",
    "rotation_matrix_z",
]
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from .intersection import SemiFunctionalModule,cat_semi_functionals
19
+ import numpy as np
20
+ from .optimize import make_parameter_from_input
21
+
22
class Transform(SemiFunctionalModule):
    """
    Base class for coordinate transformations.

    A transform relates a local coordinate system to the global one.
    Subclasses provide the local-to-global mapping as a 4x4 homogeneous
    matrix (:meth:`get_transformation_matrix`) and the global-to-local mapping
    of positions in functional form (:meth:`functional`), which receives its
    parameter tensors explicitly from :meth:`get_functional_param_args`.

    Methods:
        get_functional_param_args(): Return the parameters passed to ``functional``.
        functional(O, *params): Map positions from global to local coordinates.
        get_transformation_matrix(): Return the 4x4 local-to-global matrix.
        to_global_dir(direction): Transform directions to global space.
        to_local_dir(direction): Transform directions to local space.
        to_global_pos(position): Transform positions to global space.
        to_local_pos(position): Transform positions to local space.
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        """
        Return the parameters passed (after the positions) to :meth:`functional`.

        Returns:
            list: Parameter tensors, in the order ``functional`` expects them.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError("params_list not implemented")

    @staticmethod
    def functional(O, *params) -> torch.Tensor:
        """
        Map positions from global to local coordinates in functional style.

        Args:
            O (torch.Tensor): Positions to be transformed.
            *params: Parameters as returned by :meth:`get_functional_param_args`.

        Returns:
            torch.Tensor: Positions in local coordinates.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError("functional not implemented")

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 homogeneous local-to-global transformation matrix.

        Args:
            device (torch.device, optional): Device for the matrix.
            dtype (torch.dtype, optional): Data type for the matrix.

        Returns:
            torch.Tensor: 4x4 transformation matrix.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError("get_transformation_matrix not implemented")

    def get_transform(self):
        """Return this transform itself."""
        return self

    def to_global_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform directions from local to global space.

        Only the linear 3x3 block of the transformation matrix is applied;
        the translation column does not act on directions.

        Args:
            direction (torch.Tensor): Directions in local space, shape ``(..., 3)``.

        Returns:
            torch.Tensor: Directions in global space, same shape as ``direction``.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        R = M[:3, :3]
        return direction @ R.T

    def to_local_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform directions from global to local space.

        Applies the inverse of the linear 3x3 block of the transformation matrix.

        Args:
            direction (torch.Tensor): Directions in global space, shape ``(..., 3)``.

        Returns:
            torch.Tensor: Directions in local space, same shape as ``direction``.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        R_inv = torch.linalg.inv(M[:3, :3])
        return direction @ R_inv.T

    def to_global_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform positions from local to global space.

        Args:
            position (torch.Tensor): Positions in local space, shape ``(..., 3)``.

        Returns:
            torch.Tensor: Positions in global space, same shape as ``position``.
        """
        M = self.get_transformation_matrix(position.device, position.dtype)
        # Affine form of the homogeneous product ``[x, 1] @ M.T`` restricted to
        # its first three components: linear block plus translation column.
        # Unlike building an explicit (N, 4) buffer this also handles a single
        # (3,) position and arbitrary leading batch dimensions.
        return position @ M[:3, :3].T + M[:3, 3]

    def to_local_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform positions from global to local space.

        Args:
            position (torch.Tensor): Positions in global space.

        Returns:
            torch.Tensor: Positions in local space.
        """
        return self.functional(position, *self.get_functional_param_args())
133
+
134
+
135
class Identity(Transform):
    """
    Transform that leaves coordinates unchanged.

    Its transformation matrix is the 4x4 identity, its functional form hands
    the input positions back untouched, and it has no parameters.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        # Nothing to parametrize.
        return list()

    @staticmethod
    def functional(O: torch.Tensor) -> torch.Tensor:
        # Global and local coordinates coincide.
        return O

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        return torch.eye(4, dtype=dtype, device=device)
156
+
157
class Compose(Transform):
    """
    Chain several transforms into one.

    The transformation matrix is the product of the member matrices in list
    order (``M_0 @ M_1 @ ... @ M_{n-1}``). The parameter list is the
    concatenation of the members' parameter lists, and the functional form is
    built from the members' functionals by ``cat_semi_functionals``.

    Args:
        transform_list (list[Transform]): Transforms to chain, in order.
    """

    def __init__(self, transform_list):
        super().__init__()
        self.transform_list = nn.ModuleList(transform_list)
        self.functional = cat_semi_functionals(self.transform_list)

    def get_functional_param_args(self):
        """Concatenate the members' functional parameters in list order."""
        out = []
        for elem in self.transform_list:
            out += elem.get_functional_param_args()
        return out

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the product of the member matrices in list order.

        Args:
            device (torch.device, optional): Device of the result. Defaults to
                the device of the first member's matrix.
            dtype (torch.dtype, optional): Data type of the result. Defaults to
                the dtype of the first member's matrix.

        Returns:
            torch.Tensor: 4x4 transformation matrix (identity for an empty list).
        """
        matrices = [elem.get_transformation_matrix(device, dtype) for elem in self.transform_list]
        # Resolve unspecified device/dtype from the members. Previously a None
        # argument led to ``tmp.to(None)``, so mismatched matrices were never
        # aligned and the product failed.
        if matrices:
            if device is None:
                device = matrices[0].device
            if dtype is None:
                dtype = matrices[0].dtype
        out = torch.eye(4, device=device, dtype=dtype)
        for tmp in matrices:
            # ``.to`` is a no-op when device and dtype already match.
            out = out @ tmp.to(device=out.device, dtype=out.dtype)
        return out
189
+
190
+
191
class Offset(Transform):
    r"""
    Translation transform using an offset vector.

    The offset transformation shifts a position by a specified vector
    \( \vec{w} = (w_x, w_y, w_z) \). The transformation matrix \( M \)
    for an offset transformation is:

    .. math::

        M^{offset}(w_x, w_y, w_z) =
        \begin{bmatrix}
        1 & 0 & 0 & w_x \\
        0 & 1 & 0 & w_y \\
        0 & 0 & 1 & w_z \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    The full local-to-global matrix is the parent's matrix times
    :math:`M^{offset}`. The functional (global-to-local) form first applies the
    parent's functional and then subtracts the offset.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Offset([1.0, 2.0, 3.0], parent_transform=transf1)

    Args:
        pos (Tensor or list or float): The offset position as a 3D vector.
        parent_transform (Transform, optional): Transform this offset is
            applied on top of. Defaults to a new :class:`Identity`.
    """

    def __init__(self, pos, parent_transform=None):
        super().__init__()
        self.pos = make_parameter_from_input(pos)
        # Build the default parent per instance rather than sharing a single
        # Identity() that a default argument would evaluate only once.
        if parent_transform is None:
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        """Return ``[pos]`` followed by the parent's functional parameters."""
        return [self.pos] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, pos: torch.Tensor, *parent_param_args) -> torch.Tensor:
        """Map global positions ``O`` into this transform's local frame."""
        # Undo the parent first, then remove this offset.
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - pos

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M_offset``.

        Args:
            device (torch.device, optional): Defaults to the device of ``pos``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``pos``.
        """
        if device is None:
            device = self.pos.device
        if dtype is None:
            dtype = self.pos.dtype
        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device=device, dtype=dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Translation column of the homogeneous matrix.
        this_matrix[[0, 1, 2], -1] = self.pos.to(device=device, dtype=dtype)
        return parent_transform_matrix @ this_matrix
242
+
243
+
244
class Distance(Transform):
    r"""
    Applies a translation along a specific axis by a given distance.

    The distance transformation applies a translation by a specific distance along a given axis
    (e.g., \( x \)-, \( y \)-, or \( z \)-axis). The transformation matrix \( M \) for a distance
    transformation along the \( z \)-axis is given by:

    .. math::

        M^{dist}_z(d) =
        \begin{bmatrix}
        1 & 0 & 0 & 0 \\
        0 & 1 & 0 & 0 \\
        0 & 0 & 1 & d \\
        0 & 0 & 0 & 1
        \end{bmatrix},

    where \( d \) represents the distance of translation along the \( z \)-axis.

    Args:
        distance (float or Tensor): Distance to translate.
        axis (int): Axis along which translation is applied (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Transform this translation is
            applied on top of. Defaults to a new :class:`Identity`.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)

    Notes:
        Mapping a local position into the parent frame adds the offset
        :math:`d \cdot \mathbf{e}_i`, and the functional (global-to-local)
        form subtracts it again:

        .. math::

            \mathbf{x}_\text{parent} = \mathbf{x}_\text{local} + d \cdot \mathbf{e}_i
    """

    def __init__(self, distance, axis=2, parent_transform=None):
        super().__init__()
        self.distance = make_parameter_from_input(distance)
        # Constant unit vector selecting the translation axis. It is a plain
        # attribute, not a registered buffer, so it does not follow
        # module.to(...); the methods below move it to the distance's
        # device and dtype on use.
        self.unit_vec = torch.tensor([0., 0., 0.])
        self.unit_vec[axis] = 1.0
        # Build the default parent per instance rather than sharing a single
        # Identity() that a default argument would evaluate only once.
        if parent_transform is None:
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        """Return ``[distance, unit_vec]`` followed by the parent's functional parameters."""
        unit_vec = self.unit_vec
        # Match dtype as well as device: a float32 unit vector multiplied by a
        # 0-dim float64 distance would otherwise promote to float32.
        if unit_vec.device != self.distance.device or unit_vec.dtype != self.distance.dtype:
            unit_vec = unit_vec.to(device=self.distance.device, dtype=self.distance.dtype)
        return [self.distance, unit_vec] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, distance: torch.Tensor, unit_vec: torch.Tensor, *parent_param_args) -> torch.Tensor:
        """Map global positions ``O`` into this transform's local frame."""
        # Undo the parent first, then remove the axis-aligned translation.
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - distance * unit_vec

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M_dist``.

        Args:
            device (torch.device, optional): Defaults to the device of ``distance``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``distance``.
        """
        if device is None:
            device = self.distance.device
        if dtype is None:
            dtype = self.distance.dtype

        unit_vec = self.unit_vec.to(device=device, dtype=dtype)
        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Translation column: distance along the selected axis only.
        this_matrix[[0, 1, 2], -1] = self.distance.to(device=device, dtype=dtype) * unit_vec
        return parent_transform_matrix @ this_matrix
313
+
314
def rotation_matrix_x(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the X-axis.

    Args:
        angle (Tensor or float): Rotation angle in degrees. Non-tensor values
            are converted with ``torch.as_tensor``; tensors are used as-is, so
            autograd history is preserved.

    Returns:
        Tensor: 3x3 matrix ``[[1, 0, 0], [0, cos, -sin], [0, sin, cos]]`` on the
        angle's device, with the dtype of the angle converted to radians.
    """
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    c = torch.cos(angle)
    s = torch.sin(angle)

    # Start from the 3x3 identity and fill in the rotation block.
    rot_x = torch.eye(3, dtype=angle.dtype, device=angle.device)
    rot_x[1, 1] = c
    rot_x[1, 2] = -s
    rot_x[2, 1] = s
    rot_x[2, 2] = c
    return rot_x
339
+
340
def rotation_matrix_y(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Y-axis.

    Args:
        angle (Tensor or float): Rotation angle in degrees. Non-tensor values
            are converted with ``torch.as_tensor``; tensors are used as-is, so
            autograd history is preserved.

    Returns:
        Tensor: 3x3 matrix ``[[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]]`` on the
        angle's device, with the dtype of the angle converted to radians.
    """
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    c = torch.cos(angle)
    s = torch.sin(angle)

    # Start from the 3x3 identity and fill in the rotation block.
    rot_y = torch.eye(3, dtype=angle.dtype, device=angle.device)
    rot_y[0, 0] = c
    rot_y[0, 2] = s
    rot_y[2, 0] = -s
    rot_y[2, 2] = c
    return rot_y
365
+
366
def rotation_matrix_z(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Z-axis.

    Args:
        angle (Tensor or float): Rotation angle in degrees. Non-tensor values
            are converted with ``torch.as_tensor``; tensors are used as-is, so
            autograd history is preserved.

    Returns:
        Tensor: 3x3 matrix ``[[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]`` on the
        angle's device, with the dtype of the angle converted to radians.
    """
    angle = torch.as_tensor(angle)
    # Convert angle from degrees to radians
    angle = angle * (2.0 * torch.pi / 360.0)
    c = torch.cos(angle)
    s = torch.sin(angle)

    # Start from the 3x3 identity and fill in the rotation block.
    rot_z = torch.eye(3, dtype=angle.dtype, device=angle.device)
    rot_z[0, 0] = c
    rot_z[0, 1] = -s
    rot_z[1, 0] = s
    rot_z[1, 1] = c
    return rot_z
391
+
392
+
393
class Rotation(Transform):
    r"""
    Applies a 3D rotation around a principal axis.

    The rotational transformation rotates a point or direction around a specific axis
    (e.g., \( x \)-, \( y \)-, and \( z \)-axis). For example, the rotation matrix
    around the \( z \)-axis is:

    .. math::

        M^{rot}_z(\theta_z) =
        \begin{bmatrix}
        \cos \theta_z & -\sin \theta_z & 0 & 0 \\
        \sin \theta_z & \cos \theta_z & 0 & 0 \\
        0 & 0 & 1 & 0 \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    The full local-to-global matrix is the parent's matrix times
    :math:`M^{rot}`. The functional (global-to-local) form first applies the
    parent's functional and then the inverse rotation (angle ``-angle``).

    Args:
        angle (float or Tensor): Rotation angle in degrees.
        axis (int): Axis index (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Transform this rotation is
            applied on top of. Defaults to a new :class:`Identity`.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)
        >>> transf3 = dit.transforms.Rotation(45.,axis=0,parent_transform=transf2)
    """

    def __init__(self, angle: float, axis: int, parent_transform=None):
        # TODO: test chained rotations about x, y and z -- does the order matter?
        super().__init__()
        # Fail early instead of crashing later on an unset rotation matrix.
        if axis not in (0, 1, 2):
            raise ValueError(f"axis must be 0 (X), 1 (Y) or 2 (Z), got {axis!r}")
        self.angle = make_parameter_from_input(angle)
        self.axis = axis
        # Build the default parent per instance rather than sharing a single
        # Identity() that a default argument would evaluate only once.
        if parent_transform is None:
            parent_transform = Identity()
        self.parent_transform = parent_transform.get_transform()

    def _rotation_matrix(self, angle):
        """Return the 3x3 rotation matrix about ``self.axis`` for ``angle`` in degrees."""
        builders = (rotation_matrix_x, rotation_matrix_y, rotation_matrix_z)
        return builders[self.axis](angle)

    def get_functional_param_args(self):
        """Return ``[angle]`` followed by the parent's functional parameters."""
        return [self.angle] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, angle: torch.Tensor, *parent_param_args) -> torch.Tensor:
        """Map global positions ``O`` (rows) into this transform's local frame."""
        O = self.parent_transform.functional(O, *parent_param_args)
        # The inverse of a rotation by `angle` is a rotation by `-angle`.
        # Using -angle instead of 360 - angle keeps angle == 0 an exact
        # identity in floating point.
        R = self._rotation_matrix(-angle)
        return O @ R.T

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return ``parent_matrix @ M_rot``.

        Args:
            device (torch.device, optional): Defaults to the device of ``angle``.
            dtype (torch.dtype, optional): Defaults to the dtype of ``angle``.
        """
        if device is None:
            device = self.angle.device
        if dtype is None:
            dtype = self.angle.dtype

        parent_transform_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        R = self._rotation_matrix(self.angle).to(device=device, dtype=dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        this_matrix[:3, :3] = R
        return parent_transform_matrix @ this_matrix
471
+
472
+
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+
5
+ #from . import sympy_helper
6
+ #from . import autograd
7
+ from . import irradiance_importer
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2025 Martin Pflaum
2
+ # This file is part of the diffinytrace project, licensed under the MIT License.
3
+
4
+
5
__all__ = ["grad"]  # public API of this module
8
+
9
+ import typing
10
+ import torch
11
+
12
+
13
+ def grad(
14
+ outputs: torch.types._TensorOrTensors,
15
+ inputs: torch.types._TensorOrTensorsOrGradEdge,
16
+ grad_outputs: typing.Optional[torch.types._TensorOrTensors] = None,
17
+ retain_graph: typing.Optional[bool] = True,
18
+ create_graph: bool = True,
19
+ only_inputs: bool = True,
20
+ is_grads_batched: bool = False,
21
+ materialize_grads: bool = False,
22
+ remove_no_grad_outputs: bool = True
23
+ ):
24
+ """
25
+ Computes the gradients of the outputs with respect to the inputs.
26
+
27
+ Args:
28
+ outputs (torch.Tensor or tuple of torch.Tensor): The output tensors.
29
+ inputs (torch.Tensor or tuple of torch.Tensor): The input tensors.
30
+ grad_outputs (torch.Tensor or tuple of torch.Tensor, optional): The gradients of the outputs.
31
+ retain_graph (bool, optional): Whether to retain the graph after computing gradients.
32
+ create_graph (bool, optional): Whether to create the graph for higher-order gradients.
33
+ only_inputs (bool, optional): Whether to only compute gradients for the inputs.
34
+ is_grads_batched (bool, optional): Whether the gradients are batched.
35
+ materialize_grads (bool, optional): Whether to materialize the gradients.
36
+ remove_no_grad_outputs (bool, optional): Whether to remove outputs that do not require gradients.
37
+
38
+ Returns:
39
+ list: A list of gradients for each input tensor.
40
+ """
41
+
42
+
43
+ if torch.is_tensor(inputs):
44
+ inputs = [inputs]
45
+ inputs = [elem for elem in inputs]
46
+
47
+
48
+ if remove_no_grad_outputs:
49
+ if torch.is_tensor(grad_outputs) or torch.is_tensor(outputs):
50
+ if torch.is_tensor(outputs):
51
+ if not outputs.requires_grad:
52
+ if torch.is_tensor(inputs):
53
+ raise RuntimeError("this branch should not be called!")
54
+ else:
55
+ out = []
56
+ for elem in inputs:
57
+ if torch.is_tensor(elem):
58
+ out += [torch.zeros_like(elem)]
59
+ else:
60
+ out += [None]
61
+ return out
62
+ else:
63
+ _grad_outputs = [elem for elem in grad_outputs]
64
+
65
+ new_grad_outputs = []
66
+ new_outputs = []
67
+ for k,elem in enumerate(outputs):
68
+ if torch.is_tensor(elem):
69
+ if elem.requires_grad:
70
+ grad_elem = _grad_outputs[k]
71
+ new_outputs += [elem]
72
+ new_grad_outputs += [grad_elem]
73
+ grad_outputs = new_grad_outputs
74
+ outputs = new_outputs
75
+
76
+ inputs_requires_grad = []
77
+ back_map_input = {}
78
+ inputs_map_i = 0
79
+ for k,param in enumerate(inputs):
80
+ #param = inputs[k]
81
+
82
+ if param is None:
83
+ continue
84
+ if param.requires_grad:
85
+ inputs_requires_grad += [param]
86
+ back_map_input[k] = inputs_map_i
87
+ inputs_map_i += 1
88
+ grad_tmp = []
89
+ if len(inputs_requires_grad)!=0:
90
+ grad_tmp = torch.autograd.grad(outputs=outputs,
91
+ inputs=inputs_requires_grad,
92
+ grad_outputs=grad_outputs,
93
+ retain_graph=retain_graph,
94
+ create_graph=create_graph,
95
+ only_inputs=only_inputs,
96
+ allow_unused=True,
97
+ is_grads_batched=is_grads_batched,
98
+ materialize_grads=materialize_grads)
99
+ else:
100
+
101
+ pass
102
+ grad = [None for input in inputs]
103
+ for k in range(len(grad)):
104
+ if k in back_map_input.keys():
105
+ inputs_map_i = back_map_input[k]
106
+ grad[k] = grad_tmp[inputs_map_i]
107
+ else:
108
+ if materialize_grads:
109
+ if inputs[k] is None:
110
+ grad[k] = None
111
+ else:
112
+ grad[k] = torch.zeros_like(inputs[k])
113
+ else:
114
+ grad[k] = None
115
+
116
+ return grad