warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191) hide show
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -26,7 +26,7 @@ import re
26
26
  import sys
27
27
  import textwrap
28
28
  import types
29
- from typing import Any, Callable, Dict, Mapping, Optional, Sequence
29
+ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
30
30
 
31
31
  import warp.config
32
32
  from warp.types import *
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
57
57
 
58
58
 
59
59
  # map operator to function name
60
- builtin_operators = {}
60
+ builtin_operators: Dict[type[ast.AST], str] = {}
61
61
 
62
62
  # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
63
63
  # nice overview of python operators
@@ -122,16 +122,6 @@ def get_closure_cell_contents(obj):
122
122
  return None
123
123
 
124
124
 
125
- def get_type_origin(tp):
126
- # Compatible version of `typing.get_origin()` for Python 3.7 and older.
127
- return getattr(tp, "__origin__", None)
128
-
129
-
130
- def get_type_args(tp):
131
- # Compatible version of `typing.get_args()` for Python 3.7 and older.
132
- return getattr(tp, "__args__", ())
133
-
134
-
135
125
  def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
136
126
  """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
137
127
  # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -212,7 +202,7 @@ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
212
202
  return spec._replace(annotations=eval_annotations(spec.annotations, func))
213
203
 
214
204
 
215
- def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
205
+ def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
216
206
  indent = "\t"
217
207
 
218
208
  # handle empty structs
@@ -226,9 +216,12 @@ def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
226
216
  field_value = getattr(inst, field_name, None)
227
217
 
228
218
  if isinstance(field_value, StructInstance):
229
- field_value = struct_instance_repr_recursive(field_value, depth + 1)
219
+ field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
230
220
 
231
- lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
221
+ if use_repr:
222
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
223
+ else:
224
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
232
225
 
233
226
  lines.append(f"{indent * depth})")
234
227
  return "\n".join(lines)
@@ -351,7 +344,10 @@ class StructInstance:
351
344
  return self._ctype
352
345
 
353
346
  def __repr__(self):
354
- return struct_instance_repr_recursive(self, 0)
347
+ return struct_instance_repr_recursive(self, 0, use_repr=True)
348
+
349
+ def __str__(self):
350
+ return struct_instance_repr_recursive(self, 0, use_repr=False)
355
351
 
356
352
  def to(self, device):
357
353
  """Copies this struct with all array members moved onto the given device.
@@ -415,12 +411,14 @@ class StructInstance:
415
411
 
416
412
 
417
413
  class Struct:
418
- def __init__(self, cls, key, module):
414
+ hash: bytes
415
+
416
+ def __init__(self, cls: type, key: str, module: warp.context.Module):
419
417
  self.cls = cls
420
418
  self.module = module
421
419
  self.key = key
420
+ self.vars: Dict[str, Var] = {}
422
421
 
423
- self.vars = {}
424
422
  annotations = get_annotations(self.cls)
425
423
  for label, type in annotations.items():
426
424
  self.vars[label] = Var(label, type)
@@ -591,11 +589,11 @@ class Reference:
591
589
  self.value_type = value_type
592
590
 
593
591
 
594
- def is_reference(type):
592
+ def is_reference(type: Any) -> builtins.bool:
595
593
  return isinstance(type, Reference)
596
594
 
597
595
 
598
- def strip_reference(arg):
596
+ def strip_reference(arg: Any) -> Any:
599
597
  if is_reference(arg):
600
598
  return arg.value_type
601
599
  else:
@@ -623,7 +621,15 @@ def compute_type_str(base_name, template_params):
623
621
 
624
622
 
625
623
  class Var:
626
- def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
624
+ def __init__(
625
+ self,
626
+ label: str,
627
+ type: type,
628
+ requires_grad: builtins.bool = False,
629
+ constant: Optional[builtins.bool] = None,
630
+ prefix: builtins.bool = True,
631
+ relative_lineno: Optional[int] = None,
632
+ ):
627
633
  # convert built-in types to wp types
628
634
  if type == float:
629
635
  type = float32
