warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -20,6 +20,27 @@ from typing import Any, Callable, Mapping
20
20
  import warp.config
21
21
  from warp.types import *
22
22
 
23
+
24
+ class WarpCodegenError(RuntimeError):
25
+ def __init__(self, message):
26
+ super().__init__(message)
27
+
28
+
29
+ class WarpCodegenTypeError(TypeError):
30
+ def __init__(self, message):
31
+ super().__init__(message)
32
+
33
+
34
+ class WarpCodegenAttributeError(AttributeError):
35
+ def __init__(self, message):
36
+ super().__init__(message)
37
+
38
+
39
+ class WarpCodegenKeyError(KeyError):
40
+ def __init__(self, message):
41
+ super().__init__(message)
42
+
43
+
23
44
  # map operator to function name
24
45
  builtin_operators = {}
25
46
 
@@ -52,6 +73,19 @@ builtin_operators[ast.Invert] = "invert"
52
73
  builtin_operators[ast.LShift] = "lshift"
53
74
  builtin_operators[ast.RShift] = "rshift"
54
75
 
76
+ comparison_chain_strings = [
77
+ builtin_operators[ast.Gt],
78
+ builtin_operators[ast.Lt],
79
+ builtin_operators[ast.LtE],
80
+ builtin_operators[ast.GtE],
81
+ builtin_operators[ast.Eq],
82
+ builtin_operators[ast.NotEq],
83
+ ]
84
+
85
+
86
+ def op_str_is_chainable(op: str) -> builtins.bool:
87
+ return op in comparison_chain_strings
88
+
55
89
 
56
90
  def get_annotations(obj: Any) -> Mapping[str, Any]:
