warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179) hide show
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -26,7 +26,7 @@ import re
26
26
  import sys
27
27
  import textwrap
28
28
  import types
29
- from typing import Any, Callable, Dict, Mapping, Optional, Sequence
29
+ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
30
30
 
31
31
  import warp.config
32
32
  from warp.types import *
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
57
57
 
58
58
 
59
59
  # map operator to function name
60
- builtin_operators = {}
60
+ builtin_operators: Dict[type[ast.AST], str] = {}
61
61
 
62
62
  # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
63
63
  # nice overview of python operators
@@ -122,16 +122,6 @@ def get_closure_cell_contents(obj):
122
122
  return None
123
123
 
124
124
 
125
- def get_type_origin(tp):
126
- # Compatible version of `typing.get_origin()` for Python 3.7 and older.
127
- return getattr(tp, "__origin__", None)
128
-
129
-
130
- def get_type_args(tp):
131
- # Compatible version of `typing.get_args()` for Python 3.7 and older.
132
- return getattr(tp, "__args__", ())
133
-
134
-
135
125
  def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
136
126
  """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
137
127
  # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -415,12 +405,14 @@ class StructInstance:
415
405
 
416
406
 
417
407
  class Struct:
418
- def __init__(self, cls, key, module):
408
+ hash: bytes
409
+
410
+ def __init__(self, cls: type, key: str, module: warp.context.Module):
419
411
  self.cls = cls
420
412
  self.module = module
421
413
  self.key = key
414
+ self.vars: Dict[str, Var] = {}
422
415
 
423
- self.vars = {}
424
416
  annotations = get_annotations(self.cls)
425
417
  for label, type in annotations.items():
426
418
  self.vars[label] = Var(label, type)
@@ -591,11 +583,11 @@ class Reference:
591
583
  self.value_type = value_type
592
584
 
593
585
 
594
- def is_reference(type):
586
+ def is_reference(type: Any) -> builtins.bool:
595
587
  return isinstance(type, Reference)
596
588
 
597
589
 
598
- def strip_reference(arg):
590
+ def strip_reference(arg: Any) -> Any:
599
591
  if is_reference(arg):
600
592
  return arg.value_type
601
593
  else:
@@ -623,7 +615,15 @@ def compute_type_str(base_name, template_params):
623
615
 
624
616
 
625
617
  class Var:
626
- def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
618
+ def __init__(
619
+ self,
620
+ label: str,
621
+ type: type,
622
+ requires_grad: builtins.bool = False,
623
+ constant: Optional[builtins.bool] = None,
624
+ prefix: builtins.bool = True,
625
+ relative_lineno: Optional[int] = None,
626
+ ):
627
627
  # convert built-in types to wp types
628
628
  if type == float:
629
629
  type = float32
@@ -646,11 +646,14 @@ class Var:
646
646
  # used to associate a view array Var with its parent array Var
647
647
  self.parent = None
648
648
 
649
+ # Used to associate the variable with the Python statement that resulted in it being created.
650
+ self.relative_lineno = relative_lineno
651
+
649
652
  def __str__(self):
650
653
  return self.label
651
654
 
652
655
  @staticmethod
653
- def type_to_ctype(t, value_type=False):
656
+ def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
654
657
  if is_array(t):
655
658
  if hasattr(t.dtype, "_wp_generic_type_str_"):
656
659
  dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
@@ -681,7 +684,7 @@ class Var:
681
684
  else:
682
685
  return f"wp::{t.__name__}"
683
686
 
684
- def ctype(self, value_type=False):
687
+ def ctype(self, value_type: builtins.bool = False) -> str:
685
688
  return Var.type_to_ctype(self.type, value_type)
686
689
 
687
690
  def emit(self, prefix: str = "var"):
@@ -803,7 +806,7 @@ def func_match_args(func, arg_types, kwarg_types):
803
806
  return True
804
807
 
805
808
 
806
- def get_arg_type(arg: Union[Var, Any]):
809
+ def get_arg_type(arg: Union[Var, Any]) -> type:
807
810
  if isinstance(arg, str):
808
811
  return str
809
812
 
@@ -819,7 +822,7 @@ def get_arg_type(arg: Union[Var, Any]):
819
822
  return type(arg)
820
823
 
821
824
 
822
- def get_arg_value(arg: Union[Var, Any]):
825
+ def get_arg_value(arg: Any) -> Any:
823
826
  if isinstance(arg, Sequence):
824
827
  return tuple(get_arg_value(x) for x in arg)
825
828
 
@@ -867,6 +870,9 @@ class Adjoint:
867
870
  "please save it on a file and use `importlib` if needed."
868
871
  ) from e
869
872
 
873
+ # Indicates where the function definition starts (excludes decorators)
874
+ adj.fun_def_lineno = None
875
+
870
876
  # get function source code
871
877
  adj.source = inspect.getsource(func)
872
878
  # ensures that indented class methods can be parsed as kernels
@@ -941,9 +947,6 @@ class Adjoint:
941
947
  # for unit testing errors being spit out from kernels.
942
948
  adj.skip_build = False
943
949
 
944
- # Collect the LTOIR required at link-time
945
- adj.ltoirs = []
946
-
947
950
  # allocate extra space for a function call that requires its
948
951
  # own shared memory space, we treat shared memory as a stack
949
952
  # where each function pushes and pops space off, the extra
@@ -1133,7 +1136,7 @@ class Adjoint:
1133
1136
  name = str(index)
1134
1137
 
1135
1138
  # allocate new variable
1136
- v = Var(name, type=type, constant=constant)
1139
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
1137
1140
 
1138
1141
  adj.variables.append(v)
1139
1142
 
@@ -1158,11 +1161,44 @@ class Adjoint:
1158
1161
 
1159
1162
  return var
1160
1163
 
1161
- # append a statement to the forward pass
1162
- def add_forward(adj, statement, replay=None, skip_replay=False):
1164
+ def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
1165
+ """Get a line directive for the given statement.
1166
+
1167
+ Args:
1168
+ statement: The statement to get the line directive for.
1169
+ relative_lineno: The line number of the statement relative to the function.
1170
+
1171
+ Returns:
1172
+ A line directive for the given statement, or None if no line directive is needed.
1173
+ """
1174
+
1175
+ # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
1176
+ # emit line directives in generated code if it's not being compiled with line information
1177
+ lineinfo_enabled = (
1178
+ adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
1179
+ )
1180
+
1181
+ if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
1182
+ is_comment = statement.strip().startswith("//")
1183
+ if not is_comment:
1184
+ line = relative_lineno + adj.fun_lineno
1185
+ # Convert backslashes to forward slashes for CUDA compatibility
1186
+ normalized_path = adj.filename.replace("\\", "/")
1187
+ return f'#line {line} "{normalized_path}"'
1188
+ return None
1189
+
1190
+ def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
1191
+ """Append a statement to the forward pass."""
1192
+
1193
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1194
+ adj.blocks[-1].body_forward.append(line_directive)
1195
+
1163
1196
  adj.blocks[-1].body_forward.append(adj.indentation + statement)
1164
1197
 
1165
1198
  if not skip_replay:
1199
+ if line_directive:
1200
+ adj.blocks[-1].body_replay.append(line_directive)
1201
+
1166
1202
  if replay:
1167
1203
  # if custom replay specified then output it
1168
1204
  adj.blocks[-1].body_replay.append(adj.indentation + replay)
@@ -1171,9 +1207,14 @@ class Adjoint:
1171
1207
  adj.blocks[-1].body_replay.append(adj.indentation + statement)
1172
1208
 
1173
1209
  # append a statement to the reverse pass
1174
- def add_reverse(adj, statement):
1210
+ def add_reverse(adj, statement: str) -> None:
1211
+ """Append a statement to the reverse pass."""
1212
+
1175
1213
  adj.blocks[-1].body_reverse.append(adj.indentation + statement)
1176
1214
 
1215
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1216
+ adj.blocks[-1].body_reverse.append(line_directive)
1217
+
1177
1218
  def add_constant(adj, n):
1178
1219
  output = adj.add_var(type=type(n), constant=n)
1179
1220
  return output
@@ -1281,7 +1322,7 @@ class Adjoint:
1281
1322
 
1282
1323
  # Bind the positional and keyword arguments to the function's signature
1283
1324
  # in order to process them as Python does it.
1284
- bound_args = func.signature.bind(*args, **kwargs)
1325
+ bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
1285
1326
 
1286
1327
  # Type args are the “compile time” argument values we get from codegen.
1287
1328
  # For example, when calling `wp.vec3f(...)` from within a kernel,
@@ -1624,6 +1665,8 @@ class Adjoint:
1624
1665
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
1625
1666
 
1626
1667
  def emit_FunctionDef(adj, node):
1668
+ adj.fun_def_lineno = node.lineno
1669
+
1627
1670
  for f in node.body:
1628
1671
  # Skip variable creation for standalone constants, including docstrings
1629
1672
  if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
@@ -1688,7 +1731,7 @@ class Adjoint:
1688
1731
 
1689
1732
  if var1 != var2:
1690
1733
  # insert a phi function that selects var1, var2 based on cond
1691
- out = adj.add_builtin_call("select", [cond, var1, var2])
1734
+ out = adj.add_builtin_call("where", [cond, var2, var1])
1692
1735
  adj.symbols[sym] = out
1693
1736
 
1694
1737
  symbols_prev = adj.symbols.copy()
@@ -1712,7 +1755,7 @@ class Adjoint:
1712
1755
  if var1 != var2:
1713
1756
  # insert a phi function that selects var1, var2 based on cond
1714
1757
  # note the reversed order of vars since we want to use !cond as our select
1715
- out = adj.add_builtin_call("select", [cond, var2, var1])
1758
+ out = adj.add_builtin_call("where", [cond, var1, var2])
1716
1759
  adj.symbols[sym] = out
1717
1760
 
1718
1761
  def emit_Compare(adj, node):
@@ -1856,25 +1899,6 @@ class Adjoint:
1856
1899
  ) from e
1857
1900
  raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
1858
1901
 
1859
- def emit_String(adj, node):
1860
- # string constant
1861
- return adj.add_constant(node.s)
1862
-
1863
- def emit_Num(adj, node):
1864
- # lookup constant, if it has already been assigned then return existing var
1865
- key = (node.n, type(node.n))
1866
-
1867
- if key in adj.symbols:
1868
- return adj.symbols[key]
1869
- else:
1870
- out = adj.add_constant(node.n)
1871
- adj.symbols[key] = out
1872
- return out
1873
-
1874
- def emit_Ellipsis(adj, node):
1875
- # stubbed @wp.native_func
1876
- return
1877
-
1878
1902
  def emit_Assert(adj, node):
1879
1903
  # eval condition
1880
1904
  cond = adj.eval(node.test)
@@ -1886,24 +1910,11 @@ class Adjoint:
1886
1910
 
1887
1911
  adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
1888
1912
 
1889
- def emit_NameConstant(adj, node):
1890
- if node.value:
1891
- return adj.add_constant(node.value)
1892
- elif node.value is None:
1893
- raise WarpCodegenTypeError("None type unsupported")
1894
- else:
1895
- return adj.add_constant(False)
1896
-
1897
1913
  def emit_Constant(adj, node):
1898
- if isinstance(node, ast.Str):
1899
- return adj.emit_String(node)
1900
- elif isinstance(node, ast.Num):
1901
- return adj.emit_Num(node)
1902
- elif isinstance(node, ast.Ellipsis):
1903
- return adj.emit_Ellipsis(node)
1914
+ if node.value is None:
1915
+ raise WarpCodegenTypeError("None type unsupported")
1904
1916
  else:
1905
- assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
1906
- return adj.emit_NameConstant(node)
1917
+ return adj.add_constant(node.value)
1907
1918
 
1908
1919
  def emit_BinOp(adj, node):
1909
1920
  # evaluate binary operator arguments
@@ -1997,10 +2008,11 @@ class Adjoint:
1997
2008
  adj.end_while()
1998
2009
 
1999
2010
  def eval_num(adj, a):
2000
- if isinstance(a, ast.Num):
2001
- return True, a.n
2002
- if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
2003
- return True, -a.operand.n
2011
+ if isinstance(a, ast.Constant):
2012
+ return True, a.value
2013
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
2014
+ # Negative constant
2015
+ return True, -a.operand.value
2004
2016
 
2005
2017
  # try and resolve the expression to an object
2006
2018
  # e.g.: wp.constant in the globals scope
@@ -2530,8 +2542,8 @@ class Adjoint:
2530
2542
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2531
2543
  )
2532
2544
  else:
2533
- if adj.builder_options.get("enable_backward", True):
2534
- out = adj.add_builtin_call("assign", [target, *indices, rhs])
2545
+ if warp.config.enable_vector_component_overwrites:
2546
+ out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
2535
2547
 
2536
2548
  # re-point target symbol to out var
2537
2549
  for id in adj.symbols:
@@ -2539,8 +2551,7 @@ class Adjoint:
2539
2551
  adj.symbols[id] = out
2540
2552
  break
2541
2553
  else:
2542
- attr = adj.add_builtin_call("index", [target, *indices])
2543
- adj.add_builtin_call("store", [attr, rhs])
2554
+ adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
2544
2555
 
2545
2556
  else:
2546
2557
  raise WarpCodegenError(
@@ -2583,8 +2594,8 @@ class Adjoint:
2583
2594
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2584
2595
  adj.add_builtin_call("store", [attr, rhs])
2585
2596
  else:
2586
- if adj.builder_options.get("enable_backward", True):
2587
- out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2597
+ if warp.config.enable_vector_component_overwrites:
2598
+ out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
2588
2599
 
2589
2600
  # re-point target symbol to out var
2590
2601
  for id in adj.symbols:
@@ -2592,8 +2603,7 @@ class Adjoint:
2592
2603
  adj.symbols[id] = out
2593
2604
  break
2594
2605
  else:
2595
- attr = adj.add_builtin_call("index", [aggregate, index])
2596
- adj.add_builtin_call("store", [attr, rhs])
2606
+ adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2597
2607
 
2598
2608
  else:
2599
2609
  attr = adj.emit_Attribute(lhs)
@@ -2699,10 +2709,12 @@ class Adjoint:
2699
2709
 
2700
2710
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2701
2711
  if isinstance(node.op, ast.Add):
2702
- adj.add_builtin_call("augassign_add", [target, *indices, rhs])
2712
+ adj.add_builtin_call("add_inplace", [target, *indices, rhs])
2703
2713
  elif isinstance(node.op, ast.Sub):
2704
- adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
2714
+ adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
2705
2715
  else:
2716
+ if warp.config.verbose:
2717
+ print(f"Warning: in-place op {node.op} is not differentiable")
2706
2718
  make_new_assign_statement()
2707
2719
  return
2708
2720
 
@@ -2732,9 +2744,6 @@ class Adjoint:
2732
2744
  ast.BoolOp: emit_BoolOp,
2733
2745
  ast.Name: emit_Name,
2734
2746
  ast.Attribute: emit_Attribute,
2735
- ast.Str: emit_String, # Deprecated in 3.8; use Constant
2736
- ast.Num: emit_Num, # Deprecated in 3.8; use Constant
2737
- ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
2738
2747
  ast.Constant: emit_Constant,
2739
2748
  ast.BinOp: emit_BinOp,
2740
2749
  ast.UnaryOp: emit_UnaryOp,
@@ -2744,14 +2753,13 @@ class Adjoint:
2744
2753
  ast.Continue: emit_Continue,
2745
2754
  ast.Expr: emit_Expr,
2746
2755
  ast.Call: emit_Call,
2747
- ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
2756
+ ast.Index: emit_Index, # Deprecated in 3.9
2748
2757
  ast.Subscript: emit_Subscript,
2749
2758
  ast.Assign: emit_Assign,
2750
2759
  ast.Return: emit_Return,
2751
2760
  ast.AugAssign: emit_AugAssign,
2752
2761
  ast.Tuple: emit_Tuple,
2753
2762
  ast.Pass: emit_Pass,
2754
- ast.Ellipsis: emit_Ellipsis,
2755
2763
  ast.Assert: emit_Assert,
2756
2764
  }
2757
2765
 
@@ -2947,12 +2955,16 @@ class Adjoint:
2947
2955
 
2948
2956
  # We want to replace the expression code in-place,
2949
2957
  # so reparse it to get the correct column info.
2950
- len_value_locs = []
2958
+ len_value_locs: List[Tuple[int, int, int]] = []
2951
2959
  expr_tree = ast.parse(static_code)
2952
2960
  assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2953
2961
  expr_root = expr_tree.body[0].value
2954
2962
  for expr_node in ast.walk(expr_root):
2955
- if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
2963
+ if (
2964
+ isinstance(expr_node, ast.Call)
2965
+ and getattr(expr_node.func, "id", None) == "len"
2966
+ and len(expr_node.args) == 1
2967
+ ):
2956
2968
  len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
2957
2969
  try:
2958
2970
  len_value = eval(len_expr, len_expr_ctx)
@@ -3110,9 +3122,9 @@ class Adjoint:
3110
3122
 
3111
3123
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3112
3124
 
3113
- constants = {}
3114
- types = {}
3115
- functions = {}
3125
+ constants: Dict[str, Any] = {}
3126
+ types: Dict[Union[Struct, type], Any] = {}
3127
+ functions: Dict[warp.context.Function, Any] = {}
3116
3128
 
3117
3129
  for node in ast.walk(adj.tree):
3118
3130
  if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3155,7 +3167,7 @@ class Adjoint:
3155
3167
  # code generation
3156
3168
 
3157
3169
  cpu_module_header = """