@@ -646,11 +652,14 @@ class Var:
646
652
  # used to associate a view array Var with its parent array Var
647
653
  self.parent = None
648
654
 
655
+ # Used to associate the variable with the Python statement that resulted in it being created.
656
+ self.relative_lineno = relative_lineno
657
+
649
658
  def __str__(self):
650
659
  return self.label
651
660
 
652
661
  @staticmethod
653
- def type_to_ctype(t, value_type=False):
662
+ def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
654
663
  if is_array(t):
655
664
  if hasattr(t.dtype, "_wp_generic_type_str_"):
656
665
  dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
@@ -681,7 +690,7 @@ class Var:
681
690
  else:
682
691
  return f"wp::{t.__name__}"
683
692
 
684
- def ctype(self, value_type=False):
693
+ def ctype(self, value_type: builtins.bool = False) -> str:
685
694
  return Var.type_to_ctype(self.type, value_type)
686
695
 
687
696
  def emit(self, prefix: str = "var"):
@@ -803,7 +812,7 @@ def func_match_args(func, arg_types, kwarg_types):
803
812
  return True
804
813
 
805
814
 
806
- def get_arg_type(arg: Union[Var, Any]):
815
+ def get_arg_type(arg: Union[Var, Any]) -> type:
807
816
  if isinstance(arg, str):
808
817
  return str
809
818
 
@@ -819,7 +828,7 @@ def get_arg_type(arg: Union[Var, Any]):
819
828
  return type(arg)
820
829
 
821
830
 
822
- def get_arg_value(arg: Union[Var, Any]):
831
+ def get_arg_value(arg: Any) -> Any:
823
832
  if isinstance(arg, Sequence):
824
833
  return tuple(get_arg_value(x) for x in arg)
825
834
 
@@ -867,6 +876,9 @@ class Adjoint:
867
876
  "please save it on a file and use `importlib` if needed."
868
877
  ) from e
869
878
 
879
+ # Indicates where the function definition starts (excludes decorators)
880
+ adj.fun_def_lineno = None
881
+
870
882
  # get function source code
871
883
  adj.source = inspect.getsource(func)
872
884
  # ensures that indented class methods can be parsed as kernels
@@ -941,9 +953,6 @@ class Adjoint:
941
953
  # for unit testing errors being spit out from kernels.
942
954
  adj.skip_build = False
943
955
 
944
- # Collect the LTOIR required at link-time
945
- adj.ltoirs = []
946
-
947
956
  # allocate extra space for a function call that requires its
948
957
  # own shared memory space, we treat shared memory as a stack
949
958
  # where each function pushes and pops space off, the extra
@@ -1133,7 +1142,7 @@ class Adjoint:
1133
1142
  name = str(index)
1134
1143
 
1135
1144
  # allocate new variable
1136
- v = Var(name, type=type, constant=constant)
1145
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
1137
1146
 
1138
1147
  adj.variables.append(v)
1139
1148
 
@@ -1158,11 +1167,44 @@ class Adjoint:
1158
1167
 
1159
1168
  return var
1160
1169
 
1161
- # append a statement to the forward pass
1162
- def add_forward(adj, statement, replay=None, skip_replay=False):
1170
+ def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
1171
+ """Get a line directive for the given statement.
1172
+
1173
+ Args:
1174
+ statement: The statement to get the line directive for.
1175
+ relative_lineno: The line number of the statement relative to the function.
1176
+
1177
+ Returns:
1178
+ A line directive for the given statement, or None if no line directive is needed.
1179
+ """
1180
+
1181
+ # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
1182
+ # emit line directives in generated code if it's not being compiled with line information
1183
+ lineinfo_enabled = (
1184
+ adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
1185
+ )
1186
+
1187
+ if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
1188
+ is_comment = statement.strip().startswith("//")
1189
+ if not is_comment:
1190
+ line = relative_lineno + adj.fun_lineno
1191
+ # Convert backslashes to forward slashes for CUDA compatibility
1192
+ normalized_path = adj.filename.replace("\\", "/")
1193
+ return f'#line {line} "{normalized_path}"'
1194
+ return None
1195
+
1196
+ def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
1197
+ """Append a statement to the forward pass."""
1198
+
1199
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1200
+ adj.blocks[-1].body_forward.append(line_directive)
1201
+
1163
1202
  adj.blocks[-1].body_forward.append(adj.indentation + statement)
