warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
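
The file list alone shows where this release grew: a finite-element toolkit (warp/fem/), a sparse-matrix module (warp/sparse.py), an optimizers package (warp/optim/, with sgd.py and linear.py), Fabric interop (warp/fabric.py), and a vendored copy of the CUTLASS library scripts. A minimal smoke test, importing only top-level modules that appear in the list above (their contents are not shown in this diff, so nothing beyond the imports is assumed):

import warp as wp
import warp.fem     # finite-element toolkit added under warp/fem/
import warp.sparse  # sparse-matrix module added as warp/sparse.py
import warp.optim   # optimizers package added under warp/optim/

wp.init()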
warp/codegen.py CHANGED
@@ -7,23 +7,40 @@
 
 from __future__ import annotations
 
-import re
-import sys
 import ast
-import inspect
+import builtins
 import ctypes
+import inspect
+import math
+import re
+import sys
 import textwrap
 import types
+from typing import Any, Callable, Mapping
 
-import numpy as np
+import warp.config
+from warp.types import *
 
-from typing import Any
-from typing import Callable
-from typing import Mapping
-from typing import Union
 
-from warp.types import *
-import warp.config
+class WarpCodegenError(RuntimeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenTypeError(TypeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenAttributeError(AttributeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenKeyError(KeyError):
+    def __init__(self, message):
+        super().__init__(message)
+
 
 # map operator to function name
 builtin_operators = {}
@@ -57,6 +74,19 @@ builtin_operators[ast.Invert] = "invert"
 builtin_operators[ast.LShift] = "lshift"
 builtin_operators[ast.RShift] = "rshift"
 
+comparison_chain_strings = [
+    builtin_operators[ast.Gt],
+    builtin_operators[ast.Lt],
+    builtin_operators[ast.LtE],
+    builtin_operators[ast.GtE],
+    builtin_operators[ast.Eq],
+    builtin_operators[ast.NotEq],
+]
+
+
+def op_str_is_chainable(op: str) -> builtins.bool:
+    return op in comparison_chain_strings
+
 
 def get_annotations(obj: Any) -> Mapping[str, Any]:
     """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
@@ -67,97 +97,156 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
     return getattr(obj, "__annotations__", {})
 
 
-def _get_struct_instance_ctype(
-    inst: StructInstance,
-    parent_ctype: Union[StructInstance, None],
-    parent_field: Union[str, None],
-) -> ctypes.Structure:
-    if inst._struct_.ctype._fields_ == [("_dummy_", ctypes.c_int)]:
-        return inst._struct_.ctype()
+def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
+    indent = "\t"
+
+    # handle empty structs
+    if len(inst._cls.vars) == 0:
+        return f"{inst._cls.key}()"
 
-    if parent_ctype is None:
-        inst_ctype = inst._struct_.ctype()
-    else:
-        inst_ctype = getattr(parent_ctype, parent_field)
+    lines = []
+    lines.append(f"{inst._cls.key}(")
+
+    for field_name, _ in inst._cls.ctype._fields_:
+        field_value = getattr(inst, field_name, None)
+
+        if isinstance(field_value, StructInstance):
+            field_value = struct_instance_repr_recursive(field_value, depth + 1)
+
+        lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
 
-    for field_name, _ in inst_ctype._fields_:
-        value = getattr(inst, field_name, None)
+    lines.append(f"{indent * depth})")
+    return "\n".join(lines)
+
+
+class StructInstance:
+    def __init__(self, cls: Struct, ctype):
+        super().__setattr__("_cls", cls)
+
+        # maintain a c-types object for the top-level instance the struct
+        if not ctype:
+            super().__setattr__("_ctype", cls.ctype())
+        else:
+            super().__setattr__("_ctype", ctype)
+
+        # create Python attributes for each of the struct's variables
+        for field, var in cls.vars.items():
+            if isinstance(var.type, warp.codegen.Struct):
+                self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
+            elif isinstance(var.type, warp.types.array):
+                self.__dict__[field] = None
+            else:
+                self.__dict__[field] = var.type()
 
-        var_type = inst._struct_.vars[field_name].type
-        if isinstance(var_type, array):
+    def __setattr__(self, name, value):
+        if name not in self._cls.vars:
+            raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")
+
+        var = self._cls.vars[name]
+
+        # update our ctype flat copy
+        if isinstance(var.type, array):
             if value is None:
                 # create array with null pointer
-                setattr(inst_ctype, field_name, array_t())
+                setattr(self._ctype, name, array_t())
             else:
                 # wp.array
                 assert isinstance(value, array)
-                assert (
-                    value.dtype == var_type.dtype
-                ), "assign to struct member variable {} failed, expected type {}, got type {}".format(
-                    field_name, var_type.dtype, value.dtype
+                assert types_equal(
+                    value.dtype, var.type.dtype
+                ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                setattr(self._ctype, name, value.__ctype__())
+
+        elif isinstance(var.type, Struct):
+            # assign structs by-value, otherwise we would have problematic cases transferring ownership
+            # of the underlying ctypes data between shared Python struct instances
+
+            if not isinstance(value, StructInstance):
+                raise RuntimeError(
+                    f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
                 )
-                setattr(inst_ctype, field_name, value.__ctype__())
-        elif isinstance(var_type, Struct):
-            if value is None:
-                _get_struct_instance_ctype(StructInstance(var_type), inst_ctype, field_name)
-            else:
-                _get_struct_instance_ctype(value, inst_ctype, field_name)
-        elif issubclass(var_type, ctypes.Array):
+
+            # destination attribution on self
+            dest = getattr(self, name)
+
+            if dest._cls.key is not value._cls.key:
+                raise RuntimeError(
+                    f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
+                )
+
+            # update all nested ctype vars by deep copy
+            for n in dest._cls.vars:
+                setattr(dest, n, getattr(value, n))
+
+            # early return to avoid updating our Python StructInstance
+            return
+
+        elif issubclass(var.type, ctypes.Array):
             # vector/matrix type, e.g. vec3
             if value is None:
-                setattr(inst_ctype, field_name, var_type())
-            elif types_equal(type(value), var_type):
-                setattr(inst_ctype, field_name, value)
+                setattr(self._ctype, name, var.type())
+            elif types_equal(type(value), var.type):
+                setattr(self._ctype, name, value)
             else:
                 # conversion from list/tuple, ndarray, etc.
-                setattr(inst_ctype, field_name, var_type(value))
+                setattr(self._ctype, name, var.type(value))
+
         else:
             # primitive type
             if value is None:
-                setattr(inst_ctype, field_name, var_type._type_())
+                # zero initialize
+                setattr(self._ctype, name, var.type._type_())
             else:
-                setattr(inst_ctype, field_name, var_type._type_(value))
-
-    return inst_ctype
-
-
-def _fmt_struct_instance_repr(inst: StructInstance, depth: int) -> str:
-    indent = "\t"
-
-    if inst._struct_.ctype._fields_ == [("_dummy_", ctypes.c_int)]:
-        return f"{inst._struct_.key}()"
-
-    lines = []
-    lines.append(f"{inst._struct_.key}(")
-
-    for field_name, _ in inst._struct_.ctype._fields_:
-        if field_name == "_dummy_":
-            continue
-
-        field_value = getattr(inst, field_name, None)
-
-        if isinstance(field_value, StructInstance):
-            field_value = _fmt_struct_instance_repr(field_value, depth + 1)
+                if hasattr(value, "_type_"):
+                    # assigning warp type value (e.g.: wp.float32)
+                    value = value.value
+                # float16 needs conversion to uint16 bits
+                if var.type == warp.float16:
+                    setattr(self._ctype, name, float_to_half_bits(value))
+                else:
+                    setattr(self._ctype, name, value)
 
-        lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
+        # update Python instance
+        super().__setattr__(name, value)
 
-    lines.append(f"{indent * depth})")
-    return "\n".join(lines)
+    def __ctype__(self):
+        return self._ctype
 
+    def __repr__(self):
+        return struct_instance_repr_recursive(self, 0)
 
-class StructInstance:
-    def __init__(self, struct: Struct):
-        self.__dict__["_struct_"] = struct
+    # type description used in numpy structured arrays
+    def numpy_dtype(self):
+        return self._cls.numpy_dtype()
 
-    def __setattr__(self, name, value):
-        assert name in self._struct_.vars, "invalid struct member variable {}".format(name)
-        super().__setattr__(name, value)
+    # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
+    def numpy_value(self):
+        npvalue = []
+        for name, var in self._cls.vars.items():
+            # get the attribute value
+            value = getattr(self._ctype, name)
 
-    def __ctype__(self):
-        return _get_struct_instance_ctype(self, None, None)
+            if isinstance(var.type, array):
+                # array_t
+                npvalue.append(value.numpy_value())
+            elif isinstance(var.type, Struct):
+                # nested struct
+                npvalue.append(value.numpy_value())
+            elif issubclass(var.type, ctypes.Array):
+                if len(var.type._shape_) == 1:
+                    # vector
+                    npvalue.append(list(value))
+                else:
+                    # matrix
+                    npvalue.append([list(row) for row in value])
+            else:
+                # scalar
+                if var.type == warp.float16:
+                    npvalue.append(half_bits_to_float(value))
+                else:
+                    npvalue.append(value)
 
-    def __repr__(self):
-        return _fmt_struct_instance_repr(self, 0)
+        return tuple(npvalue)
 
 
 class Struct:
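
The rewritten StructInstance keeps a ctypes mirror (_ctype) synchronized on every attribute write, and the Struct branch above assigns nested structs by value via a member-wise deep copy instead of sharing the underlying ctypes buffer. A small sketch of the resulting Python-side semantics (hypothetical struct names):

import warp as wp

@wp.struct
class Inner:
    v: float

@wp.struct
class Outer:
    inner: Inner

a = Inner()
a.v = 1.0

o = Outer()
o.inner = a       # copied field-by-field into o's own ctypes storage
a.v = 2.0         # later writes to a do not leak into o
print(o.inner.v)  # 1.0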
@@ -184,7 +273,7 @@ class Struct:
 
         class StructType(ctypes.Structure):
             # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
-            _fields_ = fields or [("_dummy_", ctypes.c_int)]
+            _fields_ = fields or [("_dummy_", ctypes.c_byte)]
 
         self.ctype = StructType
 
@@ -235,29 +324,108 @@ class Struct:
 
         class NewStructInstance(self.cls, StructInstance):
             def __init__(inst):
-                StructInstance.__init__(inst, self)
+                StructInstance.__init__(inst, self, None)
 
         return NewStructInstance()
 
     def initializer(self):
        return self.default_constructor
 
+    # return structured NumPy dtype, including field names, formats, and offsets
+    def numpy_dtype(self):
+        names = []
+        formats = []
+        offsets = []
+        for name, var in self.vars.items():
+            names.append(name)
+            offsets.append(getattr(self.ctype, name).offset)
+            if isinstance(var.type, array):
+                # array_t
+                formats.append(array_t.numpy_dtype())
+            elif isinstance(var.type, Struct):
+                # nested struct
+                formats.append(var.type.numpy_dtype())
+            elif issubclass(var.type, ctypes.Array):
+                scalar_typestr = type_typestr(var.type._wp_scalar_type_)
+                if len(var.type._shape_) == 1:
+                    # vector
+                    formats.append(f"{var.type._length_}{scalar_typestr}")
+                else:
+                    # matrix
+                    formats.append(f"{var.type._shape_}{scalar_typestr}")
+            else:
+                # scalar
+                formats.append(type_typestr(var.type))
+
+        return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}
+
+    # constructs a Warp struct instance from a pointer to the ctype
+    def from_ptr(self, ptr):
+        if not ptr:
+            raise RuntimeError("NULL pointer exception")
+
+        # create a new struct instance
+        instance = self()
+
+        for name, var in self.vars.items():
+            offset = getattr(self.ctype, name).offset
+            if isinstance(var.type, array):
+                # We could reconstruct wp.array from array_t, but it's problematic.
+                # There's no guarantee that the original wp.array is still allocated and
+                # no easy way to make a backref.
+                # Instead, we just create a stub annotation, which is not a fully usable array object.
+                setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
+            elif isinstance(var.type, Struct):
+                # nested struct
+                value = var.type.from_ptr(ptr + offset)
+                setattr(instance, name, value)
+            elif issubclass(var.type, ctypes.Array):
+                # vector/matrix
+                value = var.type.from_ptr(ptr + offset)
+                setattr(instance, name, value)
+            else:
+                # scalar
+                cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
+                if var.type == warp.float16:
+                    setattr(instance, name, half_bits_to_float(cvalue))
+                else:
+                    setattr(instance, name, cvalue.value)
+
+        return instance
+
+
+class Reference:
+    def __init__(self, value_type):
+        self.value_type = value_type
+
+
+def is_reference(type):
+    return isinstance(type, Reference)
+
+
+def strip_reference(arg):
+    if is_reference(arg):
+        return arg.value_type
+    else:
+        return arg
+
 
 def compute_type_str(base_name, template_params):
-    if template_params is None or len(template_params) == 0:
+    if not template_params:
         return base_name
-    else:
 
-        def param2str(p):
-            if isinstance(p, int):
-                return str(p)
-            return p.__name__
+    def param2str(p):
+        if isinstance(p, int):
+            return str(p)
+        elif hasattr(p, "_type_"):
+            return f"wp::{p.__name__}"
+        return p.__name__
 
-        return f"{base_name}<{','.join(map(param2str, template_params))}>"
+    return f"{base_name}<{','.join(map(param2str, template_params))}>"
 
 
 class Var:
-    def __init__(self, label, type, requires_grad=False, constant=None):
+    def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
         # convert built-in types to wp types
         if type == float:
             type = float32
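
Struct.numpy_dtype() above mirrors the ctype layout (names, formats, offsets, itemsize), and StructInstance.numpy_value() from the previous hunk emits a matching tuple, so struct data can round-trip through NumPy structured arrays. A minimal sketch (hypothetical struct; assumes only the two methods defined above):

import numpy as np
import warp as wp

@wp.struct
class Particle:
    mass: float
    pos: wp.vec3

p = Particle()
p.mass = 0.5
p.pos = wp.vec3(1.0, 2.0, 3.0)

# structured array whose layout matches the struct's ctype
arr = np.empty(8, dtype=np.dtype(Particle.numpy_dtype()))
arr[0] = p.numpy_value()  # e.g. (0.5, [1.0, 2.0, 3.0])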
@@ -268,26 +436,49 @@ class Var:
         self.type = type
         self.requires_grad = requires_grad
         self.constant = constant
+        self.prefix = prefix
 
     def __str__(self):
         return self.label
 
-    def ctype(self):
-        if is_array(self.type):
-            if hasattr(self.type.dtype, "_wp_generic_type_str_"):
-                dtypestr = compute_type_str(self.type.dtype._wp_generic_type_str_, self.type.dtype._wp_type_params_)
-            elif isinstance(self.type.dtype, Struct):
-                dtypestr = make_full_qualified_name(self.type.dtype.cls)
+    @staticmethod
+    def type_to_ctype(t, value_type=False):
+        if is_array(t):
+            if hasattr(t.dtype, "_wp_generic_type_str_"):
+                dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
+            elif isinstance(t.dtype, Struct):
+                dtypestr = make_full_qualified_name(t.dtype.cls)
+            elif t.dtype.__name__ in ("bool", "int", "float"):
+                dtypestr = t.dtype.__name__
             else:
-                dtypestr = str(self.type.dtype.__name__)
-            classstr = type(self.type).__name__
+                dtypestr = f"wp::{t.dtype.__name__}"
+            classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
-        elif isinstance(self.type, Struct):
-            return make_full_qualified_name(self.type.cls)
-        elif hasattr(self.type, "_wp_generic_type_str_"):
-            return compute_type_str(self.type._wp_generic_type_str_, self.type._wp_type_params_)
+        elif isinstance(t, Struct):
+            return make_full_qualified_name(t.cls)
+        elif is_reference(t):
+            if not value_type:
+                return Var.type_to_ctype(t.value_type) + "*"
+            else:
+                return Var.type_to_ctype(t.value_type)
+        elif hasattr(t, "_wp_generic_type_str_"):
+            return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
+        elif t.__name__ in ("bool", "int", "float"):
+            return t.__name__
+        else:
+            return f"wp::{t.__name__}"
+
+    def ctype(self, value_type=False):
+        return Var.type_to_ctype(self.type, value_type)
+
+    def emit(self, prefix: str = "var"):
+        if self.prefix:
+            return f"{prefix}_{self.label}"
         else:
-            return str(self.type.__name__)
+            return self.label
+
+    def emit_adj(self):
+        return self.emit("adj")
 
 
 class Block:
@@ -304,33 +495,65 @@ class Block:
         self.vars = []
 
 
+def is_local_value(value) -> bool:
+    """Check whether a variable is defined inside a kernel."""
+    return isinstance(value, (warp.context.Function, Var))
+
+
 class Adjoint:
     # Source code transformer, this class takes a Python function and
     # generates forward and backward SSA forms of the function instructions
 
-    def __init__(adj, func, overload_annotations=None):
+    def __init__(
+        adj,
+        func,
+        overload_annotations=None,
+        is_user_function=False,
+        skip_forward_codegen=False,
+        skip_reverse_codegen=False,
+        custom_reverse_mode=False,
+        custom_reverse_num_input_args=-1,
+        transformers: List[ast.NodeTransformer] = [],
+    ):
         adj.func = func
 
-        # build AST from function object
-        adj.source = inspect.getsource(func)
+        adj.is_user_function = is_user_function
 
-        # get source code lines and line number where function starts
-        adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
+        # whether the generation of the forward code is skipped for this function
+        adj.skip_forward_codegen = skip_forward_codegen
+        # whether the generation of the adjoint code is skipped for this function
+        adj.skip_reverse_codegen = skip_reverse_codegen
 
-        # keep track of line number in function code
-        adj.lineno = None
+        # extract name of source file
+        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        # get source file line number where function starts
+        _, adj.fun_lineno = inspect.getsourcelines(func)
 
+        # get function source code
+        adj.source = inspect.getsource(func)
         # ensures that indented class methods can be parsed as kernels
         adj.source = textwrap.dedent(adj.source)
 
-        # extract name of source file
-        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        adj.source_lines = adj.source.splitlines()
 
-        # build AST
+        # build AST and apply node transformers
         adj.tree = ast.parse(adj.source)
+        adj.transformers = transformers
+        for transformer in transformers:
+            adj.tree = transformer.visit(adj.tree)
 
         adj.fun_name = adj.tree.body[0].name
 
+        # for keeping track of line number in function code
+        adj.lineno = None
+
+        # whether the forward code shall be used for the reverse pass and a custom
+        # function signature is applied to the reverse version of the function
+        adj.custom_reverse_mode = custom_reverse_mode
+        # the number of function arguments that pertain to the forward function
+        # input arguments (i.e. the number of arguments that are not adjoint arguments)
+        adj.custom_reverse_num_input_args = custom_reverse_num_input_args
+
         # parse argument types
         argspec = inspect.getfullargspec(func)
 
@@ -338,16 +561,17 @@
         if overload_annotations is None:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
-                raise RuntimeError(f"Incomplete argument annotations on function {adj.fun_name}")
+                raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
             adj.arg_types = argspec.annotations
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
                 if arg_name not in overload_annotations:
-                    raise RuntimeError(f"Incomplete overload annotations for function {adj.fun_name}")
+                    raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
             adj.arg_types = overload_annotations.copy()
 
         adj.args = []
+        adj.symbols = {}
 
         for name, type in adj.arg_types.items():
             # skip return hint
@@ -358,8 +582,23 @@
             arg = Var(name, type, False)
             adj.args.append(arg)
 
+            # pre-populate symbol dictionary with function argument names
+            # this is to avoid registering false references to overshadowed modules
+            adj.symbols[name] = arg
+
+        # There are cases where a same module might be rebuilt multiple times,
+        # for example when kernels are nested inside of functions, or when
+        # a kernel's launch raises an exception. Ideally we'd always want to
+        # avoid rebuilding kernels but some corner cases seem to depend on it,
+        # so we only avoid rebuilding kernels that errored out to give a chance
+        # for unit testing errors being spit out from kernels.
+        adj.skip_build = False
+
     # generate function ssa form and adjoint
     def build(adj, builder):
+        if adj.skip_build:
+            return
+
         adj.builder = builder
 
         adj.symbols = {}  # map from symbols to adjoint variables
@@ -373,7 +612,7 @@
         adj.loop_blocks = []
 
         # holds current indent level
-        adj.prefix = ""
+        adj.indentation = ""
 
         # used to generate new label indices
         adj.label_count = 0
@@ -387,20 +626,25 @@
             adj.eval(adj.tree.body[0])
         except Exception as e:
             try:
+                if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
+                    msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
+                else:
+                    msg = "Error"
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.source.splitlines()[adj.lineno]
-                msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
+                line = adj.source_lines[adj.lineno]
+                msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
                 ex, data, traceback = sys.exc_info()
-                e = ex("".join([msg] + list(data.args))).with_traceback(traceback)
+                e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
+                adj.skip_build = True
                 raise e
 
-        for a in adj.args:
-            if isinstance(a.type, Struct):
-                builder.build_struct_recursive(a.type)
-            elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
-                builder.build_struct_recursive(a.type.dtype)
-
+        if builder is not None:
+            for a in adj.args:
+                if isinstance(a.type, Struct):
+                    builder.build_struct_recursive(a.type)
+                elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
+                    builder.build_struct_recursive(a.type.dtype)
 
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
@@ -415,44 +659,56 @@
         arg_strs = []
 
         for a in args:
-            if type(a) == warp.context.Function:
+            if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
-                if prefix == "var_":
+                if prefix == "var":
                     arg_strs.append(a.key)
                 else:
-                    arg_strs.append(prefix + a.key)
-
+                    arg_strs.append(f"{prefix}_{a.key}")
+            elif is_reference(a.type):
+                arg_strs.append(f"{prefix}_{a}")
+            elif isinstance(a, Var):
+                arg_strs.append(a.emit(prefix))
             else:
-                arg_strs.append(prefix + str(a))
+                raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
 
         return arg_strs
 
     # generates argument string for a forward function call
     def format_forward_call_args(adj, args, use_initializer_list):
-        arg_str = ", ".join(adj.format_args("var_", args))
+        arg_str = ", ".join(adj.format_args("var", args))
         if use_initializer_list:
-            return "{{{}}}".format(arg_str)
+            return f"{{{arg_str}}}"
         return arg_str
 
     # generates argument string for a reverse function call
-    def format_reverse_call_args(adj, args, args_out, non_adjoint_args, non_adjoint_outputs, use_initializer_list):
-        formatted_var = adj.format_args("var_", args)
+    def format_reverse_call_args(
+        adj,
+        args_var,
+        args,
+        args_out,
+        use_initializer_list,
+        has_output_args=True,
+        require_original_output_arg=False,
+    ):
+        formatted_var = adj.format_args("var", args_var)
         formatted_out = []
-        if len(args_out) > 1:
-            formatted_out = adj.format_args("var_", args_out)
+        if has_output_args and (require_original_output_arg or len(args_out) > 1):
+            formatted_out = adj.format_args("var", args_out)
         formatted_var_adj = adj.format_args(
-            "&adj_" if use_initializer_list else "adj_", [a for i, a in enumerate(args) if i not in non_adjoint_args]
+            "&adj" if use_initializer_list else "adj",
+            args,
         )
-        formatted_out_adj = adj.format_args("adj_", [a for i, a in enumerate(args_out) if i not in non_adjoint_outputs])
+        formatted_out_adj = adj.format_args("adj", args_out)
 
         if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
             # there are no adjoint arguments, so we don't need to call the reverse function
             return None
 
         if use_initializer_list:
-            var_str = "{{{}}}".format(", ".join(formatted_var))
-            out_str = "{{{}}}".format(", ".join(formatted_out))
-            adj_str = "{{{}}}".format(", ".join(formatted_var_adj))
+            var_str = f"{{{', '.join(formatted_var)}}}"
+            out_str = f"{{{', '.join(formatted_out)}}}"
+            adj_str = f"{{{', '.join(formatted_var_adj)}}}"
             out_adj_str = ", ".join(formatted_out_adj)
             if len(args_out) > 1:
                 arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
@@ -463,10 +719,10 @@
         return arg_str
 
     def indent(adj):
-        adj.prefix = adj.prefix + "\t"
+        adj.indentation = adj.indentation + "    "
 
     def dedent(adj):
-        adj.prefix = adj.prefix[0:-1]
+        adj.indentation = adj.indentation[:-4]
 
     def begin_block(adj):
         b = Block()
@@ -481,10 +737,9 @@
     def end_block(adj):
         return adj.blocks.pop()
 
-    def add_var(adj, type=None, constant=None, name=None):
-        if name is None:
-            index = len(adj.variables)
-            name = str(index)
+    def add_var(adj, type=None, constant=None):
+        index = len(adj.variables)
+        name = str(index)
 
         # allocate new variable
         v = Var(name, type=type, constant=constant)
@@ -497,30 +752,54 @@
 
     # append a statement to the forward pass
     def add_forward(adj, statement, replay=None, skip_replay=False):
-        adj.blocks[-1].body_forward.append(adj.prefix + statement)
+        adj.blocks[-1].body_forward.append(adj.indentation + statement)
 
         if not skip_replay:
             if replay:
                 # if custom replay specified then output it
-                adj.blocks[-1].body_replay.append(adj.prefix + replay)
+                adj.blocks[-1].body_replay.append(adj.indentation + replay)
             else:
                 # by default just replay the original statement
-                adj.blocks[-1].body_replay.append(adj.prefix + statement)
+                adj.blocks[-1].body_replay.append(adj.indentation + statement)
 
     # append a statement to the reverse pass
     def add_reverse(adj, statement):
-        adj.blocks[-1].body_reverse.append(adj.prefix + statement)
+        adj.blocks[-1].body_reverse.append(adj.indentation + statement)
 
     def add_constant(adj, n):
         output = adj.add_var(type=type(n), constant=n)
         return output
 
+    def load(adj, var):
+        if is_reference(var.type):
+            var = adj.add_builtin_call("load", [var])
+        return var
+
     def add_comp(adj, op_strings, left, comps):
-        output = adj.add_var(bool)
+        output = adj.add_var(builtins.bool)
+
+        left = adj.load(left)
+        s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
+
+        prev_comp = None
 
-        s = "var_" + str(output) + " = " + ("(" * len(comps)) + "var_" + str(left) + " "
         for op, comp in zip(op_strings, comps):
-            s += op + " var_" + str(comp) + ") "
+            comp_chainable = op_str_is_chainable(op)
+            if comp_chainable and prev_comp:
+                # We restrict chaining to operands of the same type
+                if prev_comp.type is comp.type:
+                    prev_comp = adj.load(prev_comp)
+                    comp = adj.load(comp)
+                    s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
+                else:
+                    raise WarpCodegenTypeError(
+                        f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
+                    )
+            else:
+                comp = adj.load(comp)
+                s += op + " " + comp.emit() + ") "
+
+            prev_comp = comp
 
         s = s.rstrip() + ";"
 
@@ -529,109 +808,106 @@ class Adjoint:
         return output
 
     def add_bool_op(adj, op_string, exprs):
-        output = adj.add_var(bool)
-        command = (
-            "var_" + str(output) + " = " + (" " + op_string + " ").join(["var_" + str(expr) for expr in exprs]) + ";"
-        )
+        exprs = [adj.load(expr) for expr in exprs]
+        output = adj.add_var(builtins.bool)
+        command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
         adj.add_forward(command)
 
         return output
 
-    def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
-        # if func is overloaded then perform overload resolution here
-        # we validate argument types before they go to generated native code
-        resolved_func = None
+    def resolve_func(adj, func, args, min_outputs, templates, kwds):
+        arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
 
-        if func.is_builtin():
+        if not func.is_builtin():
+            # user-defined function
+            overload = func.get_overload(arg_types)
+            if overload is not None:
+                return overload
+        else:
+            # if func is overloaded then perform overload resolution here
+            # we validate argument types before they go to generated native code
             for f in func.overloads:
-                match = True
-
                 # skip type checking for variadic functions
                 if not f.variadic:
                     # check argument counts match are compatible (may be some default args)
                     if len(f.input_types) < len(args):
-                        match = False
                         continue
 
-                    # check argument types equal
-                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                        # if arg type registered as Any, treat as
-                        # template allowing any type to match
-                        if arg_type == Any:
-                            continue
-
-                        # handle function refs as a special case
-                        if arg_type == Callable and type(args[i]) is warp.context.Function:
-                            continue
-
-                        # look for default values for missing args
-                        if i >= len(args):
-                            if arg_name not in f.defaults:
-                                match = False
-                                break
-                        else:
-                            # otherwise check arg type matches input variable type
-                            if not types_equal(arg_type, args[i].type, match_generic=True):
-                                match = False
-                                break
+                    def match_args(args, f):
+                        # check argument types equal
+                        for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
+                            # if arg type registered as Any, treat as
+                            # template allowing any type to match
+                            if arg_type == Any:
+                                continue
+
+                            # handle function refs as a special case
+                            if arg_type == Callable and type(args[i]) is warp.context.Function:
+                                continue
+
+                            if arg_type == Reference and is_reference(args[i].type):
+                                continue
+
+                            # look for default values for missing args
+                            if i >= len(args):
+                                if arg_name not in f.defaults:
+                                    return False
+                            else:
+                                # otherwise check arg type matches input variable type
+                                if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
+                                    return False
+
+                        return True
+
+                    if not match_args(args, f):
+                        continue
 
                 # check output dimensions match expectations
                 if min_outputs:
                     try:
                         value_type = f.value_func(args, kwds, templates)
-                        if len(value_type) != min_outputs:
-                            match = False
+                        if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
                             continue
                     except Exception:
                         # value func may fail if the user has given
                         # incorrect args, so we need to catch this
-                        match = False
                         continue
 
                 # found a match, use it
-                if match:
-                    resolved_func = f
-                    break
-        else:
-            # user-defined function
-            arg_types = [a.type for a in args]
-            resolved_func = func.get_overload(arg_types)
-
-        if resolved_func is None:
-            arg_types = []
-
-            for x in args:
-                if isinstance(x, Var):
-                    # shorten Warp primitive type names
-                    if isinstance(x.type, list):
-                        if len(x.type) != 1:
-                            raise Exception("Argument must not be the result from a multi-valued function")
-                        arg_type = x.type[0]
-                    else:
-                        arg_type = x.type
-                    if arg_type.__module__ == "warp.types":
-                        arg_types.append(arg_type.__name__)
-                    else:
-                        arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
-
-                if isinstance(x, warp.context.Function):
-                    arg_types.append("function")
-
-            raise Exception(
-                f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
-            )
+                return f
+
+        # unresolved function, report error
+        arg_types = []
+
+        for x in args:
+            if isinstance(x, Var):
+                # shorten Warp primitive type names
+                if isinstance(x.type, list):
+                    if len(x.type) != 1:
+                        raise WarpCodegenError("Argument must not be the result from a multi-valued function")
+                    arg_type = x.type[0]
+                else:
+                    arg_type = x.type
 
-        else:
-            func = resolved_func
+                arg_types.append(type_repr(arg_type))
+
+            if isinstance(x, warp.context.Function):
+                arg_types.append("function")
+
+        raise WarpCodegenError(
+            f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
+        )
+
+    def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
+        func = adj.resolve_func(func, args, min_outputs, templates, kwds)
 
         # push any default values onto args
         for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
             if i >= len(args):
-                if arg_name in f.defaults:
+                if arg_name in func.defaults:
                     const = adj.add_constant(func.defaults[arg_name])
                     args.append(const)
                 else:
-                    match = False
                     break
 
         # if it is a user-function then build it recursively
@@ -639,93 +915,105 @@ class Adjoint:
639
915
  adj.builder.build_function(func)
640
916
 
641
917
  # evaluate the function type based on inputs
642
- value_type = func.value_func(args, kwds, templates)
918
+ arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
919
+ return_type = func.value_func(arg_types, kwds, templates)
643
920
 
644
921
  func_name = compute_type_str(func.native_func, templates)
922
+ param_types = list(func.input_types.values())
645
923
 
646
924
  use_initializer_list = func.initializer_list_func(args, templates)
647
925
 
648
- if value_type is None:
649
- # handles expression (zero output) functions, e.g.: void do_something();
650
-
651
- forward_call = "{}{}({});".format(
652
- func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
653
- )
654
- if func.skip_replay:
655
- adj.add_forward(forward_call, replay="//" + forward_call)
656
- else:
657
- adj.add_forward(forward_call)
658
-
659
- if not func.missing_grad and len(args):
660
- arg_str = adj.format_reverse_call_args(args, [], {}, {}, use_initializer_list)
661
- if arg_str is not None:
662
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
663
- adj.add_reverse(reverse_call)
926
+ args_var = [
927
+ adj.load(a)
928
+ if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
929
+ else a
930
+ for i, a in enumerate(args)
931
+ ]
664
932
 
665
- return None
933
+ if return_type is None:
934
+ # handles expression (zero output) functions, e.g.: void do_something();
666
935
 
667
- elif not isinstance(value_type, list) or len(value_type) == 1:
668
- # handle simple function (one output)
936
+ output = None
937
+ output_list = []
669
938
 
670
- if isinstance(value_type, list):
671
- value_type = value_type[0]
672
- output = adj.add_var(value_type)
673
- forward_call = "var_{} = {}{}({});".format(
674
- output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
939
+ forward_call = (
940
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
675
941
  )
942
+ replay_call = forward_call
943
+ if func.custom_replay_func is not None:
944
+ replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
676
945
 
677
- if func.skip_replay:
678
- adj.add_forward(forward_call, replay="//" + forward_call)
679
- else:
680
- adj.add_forward(forward_call)
946
+ elif not isinstance(return_type, list) or len(return_type) == 1:
947
+ # handle simple function (one output)
681
948
 
682
- if not func.missing_grad and len(args):
683
- arg_str = adj.format_reverse_call_args(args, [output], {}, {}, use_initializer_list)
684
- if arg_str is not None:
685
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
686
- adj.add_reverse(reverse_call)
949
+ if isinstance(return_type, list):
950
+ return_type = return_type[0]
951
+ output = adj.add_var(return_type)
952
+ output_list = [output]
687
953
 
688
- return output
954
+ forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
955
+ replay_call = forward_call
956
+ if func.custom_replay_func is not None:
957
+ replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
689
958
 
690
959
  else:
691
960
  # handle multiple value functions
692
961
 
693
- output = [adj.add_var(v) for v in value_type]
694
- forward_call = "{}{}({});".format(
695
- func.namespace, func_name, adj.format_forward_call_args(args + output, use_initializer_list)
962
+ output = [adj.add_var(v) for v in return_type]
963
+ output_list = output
964
+
965
+ forward_call = (
966
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
696
967
  )
697
- adj.add_forward(forward_call)
968
+ replay_call = forward_call
698
969
 
699
- if not func.missing_grad and len(args):
700
- arg_str = adj.format_reverse_call_args(args, output, {}, {}, use_initializer_list)
701
- if arg_str is not None:
702
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
703
- adj.add_reverse(reverse_call)
970
+ if func.skip_replay:
971
+ adj.add_forward(forward_call, replay="// " + replay_call)
972
+ else:
973
+ adj.add_forward(forward_call, replay=replay_call)
974
+
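The replay argument recorded above is what the backward pass re-executes to rebuild intermediate values before the adjoint statements run. A minimal sketch of the three cases handled here, with a purely illustrative statement string:

    # given: call = "var_3 = wp::noise(var_1);"     (hypothetical generated line)
    adj.add_forward(call, replay="// " + call)      # skip_replay: not re-executed
    adj.add_forward(call, replay=replay_call)       # custom_replay_func: replay_ variant
    adj.add_forward(call, replay=call)              # default: replayed verbatim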
975
+ if not func.missing_grad and len(args):
976
+ reverse_has_output_args = (
977
+ func.require_original_output_arg or len(output_list) > 1
978
+ ) and func.custom_grad_func is None
979
+ arg_str = adj.format_reverse_call_args(
980
+ args_var,
981
+ args,
982
+ output_list,
983
+ use_initializer_list,
984
+ has_output_args=reverse_has_output_args,
985
+ require_original_output_arg=func.require_original_output_arg,
986
+ )
987
+ if arg_str is not None:
988
+ reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
989
+ adj.add_reverse(reverse_call)
704
990
 
705
- if len(output) == 1:
706
- return output[0]
991
+ return output
707
992
 
708
- return output
993
+ def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
994
+ func = warp.context.builtin_functions[func_name]
995
+ return adj.add_call(func, args, min_outputs, templates, kwds)
709
996
 
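The new add_builtin_call helper folds the repeated dictionary lookup into one place; the before/after pattern, taken from the surrounding hunks:

    # before: look up the builtin by hand, then emit the call
    out = adj.add_call(warp.context.builtin_functions["select"], [cond, var1, var2])

    # after: the lookup is part of the helper
    out = adj.add_builtin_call("select", [cond, var1, var2])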
710
997
  def add_return(adj, var):
711
998
  if var is None or len(var) == 0:
712
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
999
+ adj.add_forward("return;", f"goto label{adj.label_count};")
713
1000
  elif len(var) == 1:
714
- adj.add_forward("return var_{};".format(var[0]), "goto label{};".format(adj.label_count))
1001
+ adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
715
1002
  adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
716
1003
  else:
717
1004
  for i, v in enumerate(var):
718
- adj.add_forward("ret_{} = var_{};".format(i, v))
719
- adj.add_reverse("adj_{} += adj_ret_{};".format(v, i))
720
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
1005
+ adj.add_forward(f"ret_{i} = {v.emit()};")
1006
+ adj.add_reverse(f"adj_{v} += adj_ret_{i};")
1007
+ adj.add_forward("return;", f"goto label{adj.label_count};")
721
1008
 
722
- adj.add_reverse("label{}:;".format(adj.label_count))
1009
+ adj.add_reverse(f"label{adj.label_count}:;")
723
1010
 
724
1011
  adj.label_count += 1
725
1012
 
726
1013
  # define an if statement
727
1014
  def begin_if(adj, cond):
728
- adj.add_forward("if (var_{}) {{".format(cond))
1015
+ cond = adj.load(cond)
1016
+ adj.add_forward(f"if ({cond.emit()}) {{")
729
1017
  adj.add_reverse("}")
730
1018
 
731
1019
  adj.indent()
@@ -734,10 +1022,12 @@ class Adjoint:
734
1022
  adj.dedent()
735
1023
 
736
1024
  adj.add_forward("}")
737
- adj.add_reverse(f"if (var_{cond}) {{")
1025
+ cond = adj.load(cond)
1026
+ adj.add_reverse(f"if ({cond.emit()}) {{")
738
1027
 
739
1028
  def begin_else(adj, cond):
740
- adj.add_forward(f"if (!var_{cond}) {{")
1029
+ cond = adj.load(cond)
1030
+ adj.add_forward(f"if (!{cond.emit()}) {{")
741
1031
  adj.add_reverse("}")
742
1032
 
743
1033
  adj.indent()
@@ -746,7 +1036,8 @@ class Adjoint:
746
1036
  adj.dedent()
747
1037
 
748
1038
  adj.add_forward("}")
749
- adj.add_reverse(f"if (!var_{cond}) {{")
1039
+ cond = adj.load(cond)
1040
+ adj.add_reverse(f"if (!{cond.emit()}) {{")
750
1041
 
751
1042
  # define a for-loop
752
1043
  def begin_for(adj, iter):
@@ -756,10 +1047,10 @@ class Adjoint:
756
1047
  adj.indent()
757
1048
 
758
1049
  # evaluate cond
759
- adj.add_forward(f"if (iter_cmp(var_{iter}) == 0) goto for_end_{cond_block.label};")
1050
+ adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")
760
1051
 
761
1052
  # evaluate iter
762
- val = adj.add_call(warp.context.builtin_functions["iter_next"], [iter])
1053
+ val = adj.add_builtin_call("iter_next", [iter])
763
1054
 
764
1055
  adj.begin_block()
765
1056
 
@@ -790,17 +1081,14 @@ class Adjoint:
790
1081
  reverse = []
791
1082
 
792
1083
  # reverse iterator
793
- reverse.append(adj.prefix + f"var_{iter} = wp::iter_reverse(var_{iter});")
1084
+ reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
794
1085
 
795
1086
  for i in cond_block.body_forward:
796
1087
  reverse.append(i)
797
1088
 
798
1089
  # zero adjoints
799
1090
  for i in body_block.vars:
800
- if isinstance(i.type, Struct):
801
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
802
- else:
803
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
1091
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
804
1092
 
805
1093
  # replay
806
1094
  for i in body_block.body_replay:
@@ -810,14 +1098,14 @@ class Adjoint:
810
1098
  for i in reversed(body_block.body_reverse):
811
1099
  reverse.append(i)
812
1100
 
813
- reverse.append(adj.prefix + f"\tgoto for_start_{cond_block.label};")
814
- reverse.append(adj.prefix + f"for_end_{cond_block.label}:;")
1101
+ reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
1102
+ reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")
815
1103
 
816
1104
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
817
1105
 
818
1106
  # define a while loop
819
1107
  def begin_while(adj, cond):
820
- # evaulate condition in its own block
1108
+ # evaluate condition in its own block
821
1109
  # so we can control replay
822
1110
  cond_block = adj.begin_block()
823
1111
  adj.loop_blocks.append(cond_block)
@@ -825,7 +1113,7 @@ class Adjoint:
825
1113
 
826
1114
  c = adj.eval(cond)
827
1115
 
828
- cond_block.body_forward.append(f"if ((var_{c}) == false) goto while_end_{cond_block.label};")
1116
+ cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")
829
1117
 
830
1118
  # begin block around loop
831
1119
  adj.begin_block()
@@ -859,10 +1147,7 @@ class Adjoint:
859
1147
 
860
1148
  # zero adjoints of local vars
861
1149
  for i in body_block.vars:
862
- if isinstance(i.type, Struct):
863
- reverse.append(f"adj_{i} = {i.ctype()}{{}};")
864
- else:
865
- reverse.append(f"adj_{i} = {i.ctype()}(0);")
1150
+ reverse.append(f"{i.emit_adj()} = {{}};")
866
1151
 
867
1152
  # replay
868
1153
  for i in body_block.body_replay:
@@ -882,6 +1167,10 @@ class Adjoint:
882
1167
  for f in node.body:
883
1168
  adj.eval(f)
884
1169
 
1170
+ if adj.return_var is not None and len(adj.return_var) == 1:
1171
+ if not isinstance(node.body[-1], ast.Return):
1172
+ adj.add_forward("return {};", skip_replay=True)
1173
+
885
1174
  def emit_If(adj, node):
886
1175
  if len(node.body) == 0:
887
1176
  return None
@@ -909,7 +1198,7 @@ class Adjoint:
909
1198
 
910
1199
  if var1 != var2:
911
1200
  # insert a phi function that selects var1, var2 based on cond
912
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var1, var2])
1201
+ out = adj.add_builtin_call("select", [cond, var1, var2])
913
1202
  adj.symbols[sym] = out
914
1203
 
915
1204
  symbols_prev = adj.symbols.copy()
@@ -933,7 +1222,7 @@ class Adjoint:
933
1222
  if var1 != var2:
934
1223
  # insert a phi function that selects var1, var2 based on cond
935
1224
  # note the reversed order of vars since we want to use !cond as our select
936
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var2, var1])
1225
+ out = adj.add_builtin_call("select", [cond, var2, var1])
937
1226
  adj.symbols[sym] = out
938
1227
 
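Both branches funnel conflicting definitions through the select() builtin, which acts as a phi node. A small kernel sketch that exercises this path (assumed @wp.kernel usage; only the branch assignment matters):

    import warp as wp

    @wp.kernel
    def branchy(a: wp.array(dtype=float)):
        x = 0.0
        if a[0] > 1.0:
            x = 2.0      # conflicting definition of x inside the branch
        # codegen merges the two x's here via select(cond, x_outer, x_inner)
        a[0] = x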
939
1228
  def emit_Compare(adj, node):
@@ -955,7 +1244,7 @@ class Adjoint:
955
1244
  elif isinstance(op, ast.Or):
956
1245
  func = "||"
957
1246
  else:
958
- raise KeyError("Op {} is not supported".format(op))
1247
+ raise WarpCodegenKeyError(f"Op {op} is not supported")
959
1248
 
960
1249
  return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
961
1250
 
@@ -975,7 +1264,7 @@ class Adjoint:
975
1264
  obj = capturedvars.get(str(node.id), None)
976
1265
 
977
1266
  if obj is None:
978
- raise KeyError("Referencing undefined symbol: " + str(node.id))
1267
+ raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
979
1268
 
980
1269
  if warp.types.is_value(obj):
981
1270
  # evaluate constant
@@ -987,26 +1276,96 @@ class Adjoint:
987
1276
  # pass it back to the caller for processing
988
1277
  return obj
989
1278
 
1279
+ @staticmethod
1280
+ def resolve_type_attribute(var_type: type, attr: str):
1281
+ if isinstance(var_type, type) and type_is_value(var_type):
1282
+ if attr == "dtype":
1283
+ return type_scalar_type(var_type)
1284
+ elif attr == "length":
1285
+ return type_length(var_type)
1286
+
1287
+ return getattr(var_type, attr, None)
1288
+
1289
+ def vector_component_index(adj, component, vector_type):
1290
+ if len(component) != 1:
1291
+ raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
1292
+
1293
+ dim = vector_type._shape_[0]
1294
+ swizzles = "xyzw"[0:dim]
1295
+ if component not in swizzles:
1296
+ raise WarpCodegenAttributeError(
1297
+ f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
1298
+ )
1299
+
1300
+ index = swizzles.index(component)
1301
+ index = adj.add_constant(index)
1302
+ return index
1303
+
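A standalone sketch of the mapping implemented above: one character of "xyzw", bounded by the vector's dimension, becomes a constant index (the helper name here is hypothetical):

    def component_index(component: str, dim: int) -> int:
        swizzles = "xyzw"[0:dim]                  # e.g. "xyz" for a vec3
        if len(component) != 1 or component not in swizzles:
            raise AttributeError(f"bad swizzle .{component} for dim {dim}")
        return swizzles.index(component)

    assert component_index("y", 3) == 1           # v.y -> extract(v, 1)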
1304
+ @staticmethod
1305
+ def is_differentiable_value_type(var_type):
1306
+ # checks that the argument type is a value type (i.e., not an array)
1307
+ # possibly holding differentiable values (for which gradients must be accumulated)
1308
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
1309
+
990
1310
  def emit_Attribute(adj, node):
991
- try:
992
- val = adj.eval(node.value)
1311
+ if hasattr(node, "is_adjoint"):
1312
+ node.value.is_adjoint = True
1313
+
1314
+ aggregate = adj.eval(node.value)
993
1315
 
994
- if isinstance(val, types.ModuleType) or isinstance(val, type):
995
- out = getattr(val, node.attr)
1316
+ try:
1317
+ if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
1318
+ out = getattr(aggregate, node.attr)
996
1319
 
997
1320
  if warp.types.is_value(out):
998
1321
  return adj.add_constant(out)
999
1322
 
1000
1323
  return out
1001
1324
 
1002
- # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
1003
- attr_name = val.label + "." + node.attr
1004
- attr_type = val.type.vars[node.attr].type
1325
+ if hasattr(node, "is_adjoint"):
1326
+ # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
1327
+ attr_name = aggregate.label + "." + node.attr
1328
+ attr_type = aggregate.type.vars[node.attr].type
1329
+
1330
+ return Var(attr_name, attr_type)
1331
+
1332
+ aggregate_type = strip_reference(aggregate.type)
1333
+
1334
+ # reading a vector component
1335
+ if type_is_vector(aggregate_type):
1336
+ index = adj.vector_component_index(node.attr, aggregate_type)
1337
+
1338
+ return adj.add_builtin_call("extract", [aggregate, index])
1339
+
1340
+ else:
1341
+ attr_type = Reference(aggregate_type.vars[node.attr].type)
1342
+ attr = adj.add_var(attr_type)
1343
+
1344
+ if is_reference(aggregate.type):
1345
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
1346
+ else:
1347
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
1348
+
1349
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
1350
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
1351
+ else:
1352
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
1353
+
1354
+ return attr
1355
+
1356
+ except (KeyError, AttributeError):
1357
+ # Try resolving as type attribute
1358
+ aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
1005
1359
 
1006
- return Var(attr_name, attr_type)
1360
+ type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
1361
+ if type_attribute is not None:
1362
+ return type_attribute
1007
1363
 
1008
- except KeyError:
1009
- raise RuntimeError(f"Error, `{node.attr}` is not an attribute of '{val.label}' ({val.type})")
1364
+ if isinstance(aggregate, Var):
1365
+ raise WarpCodegenAttributeError(
1366
+ f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
1367
+ )
1368
+ raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
1010
1369
 
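The attribute paths above (module/type member, vector component, struct field reference) can all be reached from a small kernel; a sketch under assumed wp.struct usage:

    import warp as wp

    @wp.struct
    class Particle:
        pos: wp.vec3

    @wp.kernel
    def shift(ps: wp.array(dtype=Particle)):
        p = ps[0]
        y = p.pos.y                       # vector component read -> extract()
        p.pos = wp.vec3(0.0, y, 0.0)      # struct field write via the reference path
        ps[0] = p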
1011
1370
  def emit_String(adj, node):
1012
1371
  # string constant
@@ -1023,19 +1382,25 @@ class Adjoint:
1023
1382
  adj.symbols[key] = out
1024
1383
  return out
1025
1384
 
1385
+ def emit_Ellipsis(adj, node):
1386
+ # stubbed @wp.native_func
1387
+ return
1388
+
1026
1389
  def emit_NameConstant(adj, node):
1027
- if node.value == True:
1390
+ if node.value:
1028
1391
  return adj.add_constant(True)
1029
- elif node.value == False:
1030
- return adj.add_constant(False)
1031
1392
  elif node.value is None:
1032
- raise TypeError("None type unsupported")
1393
+ raise WarpCodegenTypeError("None type unsupported")
1394
+ else:
1395
+ return adj.add_constant(False)
1033
1396
 
1034
1397
  def emit_Constant(adj, node):
1035
1398
  if isinstance(node, ast.Str):
1036
1399
  return adj.emit_String(node)
1037
1400
  elif isinstance(node, ast.Num):
1038
1401
  return adj.emit_Num(node)
1402
+ elif isinstance(node, ast.Ellipsis):
1403
+ return adj.emit_Ellipsis(node)
1039
1404
  else:
1040
1405
  assert isinstance(node, ast.NameConstant)
1041
1406
  return adj.emit_NameConstant(node)
@@ -1046,18 +1411,16 @@ class Adjoint:
1046
1411
  right = adj.eval(node.right)
1047
1412
 
1048
1413
  name = builtin_operators[type(node.op)]
1049
- func = warp.context.builtin_functions[name]
1050
1414
 
1051
- return adj.add_call(func, [left, right])
1415
+ return adj.add_builtin_call(name, [left, right])
1052
1416
 
1053
1417
  def emit_UnaryOp(adj, node):
1054
1418
  # evaluate unary op arguments
1055
1419
  arg = adj.eval(node.operand)
1056
1420
 
1057
1421
  name = builtin_operators[type(node.op)]
1058
- func = warp.context.builtin_functions[name]
1059
1422
 
1060
- return adj.add_call(func, [arg])
1423
+ return adj.add_builtin_call(name, [arg])
1061
1424
 
1062
1425
  def materialize_redefinitions(adj, symbols):
1063
1426
  # detect symbols with conflicting definitions (assigned inside the for loop)
@@ -1067,21 +1430,19 @@ class Adjoint:
1067
1430
  var2 = adj.symbols[sym]
1068
1431
 
1069
1432
  if var1 != var2:
1070
- if warp.config.verbose:
1433
+ if warp.config.verbose and not adj.custom_reverse_mode:
1071
1434
  lineno = adj.lineno + adj.fun_lineno
1072
- line = adj.source.splitlines()[adj.lineno]
1073
- msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1435
+ line = adj.source_lines[adj.lineno]
1436
+ msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
1074
1437
  print(msg)
1075
1438
 
1076
1439
  if var1.constant is not None:
1077
- raise Exception(
1078
- "Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable".format(
1079
- sym
1080
- )
1440
+ raise WarpCodegenError(
1441
+ f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
1081
1442
  )
1082
1443
 
1083
1444
  # overwrite the old variable value (violates SSA)
1084
- adj.add_call(warp.context.builtin_functions["copy"], [var1, var2])
1445
+ adj.add_builtin_call("assign", [var1, var2])
1085
1446
 
1086
1447
  # reset the symbol to point to the original variable
1087
1448
  adj.symbols[sym] = var1
@@ -1100,95 +1461,132 @@ class Adjoint:
1100
1461
 
1101
1462
  adj.end_while()
1102
1463
 
1103
- def is_num(adj, a):
1104
- # simple constant
1464
+ def eval_num(adj, a):
1105
1465
  if isinstance(a, ast.Num):
1106
- return True
1107
- # expression of form -constant
1108
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1109
- return True
1110
- else:
1111
- # try and resolve the expression to an object
1112
- # e.g.: wp.constant in the globals scope
1113
- obj, path = adj.resolve_path(a)
1114
- if warp.types.is_int(obj):
1466
+ return True, a.n
1467
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1468
+ return True, -a.operand.n
1469
+
1470
+ # try and resolve the expression to an object
1471
+ # e.g.: wp.constant in the globals scope
1472
+ obj, _ = adj.resolve_static_expression(a)
1473
+
1474
+ if isinstance(obj, Var) and obj.constant is not None:
1475
+ obj = obj.constant
1476
+
1477
+ return warp.types.is_int(obj), obj
1478
+
1479
+ # detects whether a loop contains a break (or continue) statement
1480
+ def contains_break(adj, body):
1481
+ for s in body:
1482
+ if isinstance(s, ast.Break):
1115
1483
  return True
1484
+ elif isinstance(s, ast.Continue):
1485
+ return True
1486
+ elif isinstance(s, ast.If):
1487
+ if adj.contains_break(s.body):
1488
+ return True
1489
+ if adj.contains_break(s.orelse):
1490
+ return True
1116
1491
  else:
1117
- return False
1492
+ # note that nested for or while loops containing a break statement
1493
+ # do not affect the current loop
1494
+ pass
1495
+
1496
+ return False
1497
+
1498
+ # returns a constant range() if unrollable, otherwise None
1499
+ def get_unroll_range(adj, loop):
1500
+ if (
1501
+ not isinstance(loop.iter, ast.Call)
1502
+ or not isinstance(loop.iter.func, ast.Name)
1503
+ or loop.iter.func.id != "range"
1504
+ or len(loop.iter.args) == 0
1505
+ or len(loop.iter.args) > 3
1506
+ ):
1507
+ return None
1118
1508
 
1119
- def eval_num(adj, a):
1120
- if isinstance(a, ast.Num):
1121
- return a.n
1122
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1123
- return -a.operand.n
1124
- else:
1125
- # try and resolve the expression to an object
1126
- # e.g.: wp.constant in the globals scope
1127
- obj, path = adj.resolve_path(a)
1128
- if warp.types.is_int(obj):
1129
- return obj
1130
- else:
1131
- return False
1509
+ # if all range() arguments are numeric constants we will unroll
1510
+ # note that this only handles trivial constants, it will not unroll
1511
+ # constant compile-time expressions e.g.: range(0, 3*2)
1512
+
1513
+ # Evaluate the arguments and check that they are numeric constants
1514
+ # It is important to do this in one pass, so that if evaluating these arguments has side effects
1515
+ # the code does not get generated more than once
1516
+ range_args = [adj.eval_num(arg) for arg in loop.iter.args]
1517
+ arg_is_numeric, arg_values = zip(*range_args)
1518
+
1519
+ if all(arg_is_numeric):
1520
+ # All arguments are numeric constants
1521
+
1522
+ # range(end)
1523
+ if len(loop.iter.args) == 1:
1524
+ start = 0
1525
+ end = arg_values[0]
1526
+ step = 1
1527
+
1528
+ # range(start, end)
1529
+ elif len(loop.iter.args) == 2:
1530
+ start = arg_values[0]
1531
+ end = arg_values[1]
1532
+ step = 1
1533
+
1534
+ # range(start, end, step)
1535
+ elif len(loop.iter.args) == 3:
1536
+ start = arg_values[0]
1537
+ end = arg_values[1]
1538
+ step = arg_values[2]
1539
+
1540
+ # test if we're above max unroll count
1541
+ max_iters = abs(end - start) // abs(step)
1542
+ max_unroll = adj.builder.options["max_unroll"]
1543
+
1544
+ ok_to_unroll = True
1545
+
1546
+ if max_iters > max_unroll:
1547
+ if warp.config.verbose:
1548
+ print(
1549
+ f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
1550
+ )
1551
+ ok_to_unroll = False
1552
+
1553
+ elif adj.contains_break(loop.body):
1554
+ if warp.config.verbose:
1555
+ print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
1556
+ ok_to_unroll = False
1557
+
1558
+ if ok_to_unroll:
1559
+ return range(start, end, step)
1560
+
1561
+ # Unroll is not possible, the range needs to be evaluated dynamically
1562
+ range_call = adj.add_builtin_call(
1563
+ "range",
1564
+ [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
1565
+ )
1566
+ return range_call
1132
1567
 
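Which loops unroll follows directly from the checks above; a kernel sketch (assuming the default max_unroll limit) showing the outcomes:

    import warp as wp

    N = wp.constant(4)                 # resolvable by eval_num

    @wp.kernel
    def loops(out: wp.array(dtype=float)):
        for i in range(3):             # literal bounds: unrolled
            out[i] = 1.0
        for i in range(N):             # wp.constant bound: unrolled
            out[i] = 2.0
        n = 3
        for i in range(n):             # runtime bound: dynamic loop
            out[i] = 3.0
        for i in range(3):             # contains break: dynamic loop
            if out[i] > 0.0:
                break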
1133
1568
  def emit_For(adj, node):
1134
1569
  # try and unroll simple range() statements that use constant args
1135
- unrolled = False
1136
-
1137
- if isinstance(node.iter, ast.Call) and node.iter.func.id == "range":
1138
- is_constant = True
1139
- for a in node.iter.args:
1140
- # if all range() arguments are numeric constants we will unroll
1141
- # note that this only handles trivial constants, it will not unroll
1142
- # constant compile-time expressions e.g.: range(0, 3*2)
1143
- if not adj.is_num(a):
1144
- is_constant = False
1145
- break
1146
-
1147
- if is_constant:
1148
- # range(end)
1149
- if len(node.iter.args) == 1:
1150
- start = 0
1151
- end = adj.eval_num(node.iter.args[0])
1152
- step = 1
1153
-
1154
- # range(start, end)
1155
- elif len(node.iter.args) == 2:
1156
- start = adj.eval_num(node.iter.args[0])
1157
- end = adj.eval_num(node.iter.args[1])
1158
- step = 1
1159
-
1160
- # range(start, end, step)
1161
- elif len(node.iter.args) == 3:
1162
- start = adj.eval_num(node.iter.args[0])
1163
- end = adj.eval_num(node.iter.args[1])
1164
- step = adj.eval_num(node.iter.args[2])
1165
-
1166
- # test if we're above max unroll count
1167
- max_iters = abs(end - start) // abs(step)
1168
- max_unroll = adj.builder.options["max_unroll"]
1169
-
1170
- if max_iters > max_unroll:
1171
- if warp.config.verbose:
1172
- print(
1173
- f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
1174
- )
1175
- else:
1176
- # unroll
1177
- for i in range(start, end, step):
1178
- const_iter = adj.add_constant(i)
1179
- var_iter = adj.add_call(warp.context.builtin_functions["int"], [const_iter])
1180
- adj.symbols[node.target.id] = var_iter
1570
+ unroll_range = adj.get_unroll_range(node)
1181
1571
 
1182
- # eval body
1183
- for s in node.body:
1184
- adj.eval(s)
1572
+ if isinstance(unroll_range, range):
1573
+ for i in unroll_range:
1574
+ const_iter = adj.add_constant(i)
1575
+ var_iter = adj.add_builtin_call("int", [const_iter])
1576
+ adj.symbols[node.target.id] = var_iter
1185
1577
 
1186
- unrolled = True
1578
+ # eval body
1579
+ for s in node.body:
1580
+ adj.eval(s)
1187
1581
 
1188
- # couldn't unroll so generate a dynamic loop
1189
- if not unrolled:
1190
- # evaluate the Iterable
1191
- iter = adj.eval(node.iter)
1582
+ # otherwise generate a dynamic loop
1583
+ else:
1584
+ # evaluate the Iterable -- only if not previously evaluated when trying to unroll
1585
+ if unroll_range is not None:
1586
+ # Range has already been evaluated when trying to unroll, do not re-evaluate
1587
+ iter = unroll_range
1588
+ else:
1589
+ iter = adj.eval(node.iter)
1192
1590
 
1193
1591
  adj.symbols[node.target.id] = adj.begin_for(iter)
1194
1592
 
@@ -1217,15 +1615,28 @@ class Adjoint:
1217
1615
  def emit_Expr(adj, node):
1218
1616
  return adj.eval(node.value)
1219
1617
 
1618
+ def check_tid_in_func_error(adj, node):
1619
+ if adj.is_user_function:
1620
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
1621
+ lineno = adj.lineno + adj.fun_lineno
1622
+ line = adj.source_lines[adj.lineno]
1623
+ raise WarpCodegenError(
1624
+ "tid() may only be called from a Warp kernel, not a Warp function. "
1625
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
1626
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
1627
+ )
1628
+
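The pattern this new check rejects, and the supported alternative (assumed usage):

    import warp as wp

    # Rejected by the check above: calling tid() inside a @wp.func
    #   @wp.func
    #   def bad() -> int:
    #       return wp.tid()            # -> WarpCodegenError
    #
    # Supported: obtain the index in the kernel and pass it down.

    @wp.func
    def write(i: int, out: wp.array(dtype=int)):
        out[i] = i

    @wp.kernel
    def k(out: wp.array(dtype=int)):
        i = wp.tid()
        write(i, out)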
1220
1629
  def emit_Call(adj, node):
1630
+ adj.check_tid_in_func_error(node)
1631
+
1221
1632
  # try and lookup function in globals by
1222
1633
  # resolving path (e.g.: module.submodule.attr)
1223
- func, path = adj.resolve_path(node.func)
1634
+ func, path = adj.resolve_static_expression(node.func)
1224
1635
  templates = []
1225
1636
 
1226
- if isinstance(func, warp.context.Function) == False:
1637
+ if not isinstance(func, warp.context.Function):
1227
1638
  if len(path) == 0:
1228
- raise RuntimeError(f"Unrecognized syntax for function call, path not valid: '{node.func}'")
1639
+ raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")
1229
1640
 
1230
1641
  attr = path[-1]
1231
1642
  caller = func
@@ -1250,7 +1661,7 @@ class Adjoint:
1250
1661
  func = caller.initializer()
1251
1662
 
1252
1663
  if func is None:
1253
- raise RuntimeError(
1664
+ raise WarpCodegenError(
1254
1665
  f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
1255
1666
  )
1256
1667
 
@@ -1259,16 +1670,25 @@ class Adjoint:
1259
1670
  # eval all arguments
1260
1671
  for arg in node.args:
1261
1672
  var = adj.eval(arg)
1673
+ if not is_local_value(var):
1674
+ raise RuntimeError(
1675
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
1676
+ )
1262
1677
  args.append(var)
1263
1678
 
1264
- # eval all keyword ags
1679
+ # eval all keyword args
1265
1680
  def kwval(kw):
1266
1681
  if isinstance(kw.value, ast.Num):
1267
1682
  return kw.value.n
1268
1683
  elif isinstance(kw.value, ast.Tuple):
1269
- return tuple(adj.eval_num(e) for e in kw.value.elts)
1684
+ arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
1685
+ if not all(arg_is_numeric):
1686
+ raise WarpCodegenError(
1687
+ f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
1688
+ )
1689
+ return arg_values
1270
1690
  else:
1271
- return adj.resolve_path(kw.value)[0]
1691
+ return adj.resolve_static_expression(kw.value)[0]
1272
1692
 
1273
1693
  kwds = {kw.arg: kwval(kw) for kw in node.keywords}
1274
1694
 
@@ -1285,10 +1705,26 @@ class Adjoint:
1285
1705
  # the ast.Index node appears in 3.7 versions
1286
1706
  # when performing array slices, e.g.: x = arr[i]
1287
1707
  # but in version 3.8 and higher it does not appear
1708
+
1709
+ if hasattr(node, "is_adjoint"):
1710
+ node.value.is_adjoint = True
1711
+
1288
1712
  return adj.eval(node.value)
1289
1713
 
1290
1714
  def emit_Subscript(adj, node):
1715
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
1716
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
1717
+ node.slice.is_adjoint = True
1718
+ var = adj.eval(node.slice)
1719
+ var_name = var.label
1720
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
1721
+ return var
1722
+
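wp.adjoint[x] simply renames the variable to adj_x in the generated code; it is intended for custom gradient functions, a sketch assuming the custom-gradient decorator this diff wires up via custom_grad_func (wp.func_grad in the released API):

    import warp as wp

    @wp.func
    def square(x: float) -> float:
        return x * x

    @wp.func_grad(square)
    def adj_square(x: float, adj_ret: float):
        wp.adjoint[x] += 2.0 * x * adj_ret    # reads/writes adj_x directly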
1291
1723
  target = adj.eval(node.value)
1724
+ if not is_local_value(target):
1725
+ raise RuntimeError(
1726
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
1727
+ )
1292
1728
 
1293
1729
  indices = []
1294
1730
 
@@ -1308,28 +1744,34 @@ class Adjoint:
1308
1744
  var = adj.eval(node.slice)
1309
1745
  indices.append(var)
1310
1746
 
1311
- if is_array(target.type):
1312
- if len(indices) == target.type.ndim:
1747
+ target_type = strip_reference(target.type)
1748
+ if is_array(target_type):
1749
+ if len(indices) == target_type.ndim:
1313
1750
  # handles array loads (where each dimension has an index specified)
1314
- out = adj.add_call(warp.context.builtin_functions["load"], [target, *indices])
1751
+ out = adj.add_builtin_call("address", [target, *indices])
1315
1752
  else:
1316
1753
  # handles array views (fewer indices than dimensions)
1317
- out = adj.add_call(warp.context.builtin_functions["view"], [target, *indices])
1754
+ out = adj.add_builtin_call("view", [target, *indices])
1318
1755
 
1319
1756
  else:
1320
1757
  # handles non-array type indexing, e.g: vec3, mat33, etc
1321
- out = adj.add_call(warp.context.builtin_functions["index"], [target, *indices])
1758
+ out = adj.add_builtin_call("extract", [target, *indices])
1322
1759
 
1323
1760
  return out
1324
1761
 
1325
1762
  def emit_Assign(adj, node):
1763
+ if len(node.targets) != 1:
1764
+ raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
1765
+
1766
+ lhs = node.targets[0]
1767
+
1326
1768
  # handle the case where we are assigning multiple output variables
1327
- if isinstance(node.targets[0], ast.Tuple):
1769
+ if isinstance(lhs, ast.Tuple):
1328
1770
  # record the expected number of outputs on the node
1329
1771
  # we do this so we can decide which function to
1330
1772
  # call based on the number of expected outputs
1331
1773
  if isinstance(node.value, ast.Call):
1332
- node.value.expects = len(node.targets[0].elts)
1774
+ node.value.expects = len(lhs.elts)
1333
1775
 
1334
1776
  # evaluate values
1335
1777
  if isinstance(node.value, ast.Tuple):
@@ -1338,40 +1780,47 @@ class Adjoint:
1338
1780
  out = adj.eval(node.value)
1339
1781
 
1340
1782
  names = []
1341
- for v in node.targets[0].elts:
1783
+ for v in lhs.elts:
1342
1784
  if isinstance(v, ast.Name):
1343
1785
  names.append(v.id)
1344
1786
  else:
1345
- raise RuntimeError(
1787
+ raise WarpCodegenError(
1346
1788
  "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
1347
1789
  )
1348
1790
 
1349
1791
  if len(names) != len(out):
1350
- raise RuntimeError(
1351
- "Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})".format(
1352
- len(out), len(names)
1353
- )
1792
+ raise WarpCodegenError(
1793
+ f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
1354
1794
  )
1355
1795
 
1356
1796
  for name, rhs in zip(names, out):
1357
1797
  if name in adj.symbols:
1358
1798
  if not types_equal(rhs.type, adj.symbols[name].type):
1359
- raise TypeError(
1360
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
1361
- name, adj.symbols[name].type, rhs.type
1362
- )
1799
+ raise WarpCodegenTypeError(
1800
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
1363
1801
  )
1364
1802
 
1365
1803
  adj.symbols[name] = rhs
1366
1804
 
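These checks guard multiple-return unpacking; for example, the two-output form of tid() in a 2D launch must be received in full, by simple names:

    import warp as wp

    @wp.kernel
    def fill(out: wp.array2d(dtype=float)):
        i, j = wp.tid()       # both outputs must be unpacked to plain variables
        out[i, j] = 1.0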
1367
- return out
1368
-
1369
1805
  # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
1370
- elif isinstance(node.targets[0], ast.Subscript):
1371
- target = adj.eval(node.targets[0].value)
1806
+ elif isinstance(lhs, ast.Subscript):
1807
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
1808
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
1809
+ lhs.slice.is_adjoint = True
1810
+ src_var = adj.eval(lhs.slice)
1811
+ var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
1812
+ value = adj.eval(node.value)
1813
+ adj.add_forward(f"{var.emit()} = {value.emit()};")
1814
+ return
1815
+
1816
+ target = adj.eval(lhs.value)
1372
1817
  value = adj.eval(node.value)
1818
+ if not is_local_value(value):
1819
+ raise RuntimeError(
1820
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
1821
+ )
1373
1822
 
1374
- slice = node.targets[0].slice
1823
+ slice = lhs.slice
1375
1824
  indices = []
1376
1825
 
1377
1826
  if isinstance(slice, ast.Tuple):
@@ -1379,7 +1828,6 @@ class Adjoint:
1379
1828
  for arg in slice.elts:
1380
1829
  var = adj.eval(arg)
1381
1830
  indices.append(var)
1382
-
1383
1831
  elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
1384
1832
  # handles the x[i, j] case (Python 3.7.x)
1385
1833
  for arg in slice.value.elts:
@@ -1390,64 +1838,84 @@ class Adjoint:
1390
1838
  var = adj.eval(slice)
1391
1839
  indices.append(var)
1392
1840
 
1393
- if is_array(target.type):
1394
- adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
1841
+ target_type = strip_reference(target.type)
1395
1842
 
1396
- elif type_is_vector(target.type) or type_is_matrix(target.type):
1397
- adj.add_call(warp.context.builtin_functions["indexset"], [target, *indices, value])
1843
+ if is_array(target_type):
1844
+ adj.add_builtin_call("array_store", [target, *indices, value])
1398
1845
 
1399
- if warp.config.verbose:
1846
+ elif type_is_vector(target_type) or type_is_matrix(target_type):
1847
+ if is_reference(target.type):
1848
+ attr = adj.add_builtin_call("indexref", [target, *indices])
1849
+ else:
1850
+ attr = adj.add_builtin_call("index", [target, *indices])
1851
+
1852
+ adj.add_builtin_call("store", [attr, value])
1853
+
1854
+ if warp.config.verbose and not adj.custom_reverse_mode:
1400
1855
  lineno = adj.lineno + adj.fun_lineno
1401
- line = adj.source.splitlines()[adj.lineno]
1856
+ line = adj.source_lines[adj.lineno]
1857
+ node_source = adj.get_node_source(lhs.value)
1402
1858
  print(
1403
- f"Warning: mutating {node.targets[0].value.id} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
1859
+ f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
1404
1860
  )
1405
1861
 
1406
1862
  else:
1407
- raise RuntimeError("Can only subscript assign array, vector, and matrix types")
1863
+ raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
1408
1864
 
1409
- return var
1410
-
1411
- elif isinstance(node.targets[0], ast.Name):
1865
+ elif isinstance(lhs, ast.Name):
1412
1866
  # symbol name
1413
- name = node.targets[0].id
1867
+ name = lhs.id
1414
1868
 
1415
1869
  # evaluate rhs
1416
1870
  rhs = adj.eval(node.value)
1417
1871
 
1418
1872
  # check type matches if symbol already defined
1419
1873
  if name in adj.symbols:
1420
- if not types_equal(rhs.type, adj.symbols[name].type):
1421
- raise TypeError(
1422
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
1423
- name, adj.symbols[name].type, rhs.type
1424
- )
1874
+ if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
1875
+ raise WarpCodegenTypeError(
1876
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
1425
1877
  )
1426
1878
 
1427
1879
  # handle simple assignment case (a = b), where we generate a value copy rather than reference
1428
- if isinstance(node.value, ast.Name):
1429
- out = adj.add_var(rhs.type)
1430
- adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
1880
+ if isinstance(node.value, ast.Name) or is_reference(rhs.type):
1881
+ out = adj.add_builtin_call("copy", [rhs])
1431
1882
  else:
1432
1883
  out = rhs
1433
1884
 
1434
1885
  # update symbol map (assumes lhs is a Name node)
1435
1886
  adj.symbols[name] = out
1436
- return out
1437
1887
 
1438
- elif isinstance(node.targets[0], ast.Attribute):
1888
+ elif isinstance(lhs, ast.Attribute):
1439
1889
  rhs = adj.eval(node.value)
1440
- attr = adj.emit_Attribute(node.targets[0])
1441
- adj.add_call(warp.context.builtin_functions["copy"], [attr, rhs])
1890
+ aggregate = adj.eval(lhs.value)
1891
+ aggregate_type = strip_reference(aggregate.type)
1442
1892
 
1443
- if warp.config.verbose:
1444
- lineno = adj.lineno + adj.fun_lineno
1445
- line = adj.source.splitlines()[adj.lineno]
1446
- msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1447
- print(msg)
1893
+ # assigning to a vector component
1894
+ if type_is_vector(aggregate_type):
1895
+ index = adj.vector_component_index(lhs.attr, aggregate_type)
1896
+
1897
+ if is_reference(aggregate.type):
1898
+ attr = adj.add_builtin_call("indexref", [aggregate, index])
1899
+ else:
1900
+ attr = adj.add_builtin_call("index", [aggregate, index])
1901
+
1902
+ adj.add_builtin_call("store", [attr, rhs])
1903
+
1904
+ else:
1905
+ attr = adj.emit_Attribute(lhs)
1906
+ if is_reference(attr.type):
1907
+ adj.add_builtin_call("store", [attr, rhs])
1908
+ else:
1909
+ adj.add_builtin_call("assign", [attr, rhs])
1910
+
1911
+ if warp.config.verbose and not adj.custom_reverse_mode:
1912
+ lineno = adj.lineno + adj.fun_lineno
1913
+ line = adj.source_lines[adj.lineno]
1914
+ msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
1915
+ print(msg)
1448
1916
 
1449
1917
  else:
1450
- raise RuntimeError("Error, unsupported assignment statement.")
1918
+ raise WarpCodegenError("Error, unsupported assignment statement.")
1451
1919
 
1452
1920
  def emit_Return(adj, node):
1453
1921
  if node.value is None:
@@ -1458,30 +1926,26 @@ class Adjoint:
1458
1926
  var = (adj.eval(node.value),)
1459
1927
 
1460
1928
  if adj.return_var is not None:
1461
- old_ctypes = tuple(v.ctype() for v in adj.return_var)
1462
- new_ctypes = tuple(v.ctype() for v in var)
1929
+ old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
1930
+ new_ctypes = tuple(v.ctype(value_type=True) for v in var)
1463
1931
  if old_ctypes != new_ctypes:
1464
- raise TypeError(
1932
+ raise WarpCodegenTypeError(
1465
1933
  f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
1466
1934
  )
1467
- else:
1468
- adj.return_var = var
1469
-
1470
- adj.add_return(var)
1471
-
1472
- def emit_AugAssign(adj, node):
1473
- # convert inplace operations (+=, -=, etc) to ssa form, e.g.: c = a + b
1474
- left = adj.eval(node.target)
1475
- right = adj.eval(node.value)
1476
1935
 
1477
- # lookup
1478
- name = builtin_operators[type(node.op)]
1479
- func = warp.context.builtin_functions[name]
1936
+ if var is not None:
1937
+ adj.return_var = tuple()
1938
+ for ret in var:
1939
+ if is_reference(ret.type):
1940
+ ret = adj.add_builtin_call("copy", [ret])
1941
+ adj.return_var += (ret,)
1480
1942
 
1481
- out = adj.add_call(func, [left, right])
1943
+ adj.add_return(adj.return_var)
1482
1944
 
1483
- # update symbol map
1484
- adj.symbols[node.target.id] = out
1945
+ def emit_AugAssign(adj, node):
1946
+ # replace augmented assignment with assignment statement + binary op
1947
+ new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
1948
+ adj.eval(new_node)
1485
1949
 
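The rewrite above desugars x += y into x = x + y before evaluation; a standalone sketch of the transform on a real AST node:

    import ast

    node = ast.parse("x += y").body[0]              # ast.AugAssign
    assert isinstance(node, ast.AugAssign)
    new_node = ast.Assign(
        targets=[node.target],
        value=ast.BinOp(node.target, node.op, node.value),
    )                                               # evaluates as "x = x + y"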
1486
1950
  def emit_Tuple(adj, node):
1487
1951
  # LHS for expressions, such as i, j, k = 1, 2, 3
@@ -1491,122 +1955,167 @@ class Adjoint:
1491
1955
  def emit_Pass(adj, node):
1492
1956
  pass
1493
1957
 
1958
+ node_visitors = {
1959
+ ast.FunctionDef: emit_FunctionDef,
1960
+ ast.If: emit_If,
1961
+ ast.Compare: emit_Compare,
1962
+ ast.BoolOp: emit_BoolOp,
1963
+ ast.Name: emit_Name,
1964
+ ast.Attribute: emit_Attribute,
1965
+ ast.Str: emit_String, # Deprecated in 3.8; use Constant
1966
+ ast.Num: emit_Num, # Deprecated in 3.8; use Constant
1967
+ ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
1968
+ ast.Constant: emit_Constant,
1969
+ ast.BinOp: emit_BinOp,
1970
+ ast.UnaryOp: emit_UnaryOp,
1971
+ ast.While: emit_While,
1972
+ ast.For: emit_For,
1973
+ ast.Break: emit_Break,
1974
+ ast.Continue: emit_Continue,
1975
+ ast.Expr: emit_Expr,
1976
+ ast.Call: emit_Call,
1977
+ ast.Index: emit_Index, # Deprecated in 3.8; use the index value directly instead.
1978
+ ast.Subscript: emit_Subscript,
1979
+ ast.Assign: emit_Assign,
1980
+ ast.Return: emit_Return,
1981
+ ast.AugAssign: emit_AugAssign,
1982
+ ast.Tuple: emit_Tuple,
1983
+ ast.Pass: emit_Pass,
1984
+ ast.Ellipsis: emit_Ellipsis,
1985
+ }
1986
+
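Hoisting the dispatch table to a class attribute builds it once at class-creation time instead of on every eval() call; the pattern in miniature:

    import ast

    class MiniVisitor:
        def emit_Constant(self, node):
            return node.value

        # plain functions stored in the class body; bound at the call site
        node_visitors = {ast.Constant: emit_Constant}

        def eval(self, node):
            return self.node_visitors[type(node)](self, node)

    expr = ast.parse("42", mode="eval").body        # ast.Constant
    assert MiniVisitor().eval(expr) == 42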
1494
1987
  def eval(adj, node):
1495
1988
  if hasattr(node, "lineno"):
1496
1989
  adj.set_lineno(node.lineno - 1)
1497
1990
 
1498
- node_visitors = {
1499
- ast.FunctionDef: Adjoint.emit_FunctionDef,
1500
- ast.If: Adjoint.emit_If,
1501
- ast.Compare: Adjoint.emit_Compare,
1502
- ast.BoolOp: Adjoint.emit_BoolOp,
1503
- ast.Name: Adjoint.emit_Name,
1504
- ast.Attribute: Adjoint.emit_Attribute,
1505
- ast.Str: Adjoint.emit_String, # Deprecated in 3.8; use Constant
1506
- ast.Num: Adjoint.emit_Num, # Deprecated in 3.8; use Constant
1507
- ast.NameConstant: Adjoint.emit_NameConstant, # Deprecated in 3.8; use Constant
1508
- ast.Constant: Adjoint.emit_Constant,
1509
- ast.BinOp: Adjoint.emit_BinOp,
1510
- ast.UnaryOp: Adjoint.emit_UnaryOp,
1511
- ast.While: Adjoint.emit_While,
1512
- ast.For: Adjoint.emit_For,
1513
- ast.Break: Adjoint.emit_Break,
1514
- ast.Continue: Adjoint.emit_Continue,
1515
- ast.Expr: Adjoint.emit_Expr,
1516
- ast.Call: Adjoint.emit_Call,
1517
- ast.Index: Adjoint.emit_Index, # Deprecated in 3.8; Use the index value directly instead.
1518
- ast.Subscript: Adjoint.emit_Subscript,
1519
- ast.Assign: Adjoint.emit_Assign,
1520
- ast.Return: Adjoint.emit_Return,
1521
- ast.AugAssign: Adjoint.emit_AugAssign,
1522
- ast.Tuple: Adjoint.emit_Tuple,
1523
- ast.Pass: Adjoint.emit_Pass,
1524
- }
1525
-
1526
- emit_node = node_visitors.get(type(node))
1527
-
1528
- if emit_node is not None:
1529
- return emit_node(adj, node)
1530
- else:
1531
- raise Exception("Error, ast node of type {} not supported".format(type(node)))
1991
+ emit_node = adj.node_visitors[type(node)]
1992
+
1993
+ return emit_node(adj, node)
1532
1994
 
1533
1995
  # helper to evaluate expressions of the form
1534
1996
  # obj1.obj2.obj3.attr in the function's global scope
1535
- def resolve_path(adj, node):
1536
- modules = []
1997
+ def resolve_path(adj, path):
1998
+ if len(path) == 0:
1999
+ return None
1537
2000
 
1538
- while isinstance(node, ast.Attribute):
1539
- modules.append(node.attr)
1540
- node = node.value
2001
+ # if root is overshadowed by local symbols, bail out
2002
+ if path[0] in adj.symbols:
2003
+ return None
1541
2004
 
1542
- if isinstance(node, ast.Name):
1543
- modules.append(node.id)
2005
+ if path[0] in __builtins__:
2006
+ return __builtins__[path[0]]
1544
2007
 
1545
- # reverse list since ast presents it backward order
1546
- path = [*reversed(modules)]
2008
+ # Look up the closure info and append it to adj.func.__globals__
2009
+ # in case you want to define a kernel inside a function and refer
2010
+ # to variables you've declared inside that function:
2011
+ extract_contents = (
2012
+ lambda contents: contents
2013
+ if isinstance(contents, warp.context.Function) or not callable(contents)
2014
+ else contents
2015
+ )
2016
+ capturedvars = dict(
2017
+ zip(
2018
+ adj.func.__code__.co_freevars,
2019
+ [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
2020
+ )
2021
+ )
2022
+ vars_dict = {**adj.func.__globals__, **capturedvars}
1547
2023
 
1548
- if len(path) == 0:
1549
- return None, path
2024
+ if path[0] in vars_dict:
2025
+ func = vars_dict[path[0]]
1550
2026
 
1551
- # try and evaluate object path
1552
- try:
1553
- # Look up the closure info and append it to adj.func.__globals__
1554
- # in case you want to define a kernel inside a function and refer
1555
- # to variables you've declared inside that function:
1556
- extract_contents = (
1557
- lambda contents: contents
1558
- if isinstance(contents, warp.context.Function) or not callable(contents)
1559
- else contents
1560
- )
1561
- capturedvars = dict(
1562
- zip(
1563
- adj.func.__code__.co_freevars,
1564
- [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
1565
- )
1566
- )
2027
+ # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
2028
+ else:
2029
+ func = getattr(warp, path[0], None)
1567
2030
 
1568
- vars_dict = {**adj.func.__globals__, **capturedvars}
1569
- func = eval(".".join(path), vars_dict)
1570
- return func, path
1571
- except:
1572
- pass
2031
+ if func:
2032
+ for i in range(1, len(path)):
2033
+ if hasattr(func, path[i]):
2034
+ func = getattr(func, path[i])
1573
2035
 
1574
- # I added this so people can eg do this kind of thing
1575
- # in a kernel:
2036
+ return func
1576
2037
 
1577
- # v = vec3(0.0,0.2,0.4)
2038
+ # Evaluates a static expression that does not depend on runtime values
2039
+ # if eval_types is True, try resolving the path using evaluated type information as well
2040
+ def resolve_static_expression(adj, root_node, eval_types=True):
2041
+ attributes = []
1578
2042
 
1579
- # vec3 is now an alias and is not in warp.context.builtin_functions.
1580
- # This means it can't be directly looked up in Adjoint.add_call, and
1581
- # needs to be looked up by digging some information out of the
1582
- # python object it actually came from.
2043
+ node = root_node
2044
+ while isinstance(node, ast.Attribute):
2045
+ attributes.append(node.attr)
2046
+ node = node.value
1583
2047
 
1584
- # Before this fix, resolve_path was returning None, as the
1585
- # "vec3" symbol is not available. In this situation I'm assuming
1586
- # it's a member of the warp module and trying to look it up:
1587
- try:
1588
- evalstr = ".".join(["warp"] + path)
1589
- func = eval(evalstr, {"warp": warp})
1590
- return func, path
1591
- except:
1592
- return None, path
2048
+ if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
2049
+ # support for operators returning modules
2050
+ # i.e. operator_name(*operator_args).x.y.z
2051
+ operator_args = node.args
2052
+ operator_name = node.func.id
2053
+
2054
+ if operator_name == "type":
2055
+ if len(operator_args) != 1:
2056
+ raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
2057
+
2058
+ # type() operator
2059
+ var = adj.eval(operator_args[0])
2060
+
2061
+ if isinstance(var, Var):
2062
+ var_type = strip_reference(var.type)
2063
+ # Allow accessing type attributes, for instance array.dtype
2064
+ while attributes:
2065
+ attr_name = attributes.pop()
2066
+ var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
2067
+
2068
+ if var_type is None:
2069
+ raise WarpCodegenAttributeError(
2070
+ f"{attr_name} is not an attribute of {type_repr(prev_type)}"
2071
+ )
2072
+
2073
+ return var_type, [type_repr(var_type)]
2074
+ else:
2075
+ raise WarpCodegenError(f"Cannot deduce the type of {var}")
2076
+
2077
+ # reverse list since ast presents it backward order
2078
+ path = [*reversed(attributes)]
2079
+ if isinstance(node, ast.Name):
2080
+ path.insert(0, node.id)
2081
+
2082
+ # Try resolving path from captured context
2083
+ captured_obj = adj.resolve_path(path)
2084
+ if captured_obj is not None:
2085
+ return captured_obj, path
2086
+
2087
+ # Still nothing found, maybe this is a predefined type attribute like `dtype`
2088
+ if eval_types:
2089
+ try:
2090
+ val = adj.eval(root_node)
2091
+ if val:
2092
+ return val, [type_repr(val)]
2093
+
2094
+ except Exception:
2095
+ pass
2096
+
2097
+ return None, path
1593
2098
 
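The type() operator handled above lets generic code resolve types at codegen time; a sketch of the kind of expression it enables, assuming warp's generic-function support via typing.Any:

    import warp as wp
    from typing import Any

    @wp.func
    def halve(x: Any):
        # type(x) is resolved statically by the branch above, so the
        # constructed constant matches whatever scalar type x has
        return x / type(x)(2)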
1594
2099
  # annotate generated code with the original source code line
1595
2100
  def set_lineno(adj, lineno):
1596
2101
  if adj.lineno is None or adj.lineno != lineno:
1597
2102
  line = lineno + adj.fun_lineno
1598
- source = adj.raw_source[lineno].strip().ljust(70)
2103
+ source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
1599
2104
  adj.add_forward(f"// {source} <L {line}>")
1600
2105
  adj.add_reverse(f"// adj: {source} <L {line}>")
1601
2106
  adj.lineno = lineno
1602
2107
 
2108
+ def get_node_source(adj, node):
2109
+ # return the Python code corresponding to the given AST node
2110
+ return ast.get_source_segment(adj.source, node)
2111
+
1603
2112
 
1604
2113
  # ----------------
1605
2114
  # code generation
1606
2115
 
1607
2116
  cpu_module_header = """
1608
2117
  #define WP_NO_CRT
1609
- #include "../native/builtin.h"
2118
+ #include "builtin.h"
1610
2119
 
1611
2120
  // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
1612
2121
  #define float(x) cast_float(x)
@@ -1615,13 +2124,16 @@ cpu_module_header = """
1615
2124
  #define int(x) cast_int(x)
1616
2125
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
1617
2126
 
1618
- using namespace wp;
2127
+ #define builtin_tid1d() wp::tid(wp::s_threadIdx)
2128
+ #define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
2129
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
2130
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
1619
2131
 
1620
2132
  """
1621
2133
 
1622
2134
  cuda_module_header = """
1623
2135
  #define WP_NO_CRT
1624
- #include "../native/builtin.h"
2136
+ #include "builtin.h"
1625
2137
 
1626
2138
  // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
1627
2139
  #define float(x) cast_float(x)
@@ -1630,8 +2142,10 @@ cuda_module_header = """
1630
2142
  #define int(x) cast_int(x)
1631
2143
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
1632
2144
 
1633
-
1634
- using namespace wp;
2145
+ #define builtin_tid1d() wp::tid(_idx)
2146
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
2147
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
2148
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
1635
2149
 
1636
2150
  """
1637
2151
 
@@ -1645,54 +2159,56 @@ struct {name}
1645
2159
  {{
1646
2160
  }}
1647
2161
 
1648
- {name}& operator += (const {name}&) {{ return *this; }}
2162
+ CUDA_CALLABLE {name}& operator += (const {name}& rhs)
2163
+ {{{prefix_add_body}
2164
+ return *this;}}
1649
2165
 
1650
2166
  }};
1651
2167
 
1652
2168
  static CUDA_CALLABLE void adj_{name}({reverse_args})
1653
2169
  {{
1654
- {reverse_body}
1655
- }}
2170
+ {reverse_body}}}
1656
2171
 
1657
- CUDA_CALLABLE void atomic_add({name}* p, {name} t)
2172
+ CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
1658
2173
  {{
1659
- {atomic_add_body}
1660
- }}
2174
+ {atomic_add_body}}}
1661
2175
 
1662
2176
 
1663
2177
  """
1664
2178
 
1665
- cpu_function_template = """
2179
+ cpu_forward_function_template = """
1666
2180
  // {filename}:{lineno}
1667
2181
  static {return_type} {name}(
1668
2182
  {forward_args})
1669
2183
  {{
1670
- {forward_body}
1671
- }}
2184
+ {forward_body}}}
1672
2185
 
2186
+ """
2187
+
2188
+ cpu_reverse_function_template = """
1673
2189
  // {filename}:{lineno}
1674
2190
  static void adj_{name}(
1675
2191
  {reverse_args})
1676
2192
  {{
1677
- {reverse_body}
1678
- }}
2193
+ {reverse_body}}}
1679
2194
 
1680
2195
  """
1681
2196
 
1682
- cuda_function_template = """
2197
+ cuda_forward_function_template = """
1683
2198
  // {filename}:{lineno}
1684
2199
  static CUDA_CALLABLE {return_type} {name}(
1685
2200
  {forward_args})
1686
2201
  {{
1687
- {forward_body}
1688
- }}
2202
+ {forward_body}}}
1689
2203
 
2204
+ """
2205
+
2206
+ cuda_reverse_function_template = """
1690
2207
  // {filename}:{lineno}
1691
2208
  static CUDA_CALLABLE void adj_{name}(
1692
2209
  {reverse_args})
1693
2210
  {{
1694
- {reverse_body}
1695
- }}
2211
+ {reverse_body}}}
1696
2212
 
1697
2213
  """
1698
2214
 
@@ -1701,25 +2217,21 @@ cuda_kernel_template = """
1701
2217
  extern "C" __global__ void {name}_cuda_kernel_forward(
1702
2218
  {forward_args})
1703
2219
  {{
1704
- size_t _idx = grid_index();
1705
- if (_idx >= dim.size)
1706
- return;
1707
-
1708
- set_launch_bounds(dim);
1709
-
1710
- {forward_body}
2220
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
2221
+ _idx < dim.size;
2222
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
2223
+ {{
2224
+ {forward_body} }}
1711
2225
  }}
1712
2226
 
1713
2227
  extern "C" __global__ void {name}_cuda_kernel_backward(
1714
2228
  {reverse_args})
1715
2229
  {{
1716
- size_t _idx = grid_index();
1717
- if (_idx >= dim.size)
1718
- return;
1719
-
1720
- set_launch_bounds(dim);
1721
-
1722
- {reverse_body}
2230
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
2231
+ _idx < dim.size;
2232
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
2233
+ {{
2234
+ {reverse_body} }}
1723
2235
  }}
1724
2236
 
1725
2237
  """
@@ -1729,39 +2241,12 @@ cpu_kernel_template = """
1729
2241
  void {name}_cpu_kernel_forward(
1730
2242
  {forward_args})
1731
2243
  {{
1732
- {forward_body}
1733
- }}
2244
+ {forward_body}}}
1734
2245
 
1735
2246
  void {name}_cpu_kernel_backward(
1736
2247
  {reverse_args})
1737
2248
  {{
1738
- {reverse_body}
1739
- }}
1740
-
1741
- """
1742
-
1743
- cuda_module_template = """
1744
-
1745
- extern "C" {{
1746
-
1747
- // Python entry points
1748
- WP_API void {name}_cuda_forward(
1749
- void* stream,
1750
- {forward_args})
1751
- {{
1752
- {name}_cuda_kernel_forward<<<(dim.size + 256 - 1) / 256, 256, 0, (cudaStream_t)stream>>>(
1753
- {forward_params});
1754
- }}
1755
-
1756
- WP_API void {name}_cuda_backward(
1757
- void* stream,
1758
- {reverse_args})
1759
- {{
1760
- {name}_cuda_kernel_backward<<<(dim.size + 256 - 1) / 256, 256, 0, (cudaStream_t)stream>>>(
1761
- {reverse_params});
1762
- }}
1763
-
1764
- }} // extern C
2249
+ {reverse_body}}}
1765
2250
 
1766
2251
  """
1767
2252
 
@@ -1773,11 +2258,9 @@ extern "C" {{
1773
2258
  WP_API void {name}_cpu_forward(
1774
2259
  {forward_args})
1775
2260
  {{
1776
- set_launch_bounds(dim);
1777
-
1778
2261
  for (size_t i=0; i < dim.size; ++i)
1779
2262
  {{
1780
- s_threadIdx = i;
2263
+ wp::s_threadIdx = i;
1781
2264
 
1782
2265
  {name}_cpu_kernel_forward(
1783
2266
  {forward_params});
@@ -1787,11 +2270,9 @@ WP_API void {name}_cpu_forward(
1787
2270
  WP_API void {name}_cpu_backward(
1788
2271
  {reverse_args})
1789
2272
  {{
1790
- set_launch_bounds(dim);
1791
-
1792
2273
  for (size_t i=0; i < dim.size; ++i)
1793
2274
  {{
1794
- s_threadIdx = i;
2275
+ wp::s_threadIdx = i;
1795
2276
 
1796
2277
  {name}_cpu_kernel_backward(
1797
2278
  {reverse_params});
@@ -1837,7 +2318,7 @@ WP_API void {name}_cpu_backward(
1837
2318
  def constant_str(value):
1838
2319
  value_type = type(value)
1839
2320
 
1840
- if value_type == bool:
2321
+ if value_type == bool or value_type == builtins.bool:
1841
2322
  if value:
1842
2323
  return "true"
1843
2324
  else:
@@ -1854,7 +2335,9 @@ def constant_str(value):
1854
2335
 
1855
2336
  scalar_value = runtime.core.half_bits_to_float
1856
2337
  else:
1857
- scalar_value = lambda x: x
2338
+
2339
+ def scalar_value(x):
2340
+ return x
1858
2341
 
1859
2342
  # list of scalar initializer values
1860
2343
  initlist = []
@@ -1871,6 +2354,9 @@ def constant_str(value):
1871
2354
  # make sure we emit the value of objects, e.g. uint32
1872
2355
  return str(value.value)
1873
2356
 
2357
+ elif value == math.inf:
2358
+ return "INFINITY"
2359
+
1874
2360
  else:
1875
2361
  # otherwise just convert constant to string
1876
2362
  return str(value)
@@ -1879,7 +2365,7 @@ def constant_str(value):
1879
2365
  def indent(args, stops=1):
1880
2366
  sep = ",\n"
1881
2367
  for i in range(stops):
1882
- sep += "\t"
2368
+ sep += "    "
1883
2369
 
1884
2370
  # return sep + args.replace(", ", "," + sep)
1885
2371
  return sep.join(args)
@@ -1887,7 +2373,9 @@ def indent(args, stops=1):
1887
2373
 
1888
2374
  # generates a C function name based on the python function name
1889
2375
  def make_full_qualified_name(func):
1890
- return re.sub("[^0-9a-zA-Z_]+", "", func.__qualname__.replace(".", "__"))
2376
+ if not isinstance(func, str):
2377
+ func = func.__qualname__
2378
+ return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
1891
2379
 
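A quick check of the name mangling above, now that plain string inputs are accepted as well:

    import re

    def make_full_qualified_name(func):
        if not isinstance(func, str):
            func = func.__qualname__
        return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))

    # dots become double underscores, other non-identifier characters are dropped
    assert make_full_qualified_name("Cloth.step") == "Cloth__step"
    assert make_full_qualified_name("<locals>.k") == "locals__k"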
1892
2380
 
1893
2381
  def codegen_struct(struct, device="cpu", indent_size=4):
@@ -1895,8 +2383,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):
 
     body = []
     indent_block = " " * indent_size
-    for label, var in struct.vars.items():
-        body.append(var.ctype() + " " + label + ";\n")
+
+    if len(struct.vars) > 0:
+        for label, var in struct.vars.items():
+            body.append(var.ctype() + " " + label + ";\n")
+    else:
+        # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
+        body.append("char _dummy_;\n")
 
     forward_args = []
     reverse_args = []
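The dummy member exists because an empty struct's size and alignment are compiler-specific corner cases; emitting a single char pins the layout everywhere. A small sketch of the branch above (helper name is hypothetical):

    def struct_body(member_ctypes):
        # member_ctypes: dict mapping field name -> C type string
        if len(member_ctypes) > 0:
            return "".join(f"{ctype} {name};\n" for name, ctype in member_ctypes.items())
        # empty struct: pin size/alignment with a dummy member
        return "char _dummy_;\n"

    assert struct_body({"pos": "wp::vec3", "mass": "float"}) == "wp::vec3 pos;\nfloat mass;\n"
    assert struct_body({}) == "char _dummy_;\n"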
@@ -1904,21 +2397,32 @@ def codegen_struct(struct, device="cpu", indent_size=4):
     forward_initializers = []
     reverse_body = []
     atomic_add_body = []
+    prefix_add_body = []
 
     # forward args
     for label, var in struct.vars.items():
-        forward_args.append(f"{var.ctype()} const& {label} = {{}}")
-        reverse_args.append(f"{var.ctype()} const&")
+        var_ctype = var.ctype()
+        forward_args.append(f"{var_ctype} const& {label} = {{}}")
+        reverse_args.append(f"{var_ctype} const&")
 
-        atomic_add_body.append(f"{indent_block}atomic_add(&p->{label}, t.{label});\n")
+        namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
+        atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
 
         prefix = f"{indent_block}," if forward_initializers else ":"
         forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
 
+    # prefix-add operator
+    for label, var in struct.vars.items():
+        if not is_array(var.type):
+            prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
     # reverse args
     for label, var in struct.vars.items():
-        reverse_args.append(var.ctype() + " const& adj_" + label)
-        reverse_body.append(f"{indent_block}adj_ret.{label} = adj_{label};\n")
+        reverse_args.append(var.ctype() + " & adj_" + label)
+        if is_array(var.type):
+            reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
+        else:
+            reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
 
     reverse_args.append(name + " & adj_ret")
 
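The reverse arguments are now mutable references and gradients accumulate rather than overwrite: array members alias their own gradient storage and are assigned, while value members use +=. A Python rendering of that rule (hypothetical names, illustration only):

    def struct_adjoint(adj_ret, adj_vars, array_fields):
        for label, grad in adj_ret.items():
            if label in array_fields:
                adj_vars[label] = grad  # adj_x = adj_ret.x;
            else:
                # adj_x += adj_ret.x;
                adj_vars[label] = adj_vars.get(label, 0.0) + grad
        return adj_vars

    out = struct_adjoint({"mass": 1.5, "pos": "grad_buf"}, {"mass": 0.5}, {"pos"})
    assert out == {"mass": 2.0, "pos": "grad_buf"}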
@@ -1929,109 +2433,101 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         forward_initializers="".join(forward_initializers),
         reverse_args=indent(reverse_args),
         reverse_body="".join(reverse_body),
+        prefix_add_body="".join(prefix_add_body),
         atomic_add_body="".join(atomic_add_body),
     )
 
 
-def codegen_func_forward_body(adj, device="cpu", indent=4):
-    body = []
-    indent_block = " " * indent
-
-    for f in adj.blocks[0].body_forward:
-        body += [f + "\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
 def codegen_func_forward(adj, func_type="kernel", device="cpu"):
-    s = ""
+    if device == "cpu":
+        indent = 4
+    elif device == "cuda":
+        if func_type == "kernel":
+            indent = 8
+        else:
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
+
+    indent_block = " " * indent
 
     # primal vars
-    s += "    //---------\n"
-    s += "    // primal vars\n"
+    lines = []
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-            s += "    " + var.ctype() + " var_" + str(var.label) + ";\n"
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-            s += "    const " + var.ctype() + " var_" + str(var.label) + " = " + constant_str(var.constant) + ";\n"
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # forward pass
-    s += "    //---------\n"
-    s += "    // forward\n"
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
 
-    if device == "cpu":
-        s += codegen_func_forward_body(adj, device=device, indent=4)
+    for f in adj.blocks[0].body_forward:
+        lines += [f + "\n"]
+
+    return "".join([indent_block + l for l in lines])
 
+
+def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
+    if device == "cpu":
+        indent = 4
     elif device == "cuda":
         if func_type == "kernel":
-            s += codegen_func_forward_body(adj, device=device, indent=8)
+            indent = 8
         else:
-            s += codegen_func_forward_body(adj, device=device, indent=4)
-
-    return s
-
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
 
-def codegen_func_reverse_body(adj, device="cpu", indent=4):
-    body = []
     indent_block = " " * indent
 
-    # forward pass
-    body += ["//---------\n"]
-    body += ["// forward\n"]
-
-    for f in adj.blocks[0].body_replay:
-        body += [f + "\n"]
-
-    # reverse pass
-    body += ["//---------\n"]
-    body += ["// reverse\n"]
-
-    for l in reversed(adj.blocks[0].body_reverse):
-        body += [l + "\n"]
-
-    body += ["return;\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
-def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
-    s = ""
+    lines = []
 
     # primal vars
-    s += "    //---------\n"
-    s += "    // primal vars\n"
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-            s += "    " + var.ctype() + " var_" + str(var.label) + ";\n"
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-            s += "    const " + var.ctype() + " var_" + str(var.label) + " = " + constant_str(var.constant) + ";\n"
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # dual vars
-    s += "    //---------\n"
-    s += "    // dual vars\n"
+    lines += ["//---------\n"]
+    lines += ["// dual vars\n"]
 
     for var in adj.variables:
-        if isinstance(var.type, Struct):
-            s += "    " + var.ctype() + " adj_" + str(var.label) + ";\n"
-        else:
-            s += "    " + var.ctype() + " adj_" + str(var.label) + "(0);\n"
+        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
 
-    if device == "cpu":
-        s += codegen_func_reverse_body(adj, device=device, indent=4)
-    elif device == "cuda":
-        if func_type == "kernel":
-            s += codegen_func_reverse_body(adj, device=device, indent=8)
-        else:
-            s += codegen_func_reverse_body(adj, device=device, indent=4)
+    # forward pass
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
+
+    for f in adj.blocks[0].body_replay:
+        lines += [f + "\n"]
+
+    # reverse pass
+    lines += ["//---------\n"]
+    lines += ["// reverse\n"]
+
+    for l in reversed(adj.blocks[0].body_reverse):
+        lines += [l + "\n"]
+
+    # In grid-stride kernels the reverse body is in a for loop
+    if device == "cuda" and func_type == "kernel":
+        lines += ["continue;\n"]
     else:
-        raise ValueError("Device {} not supported for codegen".format(device))
+        lines += ["return;\n"]
 
-    return s
+    return "".join([indent_block + l for l in lines])
 
 
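The continue/return split at the end of the reverse body follows from the comment added above: CUDA kernel bodies are now emitted inside a grid-stride loop, where each thread handles several indices, so finishing one element must not end the thread. A Python stand-in for the loop shape (illustrative only):

    def grid_stride_kernel(tid, grid_size, dim_size, body):
        i = tid
        while i < dim_size:
            body(i)  # generated code ends each iteration with continue, not return
            i += grid_size

    visited = []
    for tid in range(4):  # pretend four threads cover ten elements
        grid_stride_kernel(tid, 4, 10, visited.append)
    assert sorted(visited) == list(range(10))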
-def codegen_func(adj, device="cpu"):
+def codegen_func(adj, c_func_name: str, device="cpu", options={}):
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
@@ -2044,16 +2540,20 @@ def codegen_func(adj, device="cpu"):
     reverse_args = []
 
     # forward args
-    for arg in adj.args:
-        forward_args.append(arg.ctype() + " var_" + arg.label)
-        reverse_args.append(arg.ctype() + " var_" + arg.label)
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit()}"
+        forward_args.append(s)
+        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+            reverse_args.append(s)
     if has_multiple_outputs:
         for i, arg in enumerate(adj.return_var):
             forward_args.append(arg.ctype() + " & ret_" + str(i))
             reverse_args.append(arg.ctype() + " & ret_" + str(i))
 
     # reverse args
-    for arg in adj.args:
+    for i, arg in enumerate(adj.args):
+        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+            break
         # indexed array gradients are regular arrays
         if isinstance(arg.type, indexedarray):
             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
@@ -2065,24 +2565,96 @@ def codegen_func(adj, device="cpu"):
             reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
     elif return_type != "void":
         reverse_args.append(return_type + " & adj_ret")
-
-    # codegen body
-    forward_body = codegen_func_forward(adj, func_type="function", device=device)
-    reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+    # custom output reverse args (user-declared)
+    if adj.custom_reverse_mode:
+        for arg in adj.args[adj.custom_reverse_num_input_args :]:
+            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
 
     if device == "cpu":
-        template = cpu_function_template
+        forward_template = cpu_forward_function_template
+        reverse_template = cpu_reverse_function_template
     elif device == "cuda":
-        template = cuda_function_template
+        forward_template = cuda_forward_function_template
+        reverse_template = cuda_reverse_function_template
     else:
-        raise ValueError("Device {} is not supported".format(device))
+        raise ValueError(f"Device {device} is not supported")
 
-    s = template.format(
-        name=make_full_qualified_name(adj.func),
-        return_type=return_type,
+    # codegen body
+    forward_body = codegen_func_forward(adj, func_type="function", device=device)
+
+    s = ""
+    if not adj.skip_forward_codegen:
+        s += forward_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            forward_args=indent(forward_args),
+            forward_body=forward_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
+
+    if not adj.skip_reverse_codegen:
+        if adj.custom_reverse_mode:
+            reverse_body = "\t// user-defined adjoint code\n" + forward_body
+        else:
+            if options.get("enable_backward", True):
+                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+            else:
+                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+        s += reverse_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            reverse_args=indent(reverse_args),
+            forward_body=forward_body,
+            reverse_body=reverse_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
+
+    return s
+
+
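codegen_func now consults a per-module options dict; with "enable_backward" off, the reverse template receives only a placeholder comment instead of a generated adjoint. Assuming Warp's standard module-option mechanism, turning it off for a module would look like:

    import warp as wp

    wp.init()

    # kernels defined in this module get no adjoint code
    wp.set_module_options({"enable_backward": False})

    @wp.kernel
    def scale(x: wp.array(dtype=float), s: float):
        i = wp.tid()
        x[i] = x[i] * s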
+def codegen_snippet(adj, name, snippet, adj_snippet):
+    forward_args = []
+    reverse_args = []
+
+    # forward args
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+        forward_args.append(s)
+        reverse_args.append(s)
+
+    # reverse args
+    for i, arg in enumerate(adj.args):
+        if isinstance(arg.type, indexedarray):
+            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+        else:
+            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+
+    forward_template = cuda_forward_function_template
+    reverse_template = cuda_reverse_function_template
+
+    s = ""
+    s += forward_template.format(
+        name=name,
+        return_type="void",
         forward_args=indent(forward_args),
+        forward_body=snippet,
+        filename=adj.filename,
+        lineno=adj.fun_lineno,
+    )
+
+    if adj_snippet:
+        reverse_body = adj_snippet
+    else:
+        reverse_body = ""
+
+    s += reverse_template.format(
+        name=name,
+        return_type="void",
         reverse_args=indent(reverse_args),
-        forward_body=forward_body,
+        forward_body=snippet,
         reverse_body=reverse_body,
         filename=adj.filename,
         lineno=adj.fun_lineno,
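codegen_snippet forwards user-authored CUDA source verbatim as the forward and reverse bodies. Assuming this backs the @wp.func_native decorator added in this release, a snippet and its adjoint might be declared like so (a sketch, not verified against the shipped API):

    import warp as wp

    snippet = "y[tid] = a * x[tid] + y[tid];"
    adj_snippet = """adj_a += x[tid] * adj_y[tid];
    adj_x[tid] += a * adj_y[tid];"""

    @wp.func_native(snippet, adj_snippet)
    def saxpy(a: float,
              x: wp.array(dtype=float),
              y: wp.array(dtype=float),
              tid: int):
        ...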
@@ -2098,8 +2670,8 @@ def codegen_kernel(kernel, device, options):
 
     adj = kernel.adj
 
-    forward_args = ["launch_bounds_t dim"]
-    reverse_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
+    reverse_args = ["wp::launch_bounds_t dim"]
 
     # forward args
     for arg in adj.args:
@@ -2128,7 +2700,7 @@ def codegen_kernel(kernel, device, options):
     elif device == "cuda":
         template = cuda_kernel_template
     else:
-        raise ValueError("Device {} is not supported".format(device))
+        raise ValueError(f"Device {device} is not supported")
 
     s = template.format(
         name=kernel.get_mangled_name(),
@@ -2142,10 +2714,13 @@ def codegen_kernel(kernel, device, options):
 
 
 def codegen_module(kernel, device="cpu"):
+    if device != "cpu":
+        return ""
+
     adj = kernel.adj
 
     # build forward signature
-    forward_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim"]
 
     for arg in adj.args:
@@ -2175,14 +2750,7 @@ def codegen_module(kernel, device="cpu"):
         reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
         reverse_params.append(f"adj_{arg.label}")
 
-    if device == "cpu":
-        template = cpu_module_template
-    elif device == "cuda":
-        template = cuda_module_template
-    else:
-        raise ValueError("Device {} is not supported".format(device))
-
-    s = template.format(
+    s = cpu_module_template.format(
         name=kernel.get_mangled_name(),
         forward_args=indent(forward_args),
         reverse_args=indent(reverse_args),
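With cuda_module_template deleted (top of this file's changes), codegen_module is now CPU-only; CUDA kernels presumably no longer need generated C entry points because the runtime launches them directly. A toy reduction of the new control flow (template text abbreviated, names hypothetical):

    CPU_MODULE_TEMPLATE = "WP_API void {name}_cpu_forward({forward_args});\n"

    def codegen_module_sketch(name, forward_args, device="cpu"):
        if device != "cpu":
            return ""  # no host-side entry points emitted for CUDA
        return CPU_MODULE_TEMPLATE.format(name=name, forward_args=", ".join(forward_args))

    assert codegen_module_sketch("k", ["wp::launch_bounds_t dim"], device="cuda") == ""
    assert "k_cpu_forward" in codegen_module_sketch("k", ["wp::launch_bounds_t dim"])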