57
91
  """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
@@ -65,16 +99,14 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
65
99
  def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
66
100
  indent = "\t"
67
101
 
68
- if inst._cls.ctype._fields_ == [("_dummy_", ctypes.c_int)]:
102
+ # handle empty structs
103
+ if len(inst._cls.vars) == 0:
69
104
  return f"{inst._cls.key}()"
70
105
 
71
106
  lines = []
72
107
  lines.append(f"{inst._cls.key}(")
73
108
 
74
109
  for field_name, _ in inst._cls.ctype._fields_:
75
- if field_name == "_dummy_":
76
- continue
77
-
78
110
  field_value = getattr(inst, field_name, None)
79
111
 
80
112
  if isinstance(field_value, StructInstance):
@@ -121,9 +153,7 @@ class StructInstance:
121
153
  assert isinstance(value, array)
122
154
  assert types_equal(
123
155
  value.dtype, var.type.dtype
124
- ), "assign to struct member variable {} failed, expected type {}, got type {}".format(
125
- name, type_repr(var.type.dtype), type_repr(value.dtype)
126
- )
156
+ ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
127
157
  setattr(self._ctype, name, value.__ctype__())
128
158
 
129
159
  elif isinstance(var.type, Struct):
@@ -242,7 +272,7 @@ class Struct:
242
272
 
243
273
  class StructType(ctypes.Structure):
244
274
  # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
245
- _fields_ = fields or [("_dummy_", ctypes.c_int)]
275
+ _fields_ = fields or [("_dummy_", ctypes.c_byte)]
246
276
 
247
277
  self.ctype = StructType
248
278
 
@@ -363,21 +393,38 @@ class Struct:
363
393
  return instance
364
394
 
365
395
 
396
+ class Reference:
397
+ def __init__(self, value_type):
398
+ self.value_type = value_type
399
+
400
+
401
+ def is_reference(type):
402
+ return isinstance(type, Reference)
403
+
404
+
405
+ def strip_reference(arg):
406
+ if is_reference(arg):
407
+ return arg.value_type
408
+ else:
409
+ return arg
410
+
411
+
366
412
  def compute_type_str(base_name, template_params):
367
- if template_params is None or len(template_params) == 0:
413
+ if not template_params:
368
414
  return base_name
369
- else:
370
415
 
371
- def param2str(p):
372
- if isinstance(p, int):
373
- return str(p)
374
- return p.__name__
416
+ def param2str(p):
417
+ if isinstance(p, int):
418
+ return str(p)
419
+ elif hasattr(p, "_type_"):
420
+ return f"wp::{p.__name__}"
421
+ return p.__name__
375
422
 
376
- return f"{base_name}<{','.join(map(param2str, template_params))}>"
423
+ return f"{base_name}<{','.join(map(param2str, template_params))}>"
377
424
 
378
425
 
379
426
  class Var:
380
- def __init__(self, label, type, requires_grad=False, constant=None, prefix=True, is_adjoint=False):
427
+ def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
381
428
  # convert built-in types to wp types
382
429
  if type == float:
383
430
  type = float32
@@ -389,27 +436,39 @@ class Var:
389
436
  self.requires_grad = requires_grad
390
437
  self.constant = constant
391
438
  self.prefix = prefix
392
- self.is_adjoint = is_adjoint
393
439
 
394
440
  def __str__(self):
395
441
  return self.label
396
442
 
397
- def ctype(self):
398
- if is_array(self.type):
399
- if hasattr(self.type.dtype, "_wp_generic_type_str_"):
400
- dtypestr = compute_type_str(self.type.dtype._wp_generic_type_str_, self.type.dtype._wp_type_params_)
401
- elif isinstance(self.type.dtype, Struct):
402
- dtypestr = make_full_qualified_name(self.type.dtype.cls)
443
+ @staticmethod
444
+ def type_to_ctype(t, value_type=False):
445
+ if is_array(t):
446
+ if hasattr(t.dtype, "_wp_generic_type_str_"):
447
+ dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
448
+ elif isinstance(t.dtype, Struct):
449
+ dtypestr = make_full_qualified_name(t.dtype.cls)
450
+ elif t.dtype.__name__ in ("bool", "int", "float"):
451
+ dtypestr = t.dtype.__name__
403
452
  else:
404
- dtypestr = str(self.type.dtype.__name__)
405
- classstr = type(self.type).__name__
453
+ dtypestr = f"wp::{t.dtype.__name__}"
454
+ classstr = f"wp::{type(t).__name__}"
406
455
  return f"{classstr}_t<{dtypestr}>"
407
- elif isinstance(self.type, Struct):
408
- return make_full_qualified_name(self.type.cls)
409
- elif hasattr(self.type, "_wp_generic_type_str_"):
410
- return compute_type_str(self.type._wp_generic_type_str_, self.type._wp_type_params_)
456
+ elif isinstance(t, Struct):
457
+ return make_full_qualified_name(t.cls)
458
+ elif is_reference(t):
459
+ if not value_type:
460
+ return Var.type_to_ctype(t.value_type) + "*"
461
+ else:
462
+ return Var.type_to_ctype(t.value_type)
463
+ elif hasattr(t, "_wp_generic_type_str_"):
464
+ return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
465
+ elif t.__name__ in ("bool", "int", "float"):
466
+ return t.__name__
411
467
  else:
412
- return str(self.type.__name__)
468
+ return f"wp::{t.__name__}"
469
+
470
+ def ctype(self, value_type=False):
471
+ return Var.type_to_ctype(self.type, value_type)
413
472
 
414
473
  def emit(self, prefix: str = "var"):
415
474
  if self.prefix:
@@ -417,6 +476,9 @@ class Var:
417
476
  else:
418
477
  return self.label
419
478
 
479
+ def emit_adj(self):
480
+ return self.emit("adj")
481
+
420
482
 
421
483
  class Block:
422
484
  # Represents a basic block of instructions, e.g.: list
@@ -456,20 +518,17 @@ class Adjoint:
456
518
  # whether the generation of the adjoint code is skipped for this function
457
519
  adj.skip_reverse_codegen = skip_reverse_codegen
458
520
 
459
- # build AST from function object
460
- adj.source = inspect.getsource(func)
461
-
462
- # get source code lines and line number where function starts
463
- adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
464
-
465
- # keep track of line number in function code
466
- adj.lineno = None
521
+ # extract name of source file
522
+ adj.filename = inspect.getsourcefile(func) or "unknown source file"
523
+ # get source file line number where function starts
524
+ _, adj.fun_lineno = inspect.getsourcelines(func)
467
525
 
526
+ # get function source code
527
+ adj.source = inspect.getsource(func)
468
528
  # ensures that indented class methods can be parsed as kernels
469
529
  adj.source = textwrap.dedent(adj.source)
470
530
 
471
- # extract name of source file
472
- adj.filename = inspect.getsourcefile(func) or "unknown source file"
531
+ adj.source_lines = adj.source.splitlines()
473
532
 
474
533
  # build AST and apply node transformers
475
534
  adj.tree = ast.parse(adj.source)
@@ -479,6 +538,9 @@ class Adjoint:
479
538
 
480
539
  adj.fun_name = adj.tree.body[0].name
481
540
 
541
+ # for keeping track of line number in function code
542
+ adj.lineno = None
543
+
482
544
  # whether the forward code shall be used for the reverse pass and a custom
483
545
  # function signature is applied to the reverse version of the function
484
546
  adj.custom_reverse_mode = custom_reverse_mode
@@ -493,16 +555,17 @@ class Adjoint:
493
555
  if overload_annotations is None:
494
556
  # use source-level argument annotations
495
557
  if len(argspec.annotations) < len(argspec.args):
496
- raise RuntimeError(f"Incomplete argument annotations on function {adj.fun_name}")
558
+ raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
497
559
  adj.arg_types = argspec.annotations
498
560
  else:
499
561
  # use overload argument annotations
500
562
  for arg_name in argspec.args:
501
563
  if arg_name not in overload_annotations:
502
- raise RuntimeError(f"Incomplete overload annotations for function {adj.fun_name}")
564
+ raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
503
565
  adj.arg_types = overload_annotations.copy()
504
566
 
505
567
  adj.args = []
568
+ adj.symbols = {}
506
569
 
507
570
  for name, type in adj.arg_types.items():
508
571
  # skip return hint
@@ -513,8 +576,23 @@ class Adjoint:
513
576
  arg = Var(name, type, False)
514
577
  adj.args.append(arg)
515
578
 
579
+ # pre-populate symbol dictionary with function argument names
580
+ # this is to avoid registering false references to overshadowed modules
581
+ adj.symbols[name] = arg
582
+
583
+ # There are cases where a same module might be rebuilt multiple times,
584
+ # for example when kernels are nested inside of functions, or when
585
+ # a kernel's launch raises an exception. Ideally we'd always want to
586
+ # avoid rebuilding kernels but some corner cases seem to depend on it,
587
+ # so we only avoid rebuilding kernels that errored out to give a chance
588
+ # for unit testing errors being spit out from kernels.
589
+ adj.skip_build = False
590
+
516
591
  # generate function ssa form and adjoint
517
592
  def build(adj, builder):
593
+ if adj.skip_build:
594
+ return
595
+
518
596
  adj.builder = builder
519
597
 
520
598
  adj.symbols = {} # map from symbols to adjoint variables
@@ -528,7 +606,7 @@ class Adjoint:
528
606
  adj.loop_blocks = []
529
607
 
530
608
  # holds current indent level
531
- adj.prefix = ""
609
+ adj.indentation = ""
532
610
 
533
611
  # used to generate new label indices
534
612
  adj.label_count = 0
@@ -542,12 +620,17 @@ class Adjoint:
542
620
  adj.eval(adj.tree.body[0])
543
621
  except Exception as e:
544
622
  try:
623
+ if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
624
+ msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
625
+ else:
626
+ msg = "Error"
545
627
  lineno = adj.lineno + adj.fun_lineno
546
- line = adj.source.splitlines()[adj.lineno]
547
- msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
628
+ line = adj.source_lines[adj.lineno]
629
+ msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
548
630
  ex, data, traceback = sys.exc_info()
549
- e = ex("".join([msg] + list(data.args))).with_traceback(traceback)
631
+ e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
550
632
  finally:
633
+ adj.skip_build = True
551
634
  raise e
552
635
 
553
636
  if builder is not None:
@@ -570,16 +653,18 @@ class Adjoint:
570
653
  arg_strs = []
571
654
 
572
655
  for a in args:
573
- if type(a) == warp.context.Function:
656
+ if isinstance(a, warp.context.Function):
574
657
  # functions don't have a var_ prefix so strip it off here
575
658
  if prefix == "var":
576
659
  arg_strs.append(a.key)
577
660
  else:
578
661
  arg_strs.append(f"{prefix}_{a.key}")
662
+ elif is_reference(a.type):
663
+ arg_strs.append(f"{prefix}_{a}")
579
664
  elif isinstance(a, Var):
580
665
  arg_strs.append(a.emit(prefix))
581
666
  else:
582
- arg_strs.append(f"{prefix}_{a}")
667
+ raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
583
668
 
584
669
  return arg_strs
585
670
 
@@ -587,30 +672,37 @@ class Adjoint:
587
672
  def format_forward_call_args(adj, args, use_initializer_list):
588
673
  arg_str = ", ".join(adj.format_args("var", args))
589
674
  if use_initializer_list:
590
- return "{{{}}}".format(arg_str)
675
+ return f"{{{arg_str}}}"
591
676
  return arg_str
592
677
 
593
678
  # generates argument string for a reverse function call
594
679
  def format_reverse_call_args(
595
- adj, args, args_out, non_adjoint_args, non_adjoint_outputs, use_initializer_list, has_output_args=True
680
+ adj,
681
+ args_var,
682
+ args,
683
+ args_out,
684
+ use_initializer_list,
685
+ has_output_args=True,
686
+ require_original_output_arg=False,
596
687
  ):
597
- formatted_var = adj.format_args("var", args)
688
+ formatted_var = adj.format_args("var", args_var)
598
689
  formatted_out = []
599
- if has_output_args and len(args_out) > 1:
690
+ if has_output_args and (require_original_output_arg or len(args_out) > 1):
600
691
  formatted_out = adj.format_args("var", args_out)
601
692
  formatted_var_adj = adj.format_args(
602
- "&adj" if use_initializer_list else "adj", [a for i, a in enumerate(args) if i not in non_adjoint_args]
693
+ "&adj" if use_initializer_list else "adj",
694
+ args,
603
695
  )
604
- formatted_out_adj = adj.format_args("adj", [a for i, a in enumerate(args_out) if i not in non_adjoint_outputs])
696
+ formatted_out_adj = adj.format_args("adj", args_out)
605
697
 
606
698
  if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
607
699
  # there are no adjoint arguments, so we don't need to call the reverse function
608
700
  return None
609
701
 
610
702
  if use_initializer_list:
611
- var_str = "{{{}}}".format(", ".join(formatted_var))
612
- out_str = "{{{}}}".format(", ".join(formatted_out))
613
- adj_str = "{{{}}}".format(", ".join(formatted_var_adj))
703
+ var_str = f"{{{', '.join(formatted_var)}}}"
704
+ out_str = f"{{{', '.join(formatted_out)}}}"
705
+ adj_str = f"{{{', '.join(formatted_var_adj)}}}"
614
706
  out_adj_str = ", ".join(formatted_out_adj)
615
707
  if len(args_out) > 1:
616
708
  arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
@@ -621,10 +713,10 @@ class Adjoint:
621
713
  return arg_str
622
714
 
623
715
  def indent(adj):
624
- adj.prefix = adj.prefix + " "
716
+ adj.indentation = adj.indentation + " "
625
717
 
626
718
  def dedent(adj):
627
- adj.prefix = adj.prefix[:-4]
719
+ adj.indentation = adj.indentation[:-4]
628
720
 
629
721
  def begin_block(adj):
630
722
  b = Block()
@@ -639,10 +731,9 @@ class Adjoint:
639
731
  def end_block(adj):
640
732
  return adj.blocks.pop()
641
733
 
642
- def add_var(adj, type=None, constant=None, name=None):
643
- if name is None:
644
- index = len(adj.variables)
645
- name = str(index)
734
+ def add_var(adj, type=None, constant=None):
735
+ index = len(adj.variables)
736
+ name = str(index)
646
737
 
647
738
  # allocate new variable
648
739
  v = Var(name, type=type, constant=constant)
@@ -655,30 +746,54 @@ class Adjoint:
655
746
 
656
747
  # append a statement to the forward pass
657
748
  def add_forward(adj, statement, replay=None, skip_replay=False):
658
- adj.blocks[-1].body_forward.append(adj.prefix + statement)
749
+ adj.blocks[-1].body_forward.append(adj.indentation + statement)
659
750
 
660
751
  if not skip_replay:
661
752
  if replay:
662
753
  # if custom replay specified then output it
663
- adj.blocks[-1].body_replay.append(adj.prefix + replay)
754
+ adj.blocks[-1].body_replay.append(adj.indentation + replay)
664
755
  else:
665
756
  # by default just replay the original statement
666
- adj.blocks[-1].body_replay.append(adj.prefix + statement)
757
+ adj.blocks[-1].body_replay.append(adj.indentation + statement)
667
758
 
668
759
  # append a statement to the reverse pass
669
760
  def add_reverse(adj, statement):
670
- adj.blocks[-1].body_reverse.append(adj.prefix + statement)
761
+ adj.blocks[-1].body_reverse.append(adj.indentation + statement)
671
762
 
672
763
  def add_constant(adj, n):
673
764
  output = adj.add_var(type=type(n), constant=n)
674
765
  return output
675
766
 
767
+ def load(adj, var):
768
+ if is_reference(var.type):
769
+ var = adj.add_builtin_call("load", [var])
770
+ return var
771
+
676
772
  def add_comp(adj, op_strings, left, comps):
677
773
  output = adj.add_var(builtins.bool)
678
774
 
679
- s = "var_" + str(output) + " = " + ("(" * len(comps)) + "var_" + str(left) + " "
775
+ left = adj.load(left)
776
+ s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
777
+
778
+ prev_comp = None
779
+
680
780
  for op, comp in zip(op_strings, comps):
681
- s += op + " var_" + str(comp) + ") "
781
+ comp_chainable = op_str_is_chainable(op)
782
+ if comp_chainable and prev_comp:
783
+ # We restrict chaining to operands of the same type
784
+ if prev_comp.type is comp.type:
785
+ prev_comp = adj.load(prev_comp)
786
+ comp = adj.load(comp)
787
+ s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
788
+ else:
789
+ raise WarpCodegenTypeError(
790
+ f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
791
+ )
792
+ else:
793
+ comp = adj.load(comp)
794
+ s += op + " " + comp.emit() + ") "
795
+
796
+ prev_comp = comp
682
797
 
683
798
  s = s.rstrip() + ";"
684
799
 
@@ -687,110 +802,106 @@ class Adjoint:
687
802
  return output
688
803
 
689
804
  def add_bool_op(adj, op_string, exprs):
805
+ exprs = [adj.load(expr) for expr in exprs]
690
806
  output = adj.add_var(builtins.bool)
691
- command = (
692
- "var_" + str(output) + " = " + (" " + op_string + " ").join(["var_" + str(expr) for expr in exprs]) + ";"
693
- )
807
+ command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
694
808
  adj.add_forward(command)
695
809
 
696
810
  return output
697
811
 
698
- def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
699
- # if func is overloaded then perform overload resolution here
700
- # we validate argument types before they go to generated native code
701
- resolved_func = None
812
+ def resolve_func(adj, func, args, min_outputs, templates, kwds):
813
+ arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
702
814
 
703
- if func.is_builtin():
815
+ if not func.is_builtin():
816
+ # user-defined function
817
+ overload = func.get_overload(arg_types)
818
+ if overload is not None:
819
+ return overload
820
+ else:
821
+ # if func is overloaded then perform overload resolution here
822
+ # we validate argument types before they go to generated native code
704
823
  for f in func.overloads:
705
- match = True
706
-
707
824
  # skip type checking for variadic functions
708
825
  if not f.variadic:
709
826
  # check argument counts match are compatible (may be some default args)
710
827
  if len(f.input_types) < len(args):
711
- match = False
712
828
  continue
713
829
 
714
- # check argument types equal
715
- for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
716
- # if arg type registered as Any, treat as
717
- # template allowing any type to match
718
- if arg_type == Any:
719
- continue
720
-
721
- # handle function refs as a special case
722
- if arg_type == Callable and type(args[i]) is warp.context.Function:
723
- continue
724
-
725
- # look for default values for missing args
726
- if i >= len(args):
727
- if arg_name not in f.defaults:
728
- match = False
729
- break
730
- else:
731
- # otherwise check arg type matches input variable type
732
- if not types_equal(arg_type, args[i].type, match_generic=True):
733
- match = False
734
- break
830
+ def match_args(args, f):
831
+ # check argument types equal
832
+ for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
833
+ # if arg type registered as Any, treat as
834
+ # template allowing any type to match
835
+ if arg_type == Any:
836
+ continue
837
+
838
+ # handle function refs as a special case
839
+ if arg_type == Callable and type(args[i]) is warp.context.Function:
840
+ continue
841
+
842
+ if arg_type == Reference and is_reference(args[i].type):
843
+ continue
844
+
845
+ # look for default values for missing args
846
+ if i >= len(args):
847
+ if arg_name not in f.defaults:
848
+ return False
849
+ else:
850
+ # otherwise check arg type matches input variable type
851
+ if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
852
+ return False
853
+
854
+ return True
855
+
856
+ if not match_args(args, f):
857
+ continue
735
858
 
736
859
  # check output dimensions match expectations
737
860
  if min_outputs:
738
861
  try:
739
862
  value_type = f.value_func(args, kwds, templates)
740
- if len(value_type) != min_outputs:
741
- match = False
863
+ if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
742
864
  continue
743
865
  except Exception:
744
866
  # value func may fail if the user has given
745
867
  # incorrect args, so we need to catch this
746
- match = False
747
868
  continue
748
869
 
749
870
  # found a match, use it
750
- if match:
751
- resolved_func = f
752
- break
753
- else:
754
- # user-defined function
755
- arg_types = [a.type for a in args]
756
-
757
- resolved_func = func.get_overload(arg_types)
758
-
759
- if resolved_func is None:
760
- arg_types = []
761
-
762
- for x in args:
763
- if isinstance(x, Var):
764
- # shorten Warp primitive type names
765
- if isinstance(x.type, list):
766
- if len(x.type) != 1:
767
- raise Exception("Argument must not be the result from a multi-valued function")
768
- arg_type = x.type[0]
769
- else:
770
- arg_type = x.type
771
- if arg_type.__module__ == "warp.types":
772
- arg_types.append(arg_type.__name__)
773
- else:
774
- arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
775
-
776
- if isinstance(x, warp.context.Function):
777
- arg_types.append("function")
778
-
779
- raise Exception(
780
- f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
781
- )
871
+ return f
872
+
873
+ # unresolved function, report error
874
+ arg_types = []
875
+
876
+ for x in args:
877
+ if isinstance(x, Var):
878
+ # shorten Warp primitive type names
879
+ if isinstance(x.type, list):
880
+ if len(x.type) != 1:
881
+ raise WarpCodegenError("Argument must not be the result from a multi-valued function")
882
+ arg_type = x.type[0]
883
+ else:
884
+ arg_type = x.type
782
885
 
783
- else:
784
- func = resolved_func
886
+ arg_types.append(type_repr(arg_type))
887
+
888
+ if isinstance(x, warp.context.Function):
889
+ arg_types.append("function")
890
+
891
+ raise WarpCodegenError(
892
+ f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
893
+ )
894
+
895
+ def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
896
+ func = adj.resolve_func(func, args, min_outputs, templates, kwds)
785
897
 
786
898
  # push any default values onto args
787
899
  for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
788
900
  if i >= len(args):
789
- if arg_name in f.defaults:
901
+ if arg_name in func.defaults:
790
902
  const = adj.add_constant(func.defaults[arg_name])
791
903
  args.append(const)
792
904
  else:
793
- match = False
794
905
  break
795
906
 
796
907
  # if it is a user-function then build it recursively
@@ -798,105 +909,105 @@ class Adjoint:
798
909
  adj.builder.build_function(func)
799
910
 
800
911
  # evaluate the function type based on inputs
801
- value_type = func.value_func(args, kwds, templates)
912
+ arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
913
+ return_type = func.value_func(arg_types, kwds, templates)
802
914
 
803
915
  func_name = compute_type_str(func.native_func, templates)
916
+ param_types = list(func.input_types.values())
804
917
 
805
918
  use_initializer_list = func.initializer_list_func(args, templates)
806
919
 
807
- if value_type is None:
920
+ args_var = [
921
+ adj.load(a)
922
+ if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
923
+ else a
924
+ for i, a in enumerate(args)
925
+ ]
926
+
927
+ if return_type is None:
808
928
  # handles expression (zero output) functions, e.g.: void do_something();
809
929
 
810
- forward_call = "{}{}({});".format(
811
- func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
930
+ output = None
931
+ output_list = []
932
+
933
+ forward_call = (
934
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
812
935
  )
813
936
  replay_call = forward_call
814
937
  if func.custom_replay_func is not None:
815
- replay_call = "{}replay_{}({});".format(
816
- func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
817
- )
818
- if func.skip_replay:
819
- adj.add_forward(forward_call, replay="// " + replay_call)
820
- else:
821
- adj.add_forward(forward_call, replay=replay_call)
822
-
823
- if not func.missing_grad and len(args):
824
- arg_str = adj.format_reverse_call_args(args, [], {}, {}, use_initializer_list)
825
- if arg_str is not None:
826
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
827
- adj.add_reverse(reverse_call)
828
-
829
- return None
938
+ replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
830
939
 
831
- elif not isinstance(value_type, list) or len(value_type) == 1:
940
+ elif not isinstance(return_type, list) or len(return_type) == 1:
832
941
  # handle simple function (one output)
833
942
 
834
- if isinstance(value_type, list):
835
- value_type = value_type[0]
836
- output = adj.add_var(value_type)
837
- forward_call = "var_{} = {}{}({});".format(
838
- output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
839
- )
943
+ if isinstance(return_type, list):
944
+ return_type = return_type[0]
945
+ output = adj.add_var(return_type)
946
+ output_list = [output]
947
+
948
+ forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
840
949
  replay_call = forward_call
841
950
  if func.custom_replay_func is not None:
842
- replay_call = "var_{} = {}replay_{}({});".format(
843
- output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
844
- )
845
-
846
- if func.skip_replay:
847
- adj.add_forward(forward_call, replay="// " + replay_call)
848
- else:
849
- adj.add_forward(forward_call, replay=replay_call)
850
-
851
- if not func.missing_grad and len(args):
852
- arg_str = adj.format_reverse_call_args(args, [output], {}, {}, use_initializer_list)
853
- if arg_str is not None:
854
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
855
- adj.add_reverse(reverse_call)
856
-
857
- return output
951
+ replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
858
952
 
859
953
  else:
860
954
  # handle multiple value functions
861
955
 
862
- output = [adj.add_var(v) for v in value_type]
863
- forward_call = "{}{}({});".format(
864
- func.namespace, func_name, adj.format_forward_call_args(args + output, use_initializer_list)
956
+ output = [adj.add_var(v) for v in return_type]
957
+ output_list = output
958
+
959
+ forward_call = (
960
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
865
961
  )
866
- adj.add_forward(forward_call)
962
+ replay_call = forward_call
867
963
 
868
- if not func.missing_grad and len(args):
869
- arg_str = adj.format_reverse_call_args(
870
- args, output, {}, {}, use_initializer_list, has_output_args=func.custom_grad_func is None
871
- )
872
- if arg_str is not None:
873
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
874
- adj.add_reverse(reverse_call)
964
+ if func.skip_replay:
965
+ adj.add_forward(forward_call, replay="// " + replay_call)
966
+ else:
967
+ adj.add_forward(forward_call, replay=replay_call)
968
+
969
+ if not func.missing_grad and len(args):
970
+ reverse_has_output_args = (
971
+ func.require_original_output_arg or len(output_list) > 1
972
+ ) and func.custom_grad_func is None
973
+ arg_str = adj.format_reverse_call_args(
974
+ args_var,
975
+ args,
976
+ output_list,
977
+ use_initializer_list,
978
+ has_output_args=reverse_has_output_args,
979
+ require_original_output_arg=func.require_original_output_arg,
980
+ )
981
+ if arg_str is not None:
982
+ reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
983
+ adj.add_reverse(reverse_call)
875
984
 
876
- if len(output) == 1:
877
- return output[0]
985
+ return output
878
986
 
879
- return output
987
+ def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
988
+ func = warp.context.builtin_functions[func_name]
989
+ return adj.add_call(func, args, min_outputs, templates, kwds)
880
990
 
881
991
  def add_return(adj, var):
882
992
  if var is None or len(var) == 0:
883
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
993
+ adj.add_forward("return;", f"goto label{adj.label_count};")
884
994
  elif len(var) == 1:
885
- adj.add_forward("return var_{};".format(var[0]), "goto label{};".format(adj.label_count))
995
+ adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
886
996
  adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
887
997
  else:
888
998
  for i, v in enumerate(var):
889
- adj.add_forward("ret_{} = var_{};".format(i, v))
890
- adj.add_reverse("adj_{} += adj_ret_{};".format(v, i))
891
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
999
+ adj.add_forward(f"ret_{i} = {v.emit()};")
1000
+ adj.add_reverse(f"adj_{v} += adj_ret_{i};")
1001
+ adj.add_forward("return;", f"goto label{adj.label_count};")
892
1002
 
893
- adj.add_reverse("label{}:;".format(adj.label_count))
1003
+ adj.add_reverse(f"label{adj.label_count}:;")
894
1004
 
895
1005
  adj.label_count += 1
896
1006
 
897
1007
  # define an if statement
898
1008
  def begin_if(adj, cond):
899
- adj.add_forward("if (var_{}) {{".format(cond))
1009
+ cond = adj.load(cond)
1010
+ adj.add_forward(f"if ({cond.emit()}) {{")
900
1011
  adj.add_reverse("}")
901
1012
 
902
1013
  adj.indent()
@@ -905,10 +1016,12 @@ class Adjoint:
905
1016
  adj.dedent()
906
1017
 
907
1018
  adj.add_forward("}")
908
- adj.add_reverse(f"if (var_{cond}) {{")
1019
+ cond = adj.load(cond)
1020
+ adj.add_reverse(f"if ({cond.emit()}) {{")
909
1021
 
910
1022
  def begin_else(adj, cond):
911
- adj.add_forward(f"if (!var_{cond}) {{")
1023
+ cond = adj.load(cond)
1024
+ adj.add_forward(f"if (!{cond.emit()}) {{")
912
1025
  adj.add_reverse("}")
913
1026
 
914
1027
  adj.indent()
@@ -917,7 +1030,8 @@ class Adjoint:
917
1030
  adj.dedent()
918
1031
 
919
1032
  adj.add_forward("}")
920
- adj.add_reverse(f"if (!var_{cond}) {{")
1033
+ cond = adj.load(cond)
1034
+ adj.add_reverse(f"if (!{cond.emit()}) {{")
921
1035
 
922
1036
  # define a for-loop
923
1037
  def begin_for(adj, iter):
@@ -927,10 +1041,10 @@ class Adjoint:
927
1041
  adj.indent()
928
1042
 
929
1043
  # evaluate cond
930
- adj.add_forward(f"if (iter_cmp(var_{iter}) == 0) goto for_end_{cond_block.label};")
1044
+ adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")
931
1045
 
932
1046
  # evaluate iter
933
- val = adj.add_call(warp.context.builtin_functions["iter_next"], [iter])
1047
+ val = adj.add_builtin_call("iter_next", [iter])
934
1048
 
935
1049
  adj.begin_block()
936
1050
 
@@ -961,17 +1075,14 @@ class Adjoint:
961
1075
  reverse = []
962
1076
 
963
1077
  # reverse iterator
964
- reverse.append(adj.prefix + f"var_{iter} = wp::iter_reverse(var_{iter});")
1078
+ reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
965
1079
 
966
1080
  for i in cond_block.body_forward:
967
1081
  reverse.append(i)
968
1082
 
969
1083
  # zero adjoints
970
1084
  for i in body_block.vars:
971
- if isinstance(i.type, Struct):
972
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
973
- else:
974
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
1085
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
975
1086
 
976
1087
  # replay
977
1088
  for i in body_block.body_replay:
@@ -981,14 +1092,14 @@ class Adjoint:
981
1092
  for i in reversed(body_block.body_reverse):
982
1093
  reverse.append(i)
983
1094
 
984
- reverse.append(adj.prefix + f"\tgoto for_start_{cond_block.label};")
985
- reverse.append(adj.prefix + f"for_end_{cond_block.label}:;")
1095
+ reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
1096
+ reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")
986
1097
 
987
1098
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
988
1099
 
989
1100
  # define a while loop
990
1101
  def begin_while(adj, cond):
991
- # evaulate condition in its own block
1102
+ # evaluate condition in its own block
992
1103
  # so we can control replay
993
1104
  cond_block = adj.begin_block()
994
1105
  adj.loop_blocks.append(cond_block)
@@ -996,7 +1107,7 @@ class Adjoint:
996
1107
 
997
1108
  c = adj.eval(cond)
998
1109
 
999
- cond_block.body_forward.append(f"if ((var_{c}) == false) goto while_end_{cond_block.label};")
1110
+ cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")
1000
1111
 
1001
1112
  # being block around loop
1002
1113
  adj.begin_block()
@@ -1030,10 +1141,7 @@ class Adjoint:
1030
1141
 
1031
1142
  # zero adjoints of local vars
1032
1143
  for i in body_block.vars:
1033
- if isinstance(i.type, Struct):
1034
- reverse.append(f"adj_{i} = {i.ctype()}{{}};")
1035
- else:
1036
- reverse.append(f"adj_{i} = {i.ctype()}(0);")
1144
+ reverse.append(f"{i.emit_adj()} = {{}};")
1037
1145
 
1038
1146
  # replay
1039
1147
  for i in body_block.body_replay:
@@ -1053,6 +1161,10 @@ class Adjoint:
1053
1161
  for f in node.body:
1054
1162
  adj.eval(f)
1055
1163
 
1164
+ if adj.return_var is not None and len(adj.return_var) == 1:
1165
+ if not isinstance(node.body[-1], ast.Return):
1166
+ adj.add_forward("return {};", skip_replay=True)
1167
+
1056
1168
  def emit_If(adj, node):
1057
1169
  if len(node.body) == 0:
1058
1170
  return None
@@ -1080,7 +1192,7 @@ class Adjoint:
1080
1192
 
1081
1193
  if var1 != var2:
1082
1194
  # insert a phi function that selects var1, var2 based on cond
1083
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var1, var2])
1195
+ out = adj.add_builtin_call("select", [cond, var1, var2])
1084
1196
  adj.symbols[sym] = out
1085
1197
 
1086
1198
  symbols_prev = adj.symbols.copy()
@@ -1104,7 +1216,7 @@ class Adjoint:
1104
1216
  if var1 != var2:
1105
1217
  # insert a phi function that selects var1, var2 based on cond
1106
1218
  # note the reversed order of vars since we want to use !cond as our select
1107
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var2, var1])
1219
+ out = adj.add_builtin_call("select", [cond, var2, var1])
1108
1220
  adj.symbols[sym] = out
1109
1221
 
1110
1222
  def emit_Compare(adj, node):
@@ -1126,7 +1238,7 @@ class Adjoint:
1126
1238
  elif isinstance(op, ast.Or):
1127
1239
  func = "||"
1128
1240
  else:
1129
- raise KeyError("Op {} is not supported".format(op))
1241
+ raise WarpCodegenKeyError(f"Op {op} is not supported")
1130
1242
 
1131
1243
  return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
1132
1244
 
@@ -1146,7 +1258,7 @@ class Adjoint:
1146
1258
  obj = capturedvars.get(str(node.id), None)
1147
1259
 
1148
1260
  if obj is None:
1149
- raise KeyError("Referencing undefined symbol: " + str(node.id))
1261
+ raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
1150
1262
 
1151
1263
  if warp.types.is_value(obj):
1152
1264
  # evaluate constant
@@ -1158,26 +1270,96 @@ class Adjoint:
1158
1270
  # pass it back to the caller for processing
1159
1271
  return obj
1160
1272
 
1273
+ @staticmethod
1274
+ def resolve_type_attribute(var_type: type, attr: str):
1275
+ if isinstance(var_type, type) and type_is_value(var_type):
1276
+ if attr == "dtype":
1277
+ return type_scalar_type(var_type)
1278
+ elif attr == "length":
1279
+ return type_length(var_type)
1280
+
1281
+ return getattr(var_type, attr, None)
1282
+
1283
+ def vector_component_index(adj, component, vector_type):
1284
+ if len(component) != 1:
1285
+ raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
1286
+
1287
+ dim = vector_type._shape_[0]
1288
+ swizzles = "xyzw"[0:dim]
1289
+ if component not in swizzles:
1290
+ raise WarpCodegenAttributeError(
1291
+ f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
1292
+ )
1293
+
1294
+ index = swizzles.index(component)
1295
+ index = adj.add_constant(index)
1296
+ return index
1297
+
1298
+ @staticmethod
1299
+ def is_differentiable_value_type(var_type):
1300
+ # checks that the argument type is a value type (i.e, not an array)
1301
+ # possibly holding differentiable values (for which gradients must be accumulated)
1302
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
1303
+
1161
1304
  def emit_Attribute(adj, node):
1162
- try:
1163
- val = adj.eval(node.value)
1305
+ if hasattr(node, "is_adjoint"):
1306
+ node.value.is_adjoint = True
1307
+
1308
+ aggregate = adj.eval(node.value)
1164
1309
 
1165
- if isinstance(val, types.ModuleType) or isinstance(val, type):
1166
- out = getattr(val, node.attr)
1310
+ try:
1311
+ if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
1312
+ out = getattr(aggregate, node.attr)
1167
1313
 
1168
1314
  if warp.types.is_value(out):
1169
1315
  return adj.add_constant(out)
1170
1316
 
1171
1317
  return out
1172
1318
 
1173
- # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
1174
- attr_name = val.label + "." + node.attr
1175
- attr_type = val.type.vars[node.attr].type
1319
+ if hasattr(node, "is_adjoint"):
1320
+ # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
1321
+ attr_name = aggregate.label + "." + node.attr
1322
+ attr_type = aggregate.type.vars[node.attr].type
1323
+
1324
+ return Var(attr_name, attr_type)
1325
+
1326
+ aggregate_type = strip_reference(aggregate.type)
1176
1327
 
1177
- return Var(attr_name, attr_type)
1328
+ # reading a vector component
1329
+ if type_is_vector(aggregate_type):
1330
+ index = adj.vector_component_index(node.attr, aggregate_type)
1178
1331
 
1179
- except KeyError:
1180
- raise RuntimeError(f"Error, `{node.attr}` is not an attribute of '{val.label}' ({val.type})")
1332
+ return adj.add_builtin_call("extract", [aggregate, index])
1333
+
1334
+ else:
1335
+ attr_type = Reference(aggregate_type.vars[node.attr].type)
1336
+ attr = adj.add_var(attr_type)
1337
+
1338
+ if is_reference(aggregate.type):
1339
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
1340
+ else:
1341
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
1342
+
1343
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
1344
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
1345
+ else:
1346
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
1347
+
1348
+ return attr
1349
+
1350
+ except (KeyError, AttributeError):
1351
+ # Try resolving as type attribute
1352
+ aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
1353
+
1354
+ type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
1355
+ if type_attribute is not None:
1356
+ return type_attribute
1357
+
1358
+ if isinstance(aggregate, Var):
1359
+ raise WarpCodegenAttributeError(
1360
+ f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
1361
+ )
1362
+ raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
1181
1363
 
1182
1364
  def emit_String(adj, node):
1183
1365
  # string constant
@@ -1194,19 +1376,25 @@ class Adjoint:
1194
1376
  adj.symbols[key] = out
1195
1377
  return out
1196
1378
 
1379
+ def emit_Ellipsis(adj, node):
1380
+ # stubbed @wp.native_func
1381
+ return
1382
+
1197
1383
  def emit_NameConstant(adj, node):
1198
- if node.value is True:
1384
+ if node.value:
1199
1385
  return adj.add_constant(True)
1200
- elif node.value is False:
1201
- return adj.add_constant(False)
1202
1386
  elif node.value is None:
1203
- raise TypeError("None type unsupported")
1387
+ raise WarpCodegenTypeError("None type unsupported")
1388
+ else:
1389
+ return adj.add_constant(False)
1204
1390
 
1205
1391
  def emit_Constant(adj, node):
1206
1392
  if isinstance(node, ast.Str):
1207
1393
  return adj.emit_String(node)
1208
1394
  elif isinstance(node, ast.Num):
1209
1395
  return adj.emit_Num(node)
1396
+ elif isinstance(node, ast.Ellipsis):
1397
+ return adj.emit_Ellipsis(node)
1210
1398
  else:
1211
1399
  assert isinstance(node, ast.NameConstant)
1212
1400
  return adj.emit_NameConstant(node)
@@ -1217,18 +1405,16 @@ class Adjoint:
1217
1405
  right = adj.eval(node.right)
1218
1406
 
1219
1407
  name = builtin_operators[type(node.op)]
1220
- func = warp.context.builtin_functions[name]
1221
1408
 
1222
- return adj.add_call(func, [left, right])
1409
+ return adj.add_builtin_call(name, [left, right])
1223
1410
 
1224
1411
  def emit_UnaryOp(adj, node):
1225
1412
  # evaluate unary op arguments
1226
1413
  arg = adj.eval(node.operand)
1227
1414
 
1228
1415
  name = builtin_operators[type(node.op)]
1229
- func = warp.context.builtin_functions[name]
1230
1416
 
1231
- return adj.add_call(func, [arg])
1417
+ return adj.add_builtin_call(name, [arg])
1232
1418
 
1233
1419
  def materialize_redefinitions(adj, symbols):
1234
1420
  # detect symbols with conflicting definitions (assigned inside the for loop)
@@ -1240,19 +1426,17 @@ class Adjoint:
1240
1426
  if var1 != var2:
1241
1427
  if warp.config.verbose and not adj.custom_reverse_mode:
1242
1428
  lineno = adj.lineno + adj.fun_lineno
1243
- line = adj.source.splitlines()[adj.lineno]
1244
- msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1429
+ line = adj.source_lines[adj.lineno]
1430
+ msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
1245
1431
  print(msg)
1246
1432
 
1247
1433
  if var1.constant is not None:
1248
- raise Exception(
1249
- "Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable".format(
1250
- sym
1251
- )
1434
+ raise WarpCodegenError(
1435
+ f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
1252
1436
  )
1253
1437
 
1254
1438
  # overwrite the old variable value (violates SSA)
1255
- adj.add_call(warp.context.builtin_functions["copy"], [var1, var2])
1439
+ adj.add_builtin_call("assign", [var1, var2])
1256
1440
 
1257
1441
  # reset the symbol to point to the original variable
1258
1442
  adj.symbols[sym] = var1
@@ -1271,35 +1455,20 @@ class Adjoint:
1271
1455
 
1272
1456
  adj.end_while()
1273
1457
 
1274
- def is_num(adj, a):
1275
- # simple constant
1276
- if isinstance(a, ast.Num):
1277
- return True
1278
- # expression of form -constant
1279
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1280
- return True
1281
- else:
1282
- # try and resolve the expression to an object
1283
- # e.g.: wp.constant in the globals scope
1284
- obj, path = adj.resolve_path(a)
1285
- if warp.types.is_int(obj):
1286
- return True
1287
- else:
1288
- return False
1289
-
1290
1458
  def eval_num(adj, a):
1291
1459
  if isinstance(a, ast.Num):
1292
- return a.n
1293
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1294
- return -a.operand.n
1295
- else:
1296
- # try and resolve the expression to an object
1297
- # e.g.: wp.constant in the globals scope
1298
- obj, path = adj.resolve_path(a)
1299
- if warp.types.is_int(obj):
1300
- return obj
1301
- else:
1302
- return False
1460
+ return True, a.n
1461
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1462
+ return True, -a.operand.n
1463
+
1464
+ # try and resolve the expression to an object
1465
+ # e.g.: wp.constant in the globals scope
1466
+ obj, _ = adj.resolve_static_expression(a)
1467
+
1468
+ if isinstance(obj, Var) and obj.constant is not None:
1469
+ obj = obj.constant
1470
+
1471
+ return warp.types.is_int(obj), obj
1303
1472
 
1304
1473
  # detects whether a loop contains a break (or continue) statement
1305
1474
  def contains_break(adj, body):
@@ -1322,61 +1491,82 @@ class Adjoint:
1322
1491
 
1323
1492
  # returns a constant range() if unrollable, otherwise None
1324
1493
  def get_unroll_range(adj, loop):
1325
- if not isinstance(loop.iter, ast.Call) or not isinstance(loop.iter.func, ast.Name) or loop.iter.func.id != "range":
1494
+ if (
1495
+ not isinstance(loop.iter, ast.Call)
1496
+ or not isinstance(loop.iter.func, ast.Name)
1497
+ or loop.iter.func.id != "range"
1498
+ or len(loop.iter.args) == 0
1499
+ or len(loop.iter.args) > 3
1500
+ ):
1326
1501
  return None
1327
1502
 
1328
- for a in loop.iter.args:
1329
- # if all range() arguments are numeric constants we will unroll
1330
- # note that this only handles trivial constants, it will not unroll
1331
- # constant compile-time expressions e.g.: range(0, 3*2)
1332
- if not adj.is_num(a):
1333
- return None
1334
-
1335
- # range(end)
1336
- if len(loop.iter.args) == 1:
1337
- start = 0
1338
- end = adj.eval_num(loop.iter.args[0])
1339
- step = 1
1340
-
1341
- # range(start, end)
1342
- elif len(loop.iter.args) == 2:
1343
- start = adj.eval_num(loop.iter.args[0])
1344
- end = adj.eval_num(loop.iter.args[1])
1345
- step = 1
1346
-
1347
- # range(start, end, step)
1348
- elif len(loop.iter.args) == 3:
1349
- start = adj.eval_num(loop.iter.args[0])
1350
- end = adj.eval_num(loop.iter.args[1])
1351
- step = adj.eval_num(loop.iter.args[2])
1352
-
1353
- # test if we're above max unroll count
1354
- max_iters = abs(end - start) // abs(step)
1355
- max_unroll = adj.builder.options["max_unroll"]
1356
-
1357
- if max_iters > max_unroll:
1358
- if warp.config.verbose:
1359
- print(
1360
- f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
1361
- )
1362
- return None
1503
+ # if all range() arguments are numeric constants we will unroll
1504
+ # note that this only handles trivial constants, it will not unroll
1505
+ # constant compile-time expressions e.g.: range(0, 3*2)
1506
+
1507
+ # Evaluate the arguments and check that they are numeric constants
1508
+ # It is important to do that in one pass, so that if evaluating these arguments have side effects
1509
+ # the code does not get generated more than once
1510
+ range_args = [adj.eval_num(arg) for arg in loop.iter.args]
1511
+ arg_is_numeric, arg_values = zip(*range_args)
1512
+
1513
+ if all(arg_is_numeric):
1514
+ # All argument are numeric constants
1515
+
1516
+ # range(end)
1517
+ if len(loop.iter.args) == 1:
1518
+ start = 0
1519
+ end = arg_values[0]
1520
+ step = 1
1521
+
1522
+ # range(start, end)
1523
+ elif len(loop.iter.args) == 2:
1524
+ start = arg_values[0]
1525
+ end = arg_values[1]
1526
+ step = 1
1527
+
1528
+ # range(start, end, step)
1529
+ elif len(loop.iter.args) == 3:
1530
+ start = arg_values[0]
1531
+ end = arg_values[1]
1532
+ step = arg_values[2]
1533
+
1534
+ # test if we're above max unroll count
1535
+ max_iters = abs(end - start) // abs(step)
1536
+ max_unroll = adj.builder.options["max_unroll"]
1537
+
1538
+ ok_to_unroll = True
1539
+
1540
+ if max_iters > max_unroll:
1541
+ if warp.config.verbose:
1542
+ print(
1543
+ f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
1544
+ )
1545
+ ok_to_unroll = False
1363
1546
 
1364
- if adj.contains_break(loop.body):
1365
- if warp.config.verbose:
1366
- print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
1367
- return None
1547
+ elif adj.contains_break(loop.body):
1548
+ if warp.config.verbose:
1549
+ print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
1550
+ ok_to_unroll = False
1368
1551
 
1369
- # unroll
1370
- return range(start, end, step)
1552
+ if ok_to_unroll:
1553
+ return range(start, end, step)
1554
+
1555
+ # Unroll is not possible, range needs to be valuated dynamically
1556
+ range_call = adj.add_builtin_call(
1557
+ "range",
1558
+ [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
1559
+ )
1560
+ return range_call
1371
1561
 
1372
1562
  def emit_For(adj, node):
1373
1563
  # try and unroll simple range() statements that use constant args
1374
1564
  unroll_range = adj.get_unroll_range(node)
1375
1565
 
1376
- if unroll_range:
1566
+ if isinstance(unroll_range, range):
1377
1567
  for i in unroll_range:
1378
1568
  const_iter = adj.add_constant(i)
1379
- var_iter = adj.add_call(warp.context.builtin_functions["int"], [const_iter])
1569
+ var_iter = adj.add_builtin_call("int", [const_iter])
1380
1570
  adj.symbols[node.target.id] = var_iter
1381
1571
 
1382
1572
  # eval body
@@ -1385,8 +1575,12 @@ class Adjoint:
1385
1575
 
1386
1576
  # otherwise generate a dynamic loop
1387
1577
  else:
1388
- # evaluate the Iterable
1389
- iter = adj.eval(node.iter)
1578
+ # evaluate the Iterable -- only if not previously evaluated when trying to unroll
1579
+ if unroll_range is not None:
1580
+ # Range has already been evaluated when trying to unroll, do not re-evaluate
1581
+ iter = unroll_range
1582
+ else:
1583
+ iter = adj.eval(node.iter)
1390
1584
 
1391
1585
  adj.symbols[node.target.id] = adj.begin_for(iter)
1392
1586
 
@@ -1415,15 +1609,28 @@ class Adjoint:
1415
1609
  def emit_Expr(adj, node):
1416
1610
  return adj.eval(node.value)
1417
1611
 
1612
+ def check_tid_in_func_error(adj, node):
1613
+ if adj.is_user_function:
1614
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
1615
+ lineno = adj.lineno + adj.fun_lineno
1616
+ line = adj.source_lines[adj.lineno]
1617
+ raise WarpCodegenError(
1618
+ "tid() may only be called from a Warp kernel, not a Warp function. "
1619
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
1620
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
1621
+ )
1622
+
1418
1623
  def emit_Call(adj, node):
1624
+ adj.check_tid_in_func_error(node)
1625
+
1419
1626
  # try and lookup function in globals by
1420
1627
  # resolving path (e.g.: module.submodule.attr)
1421
- func, path = adj.resolve_path(node.func)
1628
+ func, path = adj.resolve_static_expression(node.func)
1422
1629
  templates = []
1423
1630
 
1424
- if isinstance(func, warp.context.Function) is False:
1631
+ if not isinstance(func, warp.context.Function):
1425
1632
  if len(path) == 0:
1426
- raise RuntimeError(f"Unrecognized syntax for function call, path not valid: '{node.func}'")
1633
+ raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")
1427
1634
 
1428
1635
  attr = path[-1]
1429
1636
  caller = func
@@ -1448,7 +1655,7 @@ class Adjoint:
1448
1655
  func = caller.initializer()
1449
1656
 
1450
1657
  if func is None:
1451
- raise RuntimeError(
1658
+ raise WarpCodegenError(
1452
1659
  f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
1453
1660
  )
1454
1661
 
@@ -1464,9 +1671,14 @@ class Adjoint:
1464
1671
  if isinstance(kw.value, ast.Num):
1465
1672
  return kw.value.n
1466
1673
  elif isinstance(kw.value, ast.Tuple):
1467
- return tuple(adj.eval_num(e) for e in kw.value.elts)
1674
+ arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
1675
+ if not all(arg_is_numeric):
1676
+ raise WarpCodegenError(
1677
+ f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
1678
+ )
1679
+ return arg_values
1468
1680
  else:
1469
- return adj.resolve_path(kw.value)[0]
1681
+ return adj.resolve_static_expression(kw.value)[0]
1470
1682
 
1471
1683
  kwds = {kw.arg: kwval(kw) for kw in node.keywords}
1472
1684
 
@@ -1483,15 +1695,19 @@ class Adjoint:
1483
1695
  # the ast.Index node appears in 3.7 versions
1484
1696
  # when performing array slices, e.g.: x = arr[i]
1485
1697
  # but in version 3.8 and higher it does not appear
1698
+
1699
+ if hasattr(node, "is_adjoint"):
1700
+ node.value.is_adjoint = True
1701
+
1486
1702
  return adj.eval(node.value)
1487
1703
 
1488
1704
  def emit_Subscript(adj, node):
1489
1705
  if hasattr(node.value, "attr") and node.value.attr == "adjoint":
1490
1706
  # handle adjoint of a variable, i.e. wp.adjoint[var]
1707
+ node.slice.is_adjoint = True
1491
1708
  var = adj.eval(node.slice)
1492
1709
  var_name = var.label
1493
- var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False, is_adjoint=True)
1494
- adj.symbols[var.label] = var
1710
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
1495
1711
  return var
1496
1712
 
1497
1713
  target = adj.eval(node.value)
@@ -1514,29 +1730,34 @@ class Adjoint:
1514
1730
  var = adj.eval(node.slice)
1515
1731
  indices.append(var)
1516
1732
 
1517
- if is_array(target.type):
1518
- if len(indices) == target.type.ndim:
1733
+ target_type = strip_reference(target.type)
1734
+ if is_array(target_type):
1735
+ if len(indices) == target_type.ndim:
1519
1736
  # handles array loads (where each dimension has an index specified)
1520
- out = adj.add_call(warp.context.builtin_functions["load"], [target, *indices])
1737
+ out = adj.add_builtin_call("address", [target, *indices])
1521
1738
  else:
1522
1739
  # handles array views (fewer indices than dimensions)
1523
- out = adj.add_call(warp.context.builtin_functions["view"], [target, *indices])
1740
+ out = adj.add_builtin_call("view", [target, *indices])
1524
1741
 
1525
1742
  else:
1526
1743
  # handles non-array type indexing, e.g: vec3, mat33, etc
1527
- out = adj.add_call(warp.context.builtin_functions["index"], [target, *indices])
1744
+ out = adj.add_builtin_call("extract", [target, *indices])
1528
1745
 
1529
- out.is_adjoint = target.is_adjoint
1530
1746
  return out
1531
1747
 
1532
1748
  def emit_Assign(adj, node):
1749
+ if len(node.targets) != 1:
1750
+ raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
1751
+
1752
+ lhs = node.targets[0]
1753
+
1533
1754
  # handle the case where we are assigning multiple output variables
1534
- if isinstance(node.targets[0], ast.Tuple):
1755
+ if isinstance(lhs, ast.Tuple):
1535
1756
  # record the expected number of outputs on the node
1536
1757
  # we do this so we can decide which function to
1537
1758
  # call based on the number of expected outputs
1538
1759
  if isinstance(node.value, ast.Call):
1539
- node.value.expects = len(node.targets[0].elts)
1760
+ node.value.expects = len(lhs.elts)
1540
1761
 
1541
1762
  # evaluate values
1542
1763
  if isinstance(node.value, ast.Tuple):
@@ -1545,49 +1766,43 @@ class Adjoint:
1545
1766
  out = adj.eval(node.value)
1546
1767
 
1547
1768
  names = []
1548
- for v in node.targets[0].elts:
1769
+ for v in lhs.elts:
1549
1770
  if isinstance(v, ast.Name):
1550
1771
  names.append(v.id)
1551
1772
  else:
1552
- raise RuntimeError(
1773
+ raise WarpCodegenError(
1553
1774
  "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
1554
1775
  )
1555
1776
 
1556
1777
  if len(names) != len(out):
1557
- raise RuntimeError(
1558
- "Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})".format(
1559
- len(out), len(names)
1560
- )
1778
+ raise WarpCodegenError(
1779
+ f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
1561
1780
  )
1562
1781
 
1563
1782
  for name, rhs in zip(names, out):
1564
1783
  if name in adj.symbols:
1565
1784
  if not types_equal(rhs.type, adj.symbols[name].type):
1566
- raise TypeError(
1567
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
1568
- name, adj.symbols[name].type, rhs.type
1569
- )
1785
+ raise WarpCodegenTypeError(
1786
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
1570
1787
  )
1571
1788
 
1572
1789
  adj.symbols[name] = rhs
1573
1790
 
1574
- return out
1575
-
1576
1791
  # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
1577
- elif isinstance(node.targets[0], ast.Subscript):
1578
- if hasattr(node.targets[0].value, "attr") and node.targets[0].value.attr == "adjoint":
1792
+ elif isinstance(lhs, ast.Subscript):
1793
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
1579
1794
  # handle adjoint of a variable, i.e. wp.adjoint[var]
1580
- src_var = adj.eval(node.targets[0].slice)
1795
+ lhs.slice.is_adjoint = True
1796
+ src_var = adj.eval(lhs.slice)
1581
1797
  var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
1582
- adj.symbols[var.label] = var
1583
1798
  value = adj.eval(node.value)
1584
1799
  adj.add_forward(f"{var.emit()} = {value.emit()};")
1585
- return var
1800
+ return
1586
1801
 
1587
- target = adj.eval(node.targets[0].value)
1802
+ target = adj.eval(lhs.value)
1588
1803
  value = adj.eval(node.value)
1589
1804
 
1590
- slice = node.targets[0].slice
1805
+ slice = lhs.slice
1591
1806
  indices = []
1592
1807
 
1593
1808
  if isinstance(slice, ast.Tuple):
@@ -1595,7 +1810,6 @@ class Adjoint:
1595
1810
  for arg in slice.elts:
1596
1811
  var = adj.eval(arg)
1597
1812
  indices.append(var)
1598
-
1599
1813
  elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
1600
1814
  # handles the x[i, j] case (Python 3.7.x)
1601
1815
  for arg in slice.value.elts:
@@ -1606,65 +1820,84 @@ class Adjoint:
1606
1820
  var = adj.eval(slice)
1607
1821
  indices.append(var)
1608
1822
 
1609
- if is_array(target.type):
1610
- adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
1823
+ target_type = strip_reference(target.type)
1611
1824
 
1612
- elif type_is_vector(target.type) or type_is_matrix(target.type):
1613
- adj.add_call(warp.context.builtin_functions["indexset"], [target, *indices, value])
1825
+ if is_array(target_type):
1826
+ adj.add_builtin_call("array_store", [target, *indices, value])
1827
+
1828
+ elif type_is_vector(target_type) or type_is_matrix(target_type):
1829
+ if is_reference(target.type):
1830
+ attr = adj.add_builtin_call("indexref", [target, *indices])
1831
+ else:
1832
+ attr = adj.add_builtin_call("index", [target, *indices])
1833
+
1834
+ adj.add_builtin_call("store", [attr, value])
1614
1835
 
1615
1836
  if warp.config.verbose and not adj.custom_reverse_mode:
1616
1837
  lineno = adj.lineno + adj.fun_lineno
1617
- line = adj.source.splitlines()[adj.lineno]
1618
- node_source = adj.get_node_source(node.targets[0].value)
1838
+ line = adj.source_lines[adj.lineno]
1839
+ node_source = adj.get_node_source(lhs.value)
1619
1840
  print(
1620
1841
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
1621
1842
  )
1622
1843
 
1623
1844
  else:
1624
- raise RuntimeError("Can only subscript assign array, vector, and matrix types")
1845
+ raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
1625
1846
 
1626
- return var
1627
-
1628
- elif isinstance(node.targets[0], ast.Name):
1847
+ elif isinstance(lhs, ast.Name):
1629
1848
  # symbol name
1630
- name = node.targets[0].id
1849
+ name = lhs.id
1631
1850
 
1632
1851
  # evaluate rhs
1633
1852
  rhs = adj.eval(node.value)
1634
1853
 
1635
1854
  # check type matches if symbol already defined
1636
1855
  if name in adj.symbols:
1637
- if not types_equal(rhs.type, adj.symbols[name].type):
1638
- raise TypeError(
1639
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
1640
- name, adj.symbols[name].type, rhs.type
1641
- )
1856
+ if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
1857
+ raise WarpCodegenTypeError(
1858
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
1642
1859
  )
1643
1860
 
1644
1861
  # handle simple assignment case (a = b), where we generate a value copy rather than reference
1645
- if isinstance(node.value, ast.Name):
1646
- out = adj.add_var(rhs.type)
1647
- adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
1862
+ if isinstance(node.value, ast.Name) or is_reference(rhs.type):
1863
+ out = adj.add_builtin_call("copy", [rhs])
1648
1864
  else:
1649
1865
  out = rhs
1650
1866
 
1651
1867
  # update symbol map (assumes lhs is a Name node)
1652
1868
  adj.symbols[name] = out
1653
- return out
1654
1869
 
1655
- elif isinstance(node.targets[0], ast.Attribute):
1870
+ elif isinstance(lhs, ast.Attribute):
1656
1871
  rhs = adj.eval(node.value)
1657
- attr = adj.emit_Attribute(node.targets[0])
1658
- adj.add_call(warp.context.builtin_functions["copy"], [attr, rhs])
1872
+ aggregate = adj.eval(lhs.value)
1873
+ aggregate_type = strip_reference(aggregate.type)
1659
1874
 
1660
- if warp.config.verbose and not adj.custom_reverse_mode:
1661
- lineno = adj.lineno + adj.fun_lineno
1662
- line = adj.source.splitlines()[adj.lineno]
1663
- msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1664
- print(msg)
1875
+ # assigning to a vector component
1876
+ if type_is_vector(aggregate_type):
1877
+ index = adj.vector_component_index(lhs.attr, aggregate_type)
1878
+
1879
+ if is_reference(aggregate.type):
1880
+ attr = adj.add_builtin_call("indexref", [aggregate, index])
1881
+ else:
1882
+ attr = adj.add_builtin_call("index", [aggregate, index])
1883
+
1884
+ adj.add_builtin_call("store", [attr, rhs])
1885
+
1886
+ else:
1887
+ attr = adj.emit_Attribute(lhs)
1888
+ if is_reference(attr.type):
1889
+ adj.add_builtin_call("store", [attr, rhs])
1890
+ else:
1891
+ adj.add_builtin_call("assign", [attr, rhs])
1892
+
1893
+ if warp.config.verbose and not adj.custom_reverse_mode:
1894
+ lineno = adj.lineno + adj.fun_lineno
1895
+ line = adj.source_lines[adj.lineno]
1896
+ msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1897
+ print(msg)
1665
1898
 
1666
1899
  else:
1667
- raise RuntimeError("Error, unsupported assignment statement.")
1900
+ raise WarpCodegenError("Error, unsupported assignment statement.")
1668
1901
 
1669
1902
  def emit_Return(adj, node):
1670
1903
  if node.value is None:
@@ -1675,37 +1908,26 @@ class Adjoint:
1675
1908
  var = (adj.eval(node.value),)
1676
1909
 
1677
1910
  if adj.return_var is not None:
1678
- old_ctypes = tuple(v.ctype() for v in adj.return_var)
1679
- new_ctypes = tuple(v.ctype() for v in var)
1911
+ old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
1912
+ new_ctypes = tuple(v.ctype(value_type=True) for v in var)
1680
1913
  if old_ctypes != new_ctypes:
1681
- raise TypeError(
1914
+ raise WarpCodegenTypeError(
1682
1915
  f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
1683
1916
  )
1684
- else:
1685
- adj.return_var = var
1686
1917
 
1687
- adj.add_return(var)
1918
+ if var is not None:
1919
+ adj.return_var = tuple()
1920
+ for ret in var:
1921
+ if is_reference(ret.type):
1922
+ ret = adj.add_builtin_call("copy", [ret])
1923
+ adj.return_var += (ret,)
1688
1924
 
1689
- def emit_AugAssign(adj, node):
1690
- # convert inplace operations (+=, -=, etc) to ssa form, e.g.: c = a + b
1691
- left = adj.eval(node.target)
1925
+ adj.add_return(adj.return_var)
1692
1926
 
1693
- if left.is_adjoint:
1694
- # replace augassign with assignment statement + binary op
1695
- new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
1696
- adj.eval(new_node)
1697
- return
1698
-
1699
- right = adj.eval(node.value)
1700
-
1701
- # lookup
1702
- name = builtin_operators[type(node.op)]
1703
- func = warp.context.builtin_functions[name]
1704
-
1705
- out = adj.add_call(func, [left, right])
1706
-
1707
- # update symbol map
1708
- adj.symbols[node.target.id] = out
1927
+ def emit_AugAssign(adj, node):
1928
+ # replace augmented assignment with assignment statement + binary op
1929
+ new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
1930
+ adj.eval(new_node)
1709
1931
 
1710
1932
  def emit_Tuple(adj, node):
1711
1933
  # LHS for expressions, such as i, j, k = 1, 2, 3
@@ -1715,131 +1937,159 @@ class Adjoint:
1715
1937
  def emit_Pass(adj, node):
1716
1938
  pass
1717
1939
 
1940
+ node_visitors = {
1941
+ ast.FunctionDef: emit_FunctionDef,
1942
+ ast.If: emit_If,
1943
+ ast.Compare: emit_Compare,
1944
+ ast.BoolOp: emit_BoolOp,
1945
+ ast.Name: emit_Name,
1946
+ ast.Attribute: emit_Attribute,
1947
+ ast.Str: emit_String, # Deprecated in 3.8; use Constant
1948
+ ast.Num: emit_Num, # Deprecated in 3.8; use Constant
1949
+ ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
1950
+ ast.Constant: emit_Constant,
1951
+ ast.BinOp: emit_BinOp,
1952
+ ast.UnaryOp: emit_UnaryOp,
1953
+ ast.While: emit_While,
1954
+ ast.For: emit_For,
1955
+ ast.Break: emit_Break,
1956
+ ast.Continue: emit_Continue,
1957
+ ast.Expr: emit_Expr,
1958
+ ast.Call: emit_Call,
1959
+ ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
1960
+ ast.Subscript: emit_Subscript,
1961
+ ast.Assign: emit_Assign,
1962
+ ast.Return: emit_Return,
1963
+ ast.AugAssign: emit_AugAssign,
1964
+ ast.Tuple: emit_Tuple,
1965
+ ast.Pass: emit_Pass,
1966
+ ast.Ellipsis: emit_Ellipsis,
1967
+ }
1968
+
1718
1969
  def eval(adj, node):
1719
1970
  if hasattr(node, "lineno"):
1720
1971
  adj.set_lineno(node.lineno - 1)
1721
1972
 
1722
- node_visitors = {
1723
- ast.FunctionDef: Adjoint.emit_FunctionDef,
1724
- ast.If: Adjoint.emit_If,
1725
- ast.Compare: Adjoint.emit_Compare,
1726
- ast.BoolOp: Adjoint.emit_BoolOp,
1727
- ast.Name: Adjoint.emit_Name,
1728
- ast.Attribute: Adjoint.emit_Attribute,
1729
- ast.Str: Adjoint.emit_String, # Deprecated in 3.8; use Constant
1730
- ast.Num: Adjoint.emit_Num, # Deprecated in 3.8; use Constant
1731
- ast.NameConstant: Adjoint.emit_NameConstant, # Deprecated in 3.8; use Constant
1732
- ast.Constant: Adjoint.emit_Constant,
1733
- ast.BinOp: Adjoint.emit_BinOp,
1734
- ast.UnaryOp: Adjoint.emit_UnaryOp,
1735
- ast.While: Adjoint.emit_While,
1736
- ast.For: Adjoint.emit_For,
1737
- ast.Break: Adjoint.emit_Break,
1738
- ast.Continue: Adjoint.emit_Continue,
1739
- ast.Expr: Adjoint.emit_Expr,
1740
- ast.Call: Adjoint.emit_Call,
1741
- ast.Index: Adjoint.emit_Index, # Deprecated in 3.8; Use the index value directly instead.
1742
- ast.Subscript: Adjoint.emit_Subscript,
1743
- ast.Assign: Adjoint.emit_Assign,
1744
- ast.Return: Adjoint.emit_Return,
1745
- ast.AugAssign: Adjoint.emit_AugAssign,
1746
- ast.Tuple: Adjoint.emit_Tuple,
1747
- ast.Pass: Adjoint.emit_Pass,
1748
- }
1749
-
1750
- emit_node = node_visitors.get(type(node))
1751
-
1752
- if emit_node is not None:
1753
- if adj.is_user_function:
1754
- if hasattr(node, "value") and hasattr(node.value, "func") and hasattr(node.value.func, "attr"):
1755
- if node.value.func.attr == "tid":
1756
- lineno = adj.lineno + adj.fun_lineno
1757
- line = adj.source.splitlines()[adj.lineno]
1758
-
1759
- warp.utils.warn(
1760
- "Calling wp.tid() from a @wp.func is deprecated and will be removed in a future Warp "
1761
- "version. Instead, obtain the indices from a @wp.kernel and pass them as "
1762
- f"arguments to this function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n",
1763
- PendingDeprecationWarning,
1764
- stacklevel=2,
1765
- )
1766
- return emit_node(adj, node)
1767
- else:
1768
- raise Exception("Error, ast node of type {} not supported".format(type(node)))
1973
+ emit_node = adj.node_visitors[type(node)]
1974
+
1975
+ return emit_node(adj, node)
1769
1976
 
1770
1977
  # helper to evaluate expressions of the form
1771
1978
  # obj1.obj2.obj3.attr in the function's global scope
1772
- def resolve_path(adj, node):
1773
- modules = []
1979
+ def resolve_path(adj, path):
1980
+ if len(path) == 0:
1981
+ return None
1774
1982
 
1775
- while isinstance(node, ast.Attribute):
1776
- modules.append(node.attr)
1777
- node = node.value
1983
+ # if root is overshadowed by local symbols, bail out
1984
+ if path[0] in adj.symbols:
1985
+ return None
1778
1986
 
1779
- if isinstance(node, ast.Name):
1780
- modules.append(node.id)
1987
+ if path[0] in __builtins__:
1988
+ return __builtins__[path[0]]
1781
1989
 
1782
- # reverse list since ast presents it backward order
1783
- path = [*reversed(modules)]
1990
+ # Look up the closure info and append it to adj.func.__globals__
1991
+ # in case you want to define a kernel inside a function and refer
1992
+ # to variables you've declared inside that function:
1993
+ extract_contents = (
1994
+ lambda contents: contents
1995
+ if isinstance(contents, warp.context.Function) or not callable(contents)
1996
+ else contents
1997
+ )
1998
+ capturedvars = dict(
1999
+ zip(
2000
+ adj.func.__code__.co_freevars,
2001
+ [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
2002
+ )
2003
+ )
2004
+ vars_dict = {**adj.func.__globals__, **capturedvars}
1784
2005
 
1785
- if len(path) == 0:
1786
- return None, path
2006
+ if path[0] in vars_dict:
2007
+ func = vars_dict[path[0]]
1787
2008
 
1788
- # try and evaluate object path
1789
- try:
1790
- # Look up the closure info and append it to adj.func.__globals__
1791
- # in case you want to define a kernel inside a function and refer
1792
- # to variables you've declared inside that function:
1793
- extract_contents = (
1794
- lambda contents: contents
1795
- if isinstance(contents, warp.context.Function) or not callable(contents)
1796
- else contents
1797
- )
1798
- capturedvars = dict(
1799
- zip(
1800
- adj.func.__code__.co_freevars,
1801
- [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
1802
- )
1803
- )
2009
+ # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
2010
+ else:
2011
+ func = getattr(warp, path[0], None)
1804
2012
 
1805
- vars_dict = {**adj.func.__globals__, **capturedvars}
1806
- func = eval(".".join(path), vars_dict)
1807
- return func, path
1808
- except Exception:
1809
- pass
2013
+ if func:
2014
+ for i in range(1, len(path)):
2015
+ if hasattr(func, path[i]):
2016
+ func = getattr(func, path[i])
1810
2017
 
1811
- # I added this so people can eg do this kind of thing
1812
- # in a kernel:
2018
+ return func
1813
2019
 
1814
- # v = vec3(0.0,0.2,0.4)
2020
+ # Evaluates a static expression that does not depend on runtime values
2021
+ # if eval_types is True, try resolving the path using evaluated type information as well
2022
+ def resolve_static_expression(adj, root_node, eval_types=True):
2023
+ attributes = []
1815
2024
 
1816
- # vec3 is now an alias and is not in warp.context.builtin_functions.
1817
- # This means it can't be directly looked up in Adjoint.add_call, and
1818
- # needs to be looked up by digging some information out of the
1819
- # python object it actually came from.
2025
+ node = root_node
2026
+ while isinstance(node, ast.Attribute):
2027
+ attributes.append(node.attr)
2028
+ node = node.value
1820
2029
 
1821
- # Before this fix, resolve_path was returning None, as the
1822
- # "vec3" symbol is not available. In this situation I'm assuming
1823
- # it's a member of the warp module and trying to look it up:
1824
- try:
1825
- evalstr = ".".join(["warp"] + path)
1826
- func = eval(evalstr, {"warp": warp})
1827
- return func, path
1828
- except Exception:
1829
- return None, path
2030
+ if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
2031
+ # support for operators returning modules
2032
+ # i.e. operator_name(*operator_args).x.y.z
2033
+ operator_args = node.args
2034
+ operator_name = node.func.id
2035
+
2036
+ if operator_name == "type":
2037
+ if len(operator_args) != 1:
2038
+ raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
2039
+
2040
+ # type() operator
2041
+ var = adj.eval(operator_args[0])
2042
+
2043
+ if isinstance(var, Var):
2044
+ var_type = strip_reference(var.type)
2045
+ # Allow accessing type attributes, for instance array.dtype
2046
+ while attributes:
2047
+ attr_name = attributes.pop()
2048
+ var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
2049
+
2050
+ if var_type is None:
2051
+ raise WarpCodegenAttributeError(
2052
+ f"{attr_name} is not an attribute of {type_repr(prev_type)}"
2053
+ )
2054
+
2055
+ return var_type, [type_repr(var_type)]
2056
+ else:
2057
+ raise WarpCodegenError(f"Cannot deduce the type of {var}")
2058
+
2059
+ # reverse list since ast presents it backward order
2060
+ path = [*reversed(attributes)]
2061
+ if isinstance(node, ast.Name):
2062
+ path.insert(0, node.id)
2063
+
2064
+ # Try resolving path from captured context
2065
+ captured_obj = adj.resolve_path(path)
2066
+ if captured_obj is not None:
2067
+ return captured_obj, path
2068
+
2069
+ # Still nothing found, maybe this is a predefined type attribute like `dtype`
2070
+ if eval_types:
2071
+ try:
2072
+ val = adj.eval(root_node)
2073
+ if val:
2074
+ return [val, type_repr(val)]
2075
+
2076
+ except Exception:
2077
+ pass
2078
+
2079
+ return None, path
1830
2080
 
1831
2081
  # annotate generated code with the original source code line
1832
2082
  def set_lineno(adj, lineno):
1833
2083
  if adj.lineno is None or adj.lineno != lineno:
1834
2084
  line = lineno + adj.fun_lineno
1835
- source = adj.raw_source[lineno].strip().ljust(80 - len(adj.prefix), " ")
2085
+ source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
1836
2086
  adj.add_forward(f"// {source} <L {line}>")
1837
2087
  adj.add_reverse(f"// adj: {source} <L {line}>")
1838
2088
  adj.lineno = lineno
1839
2089
 
1840
2090
  def get_node_source(adj, node):
1841
2091
  # return the Python code corresponding to the given AST node
1842
- return ast.get_source_segment("".join(adj.raw_source), node)
2092
+ return ast.get_source_segment(adj.source, node)
1843
2093
 
1844
2094
 
1845
2095
  # ----------------
@@ -1856,7 +2106,10 @@ cpu_module_header = """
1856
2106
  #define int(x) cast_int(x)