1164
1203
 
1165
1204
  if not skip_replay:
1205
+ if line_directive:
1206
+ adj.blocks[-1].body_replay.append(line_directive)
1207
+
1166
1208
  if replay:
1167
1209
  # if custom replay specified then output it
1168
1210
  adj.blocks[-1].body_replay.append(adj.indentation + replay)
@@ -1171,9 +1213,14 @@ class Adjoint:
1171
1213
  adj.blocks[-1].body_replay.append(adj.indentation + statement)
1172
1214
 
1173
1215
  # append a statement to the reverse pass
1174
- def add_reverse(adj, statement):
1216
+ def add_reverse(adj, statement: str) -> None:
1217
+ """Append a statement to the reverse pass."""
1218
+
1175
1219
  adj.blocks[-1].body_reverse.append(adj.indentation + statement)
1176
1220
 
1221
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1222
+ adj.blocks[-1].body_reverse.append(line_directive)
1223
+
1177
1224
  def add_constant(adj, n):
1178
1225
  output = adj.add_var(type=type(n), constant=n)
1179
1226
  return output
@@ -1281,7 +1328,7 @@ class Adjoint:
1281
1328
 
1282
1329
  # Bind the positional and keyword arguments to the function's signature
1283
1330
  # in order to process them as Python does it.
1284
- bound_args = func.signature.bind(*args, **kwargs)
1331
+ bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
1285
1332
 
1286
1333
  # Type args are the “compile time” argument values we get from codegen.
1287
1334
  # For example, when calling `wp.vec3f(...)` from within a kernel,
@@ -1451,6 +1498,8 @@ class Adjoint:
1451
1498
 
1452
1499
  def add_return(adj, var):
1453
1500
  if var is None or len(var) == 0:
1501
+ # NOTE: If this kernel gets compiled for a CUDA device, then we need
1502
+ # to convert the return; into a continue; in codegen_func_forward()
1454
1503
  adj.add_forward("return;", f"goto label{adj.label_count};")
1455
1504
  elif len(var) == 1:
1456
1505
  adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
@@ -1624,6 +1673,8 @@ class Adjoint:
1624
1673
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
1625
1674
 
1626
1675
  def emit_FunctionDef(adj, node):
1676
+ adj.fun_def_lineno = node.lineno
1677
+
1627
1678
  for f in node.body:
1628
1679
  # Skip variable creation for standalone constants, including docstrings
1629
1680
  if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
@@ -1688,7 +1739,7 @@ class Adjoint:
1688
1739
 
1689
1740
  if var1 != var2:
1690
1741
  # insert a phi function that selects var1, var2 based on cond
1691
- out = adj.add_builtin_call("select", [cond, var1, var2])
1742
+ out = adj.add_builtin_call("where", [cond, var2, var1])
1692
1743
  adj.symbols[sym] = out
1693
1744
 
1694
1745
  symbols_prev = adj.symbols.copy()
@@ -1712,7 +1763,7 @@ class Adjoint:
1712
1763
  if var1 != var2:
1713
1764
  # insert a phi function that selects var1, var2 based on cond
1714
1765
  # note the reversed order of vars since we want to use !cond as our select
1715
- out = adj.add_builtin_call("select", [cond, var2, var1])
1766
+ out = adj.add_builtin_call("where", [cond, var1, var2])
1716
1767
  adj.symbols[sym] = out
1717
1768
 
1718
1769
  def emit_Compare(adj, node):
@@ -1856,25 +1907,6 @@ class Adjoint:
1856
1907
  ) from e
1857
1908
  raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
1858
1909
 
1859
- def emit_String(adj, node):
1860
- # string constant
1861
- return adj.add_constant(node.s)
1862
-
1863
- def emit_Num(adj, node):
1864
- # lookup constant, if it has already been assigned then return existing var
1865
- key = (node.n, type(node.n))
1866
-
1867
- if key in adj.symbols:
1868
- return adj.symbols[key]
1869
- else:
1870
- out = adj.add_constant(node.n)
1871
- adj.symbols[key] = out
1872
- return out
1873
-
1874
- def emit_Ellipsis(adj, node):
1875
- # stubbed @wp.native_func
1876
- return
1877
-
1878
1910
  def emit_Assert(adj, node):
1879
1911
  # eval condition
1880
1912
  cond = adj.eval(node.test)
@@ -1886,24 +1918,11 @@ class Adjoint:
1886
1918
 
1887
1919
  adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
1888
1920
 
1889
- def emit_NameConstant(adj, node):
1890
- if node.value:
1891
- return adj.add_constant(node.value)
1892
- elif node.value is None:
1893
- raise WarpCodegenTypeError("None type unsupported")
1894
- else:
1895
- return adj.add_constant(False)
1896
-
1897
1921
  def emit_Constant(adj, node):
1898
- if isinstance(node, ast.Str):
1899
- return adj.emit_String(node)
1900
- elif isinstance(node, ast.Num):
1901
- return adj.emit_Num(node)
1902
- elif isinstance(node, ast.Ellipsis):
1903
- return adj.emit_Ellipsis(node)
1922
+ if node.value is None:
1923
+ raise WarpCodegenTypeError("None type unsupported")
1904
1924
  else:
1905
- assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
1906
- return adj.emit_NameConstant(node)
1925
+ return adj.add_constant(node.value)
1907
1926
 
1908
1927
  def emit_BinOp(adj, node):
1909
1928
  # evaluate binary operator arguments
@@ -1997,10 +2016,11 @@ class Adjoint:
1997
2016
  adj.end_while()
1998
2017
 
1999
2018
  def eval_num(adj, a):
2000
- if isinstance(a, ast.Num):
2001
- return True, a.n
2002
- if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
2003
- return True, -a.operand.n
2019
+ if isinstance(a, ast.Constant):
2020
+ return True, a.value
2021
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
2022
+ # Negative constant
2023
+ return True, -a.operand.value
2004
2024
 
2005
2025
  # try and resolve the expression to an object
2006
2026
  # e.g.: wp.constant in the globals scope
@@ -2530,8 +2550,8 @@ class Adjoint:
2530
2550
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2531
2551
  )
2532
2552
  else:
2533
- if adj.builder_options.get("enable_backward", True):
2534
- out = adj.add_builtin_call("assign", [target, *indices, rhs])
2553
+ if warp.config.enable_vector_component_overwrites:
2554
+ out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
2535
2555
 
2536
2556
  # re-point target symbol to out var
2537
2557
  for id in adj.symbols:
@@ -2539,8 +2559,7 @@ class Adjoint:
2539
2559
  adj.symbols[id] = out
2540
2560
  break
2541
2561
  else:
2542
- attr = adj.add_builtin_call("index", [target, *indices])
2543
- adj.add_builtin_call("store", [attr, rhs])
2562
+ adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
2544
2563
 
2545
2564
  else:
2546
2565
  raise WarpCodegenError(
@@ -2583,8 +2602,8 @@ class Adjoint:
2583
2602
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2584
2603
  adj.add_builtin_call("store", [attr, rhs])
2585
2604
  else:
2586
- if adj.builder_options.get("enable_backward", True):
2587
- out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2605
+ if warp.config.enable_vector_component_overwrites:
2606
+ out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
2588
2607
 
2589
2608
  # re-point target symbol to out var
2590
2609
  for id in adj.symbols:
@@ -2592,8 +2611,7 @@ class Adjoint:
2592
2611
  adj.symbols[id] = out
2593
2612
  break
2594
2613
  else:
2595
- attr = adj.add_builtin_call("index", [aggregate, index])
2596
- adj.add_builtin_call("store", [attr, rhs])
2614
+ adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2597
2615
 
2598
2616
  else:
2599
2617
  attr = adj.emit_Attribute(lhs)
@@ -2699,10 +2717,12 @@ class Adjoint:
2699
2717
 
2700
2718
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2701
2719
  if isinstance(node.op, ast.Add):
2702
- adj.add_builtin_call("augassign_add", [target, *indices, rhs])
2720
+ adj.add_builtin_call("add_inplace", [target, *indices, rhs])
2703
2721
  elif isinstance(node.op, ast.Sub):
2704
- adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
2722
+ adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
2705
2723
  else:
2724
+ if warp.config.verbose:
2725
+ print(f"Warning: in-place op {node.op} is not differentiable")
2706
2726
  make_new_assign_statement()
2707
2727
  return
2708
2728
 
@@ -2732,9 +2752,6 @@ class Adjoint:
2732
2752
  ast.BoolOp: emit_BoolOp,
2733
2753
  ast.Name: emit_Name,
2734
2754
  ast.Attribute: emit_Attribute,
2735
- ast.Str: emit_String, # Deprecated in 3.8; use Constant
2736
- ast.Num: emit_Num, # Deprecated in 3.8; use Constant
2737
- ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
2738
2755
  ast.Constant: emit_Constant,
2739
2756
  ast.BinOp: emit_BinOp,
2740
2757
  ast.UnaryOp: emit_UnaryOp,
@@ -2744,14 +2761,13 @@ class Adjoint:
2744
2761
  ast.Continue: emit_Continue,
2745
2762
  ast.Expr: emit_Expr,
2746
2763
  ast.Call: emit_Call,
2747
- ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
2764
+ ast.Index: emit_Index, # Deprecated in 3.9
2748
2765
  ast.Subscript: emit_Subscript,
2749
2766
  ast.Assign: emit_Assign,
2750
2767
  ast.Return: emit_Return,
2751
2768
  ast.AugAssign: emit_AugAssign,
2752
2769
  ast.Tuple: emit_Tuple,
2753
2770
  ast.Pass: emit_Pass,
2754
- ast.Ellipsis: emit_Ellipsis,
2755
2771
  ast.Assert: emit_Assert,
2756
2772
  }
2757
2773
 
@@ -2947,12 +2963,16 @@ class Adjoint:
2947
2963
 
2948
2964
  # We want to replace the expression code in-place,
2949
2965
  # so reparse it to get the correct column info.
2950
- len_value_locs = []
2966
+ len_value_locs: List[Tuple[int, int, int]] = []
2951
2967
  expr_tree = ast.parse(static_code)
2952
2968
  assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2953
2969
  expr_root = expr_tree.body[0].value
2954
2970
  for expr_node in ast.walk(expr_root):
2955
- if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
2971
+ if (
2972
+ isinstance(expr_node, ast.Call)
2973
+ and getattr(expr_node.func, "id", None) == "len"
2974
+ and len(expr_node.args) == 1
2975
+ ):
2956
2976
  len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
2957
2977
  try:
2958
2978
  len_value = eval(len_expr, len_expr_ctx)
@@ -3110,9 +3130,9 @@ class Adjoint:
3110
3130
 
3111
3131
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3112
3132
 
3113
- constants = {}
3114
- types = {}
3115
- functions = {}
3133
+ constants: Dict[str, Any] = {}
3134
+ types: Dict[Union[Struct, type], Any] = {}
3135
+ functions: Dict[warp.context.Function, Any] = {}
3116
3136
 
3117
3137
  for node in ast.walk(adj.tree):
3118
3138
  if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3155,7 +3175,7 @@ class Adjoint:
3155
3175
  # code generation
3156
3176
 
3157
3177
  cpu_module_header = """
3158
- #define WP_TILE_BLOCK_DIM {tile_size}
3178
+ #define WP_TILE_BLOCK_DIM {block_dim}
3159
3179
  #define WP_NO_CRT
3160
3180
  #include "builtin.h"
3161
3181
 
@@ -3174,7 +3194,7 @@ cpu_module_header = """
3174
3194
  """
3175
3195
 
3176
3196
  cuda_module_header = """
