warp-lang 1.2.1-py3-none-win_amd64.whl → 1.3.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +401 -199
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +122 -39
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +344 -227
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.1.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -10,13 +10,14 @@ from __future__ import annotations
  import ast
  import builtins
  import ctypes
+ import functools
  import inspect
  import math
  import re
  import sys
  import textwrap
  import types
- from typing import Any, Callable, Dict, Mapping
+ from typing import Any, Callable, Dict, Mapping, Optional, Sequence

  import warp.config
  from warp.types import *
@@ -84,17 +85,108 @@ comparison_chain_strings = [
  ]


+ def values_check_equal(a, b):
+     if isinstance(a, Sequence) and isinstance(b, Sequence):
+         if len(a) != len(b):
+             return False
+
+         return all(x == y for x, y in zip(a, b))
+
+     return a == b
+
+
  def op_str_is_chainable(op: str) -> builtins.bool:
      return op in comparison_chain_strings


+ def get_closure_cell_contents(obj):
+     """Retrieve a closure's cell contents or `None` if it's empty."""
+     try:
+         return obj.cell_contents
+     except ValueError:
+         pass
+
+     return None
+
+
+ def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
+     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
+     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
+     if not annotations:
+         return {}
+
+     if not any(isinstance(x, str) for x in annotations.values()):
+         # No annotation to un-stringize.
+         return annotations
+
+     if isinstance(obj, type):
+         # class
+         globals = {}
+         module_name = getattr(obj, "__module__", None)
+         if module_name:
+             module = sys.modules.get(module_name, None)
+             if module:
+                 globals = getattr(module, "__dict__", {})
+         locals = dict(vars(obj))
+         unwrap = obj
+     elif isinstance(obj, types.ModuleType):
+         # module
+         globals = obj.__dict__
+         locals = {}
+         unwrap = None
+     elif callable(obj):
+         # function
+         globals = getattr(obj, "__globals__", {})
+         # Capture the variables from the surrounding scope.
+         closure_vars = zip(
+             obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
+         )
+         locals = {k: v for k, v in closure_vars if v is not None}
+         unwrap = obj
+     else:
+         raise TypeError(f"{obj!r} is not a module, class, or callable.")
+
+     if unwrap is not None:
+         while True:
+             if hasattr(unwrap, "__wrapped__"):
+                 unwrap = unwrap.__wrapped__
+                 continue
+             if isinstance(unwrap, functools.partial):
+                 unwrap = unwrap.func
+                 continue
+             break
+         if hasattr(unwrap, "__globals__"):
+             globals = unwrap.__globals__
+
+     # "Inject" type parameters into the local namespace
+     # (unless they are shadowed by assignments *in* the local namespace),
+     # as a way of emulating annotation scopes when calling `eval()`.
+     type_params = getattr(obj, "__type_params__", ())
+     if type_params:
+         locals = {param.__name__: param for param in type_params} | locals
+
+     return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
+
+
  def get_annotations(obj: Any) -> Mapping[str, Any]:
-     """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
+     """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
+     # This backports `inspect.get_annotations()` for Python 3.9 and older.
      # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
      if isinstance(obj, type):
-         return obj.__dict__.get("__annotations__", {})
+         annotations = obj.__dict__.get("__annotations__", {})
+     else:
+         annotations = getattr(obj, "__annotations__", {})
+
+     # Evaluating annotations can be done using the `eval_str` parameter with
+     # the official function from the `inspect` module.
+     return eval_annotations(annotations, obj)

-     return getattr(obj, "__annotations__", {})
+
+ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
+     """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
+     # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
+     spec = inspect.getfullargspec(func)
+     return spec._replace(annotations=eval_annotations(spec.annotations, func))


  def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
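The new `eval_annotations()` helper addresses PEP 563 (`from __future__ import annotations`), which stores every annotation as a string and previously broke Warp's type-driven codegen for modules using that import (exercised by the new `warp/tests/test_future_annotations.py`). A minimal standalone sketch of the problem and the eval-based fix, independent of Warp:

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings


def scale(x: float, n: int) -> float:
    return x * n


# With PEP 563 active, the raw annotations are strings rather than types.
print(scale.__annotations__)  # {'x': 'float', 'n': 'int', 'return': 'float'}

# Un-stringizing evaluates each string in the function's global namespace,
# which is the core of what `eval_annotations()` does (plus the closure,
# module, and `functools.partial` handling seen above).
resolved = {
    name: eval(ann, scale.__globals__) if isinstance(ann, str) else ann
    for name, ann in scale.__annotations__.items()
}
print(resolved)  # {'x': <class 'float'>, 'n': <class 'int'>, 'return': <class 'float'>}
```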
@@ -490,6 +582,14 @@ class Var:
          self.constant = constant
          self.prefix = prefix

+         # records whether this Var has been read from in a kernel function (array only)
+         self.is_read = False
+         # records whether this Var has been written to in a kernel function (array only)
+         self.is_write = False
+
+         # used to associate a view array Var with its parent array Var
+         self.parent = None
+
      def __str__(self):
          return self.label

@@ -532,6 +632,42 @@ class Var:
      def emit_adj(self):
          return self.emit("adj")

+     def mark_read(self):
+         """Marks this Var as having been read from in a kernel (array only)."""
+         if not is_array(self.type):
+             return
+
+         self.is_read = True
+
+         # recursively update all parent states
+         parent = self.parent
+         while parent is not None:
+             parent.is_read = True
+             parent = parent.parent
+
+     def mark_write(self, **kwargs):
+         """Marks this Var as having been written to in a kernel (array only)."""
+         if not is_array(self.type):
+             return
+
+         # detect if we are writing to an array after reading from it within the same kernel
+         if self.is_read and warp.config.verify_autograd_array_access:
+             if "kernel_name" in kwargs and "filename" in kwargs and "lineno" in kwargs:
+                 print(
+                     f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                 )
+             else:
+                 print(
+                     f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                 )
+         self.is_write = True
+
+         # recursively update all parent states
+         parent = self.parent
+         while parent is not None:
+             parent.is_write = True
+             parent = parent.parent
+

  class Block:
      # Represents a basic block of instructions, e.g.: list
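The `is_read`/`is_write` flags and the `parent` link back the new `wp.config.verify_autograd_array_access` option (covered by `warp/tests/test_overwrite.py` in the file list): writing to an array after reading from it within one kernel can invalidate the values the backward pass needs to recompute gradients. A hedged sketch of the kind of kernel this flags, assuming a working Warp 1.3.0 install; the warning text is the one printed by `mark_write()` above:

```python
import warp as wp

wp.config.verify_autograd_array_access = True  # enable before kernels are compiled


@wp.kernel
def read_then_overwrite(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = wp.sin(x[tid])  # x is marked as read
    x[tid] = 0.0             # write-after-read on x: warning printed at build time


x = wp.zeros(8, dtype=float, requires_grad=True)
y = wp.zeros(8, dtype=float, requires_grad=True)
wp.launch(read_then_overwrite, dim=8, inputs=[x, y])
```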
@@ -547,6 +683,91 @@ class Block:
          self.vars = []


+ def apply_defaults(
+     bound_args: inspect.BoundArguments,
+     values: Mapping[str, Any],
+ ):
+     # Similar to Python's `inspect.BoundArguments.apply_defaults()`
+     # but with the possibility to pass an augmented set of default values.
+     arguments = bound_args.arguments
+     new_arguments = []
+     for name in bound_args._signature.parameters.keys():
+         try:
+             new_arguments.append((name, arguments[name]))
+         except KeyError:
+             if name in values:
+                 new_arguments.append((name, values[name]))
+
+     bound_args.arguments = dict(new_arguments)
+
+
+ def func_match_args(func, arg_types, kwarg_types):
+     try:
+         # Try to bind the given arguments to the function's signature.
+         # This is not checking whether the argument types are matching,
+         # rather it's just assigning each argument to the corresponding
+         # function parameter.
+         bound_arg_types = func.signature.bind(*arg_types, **kwarg_types)
+     except TypeError:
+         return False
+
+     # Populate the bound arguments with any default values.
+     default_arg_types = {
+         k: None if v is None else get_arg_type(v)
+         for k, v in func.defaults.items()
+         if k not in bound_arg_types.arguments
+     }
+     apply_defaults(bound_arg_types, default_arg_types)
+     bound_arg_types = tuple(bound_arg_types.arguments.values())
+
+     # Check the given argument types against the ones defined on the function.
+     for bound_arg_type, func_arg_type in zip(bound_arg_types, func.input_types.values()):
+         # Let the `value_func` callback infer the type.
+         if bound_arg_type is None:
+             continue
+
+         # if arg type registered as Any, treat as
+         # template allowing any type to match
+         if func_arg_type == Any:
+             continue
+
+         # handle function refs as a special case
+         if func_arg_type == Callable and isinstance(bound_arg_type, warp.context.Function):
+             continue
+
+         # check arg type matches input variable type
+         if not types_equal(func_arg_type, strip_reference(bound_arg_type), match_generic=True):
+             return False
+
+     return True
+
+
+ def get_arg_type(arg: Union[Var, Any]):
+     if isinstance(arg, Sequence):
+         return tuple(get_arg_type(x) for x in arg)
+
+     if isinstance(arg, (type, warp.context.Function)):
+         return arg
+
+     if isinstance(arg, Var):
+         return arg.type
+
+     return type(arg)
+
+
+ def get_arg_value(arg: Union[Var, Any]):
+     if isinstance(arg, Sequence):
+         return tuple(get_arg_value(x) for x in arg)
+
+     if isinstance(arg, (type, warp.context.Function)):
+         return arg
+
+     if isinstance(arg, Var):
+         return arg.constant
+
+     return arg
+
+
  class Adjoint:
      # Source code transformer, this class takes a Python function and
      # generates forward and backward SSA forms of the function instructions
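`apply_defaults()` mirrors `inspect.BoundArguments.apply_defaults()` but takes an arbitrary mapping, so codegen can fill unbound parameters either from a function's declared defaults or from the "type args" recovered from a templated constructor. A small standalone illustration of the binding step it builds on, using only `inspect` (the `vector` signature here is hypothetical, for demonstration):

```python
import inspect


def vector(length=None, dtype=None):
    return (length, dtype)


sig = inspect.signature(vector)
bound = sig.bind(dtype=float)
print(bound.arguments)  # {'dtype': <class 'float'>}

# Fill the missing parameters from an augmented set of values while
# preserving the signature's parameter order, as `apply_defaults()` does.
augmented = {"length": 3}
bound.arguments = {
    name: bound.arguments[name] if name in bound.arguments else augmented[name]
    for name in sig.parameters
    if name in bound.arguments or name in augmented
}
print(bound.arguments)  # {'length': 3, 'dtype': <class 'float'>}
```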
@@ -605,7 +826,7 @@
          adj.custom_reverse_num_input_args = custom_reverse_num_input_args

          # parse argument types
-         argspec = inspect.getfullargspec(func)
+         argspec = get_full_arg_spec(func)

          # ensure all arguments are annotated
          if overload_annotations is None:
@@ -646,6 +867,11 @@

      # generate function ssa form and adjoint
      def build(adj, builder, default_builder_options=None):
+         # arg Var read/write flags are held during module rebuilds, so we reset them here even when skipping a build
+         for arg in adj.args:
+             arg.is_read = False
+             arg.is_write = False
+
          if adj.skip_build:
              return

@@ -682,15 +908,11 @@
          # recursively evaluate function body
          try:
              adj.eval(adj.tree.body[0])
-         except Exception as e:
+         except Exception:
              try:
-                 if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
-                     msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
-                 else:
-                     msg = "Error"
                  lineno = adj.lineno + adj.fun_lineno
                  line = adj.source_lines[adj.lineno]
-                 msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
+                 msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
                  ex, data, traceback = sys.exc_info()
                  e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
              finally:
@@ -808,6 +1030,20 @@

          return v

+     def register_var(adj, var):
+         # We sometimes initialize `Var` instances that might be thrown away
+         # afterwards, so this method lets us defer their registration among
+         # the list of primal vars until later on, instead of registering them
+         # immediately as `adj.add_var()` or `adj.add_constant()` would.
+
+         if isinstance(var, (Reference, warp.context.Function)):
+             return var
+
+         if var.label is None:
+             return adj.add_var(var.type, var.constant)
+
+         return var
+
      # append a statement to the forward pass
      def add_forward(adj, statement, replay=None, skip_replay=False):
          adj.blocks[-1].body_forward.append(adj.indentation + statement)
@@ -873,12 +1109,10 @@

          return output

-     def resolve_func(adj, func, args, min_outputs, templates, kwds):
-         arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
-
+     def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
          if not func.is_builtin():
              # user-defined function
-             overload = func.get_overload(arg_types)
+             overload = func.get_overload(arg_types, kwarg_types)
              if overload is not None:
                  return overload
          else:
@@ -888,88 +1122,89 @@
              # skip type checking for variadic functions
              if not f.variadic:
                  # check argument counts are compatible (there may be some default args)
-                 if len(f.input_types) < len(args):
+                 if len(f.input_types) < len(arg_types) + len(kwarg_types):
                      continue

-                 def match_args(args, f):
-                     # check argument types equal
-                     for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                         # if arg type registered as Any, treat as
-                         # template allowing any type to match
-                         if arg_type == Any:
-                             continue
-
-                         # handle function refs as a special case
-                         if arg_type == Callable and type(args[i]) is warp.context.Function:
-                             continue
-
-                         if arg_type == Reference and is_reference(args[i].type):
-                             continue
-
-                         # look for default values for missing args
-                         if i >= len(args):
-                             if arg_name not in f.defaults:
-                                 return False
-                         else:
-                             # otherwise check arg type matches input variable type
-                             if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
-                                 return False
-
-                     return True
-
-                 if not match_args(args, f):
+                 if not func_match_args(f, arg_types, kwarg_types):
                      continue

                  # check output dimensions match expectations
                  if min_outputs:
-                     try:
-                         value_type = f.value_func(args, kwds, templates)
-                         if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
-                             continue
-                     except Exception:
-                         # value func may fail if the user has given
-                         # incorrect args, so we need to catch this
+                     if not isinstance(f.value_type, Sequence) or len(f.value_type) != min_outputs:
                          continue

              # found a match, use it
              return f

          # unresolved function, report error
-         arg_types = []
+         arg_type_reprs = []

-         for x in args:
-             if isinstance(x, Var):
+         for x in arg_types:
+             if isinstance(x, warp.context.Function):
+                 arg_type_reprs.append("function")
+             else:
                  # shorten Warp primitive type names
-                 if isinstance(x.type, list):
-                     if len(x.type) != 1:
+                 if isinstance(x, Sequence):
+                     if len(x) != 1:
                          raise WarpCodegenError("Argument must not be the result from a multi-valued function")
-                     arg_type = x.type[0]
+                     arg_type = x[0]
                  else:
-                     arg_type = x.type
+                     arg_type = x

-                 arg_types.append(type_repr(arg_type))
-
-             if isinstance(x, warp.context.Function):
-                 arg_types.append("function")
+                 arg_type_reprs.append(type_repr(arg_type))

          raise WarpCodegenError(
-             f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
+             f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
          )

-     def add_call(adj, func, args, min_outputs=None, templates=None, kwds=None):
-         if templates is None:
-             templates = []
+     def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
+         # Extract the types and values passed as arguments to the function call.
+         arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
+         kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
+
+         # Resolve the exact function signature among any existing overloads.
+         func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
+
+         # Bind the positional and keyword arguments to the function's signature
+         # in order to process them as Python does it.
+         bound_args = func.signature.bind(*args, **kwargs)
+
+         # Type args are the "compile time" argument values we get from codegen.
+         # For example, when calling `wp.vec3f(...)` from within a kernel,
+         # this in fact translates to calling the `vector()` built-in augmented
+         # with the type args `length=3, dtype=float`.
+         # Eventually, these need to be passed to the underlying C++ function,
+         # so we update the arguments with the type args here.
+         if type_args:
+             for arg in type_args:
+                 if arg in bound_args.arguments:
+                     # In case of conflict, ideally we'd throw an error since
+                     # what comes from codegen should be the source of truth
+                     # and users also passing the same value as an argument
+                     # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
+                     # However, for backward compatibility, we allow that form
+                     # as long as the values are equal.
+                     if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
+                         continue

-         func = adj.resolve_func(func, args, min_outputs, templates, kwds)
+                     raise RuntimeError(
+                         f"Remove the extraneous `{arg}` parameter "
+                         f"when calling the templated version of "
+                         f"`wp.{func.native_func}()`"
+                     )

-         # push any default values onto args
-         for i, (arg_name, _arg_type) in enumerate(func.input_types.items()):
-             if i >= len(args):
-                 if arg_name in func.defaults:
-                     const = adj.add_constant(func.defaults[arg_name])
-                     args.append(const)
-                 else:
-                     break
+             type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
+             apply_defaults(bound_args, type_vars)
+
+         if func.defaults:
+             default_vars = {
+                 k: Var(None, type=type(v), constant=v)
+                 for k, v in func.defaults.items()
+                 if k not in bound_args.arguments and v is not None
+             }
+             apply_defaults(bound_args, default_vars)
+
+         bound_args = bound_args.arguments

          # if it is a user-function then build it recursively
          if not func.is_builtin() and func not in adj.builder.functions:
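The `type_args` plumbing is what lets typed constructors keep working after the removal of the old `templates` mechanism: inside a kernel, `wp.vec3f(...)` now lowers to the generic `vector()` built-in with `length=3, dtype=wp.float32` merged in as compile-time arguments, and re-passing a conflicting value raises the `RuntimeError` shown above. A hedged sketch of the two equivalent spellings, assuming Warp 1.3.0:

```python
import warp as wp


@wp.kernel
def make_vectors(out: wp.array(dtype=wp.vec3f)):
    tid = wp.tid()
    a = wp.vec3f(1.0, 2.0, 3.0)                # templated form; type args are implied
    b = wp.vector(length=3, dtype=wp.float32)  # generic form; zero-initialized
    out[tid] = a + b
```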
@@ -983,23 +1218,38 @@
          if func.custom_replay_func:
              adj.builder.deferred_functions.append(func.custom_replay_func)

-         # evaluate the function type based on inputs
-         arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
-         return_type = func.value_func(arg_types, kwds, templates)
+         # Resolve the return value based on the types and values of the given arguments.
+         bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
+         bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+         return_type = func.value_func(
+             {k: strip_reference(v) for k, v in bound_arg_types.items()},
+             bound_arg_values,
+         )
+
+         if func.dispatch_func is not None:
+             # If we have a built-in that requires special handling to dispatch
+             # the arguments to the underlying C++ function, then we can resolve
+             # these using the `dispatch_func`. Since this is only called from
+             # within codegen, we pass it directly `codegen.Var` objects,
+             # which allows for some more advanced resolution to be performed,
+             # for example by checking whether an argument corresponds to
+             # a literal value or references a variable.

-         func_name = compute_type_str(func.native_func, templates)
-         param_types = list(func.input_types.values())
+             func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
+         else:
+             func_args = tuple(bound_args.values())
+             template_args = ()

-         use_initializer_list = func.initializer_list_func(args, templates)
+         func_args = tuple(adj.register_var(x) for x in func_args)
+         func_name = compute_type_str(func.native_func, template_args)
+         use_initializer_list = func.initializer_list_func(bound_args, return_type)

-         args_var = [
-             (
-                 adj.load(a)
-                 if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
-                 else a
-             )
-             for i, a in enumerate(args)
-         ]
+         fwd_args = []
+         for func_arg in func_args:
+             if not isinstance(func_arg, (Reference, warp.context.Function)):
+                 func_arg = adj.load(func_arg)
+
+             fwd_args.append(strip_reference(func_arg))

          if return_type is None:
              # handles expression (zero output) functions, e.g.: void do_something();
@@ -1008,24 +1258,24 @@
              output_list = []

              forward_call = (
-                 f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
              )
              replay_call = forward_call
              if func.custom_replay_func is not None or func.replay_snippet is not None:
-                 replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+                 replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

-         elif not isinstance(return_type, list) or len(return_type) == 1:
+         elif not isinstance(return_type, Sequence) or len(return_type) == 1:
              # handle simple function (one output)

-             if isinstance(return_type, list):
+             if isinstance(return_type, Sequence):
                  return_type = return_type[0]
              output = adj.add_var(return_type)
              output_list = [output]

-             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
              replay_call = forward_call
              if func.custom_replay_func is not None:
-                 replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+                 replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

          else:
              # handle multiple value functions
@@ -1034,7 +1284,7 @@
              output_list = output

              forward_call = (
-                 f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
+                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
              )
              replay_call = forward_call

@@ -1043,13 +1293,14 @@
          else:
              adj.add_forward(forward_call, replay=replay_call)

-         if not func.missing_grad and len(args):
+         if not func.missing_grad and len(func_args):
+             adj_args = tuple(strip_reference(x) for x in func_args)
              reverse_has_output_args = (
                  func.require_original_output_arg or len(output_list) > 1
              ) and func.custom_grad_func is None
              arg_str = adj.format_reverse_call_args(
-                 args_var,
-                 args,
+                 fwd_args,
+                 adj_args,
                  output_list,
                  use_initializer_list,
                  has_output_args=reverse_has_output_args,
@@ -1061,12 +1312,9 @@

          return output

-     def add_builtin_call(adj, func_name, args, min_outputs=None, templates=None, kwds=None):
-         if templates is None:
-             templates = []
-
+     def add_builtin_call(adj, func_name, args, min_outputs=None):
          func = warp.context.builtin_functions[func_name]
-         return adj.add_call(func, args, min_outputs, templates, kwds)
+         return adj.add_call(func, args, {}, {}, min_outputs=min_outputs)

      def add_return(adj, var):
          if var is None or len(var) == 0:
@@ -1505,7 +1753,24 @@
      def emit_BinOp(adj, node):
          # evaluate binary operator arguments
+
+         if warp.config.verify_autograd_array_access:
+             # array overwrite tracking: in-place operators are a special case
+             # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write,
+             # so we save the current arg read flags and restore them after lhs eval
+             is_read_states = []
+             for arg in adj.args:
+                 is_read_states.append(arg.is_read)
+
+         # evaluate lhs binary operator argument
          left = adj.eval(node.left)
+
+         if warp.config.verify_autograd_array_access:
+             # restore arg read flags
+             for i, arg in enumerate(adj.args):
+                 arg.is_read = is_read_states[i]
+
+         # evaluate rhs binary operator argument
          right = adj.eval(node.right)

          name = builtin_operators[type(node.op)]
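The flag save/restore above keeps in-place updates from producing false positives: in `x[tid] = x[tid] + 1.0`, evaluating the right-hand side would mark `x` as read, but that read belongs to the write being performed, so only the write is recorded. A sketch of the two cases, assuming `verify_autograd_array_access` is enabled:

```python
import warp as wp


@wp.kernel
def in_place_update(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] = x[tid] + 1.0  # read flags restored after lhs eval: recorded as a write only


@wp.kernel
def overwrite_after_read(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid]  # plain copy: a genuine read of x that is recorded
    x[tid] = 0.0     # subsequent write to x is reported as a potential gradient hazard
```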
@@ -1569,6 +1834,9 @@
          # e.g.: wp.constant in the globals scope
          obj, _ = adj.resolve_static_expression(a)

+         if obj is None:
+             obj = adj.eval(a)
+
          if isinstance(obj, Var) and obj.constant is not None:
              obj = obj.constant

@@ -1728,13 +1996,40 @@
                  f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
              )

+     def resolve_arg(adj, arg):
+         # Always try to start with evaluating the argument since it can help
+         # detect some issues such as global variables being accessed.
+         try:
+             var = adj.eval(arg)
+         except (WarpCodegenError, WarpCodegenKeyError) as e:
+             error = e
+         else:
+             error = None
+
+         # Check if we can resolve the argument as a static expression.
+         # If not, return the variable resulting from evaluating the argument.
+         expr, _ = adj.resolve_static_expression(arg)
+         if expr is None:
+             if error is not None:
+                 raise error
+
+             return var
+
+         if isinstance(expr, (type, Var, warp.context.Function)):
+             return expr
+
+         return adj.add_constant(expr)
+
      def emit_Call(adj, node):
          adj.check_tid_in_func_error(node)

          # try and lookup function in globals by
          # resolving path (e.g.: module.submodule.attr)
          func, path = adj.resolve_static_expression(node.func)
-         templates = []
+         if func is None:
+             func = adj.eval(node.func)
+
+         type_args = {}

          if not isinstance(func, warp.context.Function):
              attr = path[-1]
@@ -1747,7 +2042,6 @@

              # vector class type e.g.: wp.vec3f constructor
              if func is None and hasattr(caller, "_wp_generic_type_str_"):
-                 templates = caller._wp_type_params_
                  func = warp.context.builtin_functions.get(caller._wp_constructor_)

              # scalar class type e.g.: wp.int8 constructor
@@ -1757,43 +2051,53 @@
              # struct constructor
              if func is None and isinstance(caller, Struct):
                  adj.builder.build_struct_recursive(caller)
-                 func = caller.initializer()
+                 if node.args or node.keywords:
+                     func = caller.value_constructor
+                 else:
+                     func = caller.default_constructor
+
+             if hasattr(caller, "_wp_type_args_"):
+                 type_args = caller._wp_type_args_

              if func is None:
                  raise WarpCodegenError(
                      f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
                  )

-         args = []
-
-         # eval all arguments
+         # Check if any argument corresponds to an unsupported construct.
+         # Tuples are supported in the context of assigning multiple variables
+         # at once, but not in place of vectors when calling built-ins like
+         # `wp.length((1, 2, 3))`.
+         # Therefore, we need to catch this specific case here instead of
+         # more generally in `adj.eval()`.
          for arg in node.args:
-             var = adj.eval(arg)
-             args.append(var)
-
-         # eval all keyword args
-         def kwval(kw):
-             if isinstance(kw.value, ast.Num):
-                 return kw.value.n
-             elif isinstance(kw.value, ast.Tuple):
-                 arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
-                 if not all(arg_is_numeric):
-                     raise WarpCodegenError(
-                         f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
-                     )
-                 return arg_values
-             else:
-                 return adj.resolve_static_expression(kw.value)[0]
-
-         kwds = {kw.arg: kwval(kw) for kw in node.keywords}
+             if isinstance(arg, ast.Tuple):
+                 raise WarpCodegenError(
+                     "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
+                 )

          # get expected return count, e.g.: for multi-assignment
          min_outputs = None
          if hasattr(node, "expects"):
              min_outputs = node.expects

-         # add var with value type from the function
-         out = adj.add_call(func=func, args=args, kwds=kwds, templates=templates, min_outputs=min_outputs)
+         # Evaluate all positional and keyword arguments.
+         args = tuple(adj.resolve_arg(x) for x in node.args)
+         kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
+
+         if warp.config.verify_autograd_array_access:
+             # update arg read/write states according to what happens to that arg in the called function
+             if hasattr(func, "adj"):
+                 for i, arg in enumerate(args):
+                     if func.adj.args[i].is_write:
+                         kernel_name = adj.fun_name
+                         filename = adj.filename
+                         lineno = adj.lineno + adj.fun_lineno
+                         arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+                     if func.adj.args[i].is_read:
+                         arg.mark_read()
+
+         out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
          return out

      def emit_Index(adj, node):
@@ -1872,10 +2176,22 @@
              if len(indices) == target_type.ndim:
                  # handles array loads (where each dimension has an index specified)
                  out = adj.add_builtin_call("address", [target, *indices])
+
+                 if warp.config.verify_autograd_array_access:
+                     target.mark_read()
+
              else:
                  # handles array views (fewer indices than dimensions)
                  out = adj.add_builtin_call("view", [target, *indices])

+                 if warp.config.verify_autograd_array_access:
+                     # store reference to target Var to propagate downstream read/write state back to root arg Var
+                     out.parent = target
+
+                     # view arg inherits target Var's read/write states
+                     out.is_read = target.is_read
+                     out.is_write = target.is_write
+
          else:
              # handles non-array type indexing, e.g: vec3, mat33, etc
              out = adj.add_builtin_call("extract", [target, *indices])
@@ -1888,6 +2204,21 @@

          lhs = node.targets[0]

+         if not isinstance(lhs, ast.Tuple):
+             # Check if the rhs corresponds to an unsupported construct.
+             # Tuples are supported in the context of assigning multiple variables
+             # at once, but not for simple assignments like `x = (1, 2, 3)`.
+             # Therefore, we need to catch this specific case here instead of
+             # more generally in `adj.eval()`.
+             if isinstance(node.value, ast.List):
+                 raise WarpCodegenError(
+                     "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+                 )
+             elif isinstance(node.value, ast.Tuple):
+                 raise WarpCodegenError(
+                     "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+                 )
+
          # handle the case where we are assigning multiple output variables
          if isinstance(lhs, ast.Tuple):
              # record the expected number of outputs on the node
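These checks turn previously confusing failures into explicit codegen errors: kernels cannot build tuples or lists as values. A sketch of rejected versus supported code (hypothetical kernels, assuming Warp 1.3.0; the error surfaces when the module builds at launch):

```python
import warp as wp


@wp.kernel
def rejected(out: wp.array(dtype=wp.vec3)):
    v = (1.0, 2.0, 3.0)  # WarpCodegenError: tuple constructs are not supported in kernels
    out[0] = v


@wp.kernel
def supported(out: wp.array(dtype=wp.vec3)):
    out[0] = wp.vec3(1.0, 2.0, 3.0)  # use a fixed-size vector instead
```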
@@ -1944,7 +2275,14 @@
              if is_array(target_type):
                  adj.add_builtin_call("array_store", [target, *indices, rhs])

-             elif type_is_vector(target_type) or type_is_matrix(target_type):
+                 if warp.config.verify_autograd_array_access:
+                     kernel_name = adj.fun_name
+                     filename = adj.filename
+                     lineno = adj.lineno + adj.fun_lineno
+
+                     target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+
+             elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
                  if is_reference(target.type):
                      attr = adj.add_builtin_call("indexref", [target, *indices])
                  else:
@@ -1961,7 +2299,7 @@
                      )

              else:
-                 raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
+                 raise WarpCodegenError("Can only subscript assign array, vector, quaternion, and matrix types")

          elif isinstance(lhs, ast.Name):
              # symbol name
@@ -2050,8 +2388,7 @@

      def emit_Tuple(adj, node):
          # LHS for expressions, such as i, j, k = 1, 2, 3
-         for elem in node.elts:
-             adj.eval(elem)
+         return tuple(adj.eval(x) for x in node.elts)

      def emit_Pass(adj, node):
          pass
@@ -2089,7 +2426,12 @@
          if hasattr(node, "lineno"):
              adj.set_lineno(node.lineno - 1)

-         emit_node = adj.node_visitors[type(node)]
+         try:
+             emit_node = adj.node_visitors[type(node)]
+         except KeyError as e:
+             type_name = type(node).__name__
+             namespace = "ast." if isinstance(node, ast.AST) else ""
+             raise WarpCodegenError(f"Construct `{namespace}{type_name}` not supported in kernels.") from e

          return emit_node(adj, node)

@@ -2120,18 +2462,18 @@
          vars_dict = {**adj.func.__globals__, **capturedvars}

          if path[0] in vars_dict:
-             func = vars_dict[path[0]]
+             expr = vars_dict[path[0]]

          # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0, 0.2, 0.4)):
          else:
-             func = getattr(warp, path[0], None)
+             expr = getattr(warp, path[0], None)

-         if func:
+         if expr:
              for i in range(1, len(path)):
-                 if hasattr(func, path[i]):
-                     func = getattr(func, path[i])
+                 if hasattr(expr, path[i]):
+                     expr = getattr(expr, path[i])

-             return func
+             return expr

      # Evaluates a static expression that does not depend on runtime values
      # if eval_types is True, try resolving the path using evaluated type information as well
@@ -2182,11 +2524,6 @@
          if captured_obj is not None:
              return captured_obj, path

-         # Still nothing found, maybe this is a predefined type attribute like `dtype`
-         if eval_types:
-             val = adj.eval(root_node)
-             return [val, path]
-
          return None, path

      # annotate generated code with the original source code line
@@ -2262,10 +2599,10 @@ cpu_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

- #define builtin_tid1d() wp::tid(wp::s_threadIdx)
- #define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
- #define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
- #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
+ #define builtin_tid1d() wp::tid(task_index)
+ #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)

  """

@@ -2280,10 +2617,10 @@ cuda_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

- #define builtin_tid1d() wp::tid(_idx)
- #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
- #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
- #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
+ #define builtin_tid1d() wp::tid(task_index)
+ #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)

  """

@@ -2355,9 +2692,9 @@ cuda_kernel_template = """
  extern "C" __global__ void {name}_cuda_kernel_forward(
      {forward_args})
  {{
-     for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-          _idx < dim.size;
-          _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+     for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+          task_index < dim.size;
+          task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
      {{
  {forward_body}    }}
  }}
@@ -2365,9 +2702,9 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
  extern "C" __global__ void {name}_cuda_kernel_backward(
      {reverse_args})
  {{
-     for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-          _idx < dim.size;
-          _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+     for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+          task_index < dim.size;
+          task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
      {{
  {reverse_body}    }}
  }}
@@ -2396,10 +2733,8 @@ extern "C" {{
  WP_API void {name}_cpu_forward(
      {forward_args})
  {{
-     for (size_t i=0; i < dim.size; ++i)
+     for (size_t task_index = 0; task_index < dim.size; ++task_index)
      {{
-         wp::s_threadIdx = i;
-
          {name}_cpu_kernel_forward(
              {forward_params});
      }}
@@ -2408,10 +2743,8 @@ WP_API void {name}_cpu_forward(
  WP_API void {name}_cpu_backward(
      {reverse_args})
  {{
-     for (size_t i=0; i < dim.size; ++i)
+     for (size_t task_index = 0; task_index < dim.size; ++task_index)
      {{
-         wp::s_threadIdx = i;
-
          {name}_cpu_kernel_backward(
              {reverse_params});
      }}
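Across these templates, the CPU thread-local `wp::s_threadIdx` and the CUDA loop variable `_idx` are unified into a single `task_index` that is threaded through the kernel entry points explicitly: on the CPU it becomes an extra function parameter (see `codegen_kernel()` below), while on CUDA it remains the grid-stride loop counter. A hypothetical Python rendering of the new CPU-side contract, as a stand-in for the generated C++:

```python
# Hypothetical stand-in for the generated C++: the driver loop now passes
# `task_index` to the kernel explicitly instead of writing it into the
# thread-local `wp::s_threadIdx` before each call.
def example_cpu_forward(dim, kernel_forward, out):
    for task_index in range(dim):        # for (size_t task_index = 0; ...)
        kernel_forward(task_index, out)  # builtin_tid1d() -> wp::tid(task_index)


def example_kernel_forward(task_index, out):
    tid = task_index  # what wp.tid() resolves to
    out[tid] = float(tid)


out = [0.0] * 4
example_cpu_forward(4, example_kernel_forward, out)
print(out)  # [0.0, 1.0, 2.0, 3.0]
```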
@@ -2838,6 +3171,10 @@ def codegen_kernel(kernel, device, options):
      forward_args = ["wp::launch_bounds_t dim"]
      reverse_args = ["wp::launch_bounds_t dim"]

+     if device == "cpu":
+         forward_args.append("size_t task_index")
+         reverse_args.append("size_t task_index")
+
      # forward args
      for arg in adj.args:
          forward_args.append(arg.ctype() + " var_" + arg.label)
@@ -2886,7 +3223,7 @@ def codegen_module(kernel, device="cpu"):

      # build forward signature
      forward_args = ["wp::launch_bounds_t dim"]
-     forward_params = ["dim"]
+     forward_params = ["dim", "task_index"]

      for arg in adj.args:
          if hasattr(arg.type, "_wp_generic_type_str_"):