1857
2107
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
1858
2108
 
1859
- using namespace wp;
2109
+ #define builtin_tid1d() wp::tid(wp::s_threadIdx)
2110
+ #define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
2111
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
2112
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
1860
2113
 
1861
2114
  """
1862
2115
 
@@ -1871,8 +2124,10 @@ cuda_module_header = """
1871
2124
  #define int(x) cast_int(x)
1872
2125
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
1873
2126
 
1874
-
1875
- using namespace wp;
2127
+ #define builtin_tid1d() wp::tid(_idx)
2128
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
2129
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
2130
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
1876
2131
 
1877
2132
  """
1878
2133
 
@@ -1886,7 +2141,9 @@ struct {name}
1886
2141
  {{
1887
2142
  }}
1888
2143
 
1889
- CUDA_CALLABLE {name}& operator += (const {name}&) {{ return *this; }}
2144
+ CUDA_CALLABLE {name}& operator += (const {name}& rhs)
2145
+ {{{prefix_add_body}
2146
+ return *this;}}
1890
2147
 
1891
2148
  }};
1892
2149
 
@@ -1942,24 +2199,18 @@ cuda_kernel_template = """
1942
2199
  extern "C" __global__ void {name}_cuda_kernel_forward(
1943
2200
  {forward_args})