3158
- #define WP_TILE_BLOCK_DIM {tile_size}
3170
+ #define WP_TILE_BLOCK_DIM {block_dim}
3159
3171
  #define WP_NO_CRT
3160
3172
  #include "builtin.h"
3161
3173
 
@@ -3174,7 +3186,7 @@ cpu_module_header = """
3174
3186
  """
3175
3187
 
3176
3188
  cuda_module_header = """
3177
- #define WP_TILE_BLOCK_DIM {tile_size}
3189
+ #define WP_TILE_BLOCK_DIM {block_dim}
3178
3190
  #define WP_NO_CRT
3179
3191
  #include "builtin.h"
3180
3192
 
@@ -3197,6 +3209,7 @@ struct {name}
3197
3209
  {{
3198
3210
  {struct_body}
3199
3211
 
3212
+ {defaulted_constructor_def}
3200
3213
  CUDA_CALLABLE {name}({forward_args})
3201
3214
  {forward_initializers}
3202
3215
  {{
@@ -3239,53 +3252,53 @@ static void adj_{name}(
3239
3252
 
3240
3253
  cuda_forward_function_template = """
3241
3254
  // {filename}:{lineno}
3242
- static CUDA_CALLABLE {return_type} {name}(
3255
+ {line_directive}static CUDA_CALLABLE {return_type} {name}(
3243
3256
  {forward_args})
3244
3257
  {{
3245
- {forward_body}}}
3258
+ {forward_body}{line_directive}}}
3246
3259
 
3247
3260
  """
3248
3261
 
3249
3262
  cuda_reverse_function_template = """
3250
3263
  // {filename}:{lineno}
3251
- static CUDA_CALLABLE void adj_{name}(
3264
+ {line_directive}static CUDA_CALLABLE void adj_{name}(
3252
3265
  {reverse_args})
3253
3266
  {{
3254
- {reverse_body}}}
3267
+ {reverse_body}{line_directive}}}
3255
3268
 
3256
3269
  """
3257
3270
 
3258
3271
  cuda_kernel_template_forward = """
3259
3272
 
3260
- extern "C" __global__ void {name}_cuda_kernel_forward(
3273
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
3261
3274
  {forward_args})
3262
3275
  {{
3263
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3264
- _idx < dim.size;
3265
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3276
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3277
+ {line_directive} _idx < dim.size;
3278
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3266
3279
  {{
3267
3280
  // reset shared memory allocator
3268
- wp::tile_alloc_shared(0, true);
3281
+ {line_directive} wp::tile_alloc_shared(0, true);
3269
3282
 
3270
- {forward_body} }}
3271
- }}
3283
+ {forward_body}{line_directive} }}
3284
+ {line_directive}}}
3272
3285
 
3273
3286
  """
3274
3287
 
3275
3288
  cuda_kernel_template_backward = """
3276
3289
 
3277
- extern "C" __global__ void {name}_cuda_kernel_backward(
3290
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
3278
3291
  {reverse_args})
3279
3292
  {{
3280
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3281
- _idx < dim.size;
3282
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3293
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3294
+ {line_directive} _idx < dim.size;
3295
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3283
3296
  {{
3284
3297
  // reset shared memory allocator
3285
- wp::tile_alloc_shared(0, true);
3298
+ {line_directive} wp::tile_alloc_shared(0, true);
3286
3299
 
3287
- {reverse_body} }}
3288
- }}
3300
+ {reverse_body}{line_directive} }}
3301
+ {line_directive}}}
3289
3302
 
3290
3303
  """
3291
3304
 
@@ -3315,10 +3328,17 @@ extern "C" {{
3315
3328
  WP_API void {name}_cpu_forward(
3316
3329
  {forward_args})
3317
3330
  {{
3318
- for (size_t task_index = 0; task_index < dim.size; ++task_index)
3331
+ for (size_t task_index = 0; task_index < dim.size; ++task_index)
3319
3332
  {{
3333
+ // init shared memory allocator
3334
+ wp::tile_alloc_shared(0, true);
3335
+
3320
3336
  {name}_cpu_kernel_forward(
3321
3337
  {forward_params});
3338
+
3339
+ // check shared memory allocator
3340
+ wp::tile_alloc_shared(0, false, true);
3341
+
3322
3342
  }}
3323
3343
  }}
3324
3344
 
@@ -3335,8 +3355,14 @@ WP_API void {name}_cpu_backward(
3335
3355
  {{
3336
3356
  for (size_t task_index = 0; task_index < dim.size; ++task_index)
3337
3357
  {{
3358
+ // initialize shared memory allocator
3359
+ wp::tile_alloc_shared(0, true);
3360
+
3338
3361
  {name}_cpu_kernel_backward(
3339
3362
  {reverse_params});
3363
+
3364
+ // check shared memory allocator
3365
+ wp::tile_alloc_shared(0, false, true);
3340
3366
  }}
3341
3367
  }}
3342
3368
 
@@ -3418,7 +3444,7 @@ def indent(args, stops=1):
3418
3444
 
3419
3445
 
3420
3446
  # generates a C function name based on the python function name
3421
- def make_full_qualified_name(func):
3447
+ def make_full_qualified_name(func: Union[str, Callable]) -> str:
3422
3448
  if not isinstance(func, str):
3423
3449
  func = func.__qualname__
3424
3450
  return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
@@ -3448,7 +3474,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3448
3474
  # forward args
3449
3475
  for label, var in struct.vars.items():
3450
3476
  var_ctype = var.ctype()
3451
- forward_args.append(f"{var_ctype} const& {label} = {{}}")
3477
+ default_arg_def = " = {}" if forward_args else ""
3478
+ forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
3452
3479
  reverse_args.append(f"{var_ctype} const&")
3453
3480
 
3454
3481
  namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
@@ -3472,6 +3499,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3472
3499
 
3473
3500
  reverse_args.append(name + " & adj_ret")
3474
3501
 
3502
+ # explicitly defaulted default constructor if no default constructor has been defined
3503
+ defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
3504
+
3475
3505
  return struct_template.format(
3476
3506
  name=name,
3477
3507
  struct_body="".join([indent_block + l for l in body]),
@@ -3481,6 +3511,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3481
3511
  reverse_body="".join(reverse_body),
3482
3512
  prefix_add_body="".join(prefix_add_body),
3483
3513
  atomic_add_body="".join(atomic_add_body),
3514
+ defaulted_constructor_def=defaulted_constructor_def,
3484
3515
  )
3485
3516
 
3486
3517
 
@@ -3510,6 +3541,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3510
3541
  else:
3511
3542
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3512
3543
 
3544
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3545
+ lines.insert(-1, f"{line_directive}\n")
3546
+
3513
3547
  # forward pass
3514
3548
  lines += ["//---------\n"]
3515
3549
  lines += ["// forward\n"]
@@ -3517,7 +3551,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3517
3551
  for f in adj.blocks[0].body_forward:
3518
3552
  lines += [f + "\n"]
3519
3553
 
3520
- return "".join([indent_block + l for l in lines])
3554
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3521
3555
 
3522
3556
 
3523
3557
  def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
@@ -3547,6 +3581,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3547
3581
  else:
3548
3582
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3549
3583
 
3584
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3585
+ lines.insert(-1, f"{line_directive}\n")
3586
+
3550
3587
  # dual vars
3551
3588
  lines += ["//---------\n"]
3552
3589
  lines += ["// dual vars\n"]
@@ -3567,6 +3604,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3567
3604
  else:
3568
3605
  lines += [f"{ctype} {name} = {{}};\n"]
3569
3606
 
3607
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3608
+ lines.insert(-1, f"{line_directive}\n")
3609
+
3570
3610
  # forward pass
3571
3611
  lines += ["//---------\n"]
3572
3612
  lines += ["// forward\n"]
@@ -3587,7 +3627,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3587
3627
  else:
3588
3628
  lines += ["return;\n"]
3589
3629
 
3590
- return "".join([indent_block + l for l in lines])
3630
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3591
3631
 
3592
3632
 
3593
3633
  def codegen_func(adj, c_func_name: str, device="cpu", options=None):
@@ -3595,11 +3635,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3595
3635
  options = {}
3596
3636
 
3597
3637
  if adj.return_var is not None and "return" in adj.arg_types:
3598
- if get_type_origin(adj.arg_types["return"]) is tuple:
3599
- if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
3638
+ if get_origin(adj.arg_types["return"]) is tuple:
3639
+ if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
3600
3640
  raise WarpCodegenError(
3601
3641
  f"The function `{adj.fun_name}` has its return type "
3602
- f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
3642
+ f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3603
3643
  f"but the code returns {len(adj.return_var)} values."
3604
3644
  )
3605
3645
  elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
@@ -3608,7 +3648,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3608
3648
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3609
3649
  f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
3610
3650
  )
3611
- elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
3651
+ elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
3612
3652
  raise WarpCodegenError(
3613
3653
  f"The function `{adj.fun_name}` has its return type "
3614
3654
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
@@ -3621,6 +3661,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3621
3661
  f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
3622
3662
  )
3623
3663
 
3664
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3665
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3666
+ # a direct mapping to a Python source line.
3667
+ func_line_directive = ""
3668
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3669
+ func_line_directive = f"{line_directive}\n"
3670
+
3624
3671
  # forward header
3625
3672
  if adj.return_var is not None and len(adj.return_var) == 1:
3626
3673
  return_type = adj.return_var[0].ctype()
@@ -3684,6 +3731,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3684
3731
  forward_body=forward_body,
3685
3732
  filename=adj.filename,
3686
3733
  lineno=adj.fun_lineno,
3734
+ line_directive=func_line_directive,
3687
3735
  )
3688
3736
 
3689
3737
  if not adj.skip_reverse_codegen:
@@ -3702,6 +3750,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3702
3750
  reverse_body=reverse_body,
3703
3751
  filename=adj.filename,
3704
3752
  lineno=adj.fun_lineno,
3753
+ line_directive=func_line_directive,
3705
3754
  )
3706
3755
 
3707
3756
  return s
@@ -3744,6 +3793,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3744
3793
  forward_body=snippet,
3745
3794
  filename=adj.filename,
3746
3795
  lineno=adj.fun_lineno,
3796
+ line_directive="",
3747
3797
  )
3748
3798
 
3749
3799
  if replay_snippet is not None:
@@ -3754,6 +3804,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3754
3804
  forward_body=replay_snippet,
3755
3805
  filename=adj.filename,
3756
3806
  lineno=adj.fun_lineno,
3807
+ line_directive="",
3757
3808
  )
3758
3809
 
3759
3810
  if adj_snippet:
@@ -3769,6 +3820,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3769
3820
  reverse_body=reverse_body,
3770
3821
  filename=adj.filename,
3771
3822
  lineno=adj.fun_lineno,
3823
+ line_directive="",
3772
3824
  )
3773
3825
 
3774
3826
  return s
@@ -3781,6 +3833,13 @@ def codegen_kernel(kernel, device, options):
3781
3833
 
3782
3834
  adj = kernel.adj
3783
3835
 
3836
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3837
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3838
+ # a direct mapping to a Python source line.
3839
+ func_line_directive = ""
3840
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3841
+ func_line_directive = f"{line_directive}\n"
3842
+
3784
3843
  if device == "cpu":
3785
3844
  template_forward = cpu_kernel_template_forward
3786
3845
  template_backward = cpu_kernel_template_backward
@@ -3808,6 +3867,7 @@ def codegen_kernel(kernel, device, options):
3808
3867
  {
3809
3868
  "forward_args": indent(forward_args),
3810
3869
  "forward_body": forward_body,
3870
+ "line_directive": func_line_directive,
3811
3871
  }
3812
3872
  )
3813
3873
  template += template_forward