3177
- #define WP_TILE_BLOCK_DIM {tile_size}
3197
+ #define WP_TILE_BLOCK_DIM {block_dim}
3178
3198
  #define WP_NO_CRT
3179
3199
  #include "builtin.h"
3180
3200
 
@@ -3197,6 +3217,7 @@ struct {name}
3197
3217
  {{
3198
3218
  {struct_body}
3199
3219
 
3220
+ {defaulted_constructor_def}
3200
3221
  CUDA_CALLABLE {name}({forward_args})
3201
3222
  {forward_initializers}
3202
3223
  {{
@@ -3239,53 +3260,53 @@ static void adj_{name}(
3239
3260
 
3240
3261
  cuda_forward_function_template = """
3241
3262
  // {filename}:{lineno}
3242
- static CUDA_CALLABLE {return_type} {name}(
3263
+ {line_directive}static CUDA_CALLABLE {return_type} {name}(
3243
3264
  {forward_args})
3244
3265
  {{
3245
- {forward_body}}}
3266
+ {forward_body}{line_directive}}}
3246
3267
 
3247
3268
  """
3248
3269
 
3249
3270
  cuda_reverse_function_template = """
3250
3271
  // {filename}:{lineno}
3251
- static CUDA_CALLABLE void adj_{name}(
3272
+ {line_directive}static CUDA_CALLABLE void adj_{name}(
3252
3273
  {reverse_args})
3253
3274
  {{
3254
- {reverse_body}}}
3275
+ {reverse_body}{line_directive}}}
3255
3276
 
3256
3277
  """
3257
3278
 
3258
3279
  cuda_kernel_template_forward = """
3259
3280
 
3260
- extern "C" __global__ void {name}_cuda_kernel_forward(
3281
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
3261
3282
  {forward_args})
3262
3283
  {{
3263
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3264
- _idx < dim.size;
3265
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3284
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3285
+ {line_directive} _idx < dim.size;
3286
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3266
3287
  {{
3267
3288
  // reset shared memory allocator
3268
- wp::tile_alloc_shared(0, true);
3289
+ {line_directive} wp::tile_alloc_shared(0, true);
3269
3290
 
3270
- {forward_body} }}
3271
- }}
3291
+ {forward_body}{line_directive} }}
3292
+ {line_directive}}}
3272
3293
 
3273
3294
  """
3274
3295
 
3275
3296
  cuda_kernel_template_backward = """
3276
3297
 
3277
- extern "C" __global__ void {name}_cuda_kernel_backward(
3298
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
3278
3299
  {reverse_args})
3279
3300
  {{
3280
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3281
- _idx < dim.size;
3282
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3301
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3302
+ {line_directive} _idx < dim.size;
3303
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3283
3304
  {{
3284
3305
  // reset shared memory allocator
3285
- wp::tile_alloc_shared(0, true);
3306
+ {line_directive} wp::tile_alloc_shared(0, true);
3286
3307
 
3287
- {reverse_body} }}
3288
- }}
3308
+ {reverse_body}{line_directive} }}
3309
+ {line_directive}}}
3289
3310
 
3290
3311
  """
3291
3312
 
@@ -3315,10 +3336,17 @@ extern "C" {{
3315
3336
  WP_API void {name}_cpu_forward(
3316
3337
  {forward_args})
3317
3338
  {{
3318
- for (size_t task_index = 0; task_index < dim.size; ++task_index)
3339
+ for (size_t task_index = 0; task_index < dim.size; ++task_index)
3319
3340
  {{
3341
+ // init shared memory allocator
3342
+ wp::tile_alloc_shared(0, true);
3343
+
3320
3344
  {name}_cpu_kernel_forward(
3321
3345
  {forward_params});
3346
+
3347
+ // check shared memory allocator
3348
+ wp::tile_alloc_shared(0, false, true);
3349
+
3322
3350
  }}
3323
3351
  }}
3324
3352
 
@@ -3335,8 +3363,14 @@ WP_API void {name}_cpu_backward(
3335
3363
  {{
3336
3364
  for (size_t task_index = 0; task_index < dim.size; ++task_index)
3337
3365
  {{
3366
+ // initialize shared memory allocator
3367
+ wp::tile_alloc_shared(0, true);
3368
+
3338
3369
  {name}_cpu_kernel_backward(
3339
3370
  {reverse_params});
3371
+
3372
+ // check shared memory allocator
3373
+ wp::tile_alloc_shared(0, false, true);
3340
3374
  }}
3341
3375
  }}
3342
3376
 
@@ -3418,7 +3452,7 @@ def indent(args, stops=1):
3418
3452
 
3419
3453
 
3420
3454
  # generates a C function name based on the python function name
3421
- def make_full_qualified_name(func):
3455
+ def make_full_qualified_name(func: Union[str, Callable]) -> str:
3422
3456
  if not isinstance(func, str):
3423
3457
  func = func.__qualname__
3424
3458
  return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
@@ -3448,7 +3482,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3448
3482
  # forward args
3449
3483
  for label, var in struct.vars.items():
3450
3484
  var_ctype = var.ctype()
3451
- forward_args.append(f"{var_ctype} const& {label} = {{}}")
3485
+ default_arg_def = " = {}" if forward_args else ""
3486
+ forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
3452
3487
  reverse_args.append(f"{var_ctype} const&")
3453
3488
 
3454
3489
  namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
@@ -3472,6 +3507,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3472
3507
 
3473
3508
  reverse_args.append(name + " & adj_ret")
3474
3509
 
3510
+ # explicitly defaulted default constructor if no default constructor has been defined
3511
+ defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
3512
+
3475
3513
  return struct_template.format(
3476
3514
  name=name,
3477
3515
  struct_body="".join([indent_block + l for l in body]),
@@ -3481,6 +3519,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3481
3519
  reverse_body="".join(reverse_body),
3482
3520
  prefix_add_body="".join(prefix_add_body),
3483
3521
  atomic_add_body="".join(atomic_add_body),
3522
+ defaulted_constructor_def=defaulted_constructor_def,
3484
3523
  )
3485
3524
 
3486
3525
 
@@ -3510,14 +3549,21 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3510
3549
  else:
3511
3550
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3512
3551
 
3552
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3553
+ lines.insert(-1, f"{line_directive}\n")
3554
+
3513
3555
  # forward pass
3514
3556
  lines += ["//---------\n"]
3515
3557
  lines += ["// forward\n"]
3516
3558
 
3517
3559
  for f in adj.blocks[0].body_forward:
3518
- lines += [f + "\n"]
3560
+ if func_type == "kernel" and device == "cuda" and f.lstrip().startswith("return;"):
3561
+ # Use of grid-stride loops in CUDA kernels requires that we convert return; to continue;
3562
+ lines += [f.replace("return;", "continue;") + "\n"]
3563
+ else:
3564
+ lines += [f + "\n"]
3519
3565
 
3520
- return "".join([indent_block + l for l in lines])
3566
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3521
3567
 
3522
3568
 
3523
3569
  def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
@@ -3547,6 +3593,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3547
3593
  else:
3548
3594
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3549
3595
 
3596
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3597
+ lines.insert(-1, f"{line_directive}\n")
3598
+
3550
3599
  # dual vars
3551
3600
  lines += ["//---------\n"]
3552
3601
  lines += ["// dual vars\n"]
@@ -3567,6 +3616,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3567
3616
  else:
3568
3617
  lines += [f"{ctype} {name} = {{}};\n"]
3569
3618
 
3619
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3620
+ lines.insert(-1, f"{line_directive}\n")
3621
+
3570
3622
  # forward pass
3571
3623
  lines += ["//---------\n"]
3572
3624
  lines += ["// forward\n"]
@@ -3587,7 +3639,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3587
3639
  else:
3588
3640
  lines += ["return;\n"]
3589
3641
 
3590
- return "".join([indent_block + l for l in lines])
3642
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3591
3643
 
3592
3644
 
3593
3645
  def codegen_func(adj, c_func_name: str, device="cpu", options=None):
@@ -3595,11 +3647,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3595
3647
  options = {}
3596
3648
 
3597
3649
  if adj.return_var is not None and "return" in adj.arg_types:
3598
- if get_type_origin(adj.arg_types["return"]) is tuple:
3599
- if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
3650
+ if get_origin(adj.arg_types["return"]) is tuple:
3651
+ if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
3600
3652
  raise WarpCodegenError(
3601
3653
  f"The function `{adj.fun_name}` has its return type "
3602
- f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
3654
+ f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3603
3655
  f"but the code returns {len(adj.return_var)} values."
3604
3656
  )
3605
3657
  elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
@@ -3608,7 +3660,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3608
3660
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3609
3661
  f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
3610
3662
  )
3611
- elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
3663
+ elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
3612
3664
  raise WarpCodegenError(
3613
3665
  f"The function `{adj.fun_name}` has its return type "
3614
3666
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
@@ -3621,6 +3673,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3621
3673
  f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
3622
3674
  )
3623
3675
 
3676
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3677
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3678
+ # a direct mapping to a Python source line.
3679
+ func_line_directive = ""
3680
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3681
+ func_line_directive = f"{line_directive}\n"
3682
+
3624
3683
  # forward header
3625
3684
  if adj.return_var is not None and len(adj.return_var) == 1:
3626
3685
  return_type = adj.return_var[0].ctype()
@@ -3684,6 +3743,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3684
3743
  forward_body=forward_body,
3685
3744
  filename=adj.filename,
3686
3745
  lineno=adj.fun_lineno,
3746
+ line_directive=func_line_directive,
3687
3747
  )
3688
3748
 
3689
3749
  if not adj.skip_reverse_codegen:
@@ -3702,6 +3762,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3702
3762
  reverse_body=reverse_body,
3703
3763
  filename=adj.filename,
3704
3764
  lineno=adj.fun_lineno,
3765
+ line_directive=func_line_directive,
3705
3766
  )
3706
3767
 
3707
3768
  return s
@@ -3744,6 +3805,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3744
3805
  forward_body=snippet,
3745
3806
  filename=adj.filename,
3746
3807
  lineno=adj.fun_lineno,
3808
+ line_directive="",
3747
3809
  )
3748
3810
 
3749
3811
  if replay_snippet is not None:
@@ -3754,6 +3816,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3754
3816
  forward_body=replay_snippet,
3755
3817
  filename=adj.filename,
3756
3818
  lineno=adj.fun_lineno,
3819
+ line_directive="",
3757
3820
  )
3758
3821
 
3759
3822
  if adj_snippet:
@@ -3769,6 +3832,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3769
3832
  reverse_body=reverse_body,
3770
3833
  filename=adj.filename,
3771
3834
  lineno=adj.fun_lineno,
3835
+ line_directive="",
3772
3836
  )
3773
3837
 
3774
3838
  return s
@@ -3781,6 +3845,13 @@ def codegen_kernel(kernel, device, options):
3781
3845
 
3782
3846
  adj = kernel.adj
3783
3847
 
3848
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3849
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3850
+ # a direct mapping to a Python source line.
3851
+ func_line_directive = ""
3852
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3853
+ func_line_directive = f"{line_directive}\n"
3854
+
3784
3855
  if device == "cpu":
3785
3856
  template_forward = cpu_kernel_template_forward
3786
3857
  template_backward = cpu_kernel_template_backward
@@ -3808,6 +3879,7 @@ def codegen_kernel(kernel, device, options):
3808
3879
  {
3809
3880
  "forward_args": indent(forward_args),
3810
3881
  "forward_body": forward_body,
3882
+ "line_directive": func_line_directive,
3811
3883
  }
3812
3884
  )
3813
3885
  template += template_forward