1944
2201
  {{
1945
- size_t _idx = grid_index();
1946
- if (_idx >= dim.size)
1947
- return;
1948
-
1949
- set_launch_bounds(dim);
1950
-
1951
- {forward_body}}}
2202
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
2203
+ _idx < dim.size;
2204
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
2205
+ {forward_body}}}}}
1952
2206
 
1953
2207
  extern "C" __global__ void {name}_cuda_kernel_backward(
1954
2208
  {reverse_args})
1955
2209
  {{
1956
- size_t _idx = grid_index();
1957
- if (_idx >= dim.size)
1958
- return;
1959
-
1960
- set_launch_bounds(dim);
1961
-
1962
- {reverse_body}}}
2210
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
2211
+ _idx < dim.size;
2212
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
2213
+ {reverse_body}}}}}
1963
2214
 
1964
2215
  """
1965
2216
 
@@ -1985,11 +2236,9 @@ extern "C" {{
1985
2236
  WP_API void {name}_cpu_forward(
1986
2237
  {forward_args})
1987
2238
  {{
1988
- set_launch_bounds(dim);
1989
-
1990
2239
  for (size_t i=0; i < dim.size; ++i)
1991
2240
  {{
1992
- s_threadIdx = i;
2241
+ wp::s_threadIdx = i;
1993
2242
 
1994
2243
  {name}_cpu_kernel_forward(
1995
2244
  {forward_params});
@@ -1999,11 +2248,9 @@ WP_API void {name}_cpu_forward(
1999
2248
  WP_API void {name}_cpu_backward(
2000
2249
  {reverse_args})
2001
2250
  {{
2002
- set_launch_bounds(dim);
2003
-
2004
2251
  for (size_t i=0; i < dim.size; ++i)
2005
2252
  {{
2006
- s_threadIdx = i;
2253
+ wp::s_threadIdx = i;
2007
2254
 
2008
2255
  {name}_cpu_kernel_backward(
2009
2256
  {reverse_params});
@@ -2109,8 +2356,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2109
2356
 
2110
2357
  body = []
2111
2358
  indent_block = " " * indent_size
2112
- for label, var in struct.vars.items():
2113
- body.append(var.ctype() + " " + label + ";\n")
2359
+
2360
+ if len(struct.vars) > 0:
2361
+ for label, var in struct.vars.items():
2362
+ body.append(var.ctype() + " " + label + ";\n")
2363
+ else:
2364
+ # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
2365
+ body.append("char _dummy_;\n")
2114
2366
 
2115
2367
  forward_args = []
2116
2368
  reverse_args = []
@@ -2118,17 +2370,25 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2118
2370
  forward_initializers = []
2119
2371
  reverse_body = []
2120
2372
  atomic_add_body = []
2373
+ prefix_add_body = []
2121
2374
 
2122
2375
  # forward args
2123
2376
  for label, var in struct.vars.items():
2124
- forward_args.append(f"{var.ctype()} const& {label} = {{}}")
2125
- reverse_args.append(f"{var.ctype()} const&")
2377
+ var_ctype = var.ctype()
2378
+ forward_args.append(f"{var_ctype} const& {label} = {{}}")
2379
+ reverse_args.append(f"{var_ctype} const&")
2126
2380
 
2127
- atomic_add_body.append(f"{indent_block}adj_atomic_add(&p->{label}, t.{label});\n")
2381
+ namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
2382
+ atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
2128
2383
 
2129
2384
  prefix = f"{indent_block}," if forward_initializers else ":"
2130
2385
  forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
2131
2386
 
2387
+ # prefix-add operator
2388
+ for label, var in struct.vars.items():
2389
+ if not is_array(var.type):
2390
+ prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
2391
+
2132
2392
  # reverse args
2133
2393
  for label, var in struct.vars.items():
2134
2394
  reverse_args.append(var.ctype() + " & adj_" + label)
@@ -2146,6 +2406,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
2146
2406
  forward_initializers="".join(forward_initializers),
2147
2407
  reverse_args=indent(reverse_args),
2148
2408
  reverse_body="".join(reverse_body),
2409
+ prefix_add_body="".join(prefix_add_body),
2149
2410
  atomic_add_body="".join(atomic_add_body),
2150
2411
  )
2151
2412
 
@@ -2189,7 +2450,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
2189
2450
  return s
2190
2451
 
2191
2452
 
2192
- def codegen_func_reverse_body(adj, device="cpu", indent=4):
2453
+ def codegen_func_reverse_body(adj, device="cpu", indent=4, func_type="kernel"):
2193
2454
  body = []
2194
2455
  indent_block = " " * indent
2195
2456
 
@@ -2207,7 +2468,11 @@ def codegen_func_reverse_body(adj, device="cpu", indent=4):
2207
2468
  for l in reversed(adj.blocks[0].body_reverse):
2208
2469
  body += [l + "\n"]
2209
2470
 
2210
- body += ["return;\n"]
2471
+ # In grid-stride kernels the reverse body is in a for loop
2472
+ if device == "cuda" and func_type == "kernel":
2473
+ body += ["continue;\n"]
2474
+ else:
2475
+ body += ["return;\n"]
2211
2476
 
2212
2477
  return "".join([indent_block + l for l in body])
2213
2478
 
@@ -2230,20 +2495,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
2230
2495
  s += " // dual vars\n"
2231
2496
 
2232
2497
  for var in adj.variables:
2233
- if isinstance(var.type, Struct):
2234
- s += f" {var.ctype()} {var.emit('adj')};\n"
2235
- else:
2236
- s += f" {var.ctype()} {var.emit('adj')}(0);\n"
2498
+ s += f" {var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"
2237
2499
 
2238
2500
  if device == "cpu":
2239
2501
  s += codegen_func_reverse_body(adj, device=device, indent=4)
2240
2502
  elif device == "cuda":
2241
2503
  if func_type == "kernel":
2242
- s += codegen_func_reverse_body(adj, device=device, indent=8)
2504
+ s += codegen_func_reverse_body(adj, device=device, indent=8, func_type=func_type)
2243
2505
  else:
2244
- s += codegen_func_reverse_body(adj, device=device, indent=4)
2506
+ s += codegen_func_reverse_body(adj, device=device, indent=4, func_type=func_type)
2245
2507
  else:
2246
- raise ValueError("Device {} not supported for codegen".format(device))
2508
+ raise ValueError(f"Device {device} not supported for codegen")
2247
2509
 
2248
2510
  return s
2249
2511
 
@@ -2298,7 +2560,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
2298
2560
  forward_template = cuda_forward_function_template
2299
2561
  reverse_template = cuda_reverse_function_template
2300
2562
  else:
2301
- raise ValueError("Device {} is not supported".format(device))
2563
+ raise ValueError(f"Device {device} is not supported")
2302
2564
 
2303
2565
  # codegen body
2304
2566
  forward_body = codegen_func_forward(adj, func_type="function", device=device)
@@ -2335,6 +2597,55 @@ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
2335
2597
  return s
2336
2598
 
2337
2599
 
2600
+ def codegen_snippet(adj, name, snippet, adj_snippet):
2601
+ forward_args = []
2602
+ reverse_args = []
2603
+
2604
+ # forward args
2605
+ for i, arg in enumerate(adj.args):
2606
+ s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
2607
+ forward_args.append(s)
2608
+ reverse_args.append(s)
2609
+
2610
+ # reverse args
2611
+ for i, arg in enumerate(adj.args):
2612
+ if isinstance(arg.type, indexedarray):
2613
+ _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
2614
+ reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
2615
+ else:
2616
+ reverse_args.append(arg.ctype() + " & adj_" + arg.label)
2617
+
2618
+ forward_template = cuda_forward_function_template
2619
+ reverse_template = cuda_reverse_function_template
2620
+
2621
+ s = ""
2622
+ s += forward_template.format(
2623
+ name=name,
2624
+ return_type="void",
2625
+ forward_args=indent(forward_args),
2626
+ forward_body=snippet,
2627
+ filename=adj.filename,
2628
+ lineno=adj.fun_lineno,
2629
+ )
2630
+
2631
+ if adj_snippet:
2632
+ reverse_body = adj_snippet
2633
+ else:
2634
+ reverse_body = ""
2635
+
2636
+ s += reverse_template.format(
2637
+ name=name,
2638
+ return_type="void",
2639
+ reverse_args=indent(reverse_args),
2640
+ forward_body=snippet,
2641
+ reverse_body=reverse_body,
2642
+ filename=adj.filename,
2643
+ lineno=adj.fun_lineno,
2644
+ )
2645
+
2646
+ return s
2647
+
2648
+
2338
2649
  def codegen_kernel(kernel, device, options):
2339
2650
  # Update the module's options with the ones defined on the kernel, if any.
2340
2651
  options = dict(options)
@@ -2342,8 +2653,8 @@ def codegen_kernel(kernel, device, options):
2342
2653
 
2343
2654
  adj = kernel.adj
2344
2655
 
2345
- forward_args = ["launch_bounds_t dim"]
2346
- reverse_args = ["launch_bounds_t dim"]
2656
+ forward_args = ["wp::launch_bounds_t dim"]
2657
+ reverse_args = ["wp::launch_bounds_t dim"]
2347
2658
 
2348
2659
  # forward args
2349
2660
  for arg in adj.args:
@@ -2372,7 +2683,7 @@ def codegen_kernel(kernel, device, options):
2372
2683
  elif device == "cuda":
2373
2684
  template = cuda_kernel_template
2374
2685
  else:
2375
- raise ValueError("Device {} is not supported".format(device))
2686
+ raise ValueError(f"Device {device} is not supported")
2376
2687
 
2377
2688
  s = template.format(
2378
2689
  name=kernel.get_mangled_name(),
@@ -2392,7 +2703,7 @@ def codegen_module(kernel, device="cpu"):
2392
2703
  adj = kernel.adj
2393
2704
 
2394
2705
  # build forward signature
2395
- forward_args = ["launch_bounds_t dim"]
2706
+ forward_args = ["wp::launch_bounds_t dim"]
2396
2707
  forward_params = ["dim"]
2397
2708
 
2398
2709
  for arg in adj.args: