warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -7,23 +7,40 @@
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- import re
11
- import sys
12
10
  import ast
13
- import inspect
11
+ import builtins
14
12
  import ctypes
13
+ import inspect
14
+ import math
15
+ import re
16
+ import sys
15
17
  import textwrap
16
18
  import types
19
+ from typing import Any, Callable, Mapping
17
20
 
18
- import numpy as np
21
+ import warp.config
22
+ from warp.types import *
19
23
 
20
- from typing import Any
21
- from typing import Callable
22
- from typing import Mapping
23
- from typing import Union
24
24
 
25
- from warp.types import *
26
- import warp.config
25
+ class WarpCodegenError(RuntimeError):
26
+ def __init__(self, message):
27
+ super().__init__(message)
28
+
29
+
30
+ class WarpCodegenTypeError(TypeError):
31
+ def __init__(self, message):
32
+ super().__init__(message)
33
+
34
+
35
+ class WarpCodegenAttributeError(AttributeError):
36
+ def __init__(self, message):
37
+ super().__init__(message)
38
+
39
+
40
+ class WarpCodegenKeyError(KeyError):
41
+ def __init__(self, message):
42
+ super().__init__(message)
43
+
27
44
 
28
45
  # map operator to function name
29
46
  builtin_operators = {}
@@ -57,6 +74,19 @@ builtin_operators[ast.Invert] = "invert"
57
74
  builtin_operators[ast.LShift] = "lshift"
58
75
  builtin_operators[ast.RShift] = "rshift"
59
76
 
77
+ comparison_chain_strings = [
78
+ builtin_operators[ast.Gt],
79
+ builtin_operators[ast.Lt],
80
+ builtin_operators[ast.LtE],
81
+ builtin_operators[ast.GtE],
82
+ builtin_operators[ast.Eq],
83
+ builtin_operators[ast.NotEq],
84
+ ]
85
+
86
+
87
+ def op_str_is_chainable(op: str) -> builtins.bool:
88
+ return op in comparison_chain_strings
89
+
60
90
 
61
91
  def get_annotations(obj: Any) -> Mapping[str, Any]:
62
92
  """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
@@ -70,16 +100,14 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
70
100
  def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
71
101
  indent = "\t"
72
102
 
73
- if inst._cls.ctype._fields_ == [("_dummy_", ctypes.c_int)]:
103
+ # handle empty structs
104
+ if len(inst._cls.vars) == 0:
74
105
  return f"{inst._cls.key}()"
75
106
 
76
107
  lines = []
77
108
  lines.append(f"{inst._cls.key}(")
78
109
 
79
110
  for field_name, _ in inst._cls.ctype._fields_:
80
- if field_name == "_dummy_":
81
- continue
82
-
83
111
  field_value = getattr(inst, field_name, None)
84
112
 
85
113
  if isinstance(field_value, StructInstance):
@@ -126,9 +154,7 @@ class StructInstance:
126
154
  assert isinstance(value, array)
127
155
  assert types_equal(
128
156
  value.dtype, var.type.dtype
129
- ), "assign to struct member variable {} failed, expected type {}, got type {}".format(
130
- name, type_repr(var.type.dtype), type_repr(value.dtype)
131
- )
157
+ ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
132
158
  setattr(self._ctype, name, value.__ctype__())
133
159
 
134
160
  elif isinstance(var.type, Struct):
@@ -247,7 +273,7 @@ class Struct:
247
273
 
248
274
  class StructType(ctypes.Structure):
249
275
  # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
250
- _fields_ = fields or [("_dummy_", ctypes.c_int)]
276
+ _fields_ = fields or [("_dummy_", ctypes.c_byte)]
251
277
 
252
278
  self.ctype = StructType
253
279
 
@@ -368,21 +394,38 @@ class Struct:
368
394
  return instance
369
395
 
370
396
 
397
+ class Reference:
398
+ def __init__(self, value_type):
399
+ self.value_type = value_type
400
+
401
+
402
+ def is_reference(type):
403
+ return isinstance(type, Reference)
404
+
405
+
406
+ def strip_reference(arg):
407
+ if is_reference(arg):
408
+ return arg.value_type
409
+ else:
410
+ return arg
411
+
412
+
371
413
  def compute_type_str(base_name, template_params):
372
- if template_params is None or len(template_params) == 0:
414
+ if not template_params:
373
415
  return base_name
374
- else:
375
416
 
376
- def param2str(p):
377
- if isinstance(p, int):
378
- return str(p)
379
- return p.__name__
417
+ def param2str(p):
418
+ if isinstance(p, int):
419
+ return str(p)
420
+ elif hasattr(p, "_type_"):
421
+ return f"wp::{p.__name__}"
422
+ return p.__name__
380
423
 
381
- return f"{base_name}<{','.join(map(param2str, template_params))}>"
424
+ return f"{base_name}<{','.join(map(param2str, template_params))}>"
382
425
 
383
426
 
384
427
  class Var:
385
- def __init__(self, label, type, requires_grad=False, constant=None):
428
+ def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
386
429
  # convert built-in types to wp types
387
430
  if type == float:
388
431
  type = float32
@@ -393,26 +436,49 @@ class Var:
393
436
  self.type = type
394
437
  self.requires_grad = requires_grad
395
438
  self.constant = constant
439
+ self.prefix = prefix
396
440
 
397
441
  def __str__(self):
398
442
  return self.label
399
443
 
400
- def ctype(self):
401
- if is_array(self.type):
402
- if hasattr(self.type.dtype, "_wp_generic_type_str_"):
403
- dtypestr = compute_type_str(self.type.dtype._wp_generic_type_str_, self.type.dtype._wp_type_params_)
404
- elif isinstance(self.type.dtype, Struct):
405
- dtypestr = make_full_qualified_name(self.type.dtype.cls)
444
+ @staticmethod
445
+ def type_to_ctype(t, value_type=False):
446
+ if is_array(t):
447
+ if hasattr(t.dtype, "_wp_generic_type_str_"):
448
+ dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
449
+ elif isinstance(t.dtype, Struct):
450
+ dtypestr = make_full_qualified_name(t.dtype.cls)
451
+ elif t.dtype.__name__ in ("bool", "int", "float"):
452
+ dtypestr = t.dtype.__name__
406
453
  else:
407
- dtypestr = str(self.type.dtype.__name__)
408
- classstr = type(self.type).__name__
454
+ dtypestr = f"wp::{t.dtype.__name__}"
455
+ classstr = f"wp::{type(t).__name__}"
409
456
  return f"{classstr}_t<{dtypestr}>"
410
- elif isinstance(self.type, Struct):
411
- return make_full_qualified_name(self.type.cls)
412
- elif hasattr(self.type, "_wp_generic_type_str_"):
413
- return compute_type_str(self.type._wp_generic_type_str_, self.type._wp_type_params_)
457
+ elif isinstance(t, Struct):
458
+ return make_full_qualified_name(t.cls)
459
+ elif is_reference(t):
460
+ if not value_type:
461
+ return Var.type_to_ctype(t.value_type) + "*"
462
+ else:
463
+ return Var.type_to_ctype(t.value_type)
464
+ elif hasattr(t, "_wp_generic_type_str_"):
465
+ return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
466
+ elif t.__name__ in ("bool", "int", "float"):
467
+ return t.__name__
468
+ else:
469
+ return f"wp::{t.__name__}"
470
+
471
+ def ctype(self, value_type=False):
472
+ return Var.type_to_ctype(self.type, value_type)
473
+
474
+ def emit(self, prefix: str = "var"):
475
+ if self.prefix:
476
+ return f"{prefix}_{self.label}"
414
477
  else:
415
- return str(self.type.__name__)
478
+ return self.label
479
+
480
+ def emit_adj(self):
481
+ return self.emit("adj")
416
482
 
417
483
 
418
484
  class Block:
@@ -429,35 +495,65 @@ class Block:
429
495
  self.vars = []
430
496
 
431
497
 
498
+ def is_local_value(value) -> bool:
499
+ """Check whether a variable is defined inside a kernel."""
500
+ return isinstance(value, (warp.context.Function, Var))
501
+
502
+
432
503
  class Adjoint:
433
504
  # Source code transformer, this class takes a Python function and
434
505
  # generates forward and backward SSA forms of the function instructions
435
506
 
436
- def __init__(adj, func, overload_annotations=None, transformers: List[ast.NodeTransformer] = []):
507
+ def __init__(
508
+ adj,
509
+ func,
510
+ overload_annotations=None,
511
+ is_user_function=False,
512
+ skip_forward_codegen=False,
513
+ skip_reverse_codegen=False,
514
+ custom_reverse_mode=False,
515
+ custom_reverse_num_input_args=-1,
516
+ transformers: List[ast.NodeTransformer] = [],
517
+ ):
437
518
  adj.func = func
438
519
 
439
- # build AST from function object
440
- adj.source = inspect.getsource(func)
520
+ adj.is_user_function = is_user_function
441
521
 
442
- # get source code lines and line number where function starts
443
- adj.raw_source, adj.fun_lineno = inspect.getsourcelines(func)
522
+ # whether the generation of the forward code is skipped for this function
523
+ adj.skip_forward_codegen = skip_forward_codegen
524
+ # whether the generation of the adjoint code is skipped for this function
525
+ adj.skip_reverse_codegen = skip_reverse_codegen
444
526
 
445
- # keep track of line number in function code
446
- adj.lineno = None
527
+ # extract name of source file
528
+ adj.filename = inspect.getsourcefile(func) or "unknown source file"
529
+ # get source file line number where function starts
530
+ _, adj.fun_lineno = inspect.getsourcelines(func)
447
531
 
532
+ # get function source code
533
+ adj.source = inspect.getsource(func)
448
534
  # ensures that indented class methods can be parsed as kernels
449
535
  adj.source = textwrap.dedent(adj.source)
450
536
 
451
- # extract name of source file
452
- adj.filename = inspect.getsourcefile(func) or "unknown source file"
537
+ adj.source_lines = adj.source.splitlines()
453
538
 
454
539
  # build AST and apply node transformers
455
540
  adj.tree = ast.parse(adj.source)
541
+ adj.transformers = transformers
456
542
  for transformer in transformers:
457
543
  adj.tree = transformer.visit(adj.tree)
458
544
 
459
545
  adj.fun_name = adj.tree.body[0].name
460
546
 
547
+ # for keeping track of line number in function code
548
+ adj.lineno = None
549
+
550
+ # whether the forward code shall be used for the reverse pass and a custom
551
+ # function signature is applied to the reverse version of the function
552
+ adj.custom_reverse_mode = custom_reverse_mode
553
+ # the number of function arguments that pertain to the forward function
554
+ # input arguments (i.e. the number of arguments that are not adjoint arguments)
555
+ adj.custom_reverse_num_input_args = custom_reverse_num_input_args
556
+
461
557
  # parse argument types
462
558
  argspec = inspect.getfullargspec(func)
463
559
 
@@ -465,16 +561,17 @@ class Adjoint:
465
561
  if overload_annotations is None:
466
562
  # use source-level argument annotations
467
563
  if len(argspec.annotations) < len(argspec.args):
468
- raise RuntimeError(f"Incomplete argument annotations on function {adj.fun_name}")
564
+ raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
469
565
  adj.arg_types = argspec.annotations
470
566
  else:
471
567
  # use overload argument annotations
472
568
  for arg_name in argspec.args:
473
569
  if arg_name not in overload_annotations:
474
- raise RuntimeError(f"Incomplete overload annotations for function {adj.fun_name}")
570
+ raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
475
571
  adj.arg_types = overload_annotations.copy()
476
572
 
477
573
  adj.args = []
574
+ adj.symbols = {}
478
575
 
479
576
  for name, type in adj.arg_types.items():
480
577
  # skip return hint
@@ -485,8 +582,23 @@ class Adjoint:
485
582
  arg = Var(name, type, False)
486
583
  adj.args.append(arg)
487
584
 
585
+ # pre-populate symbol dictionary with function argument names
586
+ # this is to avoid registering false references to overshadowed modules
587
+ adj.symbols[name] = arg
588
+
589
+ # There are cases where a same module might be rebuilt multiple times,
590
+ # for example when kernels are nested inside of functions, or when
591
+ # a kernel's launch raises an exception. Ideally we'd always want to
592
+ # avoid rebuilding kernels but some corner cases seem to depend on it,
593
+ # so we only avoid rebuilding kernels that errored out to give a chance
594
+ # for unit testing errors being spit out from kernels.
595
+ adj.skip_build = False
596
+
488
597
  # generate function ssa form and adjoint
489
598
  def build(adj, builder):
599
+ if adj.skip_build:
600
+ return
601
+
490
602
  adj.builder = builder
491
603
 
492
604
  adj.symbols = {} # map from symbols to adjoint variables
@@ -500,7 +612,7 @@ class Adjoint:
500
612
  adj.loop_blocks = []
501
613
 
502
614
  # holds current indent level
503
- adj.prefix = ""
615
+ adj.indentation = ""
504
616
 
505
617
  # used to generate new label indices
506
618
  adj.label_count = 0
@@ -514,19 +626,25 @@ class Adjoint:
514
626
  adj.eval(adj.tree.body[0])
515
627
  except Exception as e:
516
628
  try:
629
+ if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
630
+ msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
631
+ else:
632
+ msg = "Error"
517
633
  lineno = adj.lineno + adj.fun_lineno
518
- line = adj.source.splitlines()[adj.lineno]
519
- msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
634
+ line = adj.source_lines[adj.lineno]
635
+ msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
520
636
  ex, data, traceback = sys.exc_info()
521
- e = ex("".join([msg] + list(data.args))).with_traceback(traceback)
637
+ e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
522
638
  finally:
639
+ adj.skip_build = True
523
640
  raise e
524
641
 
525
- for a in adj.args:
526
- if isinstance(a.type, Struct):
527
- builder.build_struct_recursive(a.type)
528
- elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
529
- builder.build_struct_recursive(a.type.dtype)
642
+ if builder is not None:
643
+ for a in adj.args:
644
+ if isinstance(a.type, Struct):
645
+ builder.build_struct_recursive(a.type)
646
+ elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
647
+ builder.build_struct_recursive(a.type.dtype)
530
648
 
531
649
  # code generation methods
532
650
  def format_template(adj, template, input_vars, output_var):
@@ -541,44 +659,56 @@ class Adjoint:
541
659
  arg_strs = []
542
660
 
543
661
  for a in args:
544
- if type(a) == warp.context.Function:
662
+ if isinstance(a, warp.context.Function):
545
663
  # functions don't have a var_ prefix so strip it off here
546
- if prefix == "var_":
664
+ if prefix == "var":
547
665
  arg_strs.append(a.key)
548
666
  else:
549
- arg_strs.append(prefix + a.key)
550
-
667
+ arg_strs.append(f"{prefix}_{a.key}")
668
+ elif is_reference(a.type):
669
+ arg_strs.append(f"{prefix}_{a}")
670
+ elif isinstance(a, Var):
671
+ arg_strs.append(a.emit(prefix))
551
672
  else:
552
- arg_strs.append(prefix + str(a))
673
+ raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
553
674
 
554
675
  return arg_strs
555
676
 
556
677
  # generates argument string for a forward function call
557
678
  def format_forward_call_args(adj, args, use_initializer_list):
558
- arg_str = ", ".join(adj.format_args("var_", args))
679
+ arg_str = ", ".join(adj.format_args("var", args))
559
680
  if use_initializer_list:
560
- return "{{{}}}".format(arg_str)
681
+ return f"{{{arg_str}}}"
561
682
  return arg_str
562
683
 
563
684
  # generates argument string for a reverse function call
564
- def format_reverse_call_args(adj, args, args_out, non_adjoint_args, non_adjoint_outputs, use_initializer_list):
565
- formatted_var = adj.format_args("var_", args)
685
+ def format_reverse_call_args(
686
+ adj,
687
+ args_var,
688
+ args,
689
+ args_out,
690
+ use_initializer_list,
691
+ has_output_args=True,
692
+ require_original_output_arg=False,
693
+ ):
694
+ formatted_var = adj.format_args("var", args_var)
566
695
  formatted_out = []
567
- if len(args_out) > 1:
568
- formatted_out = adj.format_args("var_", args_out)
696
+ if has_output_args and (require_original_output_arg or len(args_out) > 1):
697
+ formatted_out = adj.format_args("var", args_out)
569
698
  formatted_var_adj = adj.format_args(
570
- "&adj_" if use_initializer_list else "adj_", [a for i, a in enumerate(args) if i not in non_adjoint_args]
699
+ "&adj" if use_initializer_list else "adj",
700
+ args,
571
701
  )
572
- formatted_out_adj = adj.format_args("adj_", [a for i, a in enumerate(args_out) if i not in non_adjoint_outputs])
702
+ formatted_out_adj = adj.format_args("adj", args_out)
573
703
 
574
704
  if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
575
705
  # there are no adjoint arguments, so we don't need to call the reverse function
576
706
  return None
577
707
 
578
708
  if use_initializer_list:
579
- var_str = "{{{}}}".format(", ".join(formatted_var))
580
- out_str = "{{{}}}".format(", ".join(formatted_out))
581
- adj_str = "{{{}}}".format(", ".join(formatted_var_adj))
709
+ var_str = f"{{{', '.join(formatted_var)}}}"
710
+ out_str = f"{{{', '.join(formatted_out)}}}"
711
+ adj_str = f"{{{', '.join(formatted_var_adj)}}}"
582
712
  out_adj_str = ", ".join(formatted_out_adj)
583
713
  if len(args_out) > 1:
584
714
  arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
@@ -589,10 +719,10 @@ class Adjoint:
589
719
  return arg_str
590
720
 
591
721
  def indent(adj):
592
- adj.prefix = adj.prefix + "\t"
722
+ adj.indentation = adj.indentation + " "
593
723
 
594
724
  def dedent(adj):
595
- adj.prefix = adj.prefix[0:-1]
725
+ adj.indentation = adj.indentation[:-4]
596
726
 
597
727
  def begin_block(adj):
598
728
  b = Block()
@@ -607,10 +737,9 @@ class Adjoint:
607
737
  def end_block(adj):
608
738
  return adj.blocks.pop()
609
739
 
610
- def add_var(adj, type=None, constant=None, name=None):
611
- if name is None:
612
- index = len(adj.variables)
613
- name = str(index)
740
+ def add_var(adj, type=None, constant=None):
741
+ index = len(adj.variables)
742
+ name = str(index)
614
743
 
615
744
  # allocate new variable
616
745
  v = Var(name, type=type, constant=constant)
@@ -623,30 +752,54 @@ class Adjoint:
623
752
 
624
753
  # append a statement to the forward pass
625
754
  def add_forward(adj, statement, replay=None, skip_replay=False):
626
- adj.blocks[-1].body_forward.append(adj.prefix + statement)
755
+ adj.blocks[-1].body_forward.append(adj.indentation + statement)
627
756
 
628
757
  if not skip_replay:
629
758
  if replay:
630
759
  # if custom replay specified then output it
631
- adj.blocks[-1].body_replay.append(adj.prefix + replay)
760
+ adj.blocks[-1].body_replay.append(adj.indentation + replay)
632
761
  else:
633
762
  # by default just replay the original statement
634
- adj.blocks[-1].body_replay.append(adj.prefix + statement)
763
+ adj.blocks[-1].body_replay.append(adj.indentation + statement)
635
764
 
636
765
  # append a statement to the reverse pass
637
766
  def add_reverse(adj, statement):
638
- adj.blocks[-1].body_reverse.append(adj.prefix + statement)
767
+ adj.blocks[-1].body_reverse.append(adj.indentation + statement)
639
768
 
640
769
  def add_constant(adj, n):
641
770
  output = adj.add_var(type=type(n), constant=n)
642
771
  return output
643
772
 
773
+ def load(adj, var):
774
+ if is_reference(var.type):
775
+ var = adj.add_builtin_call("load", [var])
776
+ return var
777
+
644
778
  def add_comp(adj, op_strings, left, comps):
645
- output = adj.add_var(bool)
779
+ output = adj.add_var(builtins.bool)
780
+
781
+ left = adj.load(left)
782
+ s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
783
+
784
+ prev_comp = None
646
785
 
647
- s = "var_" + str(output) + " = " + ("(" * len(comps)) + "var_" + str(left) + " "
648
786
  for op, comp in zip(op_strings, comps):
649
- s += op + " var_" + str(comp) + ") "
787
+ comp_chainable = op_str_is_chainable(op)
788
+ if comp_chainable and prev_comp:
789
+ # We restrict chaining to operands of the same type
790
+ if prev_comp.type is comp.type:
791
+ prev_comp = adj.load(prev_comp)
792
+ comp = adj.load(comp)
793
+ s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
794
+ else:
795
+ raise WarpCodegenTypeError(
796
+ f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
797
+ )
798
+ else:
799
+ comp = adj.load(comp)
800
+ s += op + " " + comp.emit() + ") "
801
+
802
+ prev_comp = comp
650
803
 
651
804
  s = s.rstrip() + ";"
652
805
 
@@ -655,110 +808,106 @@ class Adjoint:
655
808
  return output
656
809
 
657
810
  def add_bool_op(adj, op_string, exprs):
658
- output = adj.add_var(bool)
659
- command = (
660
- "var_" + str(output) + " = " + (" " + op_string + " ").join(["var_" + str(expr) for expr in exprs]) + ";"
661
- )
811
+ exprs = [adj.load(expr) for expr in exprs]
812
+ output = adj.add_var(builtins.bool)
813
+ command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
662
814
  adj.add_forward(command)
663
815
 
664
816
  return output
665
817
 
666
- def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
667
- # if func is overloaded then perform overload resolution here
668
- # we validate argument types before they go to generated native code
669
- resolved_func = None
818
+ def resolve_func(adj, func, args, min_outputs, templates, kwds):
819
+ arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
670
820
 
671
- if func.is_builtin():
821
+ if not func.is_builtin():
822
+ # user-defined function
823
+ overload = func.get_overload(arg_types)
824
+ if overload is not None:
825
+ return overload
826
+ else:
827
+ # if func is overloaded then perform overload resolution here
828
+ # we validate argument types before they go to generated native code
672
829
  for f in func.overloads:
673
- match = True
674
-
675
830
  # skip type checking for variadic functions
676
831
  if not f.variadic:
677
832
  # check argument counts match are compatible (may be some default args)
678
833
  if len(f.input_types) < len(args):
679
- match = False
680
834
  continue
681
835
 
682
- # check argument types equal
683
- for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
684
- # if arg type registered as Any, treat as
685
- # template allowing any type to match
686
- if arg_type == Any:
687
- continue
688
-
689
- # handle function refs as a special case
690
- if arg_type == Callable and type(args[i]) is warp.context.Function:
691
- continue
692
-
693
- # look for default values for missing args
694
- if i >= len(args):
695
- if arg_name not in f.defaults:
696
- match = False
697
- break
698
- else:
699
- # otherwise check arg type matches input variable type
700
- if not types_equal(arg_type, args[i].type, match_generic=True):
701
- match = False
702
- break
836
+ def match_args(args, f):
837
+ # check argument types equal
838
+ for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
839
+ # if arg type registered as Any, treat as
840
+ # template allowing any type to match
841
+ if arg_type == Any:
842
+ continue
843
+
844
+ # handle function refs as a special case
845
+ if arg_type == Callable and type(args[i]) is warp.context.Function:
846
+ continue
847
+
848
+ if arg_type == Reference and is_reference(args[i].type):
849
+ continue
850
+
851
+ # look for default values for missing args
852
+ if i >= len(args):
853
+ if arg_name not in f.defaults:
854
+ return False
855
+ else:
856
+ # otherwise check arg type matches input variable type
857
+ if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
858
+ return False
859
+
860
+ return True
861
+
862
+ if not match_args(args, f):
863
+ continue
703
864
 
704
865
  # check output dimensions match expectations
705
866
  if min_outputs:
706
867
  try:
707
868
  value_type = f.value_func(args, kwds, templates)
708
- if len(value_type) != min_outputs:
709
- match = False
869
+ if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
710
870
  continue
711
871
  except Exception:
712
872
  # value func may fail if the user has given
713
873
  # incorrect args, so we need to catch this
714
- match = False
715
874
  continue
716
875
 
717
876
  # found a match, use it
718
- if match:
719
- resolved_func = f
720
- break
721
- else:
722
- # user-defined function
723
- arg_types = [a.type for a in args]
724
-
725
- resolved_func = func.get_overload(arg_types)
726
-
727
- if resolved_func is None:
728
- arg_types = []
729
-
730
- for x in args:
731
- if isinstance(x, Var):
732
- # shorten Warp primitive type names
733
- if isinstance(x.type, list):
734
- if len(x.type) != 1:
735
- raise Exception("Argument must not be the result from a multi-valued function")
736
- arg_type = x.type[0]
737
- else:
738
- arg_type = x.type
739
- if arg_type.__module__ == "warp.types":
740
- arg_types.append(arg_type.__name__)
741
- else:
742
- arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
743
-
744
- if isinstance(x, warp.context.Function):
745
- arg_types.append("function")
746
-
747
- raise Exception(
748
- f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
749
- )
877
+ return f
878
+
879
+ # unresolved function, report error
880
+ arg_types = []
881
+
882
+ for x in args:
883
+ if isinstance(x, Var):
884
+ # shorten Warp primitive type names
885
+ if isinstance(x.type, list):
886
+ if len(x.type) != 1:
887
+ raise WarpCodegenError("Argument must not be the result from a multi-valued function")
888
+ arg_type = x.type[0]
889
+ else:
890
+ arg_type = x.type
750
891
 
751
- else:
752
- func = resolved_func
892
+ arg_types.append(type_repr(arg_type))
893
+
894
+ if isinstance(x, warp.context.Function):
895
+ arg_types.append("function")
896
+
897
+ raise WarpCodegenError(
898
+ f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
899
+ )
900
+
901
+ def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
902
+ func = adj.resolve_func(func, args, min_outputs, templates, kwds)
753
903
 
754
904
  # push any default values onto args
755
905
  for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
756
906
  if i >= len(args):
757
- if arg_name in f.defaults:
907
+ if arg_name in func.defaults:
758
908
  const = adj.add_constant(func.defaults[arg_name])
759
909
  args.append(const)
760
910
  else:
761
- match = False
762
911
  break
763
912
 
764
913
  # if it is a user-function then build it recursively
@@ -766,93 +915,105 @@ class Adjoint:
766
915
  adj.builder.build_function(func)
767
916
 
768
917
  # evaluate the function type based on inputs
769
- value_type = func.value_func(args, kwds, templates)
918
+ arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
919
+ return_type = func.value_func(arg_types, kwds, templates)
770
920
 
771
921
  func_name = compute_type_str(func.native_func, templates)
922
+ param_types = list(func.input_types.values())
772
923
 
773
924
  use_initializer_list = func.initializer_list_func(args, templates)
774
925
 
775
- if value_type is None:
776
- # handles expression (zero output) functions, e.g.: void do_something();
777
-
778
- forward_call = "{}{}({});".format(
779
- func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
780
- )
781
- if func.skip_replay:
782
- adj.add_forward(forward_call, replay="//" + forward_call)
783
- else:
784
- adj.add_forward(forward_call)
926
+ args_var = [
927
+ adj.load(a)
928
+ if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
929
+ else a
930
+ for i, a in enumerate(args)
931
+ ]
785
932
 
786
- if not func.missing_grad and len(args):
787
- arg_str = adj.format_reverse_call_args(args, [], {}, {}, use_initializer_list)
788
- if arg_str is not None:
789
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
790
- adj.add_reverse(reverse_call)
933
+ if return_type is None:
934
+ # handles expression (zero output) functions, e.g.: void do_something();
791
935
 
792
- return None
936
+ output = None
937
+ output_list = []
793
938
 
794
- elif not isinstance(value_type, list) or len(value_type) == 1:
795
- # handle simple function (one output)
796
-
797
- if isinstance(value_type, list):
798
- value_type = value_type[0]
799
- output = adj.add_var(value_type)
800
- forward_call = "var_{} = {}{}({});".format(
801
- output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
939
+ forward_call = (
940
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
802
941
  )
942
+ replay_call = forward_call
943
+ if func.custom_replay_func is not None:
944
+ replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
803
945
 
804
- if func.skip_replay:
805
- adj.add_forward(forward_call, replay="//" + forward_call)
806
- else:
807
- adj.add_forward(forward_call)
946
+ elif not isinstance(return_type, list) or len(return_type) == 1:
947
+ # handle simple function (one output)
808
948
 
809
- if not func.missing_grad and len(args):
810
- arg_str = adj.format_reverse_call_args(args, [output], {}, {}, use_initializer_list)
811
- if arg_str is not None:
812
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
813
- adj.add_reverse(reverse_call)
949
+ if isinstance(return_type, list):
950
+ return_type = return_type[0]
951
+ output = adj.add_var(return_type)
952
+ output_list = [output]
814
953
 
815
- return output
954
+ forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
955
+ replay_call = forward_call
956
+ if func.custom_replay_func is not None:
957
+ replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
816
958
 
817
959
  else:
818
960
  # handle multiple value functions
819
961
 
820
- output = [adj.add_var(v) for v in value_type]
821
- forward_call = "{}{}({});".format(
822
- func.namespace, func_name, adj.format_forward_call_args(args + output, use_initializer_list)
962
+ output = [adj.add_var(v) for v in return_type]
963
+ output_list = output
964
+
965
+ forward_call = (
966
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
823
967
  )
824
- adj.add_forward(forward_call)
968
+ replay_call = forward_call
825
969
 
826
- if not func.missing_grad and len(args):
827
- arg_str = adj.format_reverse_call_args(args, output, {}, {}, use_initializer_list)
828
- if arg_str is not None:
829
- reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
830
- adj.add_reverse(reverse_call)
970
+ if func.skip_replay:
971
+ adj.add_forward(forward_call, replay="// " + replay_call)
972
+ else:
973
+ adj.add_forward(forward_call, replay=replay_call)
974
+
975
+ if not func.missing_grad and len(args):
976
+ reverse_has_output_args = (
977
+ func.require_original_output_arg or len(output_list) > 1
978
+ ) and func.custom_grad_func is None
979
+ arg_str = adj.format_reverse_call_args(
980
+ args_var,
981
+ args,
982
+ output_list,
983
+ use_initializer_list,
984
+ has_output_args=reverse_has_output_args,
985
+ require_original_output_arg=func.require_original_output_arg,
986
+ )
987
+ if arg_str is not None:
988
+ reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
989
+ adj.add_reverse(reverse_call)
831
990
 
832
- if len(output) == 1:
833
- return output[0]
991
+ return output
834
992
 
835
- return output
993
+ def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
994
+ func = warp.context.builtin_functions[func_name]
995
+ return adj.add_call(func, args, min_outputs, templates, kwds)
836
996
 
837
997
  def add_return(adj, var):
838
998
  if var is None or len(var) == 0:
839
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
999
+ adj.add_forward("return;", f"goto label{adj.label_count};")
840
1000
  elif len(var) == 1:
841
- adj.add_forward("return var_{};".format(var[0]), "goto label{};".format(adj.label_count))
1001
+ adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
842
1002
  adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
843
1003
  else:
844
1004
  for i, v in enumerate(var):
845
- adj.add_forward("ret_{} = var_{};".format(i, v))
846
- adj.add_reverse("adj_{} += adj_ret_{};".format(v, i))
847
- adj.add_forward("return;", "goto label{};".format(adj.label_count))
1005
+ adj.add_forward(f"ret_{i} = {v.emit()};")
1006
+ adj.add_reverse(f"adj_{v} += adj_ret_{i};")
1007
+ adj.add_forward("return;", f"goto label{adj.label_count};")
848
1008
 
849
- adj.add_reverse("label{}:;".format(adj.label_count))
1009
+ adj.add_reverse(f"label{adj.label_count}:;")
850
1010
 
851
1011
  adj.label_count += 1
852
1012
 
853
1013
  # define an if statement
854
1014
  def begin_if(adj, cond):
855
- adj.add_forward("if (var_{}) {{".format(cond))
1015
+ cond = adj.load(cond)
1016
+ adj.add_forward(f"if ({cond.emit()}) {{")
856
1017
  adj.add_reverse("}")
857
1018
 
858
1019
  adj.indent()
@@ -861,10 +1022,12 @@ class Adjoint:
861
1022
  adj.dedent()
862
1023
 
863
1024
  adj.add_forward("}")
864
- adj.add_reverse(f"if (var_{cond}) {{")
1025
+ cond = adj.load(cond)
1026
+ adj.add_reverse(f"if ({cond.emit()}) {{")
865
1027
 
866
1028
  def begin_else(adj, cond):
867
- adj.add_forward(f"if (!var_{cond}) {{")
1029
+ cond = adj.load(cond)
1030
+ adj.add_forward(f"if (!{cond.emit()}) {{")
868
1031
  adj.add_reverse("}")
869
1032
 
870
1033
  adj.indent()
@@ -873,7 +1036,8 @@ class Adjoint:
873
1036
  adj.dedent()
874
1037
 
875
1038
  adj.add_forward("}")
876
- adj.add_reverse(f"if (!var_{cond}) {{")
1039
+ cond = adj.load(cond)
1040
+ adj.add_reverse(f"if (!{cond.emit()}) {{")
877
1041
 
878
1042
  # define a for-loop
879
1043
  def begin_for(adj, iter):
@@ -883,10 +1047,10 @@ class Adjoint:
883
1047
  adj.indent()
884
1048
 
885
1049
  # evaluate cond
886
- adj.add_forward(f"if (iter_cmp(var_{iter}) == 0) goto for_end_{cond_block.label};")
1050
+ adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")
887
1051
 
888
1052
  # evaluate iter
889
- val = adj.add_call(warp.context.builtin_functions["iter_next"], [iter])
1053
+ val = adj.add_builtin_call("iter_next", [iter])
890
1054
 
891
1055
  adj.begin_block()
892
1056
 
@@ -917,17 +1081,14 @@ class Adjoint:
917
1081
  reverse = []
918
1082
 
919
1083
  # reverse iterator
920
- reverse.append(adj.prefix + f"var_{iter} = wp::iter_reverse(var_{iter});")
1084
+ reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
921
1085
 
922
1086
  for i in cond_block.body_forward:
923
1087
  reverse.append(i)
924
1088
 
925
1089
  # zero adjoints
926
1090
  for i in body_block.vars:
927
- if isinstance(i.type, Struct):
928
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
929
- else:
930
- reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
1091
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
931
1092
 
932
1093
  # replay
933
1094
  for i in body_block.body_replay:
@@ -937,14 +1098,14 @@ class Adjoint:
937
1098
  for i in reversed(body_block.body_reverse):
938
1099
  reverse.append(i)
939
1100
 
940
- reverse.append(adj.prefix + f"\tgoto for_start_{cond_block.label};")
941
- reverse.append(adj.prefix + f"for_end_{cond_block.label}:;")
1101
+ reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
1102
+ reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")
942
1103
 
943
1104
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
944
1105
 
945
1106
  # define a while loop
946
1107
  def begin_while(adj, cond):
947
- # evaulate condition in its own block
1108
+ # evaluate condition in its own block
948
1109
  # so we can control replay
949
1110
  cond_block = adj.begin_block()
950
1111
  adj.loop_blocks.append(cond_block)
@@ -952,7 +1113,7 @@ class Adjoint:
952
1113
 
953
1114
  c = adj.eval(cond)
954
1115
 
955
- cond_block.body_forward.append(f"if ((var_{c}) == false) goto while_end_{cond_block.label};")
1116
+ cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")
956
1117
 
957
1118
  # being block around loop
958
1119
  adj.begin_block()
@@ -986,10 +1147,7 @@ class Adjoint:
986
1147
 
987
1148
  # zero adjoints of local vars
988
1149
  for i in body_block.vars:
989
- if isinstance(i.type, Struct):
990
- reverse.append(f"adj_{i} = {i.ctype()}{{}};")
991
- else:
992
- reverse.append(f"adj_{i} = {i.ctype()}(0);")
1150
+ reverse.append(f"{i.emit_adj()} = {{}};")
993
1151
 
994
1152
  # replay
995
1153
  for i in body_block.body_replay:
@@ -1009,6 +1167,10 @@ class Adjoint:
1009
1167
  for f in node.body:
1010
1168
  adj.eval(f)
1011
1169
 
1170
+ if adj.return_var is not None and len(adj.return_var) == 1:
1171
+ if not isinstance(node.body[-1], ast.Return):
1172
+ adj.add_forward("return {};", skip_replay=True)
1173
+
1012
1174
  def emit_If(adj, node):
1013
1175
  if len(node.body) == 0:
1014
1176
  return None
@@ -1036,7 +1198,7 @@ class Adjoint:
1036
1198
 
1037
1199
  if var1 != var2:
1038
1200
  # insert a phi function that selects var1, var2 based on cond
1039
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var1, var2])
1201
+ out = adj.add_builtin_call("select", [cond, var1, var2])
1040
1202
  adj.symbols[sym] = out
1041
1203
 
1042
1204
  symbols_prev = adj.symbols.copy()
@@ -1060,7 +1222,7 @@ class Adjoint:
1060
1222
  if var1 != var2:
1061
1223
  # insert a phi function that selects var1, var2 based on cond
1062
1224
  # note the reversed order of vars since we want to use !cond as our select
1063
- out = adj.add_call(warp.context.builtin_functions["select"], [cond, var2, var1])
1225
+ out = adj.add_builtin_call("select", [cond, var2, var1])
1064
1226
  adj.symbols[sym] = out
1065
1227
 
1066
1228
  def emit_Compare(adj, node):
@@ -1082,7 +1244,7 @@ class Adjoint:
1082
1244
  elif isinstance(op, ast.Or):
1083
1245
  func = "||"
1084
1246
  else:
1085
- raise KeyError("Op {} is not supported".format(op))
1247
+ raise WarpCodegenKeyError(f"Op {op} is not supported")
1086
1248
 
1087
1249
  return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
1088
1250
 
@@ -1102,7 +1264,7 @@ class Adjoint:
1102
1264
  obj = capturedvars.get(str(node.id), None)
1103
1265
 
1104
1266
  if obj is None:
1105
- raise KeyError("Referencing undefined symbol: " + str(node.id))
1267
+ raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
1106
1268
 
1107
1269
  if warp.types.is_value(obj):
1108
1270
  # evaluate constant
@@ -1114,26 +1276,96 @@ class Adjoint:
  # pass it back to the caller for processing
  return obj

+ @staticmethod
+ def resolve_type_attribute(var_type: type, attr: str):
+ if isinstance(var_type, type) and type_is_value(var_type):
+ if attr == "dtype":
+ return type_scalar_type(var_type)
+ elif attr == "length":
+ return type_length(var_type)
+
+ return getattr(var_type, attr, None)
+
+ def vector_component_index(adj, component, vector_type):
+ if len(component) != 1:
+ raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
+
+ dim = vector_type._shape_[0]
+ swizzles = "xyzw"[0:dim]
+ if component not in swizzles:
+ raise WarpCodegenAttributeError(
+ f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
+ )
+
+ index = swizzles.index(component)
+ index = adj.add_constant(index)
+ return index
+
+ @staticmethod
+ def is_differentiable_value_type(var_type):
+ # checks that the argument type is a value type (i.e., not an array)
+ # possibly holding differentiable values (for which gradients must be accumulated)
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
+
  def emit_Attribute(adj, node):
- try:
- val = adj.eval(node.value)
+ if hasattr(node, "is_adjoint"):
+ node.value.is_adjoint = True
+
+ aggregate = adj.eval(node.value)

- if isinstance(val, types.ModuleType) or isinstance(val, type):
- out = getattr(val, node.attr)
+ try:
+ if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
+ out = getattr(aggregate, node.attr)

  if warp.types.is_value(out):
  return adj.add_constant(out)

  return out

- # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
- attr_name = val.label + "." + node.attr
- attr_type = val.type.vars[node.attr].type
+ if hasattr(node, "is_adjoint"):
+ # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
+ attr_name = aggregate.label + "." + node.attr
+ attr_type = aggregate.type.vars[node.attr].type
+
+ return Var(attr_name, attr_type)
+
+ aggregate_type = strip_reference(aggregate.type)
+
+ # reading a vector component
+ if type_is_vector(aggregate_type):
+ index = adj.vector_component_index(node.attr, aggregate_type)
+
+ return adj.add_builtin_call("extract", [aggregate, index])
+
+ else:
+ attr_type = Reference(aggregate_type.vars[node.attr].type)
+ attr = adj.add_var(attr_type)
+
+ if is_reference(aggregate.type):
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+ else:
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+ else:
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+
+ return attr

- return Var(attr_name, attr_type)
+ except (KeyError, AttributeError):
+ # Try resolving as type attribute
+ aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate

- except KeyError:
- raise RuntimeError(f"Error, `{node.attr}` is not an attribute of '{val.label}' ({val.type})")
+ type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
+ if type_attribute is not None:
+ return type_attribute
+
+ if isinstance(aggregate, Var):
+ raise WarpCodegenAttributeError(
+ f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
+ )
+ raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")

  def emit_String(adj, node):
  # string constant
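
The new vector_component_index() helper maps a swizzle character to a constant index before emitting an "extract" call. A standalone sketch of the lookup, with illustrative values:

    def vector_component_index(component: str, dim: int) -> int:
        # only single-character swizzles over the first `dim` letters of "xyzw" are legal
        swizzles = "xyzw"[:dim]
        if len(component) != 1 or component not in swizzles:
            raise AttributeError(f"Vector swizzle must be one of {swizzles}, got {component}")
        return swizzles.index(component)

    assert vector_component_index("y", 3) == 1  # v.y on a vec3 reads component 1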
@@ -1150,19 +1382,25 @@ class Adjoint:
  adj.symbols[key] = out
  return out

+ def emit_Ellipsis(adj, node):
+ # stubbed @wp.native_func
+ return
+
  def emit_NameConstant(adj, node):
- if node.value == True:
+ if node.value:
  return adj.add_constant(True)
- elif node.value == False:
- return adj.add_constant(False)
  elif node.value is None:
- raise TypeError("None type unsupported")
+ raise WarpCodegenTypeError("None type unsupported")
+ else:
+ return adj.add_constant(False)

  def emit_Constant(adj, node):
  if isinstance(node, ast.Str):
  return adj.emit_String(node)
  elif isinstance(node, ast.Num):
  return adj.emit_Num(node)
+ elif isinstance(node, ast.Ellipsis):
+ return adj.emit_Ellipsis(node)
  else:
  assert isinstance(node, ast.NameConstant)
  return adj.emit_NameConstant(node)
@@ -1173,18 +1411,16 @@ class Adjoint:
  right = adj.eval(node.right)

  name = builtin_operators[type(node.op)]
- func = warp.context.builtin_functions[name]

- return adj.add_call(func, [left, right])
+ return adj.add_builtin_call(name, [left, right])

  def emit_UnaryOp(adj, node):
  # evaluate unary op arguments
  arg = adj.eval(node.operand)

  name = builtin_operators[type(node.op)]
- func = warp.context.builtin_functions[name]

- return adj.add_call(func, [arg])
+ return adj.add_builtin_call(name, [arg])

  def materialize_redefinitions(adj, symbols):
  # detect symbols with conflicting definitions (assigned inside the for loop)
@@ -1194,21 +1430,19 @@ class Adjoint:
  var2 = adj.symbols[sym]

  if var1 != var2:
- if warp.config.verbose:
+ if warp.config.verbose and not adj.custom_reverse_mode:
  lineno = adj.lineno + adj.fun_lineno
- line = adj.source.splitlines()[adj.lineno]
- msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
+ line = adj.source_lines[adj.lineno]
+ msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
  print(msg)

  if var1.constant is not None:
- raise Exception(
- "Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable".format(
- sym
- )
+ raise WarpCodegenError(
+ f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
  )

  # overwrite the old variable value (violates SSA)
- adj.add_call(warp.context.builtin_functions["copy"], [var1, var2])
+ adj.add_builtin_call("assign", [var1, var2])

  # reset the symbol to point to the original variable
  adj.symbols[sym] = var1
@@ -1227,35 +1461,20 @@ class Adjoint:

  adj.end_while()

- def is_num(adj, a):
- # simple constant
- if isinstance(a, ast.Num):
- return True
- # expression of form -constant
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
- return True
- else:
- # try and resolve the expression to an object
- # e.g.: wp.constant in the globals scope
- obj, path = adj.resolve_path(a)
- if warp.types.is_int(obj):
- return True
- else:
- return False
-
  def eval_num(adj, a):
  if isinstance(a, ast.Num):
- return a.n
- elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
- return -a.operand.n
- else:
- # try and resolve the expression to an object
- # e.g.: wp.constant in the globals scope
- obj, path = adj.resolve_path(a)
- if warp.types.is_int(obj):
- return obj
- else:
- return False
+ return True, a.n
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
+ return True, -a.operand.n
+
+ # try and resolve the expression to an object
+ # e.g.: wp.constant in the globals scope
+ obj, _ = adj.resolve_static_expression(a)
+
+ if isinstance(obj, Var) and obj.constant is not None:
+ obj = obj.constant
+
+ return warp.types.is_int(obj), obj

  # detects whether a loop contains a break (or continue) statement
  def contains_break(adj, body):
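
eval_num() now folds the old is_num()/eval_num() pair into a single pass that returns an (is_constant, value) tuple, so each argument expression is evaluated exactly once. A self-contained sketch of the same protocol for plain literals:

    import ast

    def eval_num(node):
        # returns (is_constant, value); mirrors the tuple protocol above
        if isinstance(node, ast.Num):
            return True, node.n
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Num):
            return True, -node.operand.n
        return False, node  # not a literal; the caller must evaluate it dynamically

    args = ast.parse("range(0, -4, -1)", mode="eval").body.args
    is_numeric, values = zip(*(eval_num(a) for a in args))
    assert all(is_numeric) and values == (0, -4, -1)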
@@ -1278,61 +1497,82 @@ class Adjoint:

  # returns a constant range() if unrollable, otherwise None
  def get_unroll_range(adj, loop):
- if not isinstance(loop.iter, ast.Call) or loop.iter.func.id != "range":
+ if (
+ not isinstance(loop.iter, ast.Call)
+ or not isinstance(loop.iter.func, ast.Name)
+ or loop.iter.func.id != "range"
+ or len(loop.iter.args) == 0
+ or len(loop.iter.args) > 3
+ ):
  return None

- for a in loop.iter.args:
- # if all range() arguments are numeric constants we will unroll
- # note that this only handles trivial constants, it will not unroll
- # constant compile-time expressions e.g.: range(0, 3*2)
- if not adj.is_num(a):
- return None
-
- # range(end)
- if len(loop.iter.args) == 1:
- start = 0
- end = adj.eval_num(loop.iter.args[0])
- step = 1
-
- # range(start, end)
- elif len(loop.iter.args) == 2:
- start = adj.eval_num(loop.iter.args[0])
- end = adj.eval_num(loop.iter.args[1])
- step = 1
-
- # range(start, end, step)
- elif len(loop.iter.args) == 3:
- start = adj.eval_num(loop.iter.args[0])
- end = adj.eval_num(loop.iter.args[1])
- step = adj.eval_num(loop.iter.args[2])
-
- # test if we're above max unroll count
- max_iters = abs(end - start) // abs(step)
- max_unroll = adj.builder.options["max_unroll"]
-
- if max_iters > max_unroll:
- if warp.config.verbose:
- print(
- f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
- )
- return None
+ # if all range() arguments are numeric constants we will unroll
+ # note that this only handles trivial constants, it will not unroll
+ # constant compile-time expressions e.g.: range(0, 3*2)

- if adj.contains_break(loop.body):
- if warp.config.verbose:
- print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
- return None
+ # Evaluate the arguments and check that they are numeric constants
+ # It is important to do that in one pass, so that if evaluating these arguments has side effects
+ # the code does not get generated more than once
+ range_args = [adj.eval_num(arg) for arg in loop.iter.args]
+ arg_is_numeric, arg_values = zip(*range_args)
+
+ if all(arg_is_numeric):
+ # All arguments are numeric constants
+
+ # range(end)
+ if len(loop.iter.args) == 1:
+ start = 0
+ end = arg_values[0]
+ step = 1
+
+ # range(start, end)
+ elif len(loop.iter.args) == 2:
+ start = arg_values[0]
+ end = arg_values[1]
+ step = 1
+
+ # range(start, end, step)
+ elif len(loop.iter.args) == 3:
+ start = arg_values[0]
+ end = arg_values[1]
+ step = arg_values[2]
+
+ # test if we're above max unroll count
+ max_iters = abs(end - start) // abs(step)
+ max_unroll = adj.builder.options["max_unroll"]

- # unroll
- return range(start, end, step)
+ ok_to_unroll = True
+
+ if max_iters > max_unroll:
+ if warp.config.verbose:
+ print(
+ f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
+ )
+ ok_to_unroll = False
+
+ elif adj.contains_break(loop.body):
+ if warp.config.verbose:
+ print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
+ ok_to_unroll = False
+
+ if ok_to_unroll:
+ return range(start, end, step)
+
+ # Unroll is not possible, range needs to be evaluated dynamically
+ range_call = adj.add_builtin_call(
+ "range",
+ [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
+ )
+ return range_call

  def emit_For(adj, node):
  # try and unroll simple range() statements that use constant args
  unroll_range = adj.get_unroll_range(node)

- if unroll_range:
+ if isinstance(unroll_range, range):
  for i in unroll_range:
  const_iter = adj.add_constant(i)
- var_iter = adj.add_call(warp.context.builtin_functions["int"], [const_iter])
+ var_iter = adj.add_builtin_call("int", [const_iter])
  adj.symbols[node.target.id] = var_iter

  # eval body
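
Once the range is resolved to constants, the unroll decision is simple arithmetic. A sketch with made-up numbers (the real limit comes from the module's "max_unroll" option):

    start, end, step = 0, 100, 2
    max_iters = abs(end - start) // abs(step)   # 50 fixed iterations
    max_unroll = 16                             # hypothetical module option value
    ok_to_unroll = max_iters <= max_unroll      # False -> generate a dynamic loop instead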
@@ -1341,8 +1581,12 @@ class Adjoint:

  # otherwise generate a dynamic loop
  else:
- # evaluate the Iterable
- iter = adj.eval(node.iter)
+ # evaluate the Iterable -- only if not previously evaluated when trying to unroll
+ if unroll_range is not None:
+ # Range has already been evaluated when trying to unroll, do not re-evaluate
+ iter = unroll_range
+ else:
+ iter = adj.eval(node.iter)

  adj.symbols[node.target.id] = adj.begin_for(iter)

@@ -1371,15 +1615,28 @@ class Adjoint:
  def emit_Expr(adj, node):
  return adj.eval(node.value)

+ def check_tid_in_func_error(adj, node):
+ if adj.is_user_function:
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
+ lineno = adj.lineno + adj.fun_lineno
+ line = adj.source_lines[adj.lineno]
+ raise WarpCodegenError(
+ "tid() may only be called from a Warp kernel, not a Warp function. "
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
+ )
+
  def emit_Call(adj, node):
+ adj.check_tid_in_func_error(node)
+
  # try and lookup function in globals by
  # resolving path (e.g.: module.submodule.attr)
- func, path = adj.resolve_path(node.func)
+ func, path = adj.resolve_static_expression(node.func)
  templates = []

- if isinstance(func, warp.context.Function) == False:
+ if not isinstance(func, warp.context.Function):
  if len(path) == 0:
- raise RuntimeError(f"Unrecognized syntax for function call, path not valid: '{node.func}'")
+ raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")

  attr = path[-1]
  caller = func
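
The new check_tid_in_func_error() guard turns a previously confusing failure into a clear diagnostic. The supported pattern is to query the thread index in the kernel and pass it down; a minimal sketch (the kernel and function names are illustrative):

    import warp as wp

    @wp.func
    def scale(values: wp.array(dtype=float), i: int) -> float:
        # plain @wp.func: receives the index, never calls wp.tid() itself
        return values[i] * 2.0

    @wp.kernel
    def scale_kernel(values: wp.array(dtype=float), out: wp.array(dtype=float)):
        i = wp.tid()  # only legal inside a @wp.kernel
        out[i] = scale(values, i)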
@@ -1404,7 +1661,7 @@ class Adjoint:
  func = caller.initializer()

  if func is None:
- raise RuntimeError(
+ raise WarpCodegenError(
  f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
  )

@@ -1413,16 +1670,25 @@ class Adjoint:
  # eval all arguments
  for arg in node.args:
  var = adj.eval(arg)
+ if not is_local_value(var):
+ raise RuntimeError(
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+ )
  args.append(var)

- # eval all keyword ags
+ # eval all keyword args
  def kwval(kw):
  if isinstance(kw.value, ast.Num):
  return kw.value.n
  elif isinstance(kw.value, ast.Tuple):
- return tuple(adj.eval_num(e) for e in kw.value.elts)
+ arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
+ if not all(arg_is_numeric):
+ raise WarpCodegenError(
+ f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
+ )
+ return arg_values
  else:
- return adj.resolve_path(kw.value)[0]
+ return adj.resolve_static_expression(kw.value)[0]

  kwds = {kw.arg: kwval(kw) for kw in node.keywords}

@@ -1439,10 +1705,26 @@ class Adjoint:
  # the ast.Index node appears in 3.7 versions
  # when performing array slices, e.g.: x = arr[i]
  # but in version 3.8 and higher it does not appear
+
+ if hasattr(node, "is_adjoint"):
+ node.value.is_adjoint = True
+
  return adj.eval(node.value)

  def emit_Subscript(adj, node):
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
+ node.slice.is_adjoint = True
+ var = adj.eval(node.slice)
+ var_name = var.label
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+ return var
+
  target = adj.eval(node.value)
+ if not is_local_value(target):
+ raise RuntimeError(
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+ )

  indices = []

@@ -1462,28 +1744,34 @@ class Adjoint:
  var = adj.eval(node.slice)
  indices.append(var)

- if is_array(target.type):
- if len(indices) == target.type.ndim:
+ target_type = strip_reference(target.type)
+ if is_array(target_type):
+ if len(indices) == target_type.ndim:
  # handles array loads (where each dimension has an index specified)
- out = adj.add_call(warp.context.builtin_functions["load"], [target, *indices])
+ out = adj.add_builtin_call("address", [target, *indices])
  else:
  # handles array views (fewer indices than dimensions)
- out = adj.add_call(warp.context.builtin_functions["view"], [target, *indices])
+ out = adj.add_builtin_call("view", [target, *indices])

  else:
  # handles non-array type indexing, e.g: vec3, mat33, etc
- out = adj.add_call(warp.context.builtin_functions["index"], [target, *indices])
+ out = adj.add_builtin_call("extract", [target, *indices])

  return out

  def emit_Assign(adj, node):
+ if len(node.targets) != 1:
+ raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
+
+ lhs = node.targets[0]
+
  # handle the case where we are assigning multiple output variables
- if isinstance(node.targets[0], ast.Tuple):
+ if isinstance(lhs, ast.Tuple):
  # record the expected number of outputs on the node
  # we do this so we can decide which function to
  # call based on the number of expected outputs
  if isinstance(node.value, ast.Call):
- node.value.expects = len(node.targets[0].elts)
+ node.value.expects = len(lhs.elts)

  # evaluate values
  if isinstance(node.value, ast.Tuple):
@@ -1492,40 +1780,47 @@ class Adjoint:
  out = adj.eval(node.value)

  names = []
- for v in node.targets[0].elts:
+ for v in lhs.elts:
  if isinstance(v, ast.Name):
  names.append(v.id)
  else:
- raise RuntimeError(
+ raise WarpCodegenError(
  "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
  )

  if len(names) != len(out):
- raise RuntimeError(
- "Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})".format(
- len(out), len(names)
- )
+ raise WarpCodegenError(
+ f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
  )

  for name, rhs in zip(names, out):
  if name in adj.symbols:
  if not types_equal(rhs.type, adj.symbols[name].type):
- raise TypeError(
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
- name, adj.symbols[name].type, rhs.type
- )
+ raise WarpCodegenTypeError(
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
  )

  adj.symbols[name] = rhs

- return out
-
  # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
- elif isinstance(node.targets[0], ast.Subscript):
- target = adj.eval(node.targets[0].value)
+ elif isinstance(lhs, ast.Subscript):
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
+ lhs.slice.is_adjoint = True
+ src_var = adj.eval(lhs.slice)
+ var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
+ value = adj.eval(node.value)
+ adj.add_forward(f"{var.emit()} = {value.emit()};")
+ return
+
+ target = adj.eval(lhs.value)
  value = adj.eval(node.value)
+ if not is_local_value(value):
+ raise RuntimeError(
+ "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+ )

- slice = node.targets[0].slice
+ slice = lhs.slice
  indices = []

  if isinstance(slice, ast.Tuple):
@@ -1533,7 +1828,6 @@ class Adjoint:
  for arg in slice.elts:
  var = adj.eval(arg)
  indices.append(var)
-
  elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
  # handles the x[i, j] case (Python 3.7.x)
  for arg in slice.value.elts:
@@ -1544,64 +1838,84 @@ class Adjoint:
  var = adj.eval(slice)
  indices.append(var)

- if is_array(target.type):
- adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
+ target_type = strip_reference(target.type)

- elif type_is_vector(target.type) or type_is_matrix(target.type):
- adj.add_call(warp.context.builtin_functions["indexset"], [target, *indices, value])
+ if is_array(target_type):
+ adj.add_builtin_call("array_store", [target, *indices, value])

- if warp.config.verbose:
+ elif type_is_vector(target_type) or type_is_matrix(target_type):
+ if is_reference(target.type):
+ attr = adj.add_builtin_call("indexref", [target, *indices])
+ else:
+ attr = adj.add_builtin_call("index", [target, *indices])
+
+ adj.add_builtin_call("store", [attr, value])
+
+ if warp.config.verbose and not adj.custom_reverse_mode:
  lineno = adj.lineno + adj.fun_lineno
- line = adj.source.splitlines()[adj.lineno]
+ line = adj.source_lines[adj.lineno]
+ node_source = adj.get_node_source(lhs.value)
  print(
- f"Warning: mutating {node.targets[0].value.id} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
+ f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
  )

  else:
- raise RuntimeError("Can only subscript assign array, vector, and matrix types")
-
- return var
+ raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")

- elif isinstance(node.targets[0], ast.Name):
+ elif isinstance(lhs, ast.Name):
  # symbol name
- name = node.targets[0].id
+ name = lhs.id

  # evaluate rhs
  rhs = adj.eval(node.value)

  # check type matches if symbol already defined
  if name in adj.symbols:
- if not types_equal(rhs.type, adj.symbols[name].type):
- raise TypeError(
- "Error, assigning to existing symbol {} ({}) with different type ({})".format(
- name, adj.symbols[name].type, rhs.type
- )
+ if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
+ raise WarpCodegenTypeError(
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
  )

  # handle simple assignment case (a = b), where we generate a value copy rather than reference
- if isinstance(node.value, ast.Name):
- out = adj.add_var(rhs.type)
- adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
+ if isinstance(node.value, ast.Name) or is_reference(rhs.type):
+ out = adj.add_builtin_call("copy", [rhs])
  else:
  out = rhs

  # update symbol map (assumes lhs is a Name node)
  adj.symbols[name] = out
- return out

- elif isinstance(node.targets[0], ast.Attribute):
+ elif isinstance(lhs, ast.Attribute):
  rhs = adj.eval(node.value)
- attr = adj.emit_Attribute(node.targets[0])
- adj.add_call(warp.context.builtin_functions["copy"], [attr, rhs])
+ aggregate = adj.eval(lhs.value)
+ aggregate_type = strip_reference(aggregate.type)

- if warp.config.verbose:
- lineno = adj.lineno + adj.fun_lineno
- line = adj.source.splitlines()[adj.lineno]
- msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
- print(msg)
+ # assigning to a vector component
+ if type_is_vector(aggregate_type):
+ index = adj.vector_component_index(lhs.attr, aggregate_type)
+
+ if is_reference(aggregate.type):
+ attr = adj.add_builtin_call("indexref", [aggregate, index])
+ else:
+ attr = adj.add_builtin_call("index", [aggregate, index])
+
+ adj.add_builtin_call("store", [attr, rhs])
+
+ else:
+ attr = adj.emit_Attribute(lhs)
+ if is_reference(attr.type):
+ adj.add_builtin_call("store", [attr, rhs])
+ else:
+ adj.add_builtin_call("assign", [attr, rhs])
+
+ if warp.config.verbose and not adj.custom_reverse_mode:
+ lineno = adj.lineno + adj.fun_lineno
+ line = adj.source_lines[adj.lineno]
+ msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
+ print(msg)

  else:
- raise RuntimeError("Error, unsupported assignment statement.")
+ raise WarpCodegenError("Error, unsupported assignment statement.")

  def emit_Return(adj, node):
  if node.value is None:
@@ -1612,30 +1926,26 @@ class Adjoint:
  var = (adj.eval(node.value),)

  if adj.return_var is not None:
- old_ctypes = tuple(v.ctype() for v in adj.return_var)
- new_ctypes = tuple(v.ctype() for v in var)
+ old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
+ new_ctypes = tuple(v.ctype(value_type=True) for v in var)
  if old_ctypes != new_ctypes:
- raise TypeError(
+ raise WarpCodegenTypeError(
  f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
  )
- else:
- adj.return_var = var
-
- adj.add_return(var)

- def emit_AugAssign(adj, node):
- # convert inplace operations (+=, -=, etc) to ssa form, e.g.: c = a + b
- left = adj.eval(node.target)
- right = adj.eval(node.value)
-
- # lookup
- name = builtin_operators[type(node.op)]
- func = warp.context.builtin_functions[name]
+ if var is not None:
+ adj.return_var = tuple()
+ for ret in var:
+ if is_reference(ret.type):
+ ret = adj.add_builtin_call("copy", [ret])
+ adj.return_var += (ret,)

- out = adj.add_call(func, [left, right])
+ adj.add_return(adj.return_var)

- # update symbol map
- adj.symbols[node.target.id] = out
+ def emit_AugAssign(adj, node):
+ # replace augmented assignment with assignment statement + binary op
+ new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
+ adj.eval(new_node)

  def emit_Tuple(adj, node):
  # LHS for expressions, such as i, j, k = 1, 2, 3
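
emit_AugAssign() now desugars in-place operators at the AST level instead of emitting a call directly. The same rewrite can be reproduced with the standard ast module (requires Python 3.9+ for ast.unparse):

    import ast

    aug = ast.parse("a += b").body[0]  # an ast.AugAssign node
    assign = ast.Assign(
        targets=[aug.target],
        value=ast.BinOp(left=aug.target, op=aug.op, right=aug.value),
    )
    print(ast.unparse(ast.fix_missing_locations(assign)))  # a = a + b

Routing `a += b` through the single assignment path means it picks up the same type checks and copy semantics as the explicit `a = a + b` form.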
@@ -1645,115 +1955,160 @@ class Adjoint:
  def emit_Pass(adj, node):
  pass

+ node_visitors = {
+ ast.FunctionDef: emit_FunctionDef,
+ ast.If: emit_If,
+ ast.Compare: emit_Compare,
+ ast.BoolOp: emit_BoolOp,
+ ast.Name: emit_Name,
+ ast.Attribute: emit_Attribute,
+ ast.Str: emit_String, # Deprecated in 3.8; use Constant
+ ast.Num: emit_Num, # Deprecated in 3.8; use Constant
+ ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
+ ast.Constant: emit_Constant,
+ ast.BinOp: emit_BinOp,
+ ast.UnaryOp: emit_UnaryOp,
+ ast.While: emit_While,
+ ast.For: emit_For,
+ ast.Break: emit_Break,
+ ast.Continue: emit_Continue,
+ ast.Expr: emit_Expr,
+ ast.Call: emit_Call,
+ ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
+ ast.Subscript: emit_Subscript,
+ ast.Assign: emit_Assign,
+ ast.Return: emit_Return,
+ ast.AugAssign: emit_AugAssign,
+ ast.Tuple: emit_Tuple,
+ ast.Pass: emit_Pass,
+ ast.Ellipsis: emit_Ellipsis,
+ }
+
  def eval(adj, node):
  if hasattr(node, "lineno"):
  adj.set_lineno(node.lineno - 1)

- node_visitors = {
- ast.FunctionDef: Adjoint.emit_FunctionDef,
- ast.If: Adjoint.emit_If,
- ast.Compare: Adjoint.emit_Compare,
- ast.BoolOp: Adjoint.emit_BoolOp,
- ast.Name: Adjoint.emit_Name,
- ast.Attribute: Adjoint.emit_Attribute,
- ast.Str: Adjoint.emit_String, # Deprecated in 3.8; use Constant
- ast.Num: Adjoint.emit_Num, # Deprecated in 3.8; use Constant
- ast.NameConstant: Adjoint.emit_NameConstant, # Deprecated in 3.8; use Constant
- ast.Constant: Adjoint.emit_Constant,
- ast.BinOp: Adjoint.emit_BinOp,
- ast.UnaryOp: Adjoint.emit_UnaryOp,
- ast.While: Adjoint.emit_While,
- ast.For: Adjoint.emit_For,
- ast.Break: Adjoint.emit_Break,
- ast.Continue: Adjoint.emit_Continue,
- ast.Expr: Adjoint.emit_Expr,
- ast.Call: Adjoint.emit_Call,
- ast.Index: Adjoint.emit_Index, # Deprecated in 3.8; Use the index value directly instead.
- ast.Subscript: Adjoint.emit_Subscript,
- ast.Assign: Adjoint.emit_Assign,
- ast.Return: Adjoint.emit_Return,
- ast.AugAssign: Adjoint.emit_AugAssign,
- ast.Tuple: Adjoint.emit_Tuple,
- ast.Pass: Adjoint.emit_Pass,
- }
-
- emit_node = node_visitors.get(type(node))
-
- if emit_node is not None:
- return emit_node(adj, node)
- else:
- raise Exception("Error, ast node of type {} not supported".format(type(node)))
+ emit_node = adj.node_visitors[type(node)]
+
+ return emit_node(adj, node)

  # helper to evaluate expressions of the form
  # obj1.obj2.obj3.attr in the function's global scope
- def resolve_path(adj, node):
- modules = []
+ def resolve_path(adj, path):
+ if len(path) == 0:
+ return None

- while isinstance(node, ast.Attribute):
- modules.append(node.attr)
- node = node.value
+ # if root is overshadowed by local symbols, bail out
+ if path[0] in adj.symbols:
+ return None

- if isinstance(node, ast.Name):
- modules.append(node.id)
+ if path[0] in __builtins__:
+ return __builtins__[path[0]]

- # reverse list since ast presents it backward order
- path = [*reversed(modules)]
+ # Look up the closure info and append it to adj.func.__globals__
+ # in case you want to define a kernel inside a function and refer
+ # to variables you've declared inside that function:
+ extract_contents = (
+ lambda contents: contents
+ if isinstance(contents, warp.context.Function) or not callable(contents)
+ else contents
+ )
+ capturedvars = dict(
+ zip(
+ adj.func.__code__.co_freevars,
+ [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
+ )
+ )
+ vars_dict = {**adj.func.__globals__, **capturedvars}

- if len(path) == 0:
- return None, path
+ if path[0] in vars_dict:
+ func = vars_dict[path[0]]

- # try and evaluate object path
- try:
- # Look up the closure info and append it to adj.func.__globals__
- # in case you want to define a kernel inside a function and refer
- # to variables you've declared inside that function:
- extract_contents = (
- lambda contents: contents
- if isinstance(contents, warp.context.Function) or not callable(contents)
- else contents
- )
- capturedvars = dict(
- zip(
- adj.func.__code__.co_freevars,
- [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
- )
- )
+ # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
+ else:
+ func = getattr(warp, path[0], None)

- vars_dict = {**adj.func.__globals__, **capturedvars}
- func = eval(".".join(path), vars_dict)
- return func, path
- except:
- pass
+ if func:
+ for i in range(1, len(path)):
+ if hasattr(func, path[i]):
+ func = getattr(func, path[i])

- # I added this so people can eg do this kind of thing
- # in a kernel:
+ return func

- # v = vec3(0.0,0.2,0.4)
+ # Evaluates a static expression that does not depend on runtime values
+ # if eval_types is True, try resolving the path using evaluated type information as well
+ def resolve_static_expression(adj, root_node, eval_types=True):
+ attributes = []

- # vec3 is now an alias and is not in warp.context.builtin_functions.
- # This means it can't be directly looked up in Adjoint.add_call, and
- # needs to be looked up by digging some information out of the
- # python object it actually came from.
+ node = root_node
+ while isinstance(node, ast.Attribute):
+ attributes.append(node.attr)
+ node = node.value

- # Before this fix, resolve_path was returning None, as the
- # "vec3" symbol is not available. In this situation I'm assuming
- # it's a member of the warp module and trying to look it up:
- try:
- evalstr = ".".join(["warp"] + path)
- func = eval(evalstr, {"warp": warp})
- return func, path
- except:
- return None, path
+ if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+ # support for operators returning modules
+ # i.e. operator_name(*operator_args).x.y.z
+ operator_args = node.args
+ operator_name = node.func.id
+
+ if operator_name == "type":
+ if len(operator_args) != 1:
+ raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
+
+ # type() operator
+ var = adj.eval(operator_args[0])
+
+ if isinstance(var, Var):
+ var_type = strip_reference(var.type)
+ # Allow accessing type attributes, for instance array.dtype
+ while attributes:
+ attr_name = attributes.pop()
+ var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
+
+ if var_type is None:
+ raise WarpCodegenAttributeError(
+ f"{attr_name} is not an attribute of {type_repr(prev_type)}"
+ )
+
+ return var_type, [type_repr(var_type)]
+ else:
+ raise WarpCodegenError(f"Cannot deduce the type of {var}")
+
+ # reverse list since ast presents it backward order
+ path = [*reversed(attributes)]
+ if isinstance(node, ast.Name):
+ path.insert(0, node.id)
+
+ # Try resolving path from captured context
+ captured_obj = adj.resolve_path(path)
+ if captured_obj is not None:
+ return captured_obj, path
+
+ # Still nothing found, maybe this is a predefined type attribute like `dtype`
+ if eval_types:
+ try:
+ val = adj.eval(root_node)
+ if val:
+ return [val, type_repr(val)]
+
+ except Exception:
+ pass
+
+ return None, path

  # annotate generated code with the original source code line
  def set_lineno(adj, lineno):
  if adj.lineno is None or adj.lineno != lineno:
  line = lineno + adj.fun_lineno
- source = adj.raw_source[lineno].strip().ljust(70)
+ source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
  adj.add_forward(f"// {source} <L {line}>")
  adj.add_reverse(f"// adj: {source} <L {line}>")
  adj.lineno = lineno

+ def get_node_source(adj, node):
+ # return the Python code corresponding to the given AST node
+ return ast.get_source_segment(adj.source, node)
+

  # ----------------
  # code generation
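
resolve_static_expression() first flattens a dotted expression into a path of attribute names by walking nested ast.Attribute nodes. A self-contained sketch of that walk:

    import ast

    def attribute_path(expr: str):
        node = ast.parse(expr, mode="eval").body
        attributes = []
        while isinstance(node, ast.Attribute):
            attributes.append(node.attr)
            node = node.value
        if isinstance(node, ast.Name):
            attributes.append(node.id)
        # ast yields the chain innermost-last, so restore source order
        return list(reversed(attributes))

    assert attribute_path("warp.types.vec3") == ["warp", "types", "vec3"]

The resolved path is then checked against local symbols, Python builtins, the function's globals, and captured closure variables, in that order.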
@@ -1769,7 +2124,10 @@ cpu_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

- using namespace wp;
+ #define builtin_tid1d() wp::tid(wp::s_threadIdx)
+ #define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)

  """

@@ -1784,8 +2142,10 @@ cuda_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

-
- using namespace wp;
+ #define builtin_tid1d() wp::tid(_idx)
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

  """

@@ -1799,54 +2159,56 @@ struct {name}
  {{
  }}

- CUDA_CALLABLE {name}& operator += (const {name}&) {{ return *this; }}
+ CUDA_CALLABLE {name}& operator += (const {name}& rhs)
+ {{{prefix_add_body}
+ return *this;}}

  }};

  static CUDA_CALLABLE void adj_{name}({reverse_args})
  {{
- {reverse_body}
- }}
+ {reverse_body}}}

- CUDA_CALLABLE void atomic_add({name}* p, {name} t)
+ CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
  {{
- {atomic_add_body}
- }}
+ {atomic_add_body}}}


  """

- cpu_function_template = """
+ cpu_forward_function_template = """
  // {filename}:{lineno}
  static {return_type} {name}(
  {forward_args})
  {{
- {forward_body}
- }}
+ {forward_body}}}
+
+ """

+ cpu_reverse_function_template = """
  // {filename}:{lineno}
  static void adj_{name}(
  {reverse_args})
  {{
- {reverse_body}
- }}
+ {reverse_body}}}

  """

- cuda_function_template = """
+ cuda_forward_function_template = """
  // {filename}:{lineno}
  static CUDA_CALLABLE {return_type} {name}(
  {forward_args})
  {{
- {forward_body}
- }}
+ {forward_body}}}

+ """
+
+ cuda_reverse_function_template = """
  // {filename}:{lineno}
  static CUDA_CALLABLE void adj_{name}(
  {reverse_args})
  {{
- {reverse_body}
- }}
+ {reverse_body}}}

  """

@@ -1855,25 +2217,21 @@ cuda_kernel_template = """
  extern "C" __global__ void {name}_cuda_kernel_forward(
  {forward_args})
  {{
- size_t _idx = grid_index();
- if (_idx >= dim.size)
- return;
-
- set_launch_bounds(dim);
-
- {forward_body}
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ _idx < dim.size;
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+ {{
+ {forward_body} }}
  }}

  extern "C" __global__ void {name}_cuda_kernel_backward(
  {reverse_args})
  {{
- size_t _idx = grid_index();
- if (_idx >= dim.size)
- return;
-
- set_launch_bounds(dim);
-
- {reverse_body}
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ _idx < dim.size;
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+ {{
+ {reverse_body} }}
  }}

  """
@@ -1883,14 +2241,12 @@ cpu_kernel_template = """
  void {name}_cpu_kernel_forward(
  {forward_args})
  {{
- {forward_body}
- }}
+ {forward_body}}}

  void {name}_cpu_kernel_backward(
  {reverse_args})
  {{
- {reverse_body}
- }}
+ {reverse_body}}}

  """

@@ -1902,11 +2258,9 @@ extern "C" {{
  WP_API void {name}_cpu_forward(
  {forward_args})
  {{
- set_launch_bounds(dim);
-
  for (size_t i=0; i < dim.size; ++i)
  {{
- s_threadIdx = i;
+ wp::s_threadIdx = i;

  {name}_cpu_kernel_forward(
  {forward_params});
@@ -1916,11 +2270,9 @@ WP_API void {name}_cpu_forward(
  WP_API void {name}_cpu_backward(
  {reverse_args})
  {{
- set_launch_bounds(dim);
-
  for (size_t i=0; i < dim.size; ++i)
  {{
- s_threadIdx = i;
+ wp::s_threadIdx = i;

  {name}_cpu_kernel_backward(
  {reverse_params});
@@ -1966,7 +2318,7 @@ WP_API void {name}_cpu_backward(
  def constant_str(value):
  value_type = type(value)

- if value_type == bool:
+ if value_type == bool or value_type == builtins.bool:
  if value:
  return "true"
  else:
@@ -1983,7 +2335,9 @@ def constant_str(value):

  scalar_value = runtime.core.half_bits_to_float
  else:
- scalar_value = lambda x: x
+
+ def scalar_value(x):
+ return x

  # list of scalar initializer values
  initlist = []
@@ -2000,6 +2354,9 @@ def constant_str(value):
  # make sure we emit the value of objects, e.g. uint32
  return str(value.value)

+ elif value == math.inf:
+ return "INFINITY"
+
  else:
  # otherwise just convert constant to string
  return str(value)
@@ -2008,7 +2365,7 @@ def constant_str(value):
  def indent(args, stops=1):
  sep = ",\n"
  for i in range(stops):
- sep += "\t"
+ sep += " "

  # return sep + args.replace(", ", "," + sep)
  return sep.join(args)
@@ -2016,7 +2373,9 @@ def indent(args, stops=1):

  # generates a C function name based on the python function name
  def make_full_qualified_name(func):
- return re.sub("[^0-9a-zA-Z_]+", "", func.__qualname__.replace(".", "__"))
+ if not isinstance(func, str):
+ func = func.__qualname__
+ return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))


  def codegen_struct(struct, device="cpu", indent_size=4):
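
make_full_qualified_name() now also accepts a plain string. Its behavior is easy to verify standalone (the example class and string are hypothetical):

    import re

    def make_full_qualified_name(func):
        if not isinstance(func, str):
            func = func.__qualname__
        return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))

    class Outer:
        def method(self):
            pass

    assert make_full_qualified_name(Outer.method) == "Outer__method"
    assert make_full_qualified_name("my.kernel-v2") == "my__kernelv2"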
@@ -2024,8 +2383,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):

  body = []
  indent_block = " " * indent_size
- for label, var in struct.vars.items():
- body.append(var.ctype() + " " + label + ";\n")
+
+ if len(struct.vars) > 0:
+ for label, var in struct.vars.items():
+ body.append(var.ctype() + " " + label + ";\n")
+ else:
+ # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
+ body.append("char _dummy_;\n")

  forward_args = []
  reverse_args = []
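
codegen_struct() assembles the C++ member list, constructor arguments, and adjoint helpers from the Python-side field dictionary. A toy illustration of the string assembly for a hypothetical two-field struct (field names and ctypes are made up):

    fields = {"pos": "wp::vec3", "count": "int"}
    indent_block = " " * 4

    # `or [...]` mirrors the empty-struct _dummy_ fallback above
    body = [f"{ctype} {label};\n" for label, ctype in fields.items()] or ["char _dummy_;\n"]
    prefix_add_body = [f"{indent_block}{label} += rhs.{label};\n" for label in fields]

    print("".join(body))             # wp::vec3 pos; / int count;
    print("".join(prefix_add_body))  # pos += rhs.pos; / count += rhs.count;

prefix_add_body feeds the new operator += in the struct template below; the real implementation also skips array-typed fields, which this toy omits.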
@@ -2033,24 +2397,32 @@ def codegen_struct(struct, device="cpu", indent_size=4):
  forward_initializers = []
  reverse_body = []
  atomic_add_body = []
+ prefix_add_body = []

  # forward args
  for label, var in struct.vars.items():
- forward_args.append(f"{var.ctype()} const& {label} = {{}}")
- reverse_args.append(f"{var.ctype()} const&")
+ var_ctype = var.ctype()
+ forward_args.append(f"{var_ctype} const& {label} = {{}}")
+ reverse_args.append(f"{var_ctype} const&")

- atomic_add_body.append(f"{indent_block}atomic_add(&p->{label}, t.{label});\n")
+ namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
+ atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")

  prefix = f"{indent_block}," if forward_initializers else ":"
  forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")

+ # prefix-add operator
+ for label, var in struct.vars.items():
+ if not is_array(var.type):
+ prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
  # reverse args
  for label, var in struct.vars.items():
  reverse_args.append(var.ctype() + " & adj_" + label)
  if is_array(var.type):
- reverse_body.append(f"adj_{label} = {indent_block}adj_ret.{label};\n")
+ reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
  else:
- reverse_body.append(f"adj_{label} += {indent_block}adj_ret.{label};\n")
+ reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")

  reverse_args.append(name + " & adj_ret")

@@ -2061,109 +2433,101 @@ def codegen_struct(struct, device="cpu", indent_size=4):
  forward_initializers="".join(forward_initializers),
  reverse_args=indent(reverse_args),
  reverse_body="".join(reverse_body),
+ prefix_add_body="".join(prefix_add_body),
  atomic_add_body="".join(atomic_add_body),
  )


- def codegen_func_forward_body(adj, device="cpu", indent=4):
- body = []
- indent_block = " " * indent
-
- for f in adj.blocks[0].body_forward:
- body += [f + "\n"]
-
- return "".join([indent_block + l for l in body])
-
-
  def codegen_func_forward(adj, func_type="kernel", device="cpu"):
- s = ""
+ if device == "cpu":
+ indent = 4
+ elif device == "cuda":
+ if func_type == "kernel":
+ indent = 8
+ else:
+ indent = 4
+ else:
+ raise ValueError(f"Device {device} not supported for codegen")
+
+ indent_block = " " * indent

  # primal vars
- s += " //---------\n"
- s += " // primal vars\n"
+ lines = []
+ lines += ["//---------\n"]
+ lines += ["// primal vars\n"]

  for var in adj.variables:
  if var.constant is None:
- s += " " + var.ctype() + " var_" + str(var.label) + ";\n"
+ lines += [f"{var.ctype()} {var.emit()};\n"]
  else:
- s += " const " + var.ctype() + " var_" + str(var.label) + " = " + constant_str(var.constant) + ";\n"
+ lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

  # forward pass
- s += " //---------\n"
- s += " // forward\n"
+ lines += ["//---------\n"]
+ lines += ["// forward\n"]

- if device == "cpu":
- s += codegen_func_forward_body(adj, device=device, indent=4)
+ for f in adj.blocks[0].body_forward:
+ lines += [f + "\n"]

+ return "".join([indent_block + l for l in lines])
+
+
+ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
+ if device == "cpu":
+ indent = 4
  elif device == "cuda":
  if func_type == "kernel":
- s += codegen_func_forward_body(adj, device=device, indent=8)
+ indent = 8
  else:
- s += codegen_func_forward_body(adj, device=device, indent=4)
-
- return s
-
+ indent = 4
+ else:
+ raise ValueError(f"Device {device} not supported for codegen")

- def codegen_func_reverse_body(adj, device="cpu", indent=4):
- body = []
  indent_block = " " * indent

- # forward pass
- body += ["//---------\n"]
- body += ["// forward\n"]
-
- for f in adj.blocks[0].body_replay:
- body += [f + "\n"]
-
- # reverse pass
- body += ["//---------\n"]
- body += ["// reverse\n"]
-
- for l in reversed(adj.blocks[0].body_reverse):
- body += [l + "\n"]
-
- body += ["return;\n"]
-
- return "".join([indent_block + l for l in body])
-
-
- def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
- s = ""
+ lines = []

  # primal vars
- s += " //---------\n"
- s += " // primal vars\n"
+ lines += ["//---------\n"]
+ lines += ["// primal vars\n"]

  for var in adj.variables:
  if var.constant is None:
- s += " " + var.ctype() + " var_" + str(var.label) + ";\n"
+ lines += [f"{var.ctype()} {var.emit()};\n"]
  else:
- s += " const " + var.ctype() + " var_" + str(var.label) + " = " + constant_str(var.constant) + ";\n"
+ lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

  # dual vars
- s += " //---------\n"
- s += " // dual vars\n"
+ lines += ["//---------\n"]
+ lines += ["// dual vars\n"]

  for var in adj.variables:
- if isinstance(var.type, Struct):
- s += " " + var.ctype() + " adj_" + str(var.label) + ";\n"
- else:
- s += " " + var.ctype() + " adj_" + str(var.label) + "(0);\n"
+ lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]

- if device == "cpu":
- s += codegen_func_reverse_body(adj, device=device, indent=4)
- elif device == "cuda":
- if func_type == "kernel":
- s += codegen_func_reverse_body(adj, device=device, indent=8)
- else:
- s += codegen_func_reverse_body(adj, device=device, indent=4)
+ # forward pass
+ lines += ["//---------\n"]
+ lines += ["// forward\n"]
+
+ for f in adj.blocks[0].body_replay:
+ lines += [f + "\n"]
+
+ # reverse pass
+ lines += ["//---------\n"]
+ lines += ["// reverse\n"]
+
+ for l in reversed(adj.blocks[0].body_reverse):
+ lines += [l + "\n"]
+
+ # In grid-stride kernels the reverse body is in a for loop
+ if device == "cuda" and func_type == "kernel":
+ lines += ["continue;\n"]
  else:
- raise ValueError("Device {} not supported for codegen".format(device))
+ lines += ["return;\n"]

- return s
+ return "".join([indent_block + l for l in lines])


- def codegen_func(adj, name, device="cpu", options={}):
+ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
  # forward header
  if adj.return_var is not None and len(adj.return_var) == 1:
  return_type = adj.return_var[0].ctype()
@@ -2176,16 +2540,20 @@ def codegen_func(adj, name, device="cpu", options={}):
  reverse_args = []

  # forward args
- for arg in adj.args:
- forward_args.append(arg.ctype() + " var_" + arg.label)
- reverse_args.append(arg.ctype() + " var_" + arg.label)
+ for i, arg in enumerate(adj.args):
+ s = f"{arg.ctype()} {arg.emit()}"
+ forward_args.append(s)
+ if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+ reverse_args.append(s)
  if has_multiple_outputs:
  for i, arg in enumerate(adj.return_var):
  forward_args.append(arg.ctype() + " & ret_" + str(i))
  reverse_args.append(arg.ctype() + " & ret_" + str(i))

  # reverse args
- for arg in adj.args:
+ for i, arg in enumerate(adj.args):
+ if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+ break
  # indexed array gradients are regular arrays
  if isinstance(arg.type, indexedarray):
  _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
@@ -2197,28 +2565,96 @@ def codegen_func(adj, name, device="cpu", options={}):
  reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
  elif return_type != "void":
  reverse_args.append(return_type + " & adj_ret")
+ # custom output reverse args (user-declared)
+ if adj.custom_reverse_mode:
+ for arg in adj.args[adj.custom_reverse_num_input_args :]:
+ reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
+
+ if device == "cpu":
+ forward_template = cpu_forward_function_template
+ reverse_template = cpu_reverse_function_template
+ elif device == "cuda":
+ forward_template = cuda_forward_function_template
+ reverse_template = cuda_reverse_function_template
+ else:
+ raise ValueError(f"Device {device} is not supported")

  # codegen body
  forward_body = codegen_func_forward(adj, func_type="function", device=device)

- if options.get("enable_backward", True):
- reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
- else:
- reverse_body = ""
+ s = ""
+ if not adj.skip_forward_codegen:
+ s += forward_template.format(
+ name=c_func_name,
+ return_type=return_type,
+ forward_args=indent(forward_args),
+ forward_body=forward_body,
+ filename=adj.filename,
+ lineno=adj.fun_lineno,
+ )

- if device == "cpu":
- template = cpu_function_template
- elif device == "cuda":
- template = cuda_function_template
- else:
- raise ValueError("Device {} is not supported".format(device))
+ if not adj.skip_reverse_codegen:
+ if adj.custom_reverse_mode:
+ reverse_body = "\t// user-defined adjoint code\n" + forward_body
+ else:
+ if options.get("enable_backward", True):
+ reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+ else:
+ reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+ s += reverse_template.format(
+ name=c_func_name,
+ return_type=return_type,
+ reverse_args=indent(reverse_args),
+ forward_body=forward_body,
+ reverse_body=reverse_body,
+ filename=adj.filename,
+ lineno=adj.fun_lineno,
+ )

- s = template.format(
+ return s
+
+
+ def codegen_snippet(adj, name, snippet, adj_snippet):
+ forward_args = []
+ reverse_args = []
+
+ # forward args
+ for i, arg in enumerate(adj.args):
+ s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+ forward_args.append(s)
+ reverse_args.append(s)
+
+ # reverse args
+ for i, arg in enumerate(adj.args):
+ if isinstance(arg.type, indexedarray):
+ _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+ reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+ else:
+ reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+
+ forward_template = cuda_forward_function_template
+ reverse_template = cuda_reverse_function_template
+
+ s = ""
+ s += forward_template.format(
  name=name,
- return_type=return_type,
+ return_type="void",
  forward_args=indent(forward_args),
+ forward_body=snippet,
+ filename=adj.filename,
+ lineno=adj.fun_lineno,
+ )
+
+ if adj_snippet:
+ reverse_body = adj_snippet
+ else:
+ reverse_body = ""
+
+ s += reverse_template.format(
+ name=name,
+ return_type="void",
  reverse_args=indent(reverse_args),
- forward_body=forward_body,
+ forward_body=snippet,
  reverse_body=reverse_body,
  filename=adj.filename,
  lineno=adj.fun_lineno,
@@ -2234,8 +2670,8 @@ def codegen_kernel(kernel, device, options):

  adj = kernel.adj

- forward_args = ["launch_bounds_t dim"]
- reverse_args = ["launch_bounds_t dim"]
+ forward_args = ["wp::launch_bounds_t dim"]
+ reverse_args = ["wp::launch_bounds_t dim"]

  # forward args
  for arg in adj.args:
@@ -2264,7 +2700,7 @@ def codegen_kernel(kernel, device, options):
  elif device == "cuda":
  template = cuda_kernel_template
  else:
- raise ValueError("Device {} is not supported".format(device))
+ raise ValueError(f"Device {device} is not supported")

  s = template.format(
  name=kernel.get_mangled_name(),
@@ -2284,7 +2720,7 @@ def codegen_module(kernel, device="cpu"):
  adj = kernel.adj

  # build forward signature
- forward_args = ["launch_bounds_t dim"]
+ forward_args = ["wp::launch_bounds_t dim"]
  forward_params = ["dim"]

  for arg in adj.args: