warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; consult the package registry's advisory page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/_src/codegen.py ADDED
@@ -0,0 +1,4359 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import ast
19
+ import builtins
20
+ import ctypes
21
+ import enum
22
+ import functools
23
+ import hashlib
24
+ import inspect
25
+ import itertools
26
+ import math
27
+ import re
28
+ import sys
29
+ import textwrap
30
+ import types
31
+ from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
32
+
33
+ import warp._src.config
34
+ from warp._src.types import *
35
+
36
+ # used as a globally accessible copy
37
+ # of current compile options (block_dim) etc
38
+ options = {}
39
+
40
+
41
+ class WarpCodegenError(RuntimeError):
42
+ def __init__(self, message):
43
+ super().__init__(message)
44
+
45
+
46
+ class WarpCodegenTypeError(TypeError):
47
+ def __init__(self, message):
48
+ super().__init__(message)
49
+
50
+
51
+ class WarpCodegenAttributeError(AttributeError):
52
+ def __init__(self, message):
53
+ super().__init__(message)
54
+
55
+
56
+ def get_node_name_safe(node):
57
+ """Safely get a string representation of an AST node for error messages.
58
+
59
+ This handles different AST node types (Name, Subscript, etc.) without
60
+ raising AttributeError when accessing attributes that may not exist.
61
+ """
62
+ if hasattr(node, "id"):
63
+ return node.id
64
+ elif hasattr(node, "value") and hasattr(node, "slice"):
65
+ # Subscript node like inputs[tid]
66
+ base_name = get_node_name_safe(node.value)
67
+ return f"{base_name}[...]"
68
+ else:
69
+ return f"<{type(node).__name__}>"
70
+
71
+
72
+ class WarpCodegenKeyError(KeyError):
73
+ def __init__(self, message):
74
+ super().__init__(message)
75
+
76
+
77
+ # map operator to function name
78
+ builtin_operators: dict[type[ast.AST], str] = {}
79
+
80
+ # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
81
+ # nice overview of python operators
82
+
83
+ builtin_operators[ast.Add] = "add"
84
+ builtin_operators[ast.Sub] = "sub"
85
+ builtin_operators[ast.Mult] = "mul"
86
+ builtin_operators[ast.MatMult] = "mul"
87
+ builtin_operators[ast.Div] = "div"
88
+ builtin_operators[ast.FloorDiv] = "floordiv"
89
+ builtin_operators[ast.Pow] = "pow"
90
+ builtin_operators[ast.Mod] = "mod"
91
+ builtin_operators[ast.UAdd] = "pos"
92
+ builtin_operators[ast.USub] = "neg"
93
+ builtin_operators[ast.Not] = "unot"
94
+
95
+ builtin_operators[ast.Gt] = ">"
96
+ builtin_operators[ast.Lt] = "<"
97
+ builtin_operators[ast.GtE] = ">="
98
+ builtin_operators[ast.LtE] = "<="
99
+ builtin_operators[ast.Eq] = "=="
100
+ builtin_operators[ast.NotEq] = "!="
101
+
102
+ builtin_operators[ast.BitAnd] = "bit_and"
103
+ builtin_operators[ast.BitOr] = "bit_or"
104
+ builtin_operators[ast.BitXor] = "bit_xor"
105
+ builtin_operators[ast.Invert] = "invert"
106
+ builtin_operators[ast.LShift] = "lshift"
107
+ builtin_operators[ast.RShift] = "rshift"
108
+
109
+ comparison_chain_strings = [
110
+ builtin_operators[ast.Gt],
111
+ builtin_operators[ast.Lt],
112
+ builtin_operators[ast.LtE],
113
+ builtin_operators[ast.GtE],
114
+ builtin_operators[ast.Eq],
115
+ builtin_operators[ast.NotEq],
116
+ ]
117
+
118
+
119
+ def values_check_equal(a, b):
120
+ if isinstance(a, Sequence) and isinstance(b, Sequence):
121
+ if len(a) != len(b):
122
+ return False
123
+
124
+ return all(x == y for x, y in zip(a, b))
125
+
126
+ return a == b
127
+
128
+
129
+ def op_str_is_chainable(op: str) -> builtins.bool:
130
+ return op in comparison_chain_strings
131
+
132
+
133
+ def get_closure_cell_contents(obj):
134
+ """Retrieve a closure's cell contents or `None` if it's empty."""
135
+ try:
136
+ return obj.cell_contents
137
+ except ValueError:
138
+ pass
139
+
140
+ return None
141
+
142
+
143
+ def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
144
+ """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
145
+ # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
146
+ if not annotations:
147
+ return {}
148
+
149
+ if not any(isinstance(x, str) for x in annotations.values()):
150
+ # No annotation to un-stringize.
151
+ return annotations
152
+
153
+ if isinstance(obj, type):
154
+ # class
155
+ globals = {}
156
+ module_name = getattr(obj, "__module__", None)
157
+ if module_name:
158
+ module = sys.modules.get(module_name, None)
159
+ if module:
160
+ globals = getattr(module, "__dict__", {})
161
+ locals = dict(vars(obj))
162
+ unwrap = obj
163
+ elif isinstance(obj, types.ModuleType):
164
+ # module
165
+ globals = obj.__dict__
166
+ locals = {}
167
+ unwrap = None
168
+ elif callable(obj):
169
+ # function
170
+ globals = getattr(obj, "__globals__", {})
171
+ # Capture the variables from the surrounding scope.
172
+ closure_vars = zip(
173
+ obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
174
+ )
175
+ locals = {k: v for k, v in closure_vars if v is not None}
176
+ unwrap = obj
177
+ else:
178
+ raise TypeError(f"{obj!r} is not a module, class, or callable.")
179
+
180
+ if unwrap is not None:
181
+ while True:
182
+ if hasattr(unwrap, "__wrapped__"):
183
+ unwrap = unwrap.__wrapped__
184
+ continue
185
+ if isinstance(unwrap, functools.partial):
186
+ unwrap = unwrap.func
187
+ continue
188
+ break
189
+ if hasattr(unwrap, "__globals__"):
190
+ globals = unwrap.__globals__
191
+
192
+ # "Inject" type parameters into the local namespace
193
+ # (unless they are shadowed by assignments *in* the local namespace),
194
+ # as a way of emulating annotation scopes when calling `eval()`
195
+ type_params = getattr(obj, "__type_params__", ())
196
+ if type_params:
197
+ locals = {param.__name__: param for param in type_params} | locals
198
+
199
+ return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
200
+
201
+
202
+ def get_annotations(obj: Any) -> Mapping[str, Any]:
203
+ """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
204
+ # Python 3.10+: Use the built-in inspect.get_annotations() which handles
205
+ # PEP 649 (deferred annotation evaluation) in Python 3.14+
206
+ if hasattr(inspect, "get_annotations"):
207
+ # eval_str=True ensures stringized annotations from PEP 563 are evaluated
208
+ return inspect.get_annotations(obj, eval_str=True)
209
+ else:
210
+ # Python 3.9 and older: Manual backport of inspect.get_annotations()
211
+ # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
212
+ if isinstance(obj, type):
213
+ annotations = obj.__dict__.get("__annotations__", {})
214
+ else:
215
+ annotations = getattr(obj, "__annotations__", {})
216
+
217
+ return eval_annotations(annotations, obj)
218
+
219
+
220
+ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
221
+ """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
222
+ spec = inspect.getfullargspec(func)
223
+
224
+ # Python 3.10+: Use inspect.get_annotations()
225
+ if hasattr(inspect, "get_annotations"):
226
+ # Capture closure variables to handle cases like `foo.Data` where `foo` is a closure variable
227
+ closure_vars = dict(
228
+ zip(func.__code__.co_freevars, (get_closure_cell_contents(x) for x in (func.__closure__ or ())))
229
+ )
230
+ # Filter out None values from empty cells
231
+ closure_vars = {k: v for k, v in closure_vars.items() if v is not None}
232
+ return spec._replace(annotations=inspect.get_annotations(func, eval_str=True, locals=closure_vars))
233
+ else:
234
+ # Python 3.9 and older: Manually un-stringize annotations
235
+ # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
236
+ return spec._replace(annotations=eval_annotations(spec.annotations, func))
237
+
238
+
239
+ def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
240
+ indent = "\t"
241
+
242
+ # handle empty structs
243
+ if len(inst._cls.vars) == 0:
244
+ return f"{inst._cls.key}()"
245
+
246
+ lines = []
247
+ lines.append(f"{inst._cls.key}(")
248
+
249
+ for field_name, _ in inst._cls.ctype._fields_:
250
+ field_value = getattr(inst, field_name, None)
251
+
252
+ if isinstance(field_value, StructInstance):
253
+ field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
254
+
255
+ if use_repr:
256
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
257
+ else:
258
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
259
+
260
+ lines.append(f"{indent * depth})")
261
+ return "\n".join(lines)
262
+
263
+
264
+ class StructInstance:
265
+ def __init__(self, ctype):
266
+ # maintain a c-types object for the top-level instance the struct
267
+ super().__setattr__("_ctype", ctype)
268
+
269
+ # create Python attributes for each of the struct's variables
270
+ for k, cst in type(self)._constructors:
271
+ self.__dict__[k] = cst(ctype)
272
+
273
+ def __setattr__(self, name, value):
274
+ try:
275
+ self._setters[name](self, value)
276
+ except KeyError as err:
277
+ raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}") from err
278
+
279
+ def __ctype__(self):
280
+ return self._ctype
281
+
282
+ def __repr__(self):
283
+ return struct_instance_repr_recursive(self, 0, use_repr=True)
284
+
285
+ def __str__(self):
286
+ return struct_instance_repr_recursive(self, 0, use_repr=False)
287
+
288
+ def assign(self, value):
289
+ """Assigns the values of another struct instance to this one."""
290
+ if not isinstance(value, StructInstance):
291
+ raise RuntimeError(
292
+ f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
293
+ )
294
+
295
+ if self._cls.key is not value._cls.key:
296
+ raise RuntimeError(
297
+ f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
298
+ )
299
+
300
+ # update all nested ctype vars by deep copy
301
+ for n in self._cls.vars:
302
+ setattr(self, n, getattr(value, n))
303
+
304
+ def to(self, device):
305
+ """Copies this struct with all array members moved onto the given device.
306
+
307
+ Arrays already living on the desired device are referenced as-is, while
308
+ arrays being moved are copied.
309
+ """
310
+ out = self._cls()
311
+ stack = [(self, out, k, v) for k, v in self._cls.vars.items()]
312
+ while stack:
313
+ src, dst, name, var = stack.pop()
314
+ value = getattr(src, name)
315
+ if isinstance(var.type, array):
316
+ # array_t
317
+ setattr(dst, name, value.to(device))
318
+ elif isinstance(var.type, Struct):
319
+ # nested struct
320
+ new_struct = var.type()
321
+ setattr(dst, name, new_struct)
322
+ # The call to `setattr()` just above makes a copy of `new_struct`
323
+ # so we need to reference that new instance of the struct.
324
+ new_struct = getattr(dst, name)
325
+ stack.extend((value, new_struct, k, v) for k, v in var.type.vars.items())
326
+ else:
327
+ setattr(dst, name, value)
328
+
329
+ return out
330
+
331
+ # type description used in numpy structured arrays
332
+ def numpy_dtype(self):
333
+ return self._cls.numpy_dtype()
334
+
335
+ # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
336
+ def numpy_value(self):
337
+ npvalue = []
338
+ for name, var in self._cls.vars.items():
339
+ # get the attribute value
340
+ value = getattr(self._ctype, name)
341
+
342
+ if isinstance(var.type, array):
343
+ # array_t
344
+ npvalue.append(value.numpy_value())
345
+ elif isinstance(var.type, Struct):
346
+ # nested struct
347
+ npvalue.append(value.numpy_value())
348
+ elif issubclass(var.type, ctypes.Array):
349
+ if len(var.type._shape_) == 1:
350
+ # vector
351
+ npvalue.append(list(value))
352
+ else:
353
+ # matrix
354
+ npvalue.append([list(row) for row in value])
355
+ else:
356
+ # scalar
357
+ if var.type == warp.float16:
358
+ npvalue.append(half_bits_to_float(value))
359
+ else:
360
+ npvalue.append(value)
361
+
362
+ return tuple(npvalue)
363
+
364
+
365
+ def _make_struct_field_constructor(field: str, var_type: type):
366
+ if isinstance(var_type, Struct):
367
+ return lambda ctype: var_type.instance_type(ctype=getattr(ctype, field))
368
+ elif isinstance(var_type, warp._src.types.array):
369
+ return lambda ctype: None
370
+ elif issubclass(var_type, ctypes.Array):
371
+ # for vector/matrices, the Python attribute aliases the ctype one
372
+ return lambda ctype: getattr(ctype, field)
373
+ else:
374
+ return lambda ctype: var_type()
375
+
376
+
377
def _make_struct_field_setter(cls, field: str, var_type: type):
    """Build a setter callable for a single struct member.

    Returns a function ``(inst, value) -> None`` that writes ``value`` into
    ``inst._ctype.<field>``, with the conversion strategy selected once, up
    front, from the field's declared ``var_type`` (array, nested struct,
    vector/matrix, or primitive scalar).
    """

    def set_array_value(inst, value):
        # array-typed field: store an array_t descriptor inside the ctype
        if value is None:
            # create array with null pointer
            setattr(inst._ctype, field, array_t())
        else:
            # wp.array
            assert isinstance(value, array)
            assert types_equal(value.dtype, var_type.dtype), (
                f"assign to struct member variable {field} failed, expected type {type_repr(var_type.dtype)}, got type {type_repr(value.dtype)}"
            )
            setattr(inst._ctype, field, value.__ctype__())

            # workaround to prevent gradient buffers being garbage collected
            # since users can do struct.array.requires_grad = False the gradient array
            # would be collected while the struct ctype still holds a reference to it
            if value.requires_grad:
                cls.__setattr__(inst, "_" + field + "_grad", value.grad)

        # keep the Python-side attribute in sync with the ctype
        cls.__setattr__(inst, field, value)

    def set_struct_value(inst, value):
        # nested struct field: copy member-wise via the instance's assign()
        getattr(inst, field).assign(value)

    def set_vector_value(inst, value):
        # vector/matrix type, e.g. vec3
        if value is None:
            setattr(inst._ctype, field, var_type())
        elif type(value) is var_type:
            setattr(inst._ctype, field, value)
        else:
            # conversion from list/tuple, ndarray, etc.
            setattr(inst._ctype, field, var_type(value))

        # no need to update the Python attribute,
        # it's already aliasing the ctype one

    def set_primitive_value(inst, value):
        # primitive type
        if value is None:
            # zero initialize
            setattr(inst._ctype, field, var_type._type_())
        else:
            if hasattr(value, "_type_"):
                # assigning warp type value (e.g.: wp.float32)
                value = value.value
            # float16 needs conversion to uint16 bits
            if var_type == warp.float16:
                setattr(inst._ctype, field, float_to_half_bits(value))
            else:
                setattr(inst._ctype, field, value)

        # keep the Python-side attribute in sync with the ctype
        cls.__setattr__(inst, field, value)

    # dispatch once on the declared field type
    if isinstance(var_type, array):
        return set_array_value
    elif isinstance(var_type, Struct):
        return set_struct_value
    elif issubclass(var_type, ctypes.Array):
        return set_vector_value
    else:
        return set_primitive_value
439
+
440
+
441
class Struct:
    """Runtime metadata for a user-defined ``@wp.struct`` type.

    Holds the member ``Var``s, the generated ctypes layout, a cached content
    hash used for module caching and native naming, the constructors exposed
    to kernels, and the Python instance type produced by calling the struct.
    """

    # cached content hash (static: all field types resolve at declaration time)
    hash: bytes

    def __init__(self, key: str, cls: type, module: warp._src.context.Module):
        self.key = key
        self.cls = cls
        self.module = module
        self.vars: dict[str, Var] = {}

        if isinstance(self.cls, Sequence):
            raise RuntimeError("Warp structs must be defined as base classes")

        # one Var per annotated member, in declaration order
        annotations = get_annotations(self.cls)
        for label, type_ in annotations.items():
            self.vars[label] = Var(label, type_)

        # build the ctypes field layout mirroring the native struct
        fields = []
        for label, var in self.vars.items():
            if isinstance(var.type, array):
                fields.append((label, array_t))
            elif isinstance(var.type, Struct):
                fields.append((label, var.type.ctype))
            elif issubclass(var.type, ctypes.Array):
                fields.append((label, var.type))
            else:
                # HACK: fp16 requires conversion functions from warp.so
                if var.type is warp.float16:
                    warp.init()
                fields.append((label, var.type._type_))

        class StructType(ctypes.Structure):
            # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
            _fields_ = fields or [("_dummy_", ctypes.c_byte)]

        self.ctype = StructType

        # Compute the hash. We can cache the hash because it's static, even with nested structs.
        # All field types are specified in the annotations, so they're resolved at declaration time.
        ch = hashlib.sha256()

        ch.update(bytes(self.key, "utf-8"))

        for name, type_hint in annotations.items():
            s = f"{name}:{warp._src.types.get_type_code(type_hint)}"
            ch.update(bytes(s, "utf-8"))

            # recurse on nested structs
            if isinstance(type_hint, Struct):
                ch.update(type_hint.hash)

        self.hash = ch.digest()

        # generate unique identifier for structs in native code
        hash_suffix = f"{self.hash.hex()[:8]}"
        self.native_name = f"{self.key}_{hash_suffix}"

        # create default constructor (zero-initialize)
        self.default_constructor = warp._src.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types={},
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        # build a constructor that takes each param as a value
        input_types = {label: var.type for label, var in self.vars.items()}

        self.value_constructor = warp._src.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types=input_types,
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        self.default_constructor.add_overload(self.value_constructor)

        if isinstance(module, warp._src.context.Module):
            module.register_struct(self)

        # Define class for instances of this struct
        # To enable autocomplete on s, we inherit from self.cls.
        # For example,

        # @wp.struct
        # class A:
        #     # annotations
        #     ...

        # The type annotations are inherited in A(), allowing autocomplete in kernels
        class NewStructInstance(self.cls, StructInstance):
            cls: ClassVar[type] = self.cls
            native_name: ClassVar[str] = self.native_name

            _cls: ClassVar[type] = self
            # per-field attribute constructors and setters, built once per struct type
            _constructors: ClassVar[list[tuple[str, Callable]]] = [
                (field, _make_struct_field_constructor(field, var.type)) for field, var in self.vars.items()
            ]
            _setters: ClassVar[dict[str, Callable]] = {
                field: _make_struct_field_setter(self.cls, field, var.type) for field, var in self.vars.items()
            }

            def __init__(inst, ctype=None):
                StructInstance.__init__(inst, ctype or self.ctype())

        self.instance_type = NewStructInstance

    def __call__(self):
        """
        This function returns s = StructInstance(self)
        s uses self.cls as template.
        """
        return self.instance_type()

    def initializer(self):
        # default (zero-initializing) constructor used by codegen
        return self.default_constructor

    # return structured NumPy dtype, including field names, formats, and offsets
    def numpy_dtype(self):
        names = []
        formats = []
        offsets = []
        for name, var in self.vars.items():
            names.append(name)
            # offsets come straight from the ctypes layout so they match the native struct
            offsets.append(getattr(self.ctype, name).offset)
            if isinstance(var.type, array):
                # array_t
                formats.append(array_t.numpy_dtype())
            elif isinstance(var.type, Struct):
                # nested struct
                formats.append(var.type.numpy_dtype())
            elif issubclass(var.type, ctypes.Array):
                scalar_typestr = type_typestr(var.type._wp_scalar_type_)
                if len(var.type._shape_) == 1:
                    # vector
                    formats.append(f"{var.type._length_}{scalar_typestr}")
                else:
                    # matrix
                    formats.append(f"{var.type._shape_}{scalar_typestr}")
            else:
                # scalar
                formats.append(type_typestr(var.type))

        return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}

    # constructs a Warp struct instance from a pointer to the ctype
    def from_ptr(self, ptr):
        if not ptr:
            raise RuntimeError("NULL pointer exception")

        # create a new struct instance
        instance = self()

        for name, var in self.vars.items():
            offset = getattr(self.ctype, name).offset
            if isinstance(var.type, array):
                # We could reconstruct wp.array from array_t, but it's problematic.
                # There's no guarantee that the original wp.array is still allocated and
                # no easy way to make a backref.
                # Instead, we just create a stub annotation, which is not a fully usable array object.
                setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
            elif isinstance(var.type, Struct):
                # nested struct
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            else:
                # scalar
                cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
                if var.type == warp.float16:
                    setattr(instance, name, half_bits_to_float(cvalue))
                else:
                    setattr(instance, name, cvalue.value)

        return instance
624
+
625
+
626
class Reference:
    """Marker type wrapping a value type to denote a reference to it."""

    def __init__(self, value_type):
        # the type being referred to
        self.value_type = value_type
629
+
630
+
631
def is_reference(type: Any) -> builtins.bool:
    """Return ``True`` when *type* is a ``Reference`` wrapper."""
    return isinstance(type, Reference)
633
+
634
+
635
def strip_reference(arg: Any) -> Any:
    """Unwrap a ``Reference`` to its value type; pass anything else through."""
    return arg.value_type if is_reference(arg) else arg
640
+
641
+
642
def compute_type_str(base_name, template_params):
    """Render a C++ type name, appending template arguments when present.

    Args:
        base_name: The base C++ type name (e.g. ``"wp::vec_t"``).
        template_params: Template parameters (bools, ints, Warp types,
            tiles, or structs); an empty sequence returns *base_name* as-is.
    """
    if not template_params:
        return base_name

    def render(p):
        # note: bool must be tested before int, since bool subclasses int
        if isinstance(p, builtins.bool):
            return "true" if p else "false"
        if isinstance(p, int):
            return str(p)
        if hasattr(p, "_wp_generic_type_str_"):
            # generic vector/matrix types recurse with their own parameters
            return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
        if hasattr(p, "_type_"):
            return "bool" if p.__name__ == "bool" else f"wp::{p.__name__}"
        if is_tile(p):
            return p.ctype()
        if isinstance(p, Struct):
            return p.native_name
        return p.__name__

    rendered = ", ".join(render(p) for p in template_params)
    return f"{base_name}<{rendered}>"
666
+
667
+
668
class Var:
    """A variable in generated kernel code (argument, local, or constant)."""

    def __init__(
        self,
        label: str,
        type: type,
        requires_grad: builtins.bool = False,
        constant: builtins.bool | None = None,
        prefix: builtins.bool = True,
        relative_lineno: int | None = None,
    ):
        """Create a variable.

        Args:
            label: Source-level name of the variable.
            type: Declared type; built-in ``float``/``int``/``bool`` are
                promoted to the corresponding Warp scalar types.
            requires_grad: Whether gradients are tracked for this variable.
            constant: Compile-time constant value, if any.
            prefix: Whether ``emit()`` prepends ``var_``/``adj_`` prefixes.
            relative_lineno: Line (relative to the function) that created it.
        """
        # convert built-in types to wp types
        if type == float:
            type = float32
        elif type == int:
            type = int32
        elif type == builtins.bool:
            type = bool

        self.label = label
        self.type = type
        self.requires_grad = requires_grad
        self.constant = constant
        self.prefix = prefix

        # records whether this Var has been read from in a kernel function (array only)
        self.is_read = False
        # records whether this Var has been written to in a kernel function (array only)
        self.is_write = False

        # used to associate a view array Var with its parent array Var
        self.parent = None

        # Used to associate the variable with the Python statement that resulted in it being created.
        self.relative_lineno = relative_lineno

    def __str__(self):
        return self.label

    @staticmethod
    def dtype_to_ctype(t: type) -> str:
        """Return the C++ type name for the data type *t*."""
        if hasattr(t, "_wp_generic_type_str_"):
            return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
        elif isinstance(t, Struct):
            return t.native_name
        elif hasattr(t, "_wp_native_name_"):
            return f"wp::{t._wp_native_name_}"
        elif t.__name__ in ("bool", "int", "float"):
            return t.__name__

        return f"wp::{t.__name__}"

    @staticmethod
    def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
        """Return the C++ type name for *t*, including array/tuple/tile wrappers.

        Args:
            t: The type to convert.
            value_type: For references, return the pointed-to type rather
                than a pointer type.
        """
        if isinstance(t, fixedarray):
            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
            dtypestr = ", ".join(template_args)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}_t<{dtypestr}>"
        elif is_array(t):
            dtypestr = Var.dtype_to_ctype(t.dtype)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}_t<{dtypestr}>"
        elif get_origin(t) is tuple:
            dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
            return f"wp::tuple_t<{dtypestr}>"
        elif is_tuple(t):
            dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}<{dtypestr}>"
        elif is_tile(t):
            return t.ctype()
        elif isinstance(t, type) and issubclass(t, StructInstance):
            # ensure the actual Struct name is used instead of "NewStructInstance"
            return t.native_name
        elif is_reference(t):
            if not value_type:
                return Var.type_to_ctype(t.value_type) + "*"

            return Var.type_to_ctype(t.value_type)

        return Var.dtype_to_ctype(t)

    def ctype(self, value_type: builtins.bool = False) -> str:
        """Return this variable's C++ type name."""
        return Var.type_to_ctype(self.type, value_type)

    def emit(self, prefix: str = "var"):
        """Return the C++ identifier for this variable (e.g. ``var_x``)."""
        if self.prefix:
            return f"{prefix}_{self.label}"
        else:
            return self.label

    def emit_adj(self):
        """Return the adjoint identifier for this variable (e.g. ``adj_x``)."""
        return self.emit("adj")

    def mark_read(self):
        """Marks this Var as having been read from in a kernel (array only)."""
        if not is_array(self.type):
            return

        self.is_read = True

        # recursively update all parent states
        parent = self.parent
        while parent is not None:
            parent.is_read = True
            parent = parent.parent

    def mark_write(self, **kwargs):
        """Marks this Var as having been written to in a kernel (array only)."""
        if not is_array(self.type):
            return

        # detect if we are writing to an array after reading from it within the same kernel
        if self.is_read and warp._src.config.verify_autograd_array_access:
            # BUG FIX: the previous condition `"kernel_name" and "filename" and "lineno" in kwargs`
            # only checked `"lineno" in kwargs`, since non-empty string literals are always truthy;
            # require all three keys so the detailed message has the data it formats.
            if all(key in kwargs for key in ("kernel_name", "filename", "lineno")):
                print(
                    f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
                )
            else:
                print(
                    f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
                )
        self.is_write = True

        # recursively update all parent states
        parent = self.parent
        while parent is not None:
            parent.is_write = True
            parent = parent.parent
797
+
798
+
799
class Block:
    """A basic block: a straight-line run of statements, e.g. the body of a
    for-loop or one branch of a conditional."""

    def __init__(self):
        # statement lists for the three generated passes
        self.body_forward = []
        self.body_replay = []
        self.body_reverse = []

        # variables declared within this block, in order
        self.vars = []
811
+
812
+
813
def apply_defaults(
    bound_args: inspect.BoundArguments,
    values: Mapping[str, Any],
):
    """Fill in missing arguments on *bound_args* from *values*.

    Similar to Python's ``inspect.BoundArguments.apply_defaults()``, but the
    defaults come from the caller-supplied *values* mapping, and parameters
    present in neither source are omitted entirely. Arguments are rebuilt in
    signature order.
    """
    bound = bound_args.arguments
    bound_args.arguments = {
        name: bound[name] if name in bound else values[name]
        for name in bound_args._signature.parameters
        if name in bound or name in values
    }
828
+
829
+
830
def func_match_args(func, arg_types, kwarg_types):
    """Check whether the given argument types can match *func*'s signature.

    Args:
        func: Candidate overload (carries ``signature``, ``defaults``, and
            ``input_types``).
        arg_types: Positional argument types.
        kwarg_types: Keyword argument types.

    Returns:
        True when every bound argument type is compatible with the
        corresponding declared parameter type, False otherwise.
    """
    try:
        # Try to bind the given arguments to the function's signature.
        # This is not checking whether the argument types are matching,
        # rather it's just assigning each argument to the corresponding
        # function parameter.
        bound_arg_types = func.signature.bind(*arg_types, **kwarg_types)
    except TypeError:
        return False

    # Populate the bound arguments with any default values.
    default_arg_types = {
        k: None if v is None else get_arg_type(v)
        for k, v in func.defaults.items()
        if k not in bound_arg_types.arguments
    }
    apply_defaults(bound_arg_types, default_arg_types)
    bound_arg_types = tuple(bound_arg_types.arguments.values())

    # Check the given argument types against the ones defined on the function.
    for bound_arg_type, func_arg_type in zip(bound_arg_types, func.input_types.values()):
        # Let the `value_func` callback infer the type.
        if bound_arg_type is None:
            continue

        # if arg type registered as Any, treat as
        # template allowing any type to match
        if func_arg_type == Any:
            continue

        # handle function refs as a special case
        if func_arg_type == Callable and isinstance(bound_arg_type, warp._src.context.Function):
            continue

        # check arg type matches input variable type
        if not types_equal(func_arg_type, strip_reference(bound_arg_type), match_generic=True):
            return False

    return True
869
+
870
+
871
def get_arg_type(arg: Var | Any) -> type:
    """Map an argument (value, ``Var``, type, ...) to the type used for
    overload matching.

    Strings map to ``str``; sequences map element-wise to tuples of types;
    arrays, tuple types, Warp functions and plain types pass through;
    ``Var`` instances yield their declared type.
    """
    if isinstance(arg, str):
        return str
    if isinstance(arg, Sequence):
        return tuple(get_arg_type(item) for item in arg)
    if is_array(arg):
        return arg
    if get_origin(arg) is tuple:
        return tuple(get_arg_type(item) for item in get_args(arg))
    if is_tuple(arg):
        return arg
    if isinstance(arg, (type, warp._src.context.Function)):
        return arg
    if isinstance(arg, Var):
        # tuple-typed Vars expand to their element types
        return get_args(arg.type) if get_origin(arg.type) is tuple else arg.type
    return type(arg)
897
+
898
+
899
def get_arg_value(arg: Any) -> Any:
    """Resolve *arg* to a concrete value where possible.

    Strings pass through unchanged; sequences are resolved element-wise into
    tuples; types and Warp functions pass through; ``Var`` instances yield
    their constant value when available, otherwise the ``Var`` itself.
    """
    # BUG FIX: return strings as-is. Without this guard the Sequence branch
    # below recurses into the string's characters forever (a 1-character
    # string is a Sequence containing itself), raising RecursionError.
    # This also mirrors the explicit str handling in get_arg_type().
    if isinstance(arg, str):
        return arg

    if isinstance(arg, Sequence):
        return tuple(get_arg_value(x) for x in arg)

    if isinstance(arg, (type, warp._src.context.Function)):
        return arg

    if isinstance(arg, Var):
        # tuple-typed Vars expand to their element values
        if is_tuple(arg.type):
            return tuple(get_arg_value(x) for x in arg.type.values)

        if arg.constant is not None:
            return arg.constant

    return arg
914
+
915
+
916
+ class Adjoint:
917
+ # Source code transformer, this class takes a Python function and
918
+ # generates forward and backward SSA forms of the function instructions
919
+
920
    def __init__(
        adj,
        func: Callable[..., Any],
        overload_annotations=None,
        is_user_function=False,
        skip_forward_codegen=False,
        skip_reverse_codegen=False,
        custom_reverse_mode=False,
        custom_reverse_num_input_args=-1,
        transformers: list[ast.NodeTransformer] | None = None,
        source: str | None = None,
    ):
        """Parse *func* into an AST and record everything needed for codegen.

        Args:
            func: The Python function to transform.
            overload_annotations: Argument-type overrides for generic overloads.
            is_user_function: Whether this is a user function (vs. a kernel).
            skip_forward_codegen: Skip generating the forward pass.
            skip_reverse_codegen: Skip generating the adjoint pass.
            custom_reverse_mode: Reuse the forward code for the reverse pass
                with a custom reverse function signature.
            custom_reverse_num_input_args: Number of non-adjoint arguments in
                custom reverse mode (-1 means all of them).
            transformers: AST node transformers applied after parsing.
            source: Pre-extracted source code, if already available.
        """
        adj.func = func

        adj.is_user_function = is_user_function

        # whether the generation of the forward code is skipped for this function
        adj.skip_forward_codegen = skip_forward_codegen
        # whether the generation of the adjoint code is skipped for this function
        adj.skip_reverse_codegen = skip_reverse_codegen
        # Whether this function is used by a kernel that has the backward pass enabled.
        adj.used_by_backward_kernel = False

        # extract name of source file
        adj.filename = inspect.getsourcefile(func) or "unknown source file"
        # get source file line number where function starts
        adj.fun_lineno = 0
        adj.source = source
        if adj.source is None:
            adj.source, adj.fun_lineno = adj.extract_function_source(func)

        assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"

        # Indicates where the function definition starts (excludes decorators)
        adj.fun_def_lineno = None

        # get function source code
        # ensures that indented class methods can be parsed as kernels
        adj.source = textwrap.dedent(adj.source)

        adj.source_lines = adj.source.splitlines()

        if transformers is None:
            transformers = []

        # build AST and apply node transformers
        adj.tree = ast.parse(adj.source)
        adj.transformers = transformers
        for transformer in transformers:
            adj.tree = transformer.visit(adj.tree)

        adj.fun_name = adj.tree.body[0].name

        # for keeping track of line number in function code
        adj.lineno = None

        # whether the forward code shall be used for the reverse pass and a custom
        # function signature is applied to the reverse version of the function
        adj.custom_reverse_mode = custom_reverse_mode
        # the number of function arguments that pertain to the forward function
        # input arguments (i.e. the number of arguments that are not adjoint arguments)
        adj.custom_reverse_num_input_args = custom_reverse_num_input_args

        # parse argument types
        argspec = get_full_arg_spec(func)

        # ensure all arguments are annotated
        if overload_annotations is None:
            # use source-level argument annotations
            if len(argspec.annotations) < len(argspec.args):
                raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
        else:
            # use overload argument annotations
            for arg_name in argspec.args:
                if arg_name not in overload_annotations:
                    raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
            adj.arg_types = overload_annotations.copy()

        adj.args = []
        adj.symbols = {}

        for name, type in adj.arg_types.items():
            # skip return hint
            if name == "return":
                continue

            # add variable for argument
            arg = Var(name, type, requires_grad=False)
            adj.args.append(arg)

            # pre-populate symbol dictionary with function argument names
            # this is to avoid registering false references to overshadowed modules
            adj.symbols[name] = arg

        # Indicates whether there are unresolved static expressions in the function.
        # These stem from wp.static() expressions that could not be evaluated at declaration time.
        # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
        adj.has_unresolved_static_expressions = False

        # try to replace static expressions by their constant result if the
        # expression can be evaluated at declaration time
        adj.static_expressions: dict[str, Any] = {}
        if "static" in adj.source:
            adj.replace_static_expressions()

        # There are cases where a same module might be rebuilt multiple times,
        # for example when kernels are nested inside of functions, or when
        # a kernel's launch raises an exception. Ideally we'd always want to
        # avoid rebuilding kernels but some corner cases seem to depend on it,
        # so we only avoid rebuilding kernels that errored out to give a chance
        # for unit testing errors being spit out from kernels.
        adj.skip_build = False
1033
+
1034
+ # allocate extra space for a function call that requires its
1035
+ # own shared memory space, we treat shared memory as a stack
1036
+ # where each function pushes and pops space off, the extra
1037
+ # quantity is the 'roofline' amount required for the entire kernel
1038
+ def alloc_shared_extra(adj, num_bytes):
1039
+ adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
1040
+
1041
+ # returns the total number of bytes for a function
1042
+ # based on it's own requirements + worst case
1043
+ # requirements of any dependent functions
1044
+ def get_total_required_shared(adj):
1045
+ total_shared = 0
1046
+
1047
+ for var in adj.variables:
1048
+ if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
1049
+ total_shared += var.type.size_in_bytes()
1050
+
1051
+ return total_shared + adj.max_required_extra_shared_memory
1052
+
1053
+ @staticmethod
1054
+ def extract_function_source(func: Callable) -> tuple[str, int]:
1055
+ try:
1056
+ _, fun_lineno = inspect.getsourcelines(func)
1057
+ source = inspect.getsource(func)
1058
+ except OSError as e:
1059
+ raise RuntimeError(
1060
+ "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
1061
+ "please save it to a file and use `importlib` if needed."
1062
+ ) from e
1063
+ return source, fun_lineno
1064
+
1065
    # generate function ssa form and adjoint
    def build(adj, builder, default_builder_options=None):
        """Evaluate the function body, producing the SSA forward/reverse form.

        Args:
            builder: The module builder driving this build (may be None).
            default_builder_options: Options used when no builder is given.
        """
        # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
        for arg in adj.args:
            arg.is_read = False
            arg.is_write = False

        # a previous build of this function raised; don't rebuild (see __init__)
        if adj.skip_build:
            return

        adj.builder = builder

        if default_builder_options is None:
            default_builder_options = {}

        if adj.builder:
            adj.builder_options = adj.builder.options
        else:
            adj.builder_options = default_builder_options

        global options
        options = adj.builder_options

        adj.symbols = {}  # map from symbols to adjoint variables
        adj.variables = []  # list of local variables (in order)

        adj.return_var = None  # return type for function or kernel
        adj.loop_symbols = []  # symbols at the start of each loop
        adj.loop_const_iter_symbols = (
            set()
        )  # constant iteration variables for static loops (mutating them does not raise an error)

        # blocks
        adj.blocks = [Block()]
        adj.loop_blocks = []

        # holds current indent level
        adj.indentation = ""

        # used to generate new label indices
        adj.label_count = 0

        # tracks how much additional shared memory is required by any dependent function calls
        adj.max_required_extra_shared_memory = 0

        # update symbol map for each argument
        for a in adj.args:
            adj.symbols[a.label] = a

        # recursively evaluate function body
        try:
            adj.eval(adj.tree.body[0])
        except Exception as original_exc:
            try:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'

                # Combine the new message with the original exception's arguments
                new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)

                # Enhance the original exception with parser context before re-raising.
                # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
                raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
            finally:
                adj.skip_build = True
                adj.builder = None

        if builder is not None:
            # make sure any struct types used by the arguments are built first
            for a in adj.args:
                if isinstance(a.type, Struct):
                    builder.build_struct_recursive(a.type)
                elif isinstance(a.type, warp._src.types.array) and isinstance(a.type.dtype, Struct):
                    builder.build_struct_recursive(a.type.dtype)

        # release builder reference for GC
        adj.builder = None
1142
+
1143
+ # code generation methods
1144
+ def format_template(adj, template, input_vars, output_var):
1145
+ # output var is always the 0th index
1146
+ args = [output_var, *input_vars]
1147
+ s = template.format(*args)
1148
+
1149
+ return s
1150
+
1151
+ # generates a list of formatted args
1152
+ def format_args(adj, prefix, args):
1153
+ arg_strs = []
1154
+
1155
+ for a in args:
1156
+ if isinstance(a, warp._src.context.Function):
1157
+ # functions don't have a var_ prefix so strip it off here
1158
+ if prefix == "var":
1159
+ arg_strs.append(f"{a.namespace}{a.native_func}")
1160
+ else:
1161
+ arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
1162
+ elif is_reference(a.type):
1163
+ arg_strs.append(f"{prefix}_{a}")
1164
+ elif isinstance(a, Var):
1165
+ arg_strs.append(a.emit(prefix))
1166
+ else:
1167
+ raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
1168
+
1169
+ return arg_strs
1170
+
1171
+ # generates argument string for a forward function call
1172
+ def format_forward_call_args(adj, args, use_initializer_list):
1173
+ arg_str = ", ".join(adj.format_args("var", args))
1174
+ if use_initializer_list:
1175
+ return f"{{{arg_str}}}"
1176
+ return arg_str
1177
+
1178
    # generates argument string for a reverse function call
    def format_reverse_call_args(
        adj,
        args_var,
        args,
        args_out,
        use_initializer_list,
        has_output_args=True,
        require_original_output_arg=False,
    ):
        """Build the argument string for a reverse (adjoint) function call.

        Args:
            args_var: Vars supplying the primal input values.
            args: Vars whose adjoints are passed (by reference when using
                initializer lists).
            args_out: Output Vars; their primal values are forwarded only
                when required, their adjoints always.
            use_initializer_list: Emit brace-wrapped C++ initializer lists.
            has_output_args: Whether output primal values may be forwarded.
            require_original_output_arg: Force forwarding the original
                output even with a single output.

        Returns:
            The joined argument string, or None when there are no adjoint
            arguments (the reverse call can be skipped entirely).
        """
        formatted_var = adj.format_args("var", args_var)
        formatted_out = []
        if has_output_args and (require_original_output_arg or len(args_out) > 1):
            formatted_out = adj.format_args("var", args_out)
        formatted_var_adj = adj.format_args(
            "&adj" if use_initializer_list else "adj",
            args,
        )
        formatted_out_adj = adj.format_args("adj", args_out)

        if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
            # there are no adjoint arguments, so we don't need to call the reverse function
            return None

        if use_initializer_list:
            var_str = f"{{{', '.join(formatted_var)}}}"
            out_str = f"{{{', '.join(formatted_out)}}}"
            adj_str = f"{{{', '.join(formatted_var_adj)}}}"
            out_adj_str = ", ".join(formatted_out_adj)
            if len(args_out) > 1:
                arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
            else:
                arg_str = ", ".join([var_str, adj_str, out_adj_str])
        else:
            arg_str = ", ".join(formatted_var + formatted_out + formatted_var_adj + formatted_out_adj)
        return arg_str
1214
+
1215
+ def indent(adj):
1216
+ adj.indentation = adj.indentation + " "
1217
+
1218
+ def dedent(adj):
1219
+ adj.indentation = adj.indentation[:-4]
1220
+
1221
+ def begin_block(adj, name="block"):
1222
+ b = Block()
1223
+
1224
+ # give block a unique id
1225
+ b.label = name + "_" + str(adj.label_count)
1226
+ adj.label_count += 1
1227
+
1228
+ adj.blocks.append(b)
1229
+ return b
1230
+
1231
+ def end_block(adj):
1232
+ return adj.blocks.pop()
1233
+
1234
+ def add_var(adj, type=None, constant=None):
1235
+ index = len(adj.variables)
1236
+ name = str(index)
1237
+
1238
+ # allocate new variable
1239
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
1240
+
1241
+ adj.variables.append(v)
1242
+
1243
+ adj.blocks[-1].vars.append(v)
1244
+
1245
+ return v
1246
+
1247
+ def register_var(adj, var):
1248
+ # We sometimes initialize `Var` instances that might be thrown away
1249
+ # afterwards, so this method allows to defer their registration among
1250
+ # the list of primal vars until later on, instead of registering them
1251
+ # immediately if we were to use `adj.add_var()` or `adj.add_constant()`.
1252
+
1253
+ if isinstance(var, (Reference, warp._src.context.Function)):
1254
+ return var
1255
+
1256
+ if isinstance(var, int):
1257
+ return adj.add_constant(var)
1258
+
1259
+ if var.label is None:
1260
+ return adj.add_var(var.type, var.constant)
1261
+
1262
+ return var
1263
+
1264
+ def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
1265
+ """Get a line directive for the given statement.
1266
+
1267
+ Args:
1268
+ statement: The statement to get the line directive for.
1269
+ relative_lineno: The line number of the statement relative to the function.
1270
+
1271
+ Returns:
1272
+ A line directive for the given statement, or None if no line directive is needed.
1273
+ """
1274
+
1275
+ if adj.filename == "unknown source file" or adj.fun_lineno == 0:
1276
+ # Early return if function is not associated with a source file or is otherwise invalid
1277
+ # TODO: Get line directives working with wp.map() functions
1278
+ return None
1279
+
1280
+ # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
1281
+ # emit line directives in generated code if it's not being compiled with line information
1282
+ build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp._src.config.mode
1283
+
1284
+ lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"
1285
+
1286
+ if relative_lineno is not None and lineinfo_enabled and warp._src.config.line_directives:
1287
+ is_comment = statement.strip().startswith("//")
1288
+ if not is_comment:
1289
+ line = relative_lineno + adj.fun_lineno
1290
+ # Convert backslashes to forward slashes for CUDA compatibility
1291
+ normalized_path = adj.filename.replace("\\", "/")
1292
+ return f'#line {line} "{normalized_path}"'
1293
+ return None
1294
+
1295
+ def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
1296
+ """Append a statement to the forward pass."""
1297
+
1298
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1299
+ adj.blocks[-1].body_forward.append(line_directive)
1300
+
1301
+ adj.blocks[-1].body_forward.append(adj.indentation + statement)
1302
+
1303
+ if not skip_replay:
1304
+ if line_directive:
1305
+ adj.blocks[-1].body_replay.append(line_directive)
1306
+
1307
+ if replay:
1308
+ # if custom replay specified then output it
1309
+ adj.blocks[-1].body_replay.append(adj.indentation + replay)
1310
+ else:
1311
+ # by default just replay the original statement
1312
+ adj.blocks[-1].body_replay.append(adj.indentation + statement)
1313
+
1314
+ # append a statement to the reverse pass
1315
def add_reverse(adj, statement: str) -> None:
    """Append a statement to the reverse (adjoint) pass.

    NOTE: the line directive is appended *after* the statement on purpose:
    ``body_reverse`` is emitted in reversed order (see the loop-emission code,
    e.g. ``extend(reversed(reverse))``), so in the final output the directive
    ends up immediately before the statement, mirroring ``add_forward``.
    """

    adj.blocks[-1].body_reverse.append(adj.indentation + statement)

    if line_directive := adj.get_line_directive(statement, adj.lineno):
        adj.blocks[-1].body_reverse.append(line_directive)
1322
+
1323
def add_constant(adj, n):
    """Register the literal value ``n`` as a constant variable and return it."""
    return adj.add_var(type=type(n), constant=n)
1326
+
1327
def load(adj, var):
    """Return a dereferenced copy of ``var`` when it is a reference type; otherwise pass it through unchanged."""
    if not is_reference(var.type):
        return var
    return adj.add_builtin_call("load", [var])
1331
+
1332
def add_comp(adj, op_strings, left, comps):
    """Build a (possibly chained) comparison expression, e.g. ``a < b <= c``.

    Args:
        op_strings: C++ operator strings, one per comparison.
        left: leftmost operand Var.
        comps: Vars for each right-hand comparator.

    Returns:
        A new boolean Var holding the result. Chained comparisons are lowered
        to ``(a < b) && (b <= c)`` with the necessary parentheses; non-chainable
        operators simply close their parenthesis group.
    """
    output = adj.add_var(builtins.bool)

    left = adj.load(left)
    # open one parenthesis per comparison; each loop iteration closes one
    s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "

    prev_comp_var = None

    for op, comp in zip(op_strings, comps):
        comp_chainable = op_str_is_chainable(op)
        if comp_chainable and prev_comp_var:
            # We restrict chaining to operands of the same type
            if prev_comp_var.type is comp.type:
                prev_comp_var = adj.load(prev_comp_var)
                comp_var = adj.load(comp)
                s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
            else:
                raise WarpCodegenTypeError(
                    f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                )
        else:
            comp_var = adj.load(comp)
            s += op + " " + comp_var.emit() + ") "

        prev_comp_var = comp_var

    s = s.rstrip() + ";"

    adj.add_forward(s)

    return output
1363
+
1364
def add_bool_op(adj, op_string, exprs):
    """Join ``exprs`` with the boolean operator ``op_string`` ("&&" or "||") and
    assign the result to a fresh boolean Var, which is returned."""
    # dereference any reference operands before emitting
    exprs = [adj.load(expr) for expr in exprs]
    output = adj.add_var(builtins.bool)
    command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
    adj.add_forward(command)

    return output
1371
+
1372
def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
    """Resolve ``func`` to a concrete overload matching the given argument types.

    User functions delegate to ``func.get_overload``; builtins are matched by
    scanning ``func.overloads``. Raises WarpCodegenError when no overload fits.
    """
    if not func.is_builtin():
        # user-defined function
        overload = func.get_overload(arg_types, kwarg_types)
        if overload is not None:
            return overload
    else:
        # if func is overloaded then perform overload resolution here
        # we validate argument types before they go to generated native code
        for f in func.overloads:
            # skip type checking for variadic functions
            if not f.variadic:
                # check argument counts match are compatible (may be some default args)
                if len(f.input_types) < len(arg_types) + len(kwarg_types):
                    continue

                if not func_match_args(f, arg_types, kwarg_types):
                    continue

            # check output dimensions match expectations
            if min_outputs:
                value_type = f.value_func(None, None)
                if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
                    continue

            # found a match, use it
            return f

    # unresolved function, report error
    arg_type_reprs = []

    for x in itertools.chain(arg_types, kwarg_types.values()):
        if isinstance(x, warp._src.context.Function):
            arg_type_reprs.append("function")
        else:
            # shorten Warp primitive type names
            if isinstance(x, Sequence):
                if len(x) != 1:
                    raise WarpCodegenError("Argument must not be the result from a multi-valued function")
                arg_type = x[0]
            else:
                arg_type = x

            arg_type_reprs.append(type_repr(arg_type))

    raise WarpCodegenError(
        f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
    )
1420
+
1421
def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
    """Emit forward, replay, and reverse code for a call to ``func``.

    Resolves the overload, binds arguments Python-style, recursively builds
    user functions, allocates output Vars, and appends the generated C++ call
    statements. Returns the output Var (or tuple/list of Vars, or None for
    void functions).
    """
    # Extract the types and values passed as arguments to the function call.
    arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
    kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

    # Resolve the exact function signature among any existing overload.
    func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

    # Bind the positional and keyword arguments to the function's signature
    # in order to process them as Python does it.
    bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

    # Type args are the "compile time" argument values we get from codegen.
    # For example, when calling `wp.vec3f(...)` from within a kernel,
    # this translates in fact to calling the `vector()` built-in augmented
    # with the type args `length=3, dtype=float`.
    # Eventually, these need to be passed to the underlying C++ function,
    # so we update the arguments with the type args here.
    if type_args:
        for arg in type_args:
            if arg in bound_args.arguments:
                # In case of conflict, ideally we'd throw an error since
                # what comes from codegen should be the source of truth
                # and users also passing the same value as an argument
                # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
                # However, for backward compatibility, we allow that form
                # as long as the values are equal.
                if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
                    continue

                raise RuntimeError(
                    f"Remove the extraneous `{arg}` parameter "
                    f"when calling the templated version of "
                    f"`wp.{func.native_func}()`"
                )

        type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
        apply_defaults(bound_args, type_vars)

    if func.defaults:
        # fill in any unbound parameters that have declared defaults
        default_vars = {
            k: Var(None, type=type(v), constant=v)
            for k, v in func.defaults.items()
            if k not in bound_args.arguments and v is not None
        }
        apply_defaults(bound_args, default_vars)

    bound_args = bound_args.arguments

    # if it is a user-function then build it recursively
    if not func.is_builtin():
        # If the function called is a user function,
        # we need to ensure its adjoint is also being generated.
        if adj.used_by_backward_kernel:
            func.adj.used_by_backward_kernel = True

        if adj.builder is None:
            func.build(None)

        elif func not in adj.builder.functions:
            adj.builder.build_function(func)
            # add custom grad, replay functions to the list of functions
            # to be built later (invalid code could be generated if we built them now)
            # so that they are not missed when only the forward function is imported
            # from another module
            if func.custom_grad_func:
                adj.builder.deferred_functions.append(func.custom_grad_func)
            if func.custom_replay_func:
                adj.builder.deferred_functions.append(func.custom_replay_func)

    # Resolve the return value based on the types and values of the given arguments.
    bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
    bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}

    return_type = func.value_func(
        {k: strip_reference(v) for k, v in bound_arg_types.items()},
        bound_arg_values,
    )

    # Handle the special case where a Var instance is returned from the `value_func`
    # callback, in which case we replace the call with a reference to that variable.
    if isinstance(return_type, Var):
        return adj.register_var(return_type)
    elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
        return tuple(adj.register_var(x) for x in return_type)

    if get_origin(return_type) is tuple:
        types = get_args(return_type)
        return_type = warp._src.types.tuple_t(types=types, values=(None,) * len(types))

    # immediately allocate output variables so we can pass them into the dispatch method
    if return_type is None:
        # void function
        output = None
        output_list = []
    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # single return value function
        if isinstance(return_type, Sequence):
            return_type = return_type[0]
        output = adj.add_var(return_type)
        output_list = [output]
    else:
        # multiple return value function
        output = [adj.add_var(v) for v in return_type]
        output_list = output

    # If we have a built-in that requires special handling to dispatch
    # the arguments to the underlying C++ function, then we can resolve
    # these using the `dispatch_func`. Since this is only called from
    # within codegen, we pass it directly `codegen.Var` objects,
    # which allows for some more advanced resolution to be performed,
    # for example by checking whether an argument corresponds to
    # a literal value or references a variable.
    extra_shared_memory = 0
    if func.lto_dispatch_func is not None:
        func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
            func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
        )
    elif func.dispatch_func is not None:
        func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
    else:
        func_args = tuple(bound_args.values())
        template_args = ()

    func_args = tuple(adj.register_var(x) for x in func_args)
    func_name = compute_type_str(func.native_func, template_args)
    use_initializer_list = func.initializer_list_func(bound_args, return_type)

    fwd_args = []
    for func_arg in func_args:
        if not isinstance(func_arg, (Reference, warp._src.context.Function)):
            func_arg_var = adj.load(func_arg)
        else:
            func_arg_var = func_arg

        # if the argument is a function (and not a builtin), then build it recursively
        if isinstance(func_arg_var, warp._src.context.Function) and not func_arg_var.is_builtin():
            if adj.used_by_backward_kernel:
                func_arg_var.adj.used_by_backward_kernel = True

            adj.builder.build_function(func_arg_var)

        fwd_args.append(strip_reference(func_arg_var))

    if return_type is None:
        # handles expression (zero output) functions, e.g.: void do_something();
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        )
        replay_call = forward_call
        if func.custom_replay_func is not None or func.replay_snippet is not None:
            replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # handle simple function (one output)
        forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        replay_call = forward_call
        if func.custom_replay_func is not None:
            replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    else:
        # handle multiple value functions
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
        )
        replay_call = forward_call

    if func.skip_replay:
        adj.add_forward(forward_call, replay="// " + replay_call)
    else:
        adj.add_forward(forward_call, replay=replay_call)

    if func.is_differentiable and len(func_args):
        adj_args = tuple(strip_reference(x) for x in func_args)
        reverse_has_output_args = (
            func.require_original_output_arg or len(output_list) > 1
        ) and func.custom_grad_func is None
        arg_str = adj.format_reverse_call_args(
            fwd_args,
            adj_args,
            output_list,
            use_initializer_list,
            has_output_args=reverse_has_output_args,
            require_original_output_arg=func.require_original_output_arg,
        )
        if arg_str is not None:
            reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
            adj.add_reverse(reverse_call)

    # update our smem roofline requirements based on any
    # shared memory required by the dependent function call
    if not func.is_builtin():
        adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
    else:
        adj.alloc_shared_extra(extra_shared_memory)

    return output
1618
+
1619
def add_builtin_call(adj, func_name, args, min_outputs=None):
    """Look up a builtin function by name and emit a call to it (no kwargs or type args)."""
    builtin = warp._src.context.builtin_functions[func_name]
    return adj.add_call(builtin, args, {}, {}, min_outputs=min_outputs)
1622
+
1623
def add_return(adj, var):
    """Emit a return statement for the forward pass plus the matching adjoint
    accumulation and a goto label so the replay/reverse passes can jump past it.

    ``var`` is None or a (possibly empty) tuple of return-value Vars.
    """
    if var is None or len(var) == 0:
        # NOTE: If this kernel gets compiled for a CUDA device, then we need
        # to convert the return; into a continue; in codegen_func_forward()
        adj.add_forward("return;", f"goto label{adj.label_count};")
    elif len(var) == 1:
        adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
        adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
    else:
        # multi-valued return: write each value to an output slot before returning
        for i, v in enumerate(var):
            adj.add_forward(f"ret_{i} = {v.emit()};")
            adj.add_reverse(f"adj_{v} += adj_ret_{i};")
        adj.add_forward("return;", f"goto label{adj.label_count};")

    # label target for the replay-pass goto emitted above
    adj.add_reverse(f"label{adj.label_count}:;")

    adj.label_count += 1
1640
+
1641
+ # define an if statement
1642
# define an if statement
def begin_if(adj, cond):
    """Open an if-block: emits the forward `if (...) {` and the matching
    closing brace into the reverse pass (which is emitted in reverse order)."""
    cond = adj.load(cond)
    adj.add_forward(f"if ({cond.emit()}) {{")
    adj.add_reverse("}")

    adj.indent()
1648
+
1649
def end_if(adj, cond):
    """Close an if-block: forward gets `}`, reverse gets the re-opened
    `if (...) {` (reverse bodies are emitted back-to-front)."""
    adj.dedent()

    adj.add_forward("}")
    cond = adj.load(cond)
    adj.add_reverse(f"if ({cond.emit()}) {{")
1655
+
1656
def begin_else(adj, cond):
    """Open an else-block, emitted as `if (!cond) {` so no dedicated else
    construct is needed in the generated code."""
    cond = adj.load(cond)
    adj.add_forward(f"if (!{cond.emit()}) {{")
    adj.add_reverse("}")

    adj.indent()
1662
+
1663
def end_else(adj, cond):
    """Close an else-block (mirror of `end_if`, using the negated condition)."""
    adj.dedent()

    adj.add_forward("}")
    cond = adj.load(cond)
    adj.add_reverse(f"if (!{cond.emit()}) {{")
1669
+
1670
+ # define a for-loop
1671
# define a for-loop
def begin_for(adj, iter):
    """Open a for-loop over the iterator Var ``iter``.

    Emits the loop-start label and termination test into a dedicated
    condition block (pushed on ``loop_blocks``), then opens a fresh block for
    the loop body. Returns the Var produced by ``iter_next`` (the loop value).
    """
    cond_block = adj.begin_block("for")
    adj.loop_blocks.append(cond_block)
    adj.add_forward(f"start_{cond_block.label}:;")
    adj.indent()

    # evaluate cond
    adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto end_{cond_block.label};")

    # evaluate iter
    val = adj.add_builtin_call("iter_next", [iter])

    adj.begin_block()

    return val
1686
+
1687
def end_for(adj, iter):
    """Close a for-loop opened by ``begin_for`` and splice its condition and
    body blocks into the parent block for both forward and reverse passes.

    The reverse pass iterates with ``wp::iter_reverse``, zeroes local
    adjoints, replays the body, then runs the body's reverse statements.
    """
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    adj.add_forward(f"goto start_{cond_block.label};", skip_replay=True)

    adj.dedent()
    adj.add_forward(f"end_{cond_block.label}:;", skip_replay=True)

    ####################
    # reverse pass
    # NOTE: `reverse` is assembled in forward order and appended with
    # extend(reversed(...)) because body_reverse is emitted back-to-front.

    reverse = []

    # reverse iterator
    reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")

    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints
    for i in body_block.vars:
        if is_tile(i.type):
            if i.type.owner:
                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
        else:
            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(adj.indentation + f"\tgoto start_{cond_block.label};")
    reverse.append(adj.indentation + f"end_{cond_block.label}:;")

    adj.blocks[-1].body_reverse.extend(reversed(reverse))
1737
+
1738
+ # define a while loop
1739
# define a while loop
def begin_while(adj, cond):
    """Open a while-loop: evaluate the condition in its own block (so replay
    can be controlled independently), then open a block for the loop body."""
    # evaluate condition in its own block
    # so we can control replay
    cond_block = adj.begin_block("while")
    adj.loop_blocks.append(cond_block)
    cond_block.body_forward.append(f"start_{cond_block.label}:;")

    c = adj.eval(cond)
    c = adj.load(c)

    cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")

    # begin block around loop body
    adj.begin_block()
    adj.indent()
1754
+
1755
def end_while(adj):
    """Close a while-loop opened by ``begin_while``, splicing condition and
    body blocks into the parent for the forward pass and assembling the
    reverse pass (condition, zeroed adjoints, replay, then reversed body)."""
    adj.dedent()
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    adj.blocks[-1].body_forward.append(f"goto start_{cond_block.label};")
    adj.blocks[-1].body_forward.append(f"end_{cond_block.label}:;")

    ####################
    # reverse pass
    # assembled forward then appended reversed, since body_reverse is emitted back-to-front
    reverse = []

    # cond
    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints of local vars
    for i in body_block.vars:
        reverse.append(f"{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(f"goto start_{cond_block.label};")
    reverse.append(f"end_{cond_block.label}:;")

    # output
    adj.blocks[-1].body_reverse.extend(reversed(reverse))
1798
+
1799
def emit_FunctionDef(adj, node):
    """Evaluate a function body AST node-by-node, append an implicit
    `return {};` when the last statement isn't a return, and — for
    `@wp.func_native` functions — resolve the declared return annotation."""
    adj.fun_def_lineno = node.lineno

    for f in node.body:
        # Skip variable creation for standalone constants, including docstrings
        if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
            continue
        adj.eval(f)

    if adj.return_var is not None and len(adj.return_var) == 1:
        if not isinstance(node.body[-1], ast.Return):
            adj.add_forward("return {};", skip_replay=True)

    # native function case: return type is specified, eg -> int or -> wp.float32
    is_func_native = False
    if node.decorator_list is not None and len(node.decorator_list) == 1:
        obj = node.decorator_list[0]
        if isinstance(obj, ast.Call):
            if isinstance(obj.func, ast.Attribute):
                if obj.func.attr == "func_native":
                    is_func_native = True
    if is_func_native and node.returns is not None:
        # NOTE(review): return annotations are resolved with eval() on the
        # annotation's identifier — assumes annotations are trusted kernel source
        if isinstance(node.returns, ast.Name):  # python built-in type
            var = Var(label="return_type", type=eval(node.returns.id))
        elif isinstance(node.returns, ast.Attribute):  # warp type
            var = Var(label="return_type", type=eval(node.returns.attr))
        else:
            raise WarpCodegenTypeError("Native function return type not recognized")
        adj.return_var = (var,)
1828
+
1829
def emit_If(adj, node):
    """Lower an `if`/`else` statement.

    Constant conditions are resolved at codegen time (only the taken branch
    is evaluated). Otherwise both branches are emitted and any symbol
    reassigned inside a branch is merged with a `where` select (phi node).
    """
    if len(node.body) == 0:
        return None

    # eval condition
    cond = adj.eval(node.test)

    if cond.constant is not None:
        # resolve constant condition
        if cond.constant:
            for stmt in node.body:
                adj.eval(stmt)
        else:
            for stmt in node.orelse:
                adj.eval(stmt)
        return None

    # save symbol map
    symbols_prev = adj.symbols.copy()

    # eval body
    adj.begin_if(cond)

    for stmt in node.body:
        adj.eval(stmt)

    adj.end_if(cond)

    # detect existing symbols with conflicting definitions (variables assigned inside the branch)
    # and resolve with a phi (select) function
    for items in symbols_prev.items():
        sym = items[0]
        var1 = items[1]
        var2 = adj.symbols[sym]

        if var1 != var2:
            # insert a phi function that selects var1, var2 based on cond
            out = adj.add_builtin_call("where", [cond, var2, var1])
            adj.symbols[sym] = out

    symbols_prev = adj.symbols.copy()

    # evaluate 'else' statement as if (!cond)
    if len(node.orelse) > 0:
        adj.begin_else(cond)

        for stmt in node.orelse:
            adj.eval(stmt)

        adj.end_else(cond)

        # detect existing symbols with conflicting definitions (variables assigned inside the else)
        # and resolve with a phi (select) function
        for items in symbols_prev.items():
            sym = items[0]
            var1 = items[1]
            var2 = adj.symbols[sym]

            if var1 != var2:
                # insert a phi function that selects var1, var2 based on cond
                # note the reversed order of vars since we want to use !cond as our select
                out = adj.add_builtin_call("where", [cond, var1, var2])
                adj.symbols[sym] = out
1892
+
1893
def emit_IfExp(adj, node):
    """Lower a conditional expression ``a if cond else b``.

    Constant conditions short-circuit at codegen time; otherwise both
    branches are emitted and merged with a `where` select.
    """
    cond = adj.eval(node.test)

    if cond.constant is not None:
        # only the taken branch is evaluated, matching Python semantics
        return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)

    adj.begin_if(cond)
    body = adj.eval(node.body)
    adj.end_if(cond)

    adj.begin_else(cond)
    orelse = adj.eval(node.orelse)
    adj.end_else(cond)

    return adj.add_builtin_call("where", [cond, body, orelse])
1908
+
1909
def emit_Compare(adj, node):
    """Lower a (possibly chained) comparison, e.g. ``a < b <= c``.

    ``node.left`` is the leftmost operand; ``node.ops`` and
    ``node.comparators`` hold the remaining operators and operands in order.
    """
    lhs = adj.eval(node.left)
    operands = [adj.eval(comparator) for comparator in node.comparators]
    operators = [builtin_operators[type(operator)] for operator in node.ops]

    return adj.add_comp(operators, lhs, operands)
1918
+
1919
def emit_BoolOp(adj, node):
    """Lower a boolean ``and``/``or`` over node.values to a single C++ boolean op."""
    op = node.op
    if isinstance(op, ast.And):
        func = "&&"
    elif isinstance(op, ast.Or):
        func = "||"
    else:
        raise WarpCodegenKeyError(f"Op {op} is not supported")

    operands = [adj.eval(expr) for expr in node.values]
    return adj.add_bool_op(func, operands)
1931
+
1932
def emit_Name(adj, node):
    """Resolve a bare name: local symbol, constant value, or an external
    object (function, type, struct, or module) passed back to the caller."""
    # lookup symbol, if it has already been assigned to a variable then return the existing mapping
    if node.id in adj.symbols:
        return adj.symbols[node.id]

    obj = adj.resolve_external_reference(node.id)

    if obj is None:
        raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))

    if warp._src.types.is_value(obj):
        # evaluate constant and cache it so repeated references share one Var
        out = adj.add_constant(obj)
        adj.symbols[node.id] = out
        return out

    # the named object is either a function, class name, or module
    # pass it back to the caller for processing
    if isinstance(obj, warp._src.context.Function):
        return obj
    if isinstance(obj, type):
        return obj
    if isinstance(obj, Struct):
        # referenced structs must be built before use
        adj.builder.build_struct_recursive(obj)
        return obj
    if isinstance(obj, types.ModuleType):
        return obj

    raise TypeError(f"Invalid external reference type: {type(obj)}")
1961
+
1962
@staticmethod
def resolve_type_attribute(var_type: type, attr: str):
    """Resolve an attribute looked up on a type itself.

    Value types expose the synthetic attributes ``dtype`` (scalar type) and
    ``length`` (element count); anything else falls back to ``getattr``,
    returning None when absent.
    """
    if isinstance(var_type, type) and type_is_value(var_type):
        if attr == "dtype":
            return type_scalar_type(var_type)
        elif attr == "length":
            return type_size(var_type)

    return getattr(var_type, attr, None)
1971
+
1972
def vector_component_index(adj, component, vector_type):
    """Map a single-character swizzle ('x', 'y', 'z', 'w') to a constant index Var.

    Only the first ``vector_type._shape_[0]`` characters of "xyzw" are valid
    for the given vector type.
    """
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")

    dim = vector_type._shape_[0]
    swizzles = "xyzw"[0:dim]
    if component not in swizzles:
        raise WarpCodegenAttributeError(
            f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
        )

    return adj.add_constant(swizzles.index(component))
1986
+
1987
def transform_component(adj, component):
    """Validate a transform attribute name; only 'p' (translation) and 'q'
    (rotation) are accepted, and the validated name is returned unchanged."""
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")

    if component not in ("p", "q"):
        raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")

    return component
1995
+
1996
@staticmethod
def is_differentiable_value_type(var_type):
    # checks that the argument type is a value type (i.e, not an array)
    # possibly holding differentiable values (for which gradients must be accumulated)
    return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
2001
+
2002
def emit_Attribute(adj, node):
    """Lower an attribute access ``value.attr``.

    Handles, in order: constants, module/type attributes, adjoint struct
    references, vector/quaternion swizzles, transform components ('p'/'q'),
    and struct member access (emitting a pointer-based Reference plus the
    matching adjoint accumulation). Falls back to type attributes
    (``dtype``/``length``) when the aggregate lookup fails.
    """
    if hasattr(node, "is_adjoint"):
        node.value.is_adjoint = True

    aggregate = adj.eval(node.value)

    try:
        if isinstance(aggregate, Var) and aggregate.constant is not None:
            # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
            return aggregate

        if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
            out = getattr(aggregate, node.attr)

            if warp._src.types.is_value(out):
                return adj.add_constant(out)
            if isinstance(out, (enum.IntEnum, enum.IntFlag)):
                # enums are lowered to their integer value
                return adj.add_constant(int(out))

            return out

        if hasattr(node, "is_adjoint"):
            # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
            attr_name = aggregate.label + "." + node.attr
            attr_type = aggregate.type.vars[node.attr].type

            return Var(attr_name, attr_type)

        aggregate_type = strip_reference(aggregate.type)

        # reading a vector or quaternion component
        if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
            index = adj.vector_component_index(node.attr, aggregate_type)

            return adj.add_builtin_call("extract", [aggregate, index])

        elif type_is_transformation(aggregate_type):
            component = adj.transform_component(node.attr)

            if component == "p":
                return adj.add_builtin_call("transform_get_translation", [aggregate])
            else:
                return adj.add_builtin_call("transform_get_rotation", [aggregate])

        else:
            attr_var = aggregate_type.vars[node.attr]

            # represent pointer types as uint64
            if isinstance(attr_var.type, pointer_t):
                cast = f"({Var.dtype_to_ctype(uint64)}*)"
                adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
                attr_type = Reference(uint64)
            else:
                cast = ""
                adj_cast = ""
                attr_type = Reference(attr_var.type)

            attr = adj.add_var(attr_type)

            # take the address of the member; `->` for references, `.` otherwise
            if is_reference(aggregate.type):
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
            else:
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

            # differentiable members accumulate gradients; others just copy
            if adj.is_differentiable_value_type(strip_reference(attr_type)):
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
            else:
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

            return attr

    except (KeyError, AttributeError) as e:
        # Try resolving as type attribute
        aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate

        type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
        if type_attribute is not None:
            return type_attribute

        if isinstance(aggregate, Var):
            node_name = get_node_name_safe(node.value)
            raise WarpCodegenAttributeError(
                f"Error, `{node.attr}` is not an attribute of '{node_name}' ({type_repr(aggregate.type)})"
            ) from e
        raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
2087
+
2088
def emit_Assert(adj, node):
    """Lower a Python `assert` to a C `assert()` that embeds the original
    source text as the message."""
    # eval condition
    cond = adj.eval(node.test)
    cond = adj.load(cond)

    source_segment = ast.get_source_segment(adj.source, node)
    # If a message was provided with the assert, " marks can interfere with the generated code
    escaped_segment = source_segment.replace('"', '\\"')

    adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
2098
+
2099
def emit_Constant(adj, node):
    """Lower an `ast.Constant` to a kernel constant variable.

    `None` has no Warp representation and is rejected.
    """
    value = node.value
    if value is None:
        raise WarpCodegenTypeError("None type unsupported")
    return adj.add_constant(value)
2104
+
2105
def emit_BinOp(adj, node):
    """Lower a binary operation, preferring any user-defined operator
    overload before falling back to the builtin of the same name."""
    # evaluate binary operator arguments

    if warp._src.config.verify_autograd_array_access:
        # array overwrite tracking: in-place operators are a special case
        # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write
        # so we save the current arg read flags and restore them after lhs eval
        is_read_states = []
        for arg in adj.args:
            is_read_states.append(arg.is_read)

    # evaluate lhs binary operator argument
    left = adj.eval(node.left)

    if warp._src.config.verify_autograd_array_access:
        # restore arg read flags
        for i, arg in enumerate(adj.args):
            arg.is_read = is_read_states[i]

    # evaluate rhs binary operator argument
    right = adj.eval(node.right)

    name = builtin_operators[type(node.op)]

    try:
        # Check if there is any user-defined overload for this operator
        user_func = adj.resolve_external_reference(name)
        if isinstance(user_func, warp._src.context.Function):
            return adj.add_call(user_func, (left, right), {}, {})
    except WarpCodegenError:
        # no user overload in scope — fall through to the builtin
        pass

    return adj.add_builtin_call(name, [left, right])
2138
+
2139
def emit_UnaryOp(adj, node):
    """Lower a unary operation; negation of finite constants is folded at
    codegen time (non-finite values are left to the generated C++)."""
    # evaluate unary op arguments
    arg = adj.eval(node.operand)

    # evaluate expression to a compile-time constant if arg is a constant
    if arg.constant is not None and math.isfinite(arg.constant):
        if isinstance(node.op, ast.USub):
            return adj.add_constant(-arg.constant)

    name = builtin_operators[type(node.op)]

    return adj.add_builtin_call(name, [arg])
2151
+
2152
def materialize_redefinitions(adj, symbols):
    """Reconcile symbols reassigned inside a dynamic loop body.

    For each symbol whose Var changed relative to the snapshot ``symbols``,
    emits an ``assign`` back into the original Var (deliberately violating
    SSA so the loop-carried value is visible on the next iteration) and
    restores the symbol table entry.
    """
    # detect symbols with conflicting definitions (assigned inside the for loop)
    for items in symbols.items():
        sym = items[0]
        if adj.is_constant_iter_symbol(sym):
            # ignore constant overwriting in for-loops if it is a loop iterator
            # (it is no problem to unroll static loops multiple times in sequence)
            continue

        var1 = items[1]
        var2 = adj.symbols[sym]

        if var1 != var2:
            if warp._src.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                print(msg)

            if var1.constant is not None:
                raise WarpCodegenError(
                    f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
                )

            # overwrite the old variable value (violates SSA)
            adj.add_builtin_call("assign", [var1, var2])

            # reset the symbol to point to the original variable
            adj.symbols[sym] = var1
2181
+
2182
    def emit_While(adj, node):
        """Emit a dynamic while-loop for an ast.While node.

        Symbols are snapshotted around the body so that redefinitions made
        inside the loop can be materialized back into their pre-loop storage.
        Note: `node.orelse` is not visited here.
        """
        adj.begin_while(node.test)

        # snapshot symbols so redefinitions inside the body can be detected
        adj.loop_symbols.append(adj.symbols.copy())

        # eval body
        for s in node.body:
            adj.eval(s)

        adj.materialize_redefinitions(adj.loop_symbols[-1])
        adj.loop_symbols.pop()

        adj.end_while()
+
2196
    def eval_num(adj, a):
        """Try to evaluate AST node `a` to a compile-time value.

        Returns a `(known, value)` tuple: `known` is True when `a` is a
        literal (possibly negated) constant, or when the object resolved
        from a static expression / evaluation is an integer. Otherwise
        `value` may be the Var resulting from evaluating the node.
        """
        if isinstance(a, ast.Constant):
            return True, a.value
        if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
            # Negative constant
            return True, -a.operand.value

        # try and resolve the expression to an object
        # e.g.: wp.constant in the globals scope
        obj, _ = adj.resolve_static_expression(a)

        if obj is None:
            obj = adj.eval(a)

        # unwrap Vars that carry a compile-time constant value
        if isinstance(obj, Var) and obj.constant is not None:
            obj = obj.constant

        return warp._src.types.is_int(obj), obj
+
2215
+ # detects whether a loop contains a break (or continue) statement
2216
+ def contains_break(adj, body):
2217
+ for s in body:
2218
+ if isinstance(s, ast.Break):
2219
+ return True
2220
+ elif isinstance(s, ast.Continue):
2221
+ return True
2222
+ elif isinstance(s, ast.If):
2223
+ if adj.contains_break(s.body):
2224
+ return True
2225
+ if adj.contains_break(s.orelse):
2226
+ return True
2227
+ else:
2228
+ # note that nested for or while loops containing a break statement
2229
+ # do not affect the current loop
2230
+ pass
2231
+
2232
+ return False
2233
+
2234
    # returns a constant range() if unrollable, a dynamically evaluated range Var
    # for non-unrollable range() calls, or None when the iterator is not range()
    def get_unroll_range(adj, loop):
        """Classify a for-loop iterator for unrolling.

        Returns a Python `range` when the loop is a 1-3 argument `range()`
        call with constant numeric arguments whose trip count fits the
        module's unroll limit and whose body has no break/continue; returns
        the Var of a dynamically emitted `range` builtin call when it is a
        `range()` call that cannot be unrolled; returns None when the
        iterator is not a simple `range()` call at all.
        """
        if (
            not isinstance(loop.iter, ast.Call)
            or not isinstance(loop.iter.func, ast.Name)
            or loop.iter.func.id != "range"
            or len(loop.iter.args) == 0
            or len(loop.iter.args) > 3
        ):
            return None

        # if all range() arguments are numeric constants we will unroll
        # note that this only handles trivial constants, it will not unroll
        # constant compile-time expressions e.g.: range(0, 3*2)

        # Evaluate the arguments and check that they are numeric constants
        # It is important to do that in one pass, so that if evaluating these arguments have side effects
        # the code does not get generated more than once
        range_args = [adj.eval_num(arg) for arg in loop.iter.args]
        arg_is_numeric, arg_values = zip(*range_args)

        if all(arg_is_numeric):
            # All argument are numeric constants

            # range(end)
            if len(loop.iter.args) == 1:
                start = 0
                end = arg_values[0]
                step = 1

            # range(start, end)
            elif len(loop.iter.args) == 2:
                start = arg_values[0]
                end = arg_values[1]
                step = 1

            # range(start, end, step)
            elif len(loop.iter.args) == 3:
                start = arg_values[0]
                end = arg_values[1]
                step = arg_values[2]

            # test if we're above max unroll count
            max_iters = abs(end - start) // abs(step)

            # per-module option takes precedence over the global config
            if "max_unroll" in adj.builder_options:
                max_unroll = adj.builder_options["max_unroll"]
            else:
                max_unroll = warp._src.config.max_unroll

            ok_to_unroll = True

            if max_iters > max_unroll:
                if warp._src.config.verbose:
                    print(
                        f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
                    )
                ok_to_unroll = False

            elif adj.contains_break(loop.body):
                if warp._src.config.verbose:
                    print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
                ok_to_unroll = False

            if ok_to_unroll:
                return range(start, end, step)

        # Unroll is not possible, range needs to be evaluated dynamically
        range_call = adj.add_builtin_call(
            "range",
            [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
        )
        return range_call
+
2308
+ def record_constant_iter_symbol(adj, sym):
2309
+ adj.loop_const_iter_symbols.add(sym)
2310
+
2311
+ def is_constant_iter_symbol(adj, sym):
2312
+ return sym in adj.loop_const_iter_symbols
2313
+
2314
    def emit_For(adj, node):
        """Emit a for-loop, unrolling it when the range is a compile-time constant.

        When `get_unroll_range` yields a Python `range`, the body is evaluated
        once per iteration with the loop symbol bound to a constant. Otherwise
        a dynamic loop is generated via begin_for/end_for with the usual
        symbol snapshot/materialization around the body.
        """
        # try and unroll simple range() statements that use constant args
        unroll_range = adj.get_unroll_range(node)

        if isinstance(unroll_range, range):
            const_iter_sym = node.target.id
            # prevent constant conflicts in `materialize_redefinitions()`
            adj.record_constant_iter_symbol(const_iter_sym)

            # unroll static for-loop
            for i in unroll_range:
                const_iter = adj.add_constant(i)
                adj.symbols[const_iter_sym] = const_iter

                # eval body
                for s in node.body:
                    adj.eval(s)

        # otherwise generate a dynamic loop
        else:
            # evaluate the Iterable -- only if not previously evaluated when trying to unroll
            if unroll_range is not None:
                # Range has already been evaluated when trying to unroll, do not re-evaluate
                iter = unroll_range
            else:
                iter = adj.eval(node.iter)

            adj.symbols[node.target.id] = adj.begin_for(iter)

            # for loops should be side-effect free, here we store a copy
            adj.loop_symbols.append(adj.symbols.copy())

            # eval body
            for s in node.body:
                adj.eval(s)

            adj.materialize_redefinitions(adj.loop_symbols[-1])
            adj.loop_symbols.pop()

            adj.end_for(iter)
+
2355
+ def emit_Break(adj, node):
2356
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2357
+
2358
+ adj.add_forward(f"goto end_{adj.loop_blocks[-1].label};")
2359
+
2360
+ def emit_Continue(adj, node):
2361
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2362
+
2363
+ adj.add_forward(f"goto start_{adj.loop_blocks[-1].label};")
2364
+
2365
+ def emit_Expr(adj, node):
2366
+ return adj.eval(node.value)
2367
+
2368
+ def check_tid_in_func_error(adj, node):
2369
+ if adj.is_user_function:
2370
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
2371
+ lineno = adj.lineno + adj.fun_lineno
2372
+ line = adj.source_lines[adj.lineno]
2373
+ raise WarpCodegenError(
2374
+ "tid() may only be called from a Warp kernel, not a Warp function. "
2375
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
2376
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
2377
+ )
2378
+
2379
    def resolve_arg(adj, arg):
        """Resolve a call argument to a Var or a statically known value.

        The argument is first evaluated (errors are deferred, not raised,
        since a static resolution may still succeed), then checked against
        static expression resolution. Statically resolved types, structs,
        Vars and functions are returned as-is; IntEnum/IntFlag values are
        lowered to integer constants; other static values become constants.
        """
        # Always try to start with evaluating the argument since it can help
        # detecting some issues such as global variables being accessed.
        try:
            var = adj.eval(arg)
        except (WarpCodegenError, WarpCodegenKeyError) as e:
            error = e
        else:
            error = None

        # Check if we can resolve the argument as a static expression.
        # If not, return the variable resulting from evaluating the argument.
        expr, _ = adj.resolve_static_expression(arg)
        if expr is None:
            # re-raise the deferred evaluation error, if any
            if error is not None:
                raise error

            return var

        if isinstance(expr, (type, Struct, Var, warp._src.context.Function)):
            return expr

        if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
            return adj.add_constant(int(expr))

        return adj.add_constant(expr)
+
2406
    def emit_Call(adj, node):
        """Emit a function call expression.

        Resolution order: a pre-attached `warp_func`, a static path
        expression, evaluation of the func node, wp.static() expressions,
        then (for dotted paths that did not resolve to a warp Function)
        builtins, vector/scalar constructors, struct constructors. Arguments
        are resolved and the call is added; with autograd array-access
        verification enabled, argument read/write flags are propagated from
        the resolved overload.
        """
        adj.check_tid_in_func_error(node)

        # try and lookup function in globals by
        # resolving path (e.g.: module.submodule.attr)
        if hasattr(node.func, "warp_func"):
            func = node.func.warp_func
            path = []
        else:
            func, path = adj.resolve_static_expression(node.func)
            if func is None:
                func = adj.eval(node.func)

        if adj.is_static_expression(func):
            # try to evaluate wp.static() expressions
            obj, code = adj.evaluate_static_expression(node)
            if obj is not None:
                adj.static_expressions[code] = obj
                if isinstance(obj, warp._src.context.Function):
                    # special handling for wp.static() evaluating to a function
                    return obj
                else:
                    out = adj.add_constant(obj)
                    return out

        type_args = {}

        if len(path) > 0 and not isinstance(func, warp._src.context.Function):
            attr = path[-1]
            caller = func
            func = None

            # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
            if attr in warp._src.context.builtin_functions:
                func = warp._src.context.builtin_functions[attr]

            # vector class type e.g.: wp.vec3f constructor
            if func is None and hasattr(caller, "_wp_generic_type_str_"):
                func = warp._src.context.builtin_functions.get(caller._wp_constructor_)

            # scalar class type e.g.: wp.int8 constructor
            if func is None and hasattr(caller, "__name__") and caller.__name__ in warp._src.context.builtin_functions:
                func = warp._src.context.builtin_functions.get(caller.__name__)

            # struct constructor
            if func is None and isinstance(caller, Struct):
                if adj.builder is not None:
                    adj.builder.build_struct_recursive(caller)
                # value constructor when arguments are given, default otherwise
                if node.args or node.keywords:
                    func = caller.value_constructor
                else:
                    func = caller.default_constructor

            # lambda function
            if func is None and getattr(caller, "__name__", None) == "<lambda>":
                raise NotImplementedError("Lambda expressions are not yet supported")

            if hasattr(caller, "_wp_type_args_"):
                type_args = caller._wp_type_args_

            if func is None:
                raise WarpCodegenError(
                    f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
                )

        # get expected return count, e.g.: for multi-assignment
        min_outputs = None
        if hasattr(node, "expects"):
            min_outputs = node.expects

        # Evaluate all positional and keywords arguments.
        args = tuple(adj.resolve_arg(x) for x in node.args)
        kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}

        out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)

        if warp._src.config.verify_autograd_array_access:
            # Extract the types and values passed as arguments to the function call.
            arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
            kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

            # Resolve the exact function signature among any existing overload.
            resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

            # update arg read/write states according to what happens to that arg in the called function
            if hasattr(resolved_func, "adj"):
                for i, arg in enumerate(args):
                    if resolved_func.adj.args[i].is_write:
                        kernel_name = adj.fun_name
                        filename = adj.filename
                        lineno = adj.lineno + adj.fun_lineno
                        arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
                    if resolved_func.adj.args[i].is_read:
                        arg.mark_read()

        return out
+
2503
+ def emit_Index(adj, node):
2504
+ # the ast.Index node appears in 3.7 versions
2505
+ # when performing array slices, e.g.: x = arr[i]
2506
+ # but in version 3.8 and higher it does not appear
2507
+
2508
+ if hasattr(node, "is_adjoint"):
2509
+ node.value.is_adjoint = True
2510
+
2511
+ return adj.eval(node.value)
2512
+
2513
    def eval_indices(adj, target_type, indices):
        """Evaluate a sequence of index AST nodes into a tuple of values.

        For generic vec/mat-style types (those carrying
        `_wp_generic_type_hint_`), slice nodes are resolved at compile time:
        bounds are taken from the evaluated constants, clamped to the
        dimension length and normalized for negative values, mirroring
        Python slice semantics. Other targets simply evaluate each node.
        """
        nodes = indices
        if hasattr(target_type, "_wp_generic_type_hint_"):
            indices = []
            for dim, node in enumerate(nodes):
                if isinstance(node, ast.Slice):
                    # In the context of slicing a vec/mat type, indices are expected
                    # to be compile-time constants, hence we can infer the actual slice
                    # bounds also at compile-time.
                    length = target_type._shape_[dim]
                    step = 1 if node.step is None else adj.eval(node.step).constant

                    if node.lower is None:
                        # open lower bound: end of the dim for negative steps
                        start = length - 1 if step < 0 else 0
                    else:
                        start = adj.eval(node.lower).constant
                        start = min(max(start, -length), length)
                        start = start + length if start < 0 else start

                    if node.upper is None:
                        # open upper bound: one-past-the-end in step direction
                        stop = -1 if step < 0 else length
                    else:
                        stop = adj.eval(node.upper).constant
                        stop = min(max(stop, -length), length)
                        stop = stop + length if stop < 0 else stop

                    slice = adj.add_builtin_call("slice", (start, stop, step))
                    indices.append(slice)
                else:
                    indices.append(adj.eval(node))

            return tuple(indices)
        else:
            return tuple(adj.eval(x) for x in nodes)
+
2548
    def emit_indexing(adj, target, indices):
        """Emit code indexing `target` (array, tile, or vec/mat-like) with `indices`.

        Arrays: a full set of integer indices produces an `address` load;
        otherwise a `view` is produced (with int indices normalized to
        zero-step slices for plain arrays). Tiles map to tile_extract /
        tile_view; any other type uses the generic `extract` builtin.
        """
        target_type = strip_reference(target.type)
        indices = adj.eval_indices(target_type, indices)

        if is_array(target_type):
            if len(indices) == target_type.ndim and all(
                warp._src.types.type_is_int(strip_reference(x.type)) for x in indices
            ):
                # handles array loads (where each dimension has an index specified)
                out = adj.add_builtin_call("address", [target, *indices])

                if warp._src.config.verify_autograd_array_access:
                    target.mark_read()

            else:
                if isinstance(target_type, warp._src.types.array):
                    # In order to reduce the number of overloads needed in the C
                    # implementation to support combinations of int/slice indices,
                    # we convert all integer indices into slices, and set their
                    # step to 0 if they are representing an integer index.
                    new_indices = []
                    for idx in indices:
                        if not warp._src.types.is_slice(strip_reference(idx.type)):
                            new_idx = adj.add_builtin_call("slice", (idx, idx, 0))
                            new_indices.append(new_idx)
                        else:
                            new_indices.append(idx)

                    indices = new_indices

                # handles array views (fewer indices than dimensions)
                out = adj.add_builtin_call("view", [target, *indices])

                if warp._src.config.verify_autograd_array_access:
                    # store reference to target Var to propagate downstream read/write state back to root arg Var
                    out.parent = target

                    # view arg inherits target Var's read/write states
                    out.is_read = target.is_read
                    out.is_write = target.is_write

        elif is_tile(target_type):
            if len(indices) >= len(target_type.shape):  # equality for scalars, inequality for composite types
                # handles extracting a single element from a tile
                out = adj.add_builtin_call("tile_extract", [target, *indices])
            elif len(indices) < len(target_type.shape):
                # handles tile views
                out = adj.add_builtin_call("tile_view", [target, indices])
            else:
                # NOTE(review): unreachable — the >= and < branches above cover all cases
                raise RuntimeError(
                    f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
                )

        else:
            # handles non-array type indexing, e.g: vec3, mat33, etc
            out = adj.add_builtin_call("extract", [target, *indices])

        return out
+
2607
+ # from a list of lists of indices, strip the first `count` indices
2608
+ @staticmethod
2609
+ def strip_indices(indices, count):
2610
+ dim = count
2611
+ while count > 0:
2612
+ ij = indices[0]
2613
+ indices = indices[1:]
2614
+ count -= len(ij)
2615
+
2616
+ # report straddling like in `arr2d[0][1,2]` as a syntax error
2617
+ if count < 0:
2618
+ raise WarpCodegenError(
2619
+ f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
2620
+ )
2621
+
2622
+ return indices
2623
+
2624
    def recurse_subscript(adj, node, indices):
        """Recursively unwind a chained subscript expression.

        Walks down `node.value` collecting index groups (one group per
        subscript, prepended so they end up in source order). When the base
        resolves to an array and more indices than its ndim were collected,
        the leading ndim indices are applied immediately and stripped.
        Returns the (possibly partially indexed) target Var and the
        remaining index groups.
        """
        if isinstance(node, ast.Name):
            target = adj.eval(node)
            return target, indices

        if isinstance(node, ast.Subscript):
            # wp.adjoint[...] subscripts are handled by direct evaluation
            if hasattr(node.value, "attr") and node.value.attr == "adjoint":
                return adj.eval(node), indices

            if isinstance(node.slice, ast.Tuple):
                ij = node.slice.elts
            elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
                # The node `ast.Index` is deprecated in Python 3.9.
                ij = node.slice.value.elts
            elif isinstance(node.slice, ast.ExtSlice):
                # The node `ast.ExtSlice` is deprecated in Python 3.9.
                ij = node.slice.dims
            else:
                ij = [node.slice]

            indices = [ij, *indices]  # prepend

            target, indices = adj.recurse_subscript(node.value, indices)

            target_type = strip_reference(target.type)
            if is_array(target_type):
                flat_indices = [i for ij in indices for i in ij]
                if len(flat_indices) > target_type.ndim:
                    # consume the array dims now; leftover indices apply to the element type
                    target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
                    indices = adj.strip_indices(indices, target_type.ndim)

            return target, indices

        # any other expression (e.g. attribute access) evaluates directly
        target = adj.eval(node)
        return target, indices
+
2660
+ # returns the object being indexed, and the list of indices
2661
+ def eval_subscript(adj, node):
2662
+ target, indices = adj.recurse_subscript(node, [])
2663
+ flat_indices = [i for ij in indices for i in ij]
2664
+ return target, flat_indices
2665
+
2666
+ def emit_Subscript(adj, node):
2667
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
2668
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
2669
+ node.slice.is_adjoint = True
2670
+ var = adj.eval(node.slice)
2671
+ var_name = var.label
2672
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
2673
+ return var
2674
+
2675
+ target, indices = adj.eval_subscript(node)
2676
+
2677
+ return adj.emit_indexing(target, indices)
2678
+
2679
+ def emit_Slice(adj, node):
2680
+ start = SLICE_BEGIN if node.lower is None else adj.eval(node.lower)
2681
+ stop = SLICE_END if node.upper is None else adj.eval(node.upper)
2682
+ step = 1 if node.step is None else adj.eval(node.step)
2683
+ return adj.add_builtin_call("slice", (start, stop, step))
2684
+
2685
    def emit_Assign(adj, node):
        """Emit code for an assignment statement.

        Handles four target shapes: tuple targets (multi-output unpacking),
        subscript targets (array/tile/vec-component stores, plus the
        wp.adjoint[var] special form), plain names (symbol binding with type
        checking against prior definitions), and attribute targets
        (vector/quaternion components, transform p/q parts, struct fields).
        """
        if len(node.targets) != 1:
            raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

        # Check if the rhs corresponds to an unsupported construct.
        # Tuples are supported in the context of assigning multiple variables
        # at once, but not for simple assignments like `x = (1, 2, 3)`.
        # Therefore, we need to catch this specific case here instead of
        # more generally in `adj.eval()`.
        if isinstance(node.value, ast.List):
            raise WarpCodegenError(
                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
            )

        lhs = node.targets[0]

        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
            # record the expected number of outputs on the node
            # we do this so we can decide which function to
            # call based on the number of expected outputs
            node.value.expects = len(lhs.elts)

        # evaluate rhs
        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
            rhs = [adj.eval(v) for v in node.value.elts]
        else:
            rhs = adj.eval(node.value)

        # handle the case where we are assigning multiple output variables
        if isinstance(lhs, ast.Tuple):
            subtype = getattr(rhs, "type", None)

            # tuple-typed rhs values are unpacked element by element
            if isinstance(subtype, warp._src.types.tuple_t):
                if len(rhs.type.types) != len(lhs.elts):
                    raise WarpCodegenError(
                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                    )
                rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

            names = []
            for v in lhs.elts:
                if isinstance(v, ast.Name):
                    names.append(v.id)
                else:
                    raise WarpCodegenError(
                        "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                    )

            if len(names) != len(rhs):
                raise WarpCodegenError(
                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                )

            out = rhs
            for name, rhs in zip(names, out):
                # redefinitions must keep the symbol's original type
                if name in adj.symbols:
                    if not types_equal(rhs.type, adj.symbols[name].type):
                        raise WarpCodegenTypeError(
                            f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                        )

                adj.symbols[name] = rhs

        # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
        elif isinstance(lhs, ast.Subscript):
            if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                # handle adjoint of a variable, i.e. wp.adjoint[var]
                lhs.slice.is_adjoint = True
                src_var = adj.eval(lhs.slice)
                var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
                adj.add_forward(f"{var.emit()} = {rhs.emit()};")
                return

            target, indices = adj.eval_subscript(lhs)

            target_type = strip_reference(target.type)
            indices = adj.eval_indices(target_type, indices)

            if is_array(target_type):
                adj.add_builtin_call("array_store", [target, *indices, rhs])

                if warp._src.config.verify_autograd_array_access:
                    kernel_name = adj.fun_name
                    filename = adj.filename
                    lineno = adj.lineno + adj.fun_lineno

                    target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

            elif is_tile(target_type):
                adj.add_builtin_call("assign", [target, *indices, rhs])

            elif (
                type_is_vector(target_type)
                or type_is_quaternion(target_type)
                or type_is_matrix(target_type)
                or type_is_transformation(target_type)
            ):
                # recursively unwind AST, stopping at penultimate node
                root = lhs
                while hasattr(root.value, "value"):
                    root = root.value
                # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
                if hasattr(root, "attr") and root.attr == "adjoint":
                    attr = adj.add_builtin_call("index", [target, *indices])
                    adj.add_builtin_call("store", [attr, rhs])
                    return

                # TODO: array vec component case
                if is_reference(target.type):
                    attr = adj.add_builtin_call("indexref", [target, *indices])
                    adj.add_builtin_call("store", [attr, rhs])

                    if warp._src.config.verbose and not adj.custom_reverse_mode:
                        lineno = adj.lineno + adj.fun_lineno
                        line = adj.source_lines[adj.lineno]
                        node_source = adj.get_node_source(lhs.value)
                        print(
                            f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                        )
                else:
                    if warp._src.config.enable_vector_component_overwrites:
                        out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])

                        # re-point target symbol to out var
                        for id in adj.symbols:
                            if adj.symbols[id] == target:
                                adj.symbols[id] = out
                                break
                    else:
                        adj.add_builtin_call("assign_inplace", [target, *indices, rhs])

            else:
                raise WarpCodegenError(
                    f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
                )

        elif isinstance(lhs, ast.Name):
            # symbol name
            name = lhs.id

            # check type matches if symbol already defined
            if name in adj.symbols:
                if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
                    raise WarpCodegenTypeError(
                        f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                    )

            if isinstance(node.value, ast.Tuple):
                out = rhs
            elif isinstance(rhs, Sequence):
                out = adj.add_builtin_call("tuple", rhs)
            elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
                # aliasing a name or a reference requires a copy
                out = adj.add_builtin_call("copy", [rhs])
            else:
                out = rhs

            # update symbol map (assumes lhs is a Name node)
            adj.symbols[name] = out

        elif isinstance(lhs, ast.Attribute):
            aggregate = adj.eval(lhs.value)
            aggregate_type = strip_reference(aggregate.type)

            # assigning to a vector or quaternion component
            if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
                index = adj.vector_component_index(lhs.attr, aggregate_type)

                if is_reference(aggregate.type):
                    attr = adj.add_builtin_call("indexref", [aggregate, index])
                    adj.add_builtin_call("store", [attr, rhs])
                else:
                    if warp._src.config.enable_vector_component_overwrites:
                        out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])

                        # re-point target symbol to out var
                        for id in adj.symbols:
                            if adj.symbols[id] == aggregate:
                                adj.symbols[id] = out
                                break
                    else:
                        adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])

            elif type_is_transformation(aggregate_type):
                component = adj.transform_component(lhs.attr)

                # TODO: x[i,j].p = rhs case
                if is_reference(aggregate.type):
                    raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")

                if component == "p":
                    return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
                else:
                    return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])

            else:
                # struct field (or other attribute) assignment
                attr = adj.emit_Attribute(lhs)
                if is_reference(attr.type):
                    adj.add_builtin_call("store", [attr, rhs])
                else:
                    adj.add_builtin_call("assign", [attr, rhs])

                if warp._src.config.verbose and not adj.custom_reverse_mode:
                    lineno = adj.lineno + adj.fun_lineno
                    line = adj.source_lines[adj.lineno]
                    msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
                    print(msg)

        else:
            raise WarpCodegenError("Error, unsupported assignment statement.")
+
2895
    def emit_Return(adj, node):
        """Emit a return statement, recording the function's return Vars.

        All return statements in a function must produce the same ctypes;
        references are copied into concrete values before being recorded.
        """
        if node.value is None:
            var = None
        elif isinstance(node.value, ast.Tuple):
            var = tuple(adj.eval(arg) for arg in node.value.elts)
        else:
            var = adj.eval(node.value)
            if not isinstance(var, list) and not isinstance(var, tuple):
                var = (var,)

        # enforce consistent return types across all return statements
        if adj.return_var is not None:
            old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
            # NOTE(review): a bare `return` after a value-returning one makes `var` None
            # here, which would raise TypeError rather than the message below — confirm
            new_ctypes = tuple(v.ctype(value_type=True) for v in var)
            if old_ctypes != new_ctypes:
                raise WarpCodegenTypeError(
                    f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
                )

        if var is not None:
            adj.return_var = ()
            for ret in var:
                # references must be materialized into values before returning
                if is_reference(ret.type):
                    ret_var = adj.add_builtin_call("copy", [ret])
                else:
                    ret_var = ret
                adj.return_var += (ret_var,)

        adj.add_return(adj.return_var)
+
2924
+ def emit_AugAssign(adj, node):
2925
+ lhs = node.target
2926
+
2927
+ # replace augmented assignment with assignment statement + binary op (default behaviour)
2928
+ def make_new_assign_statement():
2929
+ new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
2930
+ adj.eval(new_node)
2931
+
2932
+ rhs = adj.eval(node.value)
2933
+
2934
+ if isinstance(lhs, ast.Subscript):
2935
+ # wp.adjoint[var] appears in custom grad functions, and does not require
2936
+ # special consideration in the AugAssign case
2937
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
2938
+ make_new_assign_statement()
2939
+ return
2940
+
2941
+ target, indices = adj.eval_subscript(lhs)
2942
+
2943
+ target_type = strip_reference(target.type)
2944
+ indices = adj.eval_indices(target_type, indices)
2945
+
2946
+ if is_array(target_type):
2947
+ # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
2948
+ if target_type.dtype in warp._src.types.non_atomic_types:
2949
+ make_new_assign_statement()
2950
+ return
2951
+
2952
+ # the same holds true for vecs/mats/quats that are composed of these types
2953
+ if (
2954
+ type_is_vector(target_type.dtype)
2955
+ or type_is_quaternion(target_type.dtype)
2956
+ or type_is_matrix(target_type.dtype)
2957
+ or type_is_transformation(target_type.dtype)
2958
+ ):
2959
+ dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
2960
+ if dtype in warp._src.types.non_atomic_types:
2961
+ make_new_assign_statement()
2962
+ return
2963
+
2964
+ kernel_name = adj.fun_name
2965
+ filename = adj.filename
2966
+ lineno = adj.lineno + adj.fun_lineno
2967
+
2968
+ if isinstance(node.op, ast.Add):
2969
+ adj.add_builtin_call("atomic_add", [target, *indices, rhs])
2970
+
2971
+ if warp._src.config.verify_autograd_array_access:
2972
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2973
+
2974
+ elif isinstance(node.op, ast.Sub):
2975
+ adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
2976
+
2977
+ if warp._src.config.verify_autograd_array_access:
2978
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2979
+
2980
+ elif isinstance(node.op, ast.BitAnd):
2981
+ adj.add_builtin_call("atomic_and", [target, *indices, rhs])
2982
+
2983
+ if warp._src.config.verify_autograd_array_access:
2984
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2985
+
2986
+ elif isinstance(node.op, ast.BitOr):
2987
+ adj.add_builtin_call("atomic_or", [target, *indices, rhs])
2988
+
2989
+ if warp._src.config.verify_autograd_array_access:
2990
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2991
+
2992
+ elif isinstance(node.op, ast.BitXor):
2993
+ adj.add_builtin_call("atomic_xor", [target, *indices, rhs])
2994
+
2995
+ if warp._src.config.verify_autograd_array_access:
2996
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2997
+ else:
2998
+ if warp._src.config.verbose:
2999
+ print(f"Warning: in-place op {node.op} is not differentiable")
3000
+ make_new_assign_statement()
3001
+ return
3002
+
3003
+ elif (
3004
+ type_is_vector(target_type)
3005
+ or type_is_quaternion(target_type)
3006
+ or type_is_matrix(target_type)
3007
+ or type_is_transformation(target_type)
3008
+ ):
3009
+ if isinstance(node.op, ast.Add):
3010
+ adj.add_builtin_call("add_inplace", [target, *indices, rhs])
3011
+ elif isinstance(node.op, ast.Sub):
3012
+ adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
3013
+ elif isinstance(node.op, ast.BitAnd):
3014
+ adj.add_builtin_call("bit_and_inplace", [target, *indices, rhs])
3015
+ elif isinstance(node.op, ast.BitOr):
3016
+ adj.add_builtin_call("bit_or_inplace", [target, *indices, rhs])
3017
+ elif isinstance(node.op, ast.BitXor):
3018
+ adj.add_builtin_call("bit_xor_inplace", [target, *indices, rhs])
3019
+ else:
3020
+ if warp._src.config.verbose:
3021
+ print(f"Warning: in-place op {node.op} is not differentiable")
3022
+ make_new_assign_statement()
3023
+ return
3024
+
3025
+ elif is_tile(target.type):
3026
+ if isinstance(node.op, ast.Add):
3027
+ adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
3028
+ elif isinstance(node.op, ast.Sub):
3029
+ adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
3030
+ elif isinstance(node.op, ast.BitAnd):
3031
+ adj.add_builtin_call("tile_bit_and_inplace", [target, *indices, rhs])
3032
+ elif isinstance(node.op, ast.BitOr):
3033
+ adj.add_builtin_call("tile_bit_or_inplace", [target, *indices, rhs])
3034
+ elif isinstance(node.op, ast.BitXor):
3035
+ adj.add_builtin_call("tile_bit_xor_inplace", [target, *indices, rhs])
3036
+ else:
3037
+ if warp._src.config.verbose:
3038
+ print(f"Warning: in-place op {node.op} is not differentiable")
3039
+ make_new_assign_statement()
3040
+ return
3041
+
3042
+ else:
3043
+ raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
3044
+
3045
+ elif isinstance(lhs, ast.Name):
3046
+ target = adj.eval(node.target)
3047
+
3048
+ if is_tile(target.type) and is_tile(rhs.type):
3049
+ if isinstance(node.op, ast.Add):
3050
+ adj.add_builtin_call("add_inplace", [target, rhs])
3051
+ elif isinstance(node.op, ast.Sub):
3052
+ adj.add_builtin_call("sub_inplace", [target, rhs])
3053
+ elif isinstance(node.op, ast.BitAnd):
3054
+ adj.add_builtin_call("bit_and_inplace", [target, rhs])
3055
+ elif isinstance(node.op, ast.BitOr):
3056
+ adj.add_builtin_call("bit_or_inplace", [target, rhs])
3057
+ elif isinstance(node.op, ast.BitXor):
3058
+ adj.add_builtin_call("bit_xor_inplace", [target, rhs])
3059
+ else:
3060
+ make_new_assign_statement()
3061
+ return
3062
+ else:
3063
+ make_new_assign_statement()
3064
+ return
3065
+
3066
+ # TODO
3067
+ elif isinstance(lhs, ast.Attribute):
3068
+ make_new_assign_statement()
3069
+ return
3070
+
3071
+ else:
3072
+ make_new_assign_statement()
3073
+ return
3074
+
3075
def emit_Tuple(adj, node):
    # Evaluate every element expression, then build the tuple via the builtin.
    evaluated = tuple(map(adj.eval, node.elts))
    return adj.add_builtin_call("tuple", evaluated)
3078
+
3079
def emit_Pass(adj, node):
    """A `pass` statement emits no code and produces no value."""
    return None
3081
+
3082
# Dispatch table mapping each supported AST node type to its emit_* handler.
# Consulted by eval(); any node type missing here is rejected as unsupported.
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
    ast.FunctionDef: emit_FunctionDef,
    ast.If: emit_If,
    ast.IfExp: emit_IfExp,
    ast.Compare: emit_Compare,
    ast.BoolOp: emit_BoolOp,
    ast.Name: emit_Name,
    ast.Attribute: emit_Attribute,
    ast.Constant: emit_Constant,
    ast.BinOp: emit_BinOp,
    ast.UnaryOp: emit_UnaryOp,
    ast.While: emit_While,
    ast.For: emit_For,
    ast.Break: emit_Break,
    ast.Continue: emit_Continue,
    ast.Expr: emit_Expr,
    ast.Call: emit_Call,
    ast.Index: emit_Index,  # Deprecated in 3.9
    ast.Subscript: emit_Subscript,
    ast.Slice: emit_Slice,
    ast.Assign: emit_Assign,
    ast.Return: emit_Return,
    ast.AugAssign: emit_AugAssign,
    ast.Tuple: emit_Tuple,
    ast.Pass: emit_Pass,
    ast.Assert: emit_Assert,
}
3109
+
3110
def eval(adj, node):
    """Dispatch `node` to its emit_* handler, tracking the current source line.

    Raises WarpCodegenError for AST constructs with no registered handler.
    """
    lineno = getattr(node, "lineno", None)
    if lineno is not None:
        adj.set_lineno(lineno - 1)

    try:
        handler = adj.node_visitors[type(node)]
    except KeyError as e:
        node_name = type(node).__name__
        prefix = "ast." if isinstance(node, ast.AST) else ""
        raise WarpCodegenError(f"Construct `{prefix}{node_name}` not supported in kernels.") from e

    return handler(adj, node)
3122
+
3123
# helper to evaluate expressions of the form
# obj1.obj2.obj3.attr in the function's global scope
def resolve_path(adj, path):
    """Resolve a dotted name path: closure/globals first, then the warp
    module, then Python builtins; returns None if the root cannot be found
    or is shadowed by a local symbol."""
    if not path:
        return None

    root = path[0]

    # a local symbol shadows any external name with the same root
    if root in adj.symbols:
        return None

    # closure and global variables take priority
    obj = adj.resolve_external_reference(root)

    # support Warp types in kernels without the module prefix, e.g. vec3(...)
    if obj is None:
        obj = getattr(warp, root, None)

    # finally, fall back to Python builtins
    if obj is None:
        obj = __builtins__.get(root)

    if obj is None:
        return None

    # walk the remaining attributes; unresolvable ones are silently skipped
    for attr in path[1:]:
        if hasattr(obj, attr):
            obj = getattr(obj, attr)

    return obj
3150
+
3151
# retrieves a dictionary of all closure and global variables and their values
# to be used in the evaluation context of wp.static() expressions
def get_static_evaluation_context(adj):
    """Build the name->value mapping visible to a wp.static() expression."""
    cells = adj.func.__closure__ or []
    closure_vars = {
        name: cell.cell_contents for name, cell in zip(adj.func.__code__.co_freevars, cells)
    }

    # globals first, then closure vars, so captured names take precedence
    return {**adj.func.__globals__, **closure_vars}
3167
+
3168
def is_static_expression(adj, func):
    """Return True iff `func` is the wp.static() builtin function itself."""
    if not isinstance(func, types.FunctionType):
        return False
    return func.__module__ == "warp._src.builtins" and func.__qualname__ == "static"
3174
+
3175
# verify the return type of a wp.static() expression is supported inside a Warp kernel
def verify_static_return_value(adj, value):
    """Raise ValueError if `value` cannot be used inside a Warp kernel.

    Accepted: Warp scalar/vector values, strings, Warp functions, and struct
    instances whose members (recursively) are themselves accepted.
    """
    if value is None:
        raise ValueError("None is returned")
    if warp._src.types.is_value(value):
        return True
    if warp._src.types.is_array(value):
        # more useful explanation for the common case of creating a Warp array
        raise ValueError("a Warp array cannot be created inside Warp kernels")
    if isinstance(value, str):
        # we want to support cases such as `print(wp.static("test"))`
        return True
    if isinstance(value, warp._src.context.Function):
        return True

    # recursively validate struct members, tracking the attribute path so the
    # error message can point at the offending field
    def verify_struct(s: StructInstance, attr_path: list[str]):
        for key in s._cls.vars.keys():
            v = getattr(s, key)
            if issubclass(type(v), StructInstance):
                verify_struct(v, [*attr_path, key])
            else:
                try:
                    adj.verify_static_return_value(v)
                except ValueError as e:
                    raise ValueError(
                        f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
                    ) from e

    if issubclass(type(value), StructInstance):
        return verify_struct(value, [])

    raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
3207
+
3208
+ # find the source code string of an AST node
3209
+ @staticmethod
3210
+ def extract_node_source_from_lines(source_lines, node) -> str | None:
3211
+ if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
3212
+ return None
3213
+
3214
+ start_line = node.lineno - 1 # line numbers start at 1
3215
+ start_col = node.col_offset
3216
+
3217
+ if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
3218
+ end_line = node.end_lineno - 1
3219
+ end_col = node.end_col_offset
3220
+ else:
3221
+ # fallback for Python versions before 3.8
3222
+ # we have to find the end line and column manually
3223
+ end_line = start_line
3224
+ end_col = start_col
3225
+ parenthesis_count = 1
3226
+ for lineno in range(start_line, len(source_lines)):
3227
+ if lineno == start_line:
3228
+ c_start = start_col
3229
+ else:
3230
+ c_start = 0
3231
+ line = source_lines[lineno]
3232
+ for i in range(c_start, len(line)):
3233
+ c = line[i]
3234
+ if c == "(":
3235
+ parenthesis_count += 1
3236
+ elif c == ")":
3237
+ parenthesis_count -= 1
3238
+ if parenthesis_count == 0:
3239
+ end_col = i
3240
+ end_line = lineno
3241
+ break
3242
+ if parenthesis_count == 0:
3243
+ break
3244
+
3245
+ if start_line == end_line:
3246
+ # single-line expression
3247
+ return source_lines[start_line][start_col:end_col]
3248
+ else:
3249
+ # multi-line expression
3250
+ lines = []
3251
+ # first line (from start_col to the end)
3252
+ lines.append(source_lines[start_line][start_col:])
3253
+ # middle lines (entire lines)
3254
+ lines.extend(source_lines[start_line + 1 : end_line])
3255
+ # last line (from the start to end_col)
3256
+ lines.append(source_lines[end_line][:end_col])
3257
+ return "".join(lines).strip()
3258
+
3259
@staticmethod
def extract_lambda_source(func, only_body=False) -> str | None:
    """Recover the source text of a lambda `func`.

    Returns the full lambda expression, or only its body when `only_body`
    is True; returns None if no Lambda node is found in the parsed snippet.
    Raises WarpCodegenError if the source is unavailable (OSError from
    inspect, e.g. a lambda defined in a REPL).
    """
    try:
        source_lines = inspect.getsourcelines(func)[0]
        # trim anything on the first line that precedes the `lambda` keyword
        source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
    except OSError as e:
        raise WarpCodegenError(
            "Could not access lambda function source code. Please use a named function instead."
        ) from e
    source = "".join(source_lines)
    source = source[source.index("lambda") :].rstrip()
    # Remove trailing unbalanced parentheses
    while source.count("(") < source.count(")"):
        source = source[:-1]
    # extract lambda expression up until a comma, e.g. in the case of
    # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
    si = max(source.rfind(")"), source.find(":"))
    ci = source.find(",", si)
    if ci != -1:
        source = source[:ci]
    # reparse the trimmed snippet and locate the Lambda node within it
    tree = ast.parse(source)
    lambda_source = None
    for node in ast.walk(tree):
        if isinstance(node, ast.Lambda):
            if only_body:
                # extract the body of the lambda function
                lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
            else:
                # extract the entire lambda function
                lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
            break
    return lambda_source
3291
+
3292
+ def extract_node_source(adj, node) -> str | None:
3293
+ return adj.extract_node_source_from_lines(adj.source_lines, node)
3294
+
3295
# handles a wp.static() expression and returns the resulting object and a string representing the code
# of the static expression
def evaluate_static_expression(adj, node) -> tuple[Any, str]:
    """Evaluate the argument of a wp.static() call at code-generation time.

    Returns (value, code) where `code` is the extracted source text of the
    expression. Raises WarpCodegenError if the source cannot be extracted,
    the expression fails to evaluate, or its value is unsupported in kernels.
    """
    if len(node.args) == 1:
        static_code = adj.extract_node_source(node.args[0])
    elif len(node.keywords) == 1:
        static_code = adj.extract_node_source(node.keywords[0])
    else:
        raise WarpCodegenError("warp.static() requires a single argument or keyword")
    if static_code is None:
        raise WarpCodegenError("Error extracting source code from wp.static() expression")

    # Since this is an expression, we can enforce it to be defined on a single line.
    static_code = static_code.replace("\n", "")
    code_to_eval = static_code  # code to be evaluated

    vars_dict = adj.get_static_evaluation_context()
    # add constant variables to the static call context
    constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
    vars_dict.update(constant_vars)

    # Replace all constant `len()` expressions with their value.
    if "len" in static_code:
        # evaluation context for len(): constants plus symbol types, with
        # `len` mapped to warp's compile-time type_length
        len_expr_ctx = vars_dict.copy()
        constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
        len_expr_ctx.update(constant_types)
        len_expr_ctx.update({"len": warp._src.types.type_length})

        # We want to replace the expression code in-place,
        # so reparse it to get the correct column info.
        len_value_locs: list[tuple[int, int, int]] = []
        expr_tree = ast.parse(static_code)
        assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
        expr_root = expr_tree.body[0].value
        for expr_node in ast.walk(expr_root):
            if (
                isinstance(expr_node, ast.Call)
                and getattr(expr_node.func, "id", None) == "len"
                and len(expr_node.args) == 1
            ):
                len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
                try:
                    len_value = eval(len_expr, len_expr_ctx)
                except Exception:
                    # leave non-evaluable len() calls untouched
                    pass
                else:
                    len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))

        if len_value_locs:
            # splice the computed values back into the expression text
            new_static_code = ""
            loc = 0
            for value, start, end in len_value_locs:
                new_static_code += f"{static_code[loc:start]}{value}"
                loc = end

            new_static_code += static_code[len_value_locs[-1][2] :]
            code_to_eval = new_static_code

    try:
        # NOTE: eval() here is by design — wp.static() executes user-supplied
        # expression code at module build time, not on untrusted input
        value = eval(code_to_eval, vars_dict)
        if isinstance(value, (enum.IntEnum, enum.IntFlag)):
            value = int(value)
        if warp._src.config.verbose:
            print(f"Evaluated static command: {static_code} = {value}")
    except NameError as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
        ) from e
    except Exception as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    try:
        adj.verify_static_return_value(value)
    except ValueError as e:
        raise WarpCodegenError(
            f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    return value, static_code
3376
+
3377
# try to replace wp.static() expressions by their evaluated value if the
# expression can be evaluated
def replace_static_expressions(adj):
    """Rewrite adj.tree, substituting evaluable wp.static() calls in place.

    Evaluated values become ast.Constant nodes (or a `__warp_func__` Name
    node carrying a Warp function reference). Calls that cannot be evaluated
    are left intact and flagged via adj.has_unresolved_static_expressions.
    """

    class StaticExpressionReplacer(ast.NodeTransformer):
        def visit_Call(self, node):
            func, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if adj.is_static_expression(func):
                try:
                    # the static expression will execute as long as the static expression is valid and
                    # only depends on global or captured variables
                    obj, code = adj.evaluate_static_expression(node)
                    if code is not None:
                        adj.static_expressions[code] = obj
                        if isinstance(obj, warp._src.context.Function):
                            name_node = ast.Name("__warp_func__")
                            # we add a pointer to the Warp function here so that we can refer to it later at
                            # codegen time (note that the function key itself is not sufficient to uniquely
                            # identify the function, as the function may be redefined between the current time
                            # of wp.static() declaration and the time of codegen during module building)
                            name_node.warp_func = obj
                            return ast.copy_location(name_node, node)
                        else:
                            return ast.copy_location(ast.Constant(value=obj), node)
                except Exception:
                    # Ignoring failing static expressions should generally not be an issue because only
                    # one of these cases should be possible:
                    # 1) the static expression itself is invalid code, in which case the module cannot be
                    #    built all,
                    # 2) the static expression contains a reference to a local (even if constant) variable
                    #    (and is therefore not executable and raises this exception), in which
                    #    case changing the constant, or the code affecting this constant, would lead to
                    #    a different module hash anyway.
                    # In any case, we mark this Adjoint to have unresolvable static expressions.
                    # This will trigger a code generation step even if the module hash is unchanged.
                    adj.has_unresolved_static_expressions = True
                    pass

            return self.generic_visit(node)

    adj.tree = StaticExpressionReplacer().visit(adj.tree)
3417
+
3418
# Evaluates a static expression that does not depend on runtime values
# if eval_types is True, try resolving the path using evaluated type information as well
def resolve_static_expression(adj, root_node, eval_types=True):
    """Resolve a dotted (and optionally type()-based) expression at codegen time.

    Returns (object, path): `object` is None if resolution failed; `path` is
    the attribute chain that was attempted. With eval_types=True, handles the
    special form `type(expr).attr...` by resolving against variable types.
    """
    attributes = []

    # peel off trailing attribute accesses; a.b.c yields ["c", "b"]
    node = root_node
    while isinstance(node, ast.Attribute):
        attributes.append(node.attr)
        node = node.value

    if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        # support for operators returning modules
        # i.e. operator_name(*operator_args).x.y.z
        operator_args = node.args
        operator_name = node.func.id

        if operator_name == "type":
            if len(operator_args) != 1:
                raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")

            # type() operator
            var = adj.eval(operator_args[0])

            if isinstance(var, Var):
                var_type = strip_reference(var.type)
                # Allow accessing type attributes, for instance array.dtype
                while attributes:
                    attr_name = attributes.pop()
                    var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type

                    if var_type is None:
                        raise WarpCodegenAttributeError(
                            f"{attr_name} is not an attribute of {type_repr(prev_type)}"
                        )

                return var_type, [str(var_type)]
            else:
                raise WarpCodegenError(f"Cannot deduce the type of {var}")

    # reverse list since ast presents it in backward order
    path = [*reversed(attributes)]
    if isinstance(node, ast.Name):
        path.insert(0, node.id)

    # Try resolving path from captured context
    captured_obj = adj.resolve_path(path)
    if captured_obj is not None:
        return captured_obj, path

    return None, path
3468
+
3469
def resolve_external_reference(adj, name: str):
    """Look up `name` in the function's closure first, then its globals;
    returns None when the name is found in neither."""
    code = adj.func.__code__
    if name in code.co_freevars:
        # closure variable: fetch the corresponding cell's contents
        cell_index = code.co_freevars.index(name)
        return adj.func.__closure__[cell_index].cell_contents
    # fall back to module-level globals
    return adj.func.__globals__.get(name)
3478
+
3479
# annotate generated code with the original source code line
def set_lineno(adj, lineno):
    """Emit forward/adjoint source-line comment markers when the line changes."""
    if adj.lineno == lineno:
        return
    absolute_line = lineno + adj.fun_lineno
    text = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
    adj.add_forward(f"// {text} <L {absolute_line}>")
    adj.add_reverse(f"// adj: {text} <L {absolute_line}>")
    adj.lineno = lineno
3487
+
3488
def get_node_source(adj, node):
    """Return the Python source segment of `node` within adj.source."""
    return ast.get_source_segment(adj.source, node)
3491
+
3492
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]:
    """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""

    # NOTE(review): ast.walk is breadth-first, not source order, so this
    # shadow tracking is approximate — confirm the ordering assumption holds
    # for the hashing use case.
    local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed

    constants: dict[str, Any] = {}
    types: dict[Struct | type, Any] = {}
    functions: dict[warp._src.context.Function, Any] = {}

    for node in ast.walk(adj.tree):
        if isinstance(node, ast.Name) and node.id not in local_variables:
            # look up in closure/global variables
            obj = adj.resolve_external_reference(node.id)
            if warp._src.types.is_value(obj):
                constants[node.id] = obj

        elif isinstance(node, ast.Attribute):
            # dotted constant reference, e.g. module.CONSTANT
            obj, path = adj.resolve_static_expression(node, eval_types=False)
            if warp._src.types.is_value(obj):
                constants[".".join(path)] = obj

        elif isinstance(node, ast.Call):
            func, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if isinstance(func, warp._src.context.Function) and not func.is_builtin():
                # calling user-defined function
                functions[func] = None
            elif isinstance(func, Struct):
                # calling struct constructor
                types[func] = None
            elif isinstance(func, type) and warp._src.types.type_is_value(func):
                # calling value type constructor
                types[func] = None

        elif isinstance(node, ast.Assign):
            # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
            lhs = node.targets[0]
            if isinstance(lhs, ast.Tuple):
                for v in lhs.elts:
                    if isinstance(v, ast.Name):
                        local_variables.add(v.id)
            elif isinstance(lhs, ast.Name):
                local_variables.add(lhs.id)

    return constants, types, functions
3536
+
3537
+
3538
# ----------------
# code generation

# Preamble prepended to generated CPU module sources: tile block dim,
# builtin.h, and macros mapping float()/int() casts plus tid()/block_dim()
# builtins onto the task-index based CPU launch model.
cpu_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(task_index, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)

#define builtin_block_dim() wp::block_dim()

"""

# CUDA variant of the module preamble; the thread index comes from _idx and
# wp.breakpoint()'s __debugbreak() maps to __brkpt() so cuda-gdb attributes
# the stop to the generated .cu line.
cuda_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
#if defined(__CUDACC__) && !defined(_MSC_VER)
#define __debugbreak() __brkpt()
#endif

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(_idx, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

#define builtin_block_dim() wp::block_dim()

"""

# C++ struct emitted per Warp struct: member fields, member-initializing
# constructor, operator+=, plus the adj_/add/adj_atomic_add helpers needed
# when compiling adjoints. Filled in by codegen_struct().
struct_template = """
struct {name}
{{
{struct_body}

    {defaulted_constructor_def}
    CUDA_CALLABLE {name}({forward_args})
    {forward_initializers}
    {{
    }}

    CUDA_CALLABLE {name}& operator += (const {name}& rhs)
    {{{prefix_add_body}
        return *this;}}

}};

static CUDA_CALLABLE void adj_{name}({reverse_args})
{{
{reverse_body}}}

// Required when compiling adjoints.
CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
{{
    return {name}();
}}

CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
{{
{atomic_add_body}}}


"""

# CPU forward/adjoint wrappers for user-defined functions.
cpu_forward_function_template = """
// (unknown):{lineno}
static {return_type} {name}(
    {forward_args})
{{
{forward_body}}}

"""

cpu_reverse_function_template = """
// (unknown):{lineno}
static void adj_{name}(
    {reverse_args})
{{
{reverse_body}}}

"""

# CUDA forward/adjoint function wrappers; {line_directive} carries optional
# #line information so device debuggers map back to Python source.
cuda_forward_function_template = """
// (unknown):{lineno}
{line_directive}static CUDA_CALLABLE {return_type} {name}(
    {forward_args})
{{
{forward_body}{line_directive}}}

"""

cuda_reverse_function_template = """
// (unknown):{lineno}
{line_directive}static CUDA_CALLABLE void adj_{name}(
    {reverse_args})
{{
{reverse_body}{line_directive}}}

"""

# CUDA kernel entry points: a grid-stride loop over dim.size, resetting the
# shared-memory tile allocator at the start of every iteration.
cuda_kernel_template_forward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
    {forward_args})
{{
{line_directive}    wp::tile_shared_storage_t tile_mem;

{line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive}         _idx < dim.size;
{line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
    {{
        // reset shared memory allocator
{line_directive}        wp::tile_shared_storage_t::init();

{forward_body}{line_directive}    }}
{line_directive}}}

"""

cuda_kernel_template_backward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
    {reverse_args})
{{
{line_directive}    wp::tile_shared_storage_t tile_mem;

{line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive}         _idx < dim.size;
{line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
    {{
        // reset shared memory allocator
{line_directive}        wp::tile_shared_storage_t::init();

{reverse_body}{line_directive}    }}
{line_directive}}}

"""

# CPU kernel bodies: invoked once per task index; kernel arguments are
# passed via a pointer to the generated wp_args_{name} struct.
cpu_kernel_template_forward = """

void {name}_cpu_kernel_forward(
    {forward_args},
    wp_args_{name} *_wp_args)
{{
{forward_body}}}

"""

cpu_kernel_template_backward = """

void {name}_cpu_kernel_backward(
    {reverse_args},
    wp_args_{name} *_wp_args,
    wp_args_{name} *_wp_adj_args)
{{
{reverse_body}}}

"""

# extern "C" module entry points called from Python: a host-side loop over
# dim.size, optionally pointing the shared tile allocator at stack storage.
cpu_module_template_forward = """

extern "C" {{

// Python CPU entry points
WP_API void {name}_cpu_forward(
    wp::launch_bounds_t dim,
    wp_args_{name} *_wp_args)
{{
    wp::tile_shared_storage_t tile_mem;
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
    wp::shared_tile_storage = &tile_mem;
#endif

    for (size_t task_index = 0; task_index < dim.size; ++task_index)
    {{
        {name}_cpu_kernel_forward(dim, task_index, _wp_args);
    }}
}}

}} // extern C

"""

cpu_module_template_backward = """

extern "C" {{

WP_API void {name}_cpu_backward(
    wp::launch_bounds_t dim,
    wp_args_{name} *_wp_args,
    wp_args_{name} *_wp_adj_args)
{{
    wp::tile_shared_storage_t tile_mem;
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
    wp::shared_tile_storage = &tile_mem;
#endif

    for (size_t task_index = 0; task_index < dim.size; ++task_index)
    {{
        {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);
    }}
}}

}} // extern C

"""
3764
+
3765
+
3766
# converts a constant Python value to equivalent C-repr
def constant_str(value):
    """Convert a constant Python value to an equivalent C++ literal string.

    Handles bools, strings (escaped), ctypes vector/matrix arrays, Warp
    scalar types, struct instances, and the IEEE special values inf/-inf/nan;
    anything else falls back to str(value).
    """
    value_type = type(value)

    # note: `bool` may be shadowed by a Warp type in this module's namespace,
    # so the builtin is checked explicitly as well
    if value_type == bool or value_type == builtins.bool:
        return "true" if value else "false"

    elif value_type == str:
        # ensure constant strings are correctly escaped
        return '"' + str(value.encode("unicode-escape").decode()) + '"'

    elif isinstance(value, ctypes.Array):
        if value_type._wp_scalar_type_ == float16:
            # special case for float16, which is stored as uint16 in the ctypes.Array
            from warp._src.context import runtime

            scalar_value = runtime.core.wp_half_bits_to_float
        else:

            def scalar_value(x):
                return x

        # list of scalar initializer values
        initlist = []
        for i in range(value._length_):
            x = ctypes.Array.__getitem__(value, i)
            initlist.append(str(scalar_value(x)).lower())

        if value._wp_scalar_type_ is bool:
            dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
        else:
            dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"

        # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
        return f"{dtypestr}{{{', '.join(initlist)}}}"

    elif value_type in warp._src.types.scalar_types:
        # make sure we emit the value of objects, e.g. uint32
        return str(value.value)

    elif issubclass(value_type, StructInstance):
        # constant struct instance: emit a constructor call with each member
        arg_strs = []
        for key, var in value._cls.vars.items():
            attr = getattr(value, key)
            arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
        arg_str = ", ".join(arg_strs)
        return f"{value.native_name}({arg_str})"

    elif value == math.inf:
        return "INFINITY"

    elif value == -math.inf:
        # fix: previously fell through to str(value), emitting "-inf",
        # which is not a valid C/C++ literal
        return "-INFINITY"

    elif math.isnan(value):
        return "NAN"

    else:
        # otherwise just convert constant to string
        return str(value)
3827
+
3828
+
3829
def indent(args, stops=1):
    """Join `args` with a comma-newline separator indented 4 spaces per stop."""
    separator = ",\n" + "    " * stops
    return separator.join(args)
3836
+
3837
+
3838
# generates a C function name based on the python function name
def make_full_qualified_name(func: Union[str, Callable]) -> str:
    """Sanitize a Python (qualified) function name into a valid C identifier:
    dots become double underscores, all other invalid characters are dropped."""
    qualname = func if isinstance(func, str) else func.__qualname__
    return re.sub("[^0-9a-zA-Z_]+", "", qualname.replace(".", "__"))
3843
+
3844
+
3845
def codegen_struct(struct, device="cpu", indent_size=4):
    """Generate C++ source for `struct` from struct_template.

    Emits the member declarations, a member-initializing constructor,
    operator+=, adj_<name>() for the backward pass, and the add()/
    adj_atomic_add() helpers required when compiling adjoints.
    `device` is currently unused here.
    """
    name = struct.native_name

    body = []
    indent_block = " " * indent_size

    if len(struct.vars) > 0:
        for label, var in struct.vars.items():
            body.append(var.ctype() + " " + label + ";\n")
    else:
        # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
        body.append("char _dummy_;\n")

    forward_args = []
    reverse_args = []

    forward_initializers = []
    reverse_body = []
    atomic_add_body = []
    prefix_add_body = []

    # forward args
    for label, var in struct.vars.items():
        var_ctype = var.ctype()
        # every constructor parameter after the first gets a `= {}` default
        default_arg_def = " = {}" if forward_args else ""
        forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
        reverse_args.append(f"{var_ctype} const&")

        # qualify adj_atomic_add for wp:: member types (and bool, which maps
        # to the wp:: overload as well)
        namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
        atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")

        # first initializer uses ":", subsequent ones use ","
        prefix = f"{indent_block}," if forward_initializers else ":"
        forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")

    # prefix-add operator
    for label, var in struct.vars.items():
        if not is_array(var.type):
            prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")

    # reverse args
    for label, var in struct.vars.items():
        reverse_args.append(var.ctype() + " & adj_" + label)
        if is_array(var.type):
            # array adjoints are assigned; non-array adjoints are accumulated
            reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
        else:
            reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")

    reverse_args.append(name + " & adj_ret")

    # explicitly defaulted default constructor if no default constructor has been defined
    defaulted_constructor_def = f"{name}() = default;" if forward_args else ""

    return struct_template.format(
        name=name,
        struct_body="".join([indent_block + l for l in body]),
        forward_args=indent(forward_args),
        forward_initializers="".join(forward_initializers),
        reverse_args=indent(reverse_args),
        reverse_body="".join(reverse_body),
        prefix_add_body="".join(prefix_add_body),
        atomic_add_body="".join(atomic_add_body),
        defaulted_constructor_def=defaulted_constructor_def,
    )
3908
+
3909
+
3910
def codegen_func_forward(adj, func_type="kernel", device="cpu"):
    """Generate the body of the forward pass for a kernel or function.

    Args:
        adj: Adjoint object holding args, variables and the forward block.
        func_type: "kernel" or "function"; affects indentation and the CUDA
            return-to-continue rewrite.
        device: Target device, "cpu" or "cuda".

    Returns:
        The indented forward-pass source as a single string.

    Raises:
        ValueError: If ``device`` is neither "cpu" nor "cuda".
    """
    if device == "cpu":
        num_spaces = 4
    elif device == "cuda":
        # CUDA kernel bodies sit one level deeper (inside the launch loop)
        num_spaces = 8 if func_type == "kernel" else 4
    else:
        raise ValueError(f"Device {device} not supported for codegen")

    pad = " " * num_spaces
    lines = []

    # CPU kernels receive their arguments packed in a struct; unpack into locals
    if device == "cpu" and func_type == "kernel":
        lines.append("//---------\n")
        lines.append("// argument vars\n")
        for arg in adj.args:
            lines.append(f"{arg.ctype()} {arg.emit()} = _wp_args->{arg.label};\n")

    # primal variable declarations
    lines.append("//---------\n")
    lines.append("// primal vars\n")
    for var in adj.variables:
        if is_tile(var.type):
            lines.append(f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n")
        elif var.constant is None:
            lines.append(f"{var.ctype()} {var.emit()};\n")
        else:
            lines.append(f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n")

        # map the declaration back to its Python source line, when available
        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
            lines.insert(-1, f"{line_directive}\n")

    # forward body
    lines.append("//---------\n")
    lines.append("// forward\n")
    for stmt in adj.blocks[0].body_forward:
        if func_type == "kernel" and device == "cuda" and stmt.lstrip().startswith("return;"):
            # Use of grid-stride loops in CUDA kernels requires that we convert return; to continue;
            lines.append(stmt.replace("return;", "continue;") + "\n")
        else:
            lines.append(stmt + "\n")

    # indent everything except #line directives, which are emitted unindented
    return "".join(line.lstrip() if line.lstrip().startswith("#line") else pad + line for line in lines)
3960
+
3961
+
3962
def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    """Generate the body of the reverse (adjoint) pass for a kernel or function.

    Args:
        adj: Adjoint object holding args, variables, the replay block and the
            reverse block.
        func_type: "kernel" or "function"; affects indentation and whether the
            body ends with ``continue;`` (CUDA grid-stride loop) or ``return;``.
        device: Target device, "cpu" or "cuda".

    Returns:
        The indented reverse-pass source as a single string.

    Raises:
        ValueError: If ``device`` is neither "cpu" nor "cuda".
    """
    # CUDA kernel bodies sit one indentation level deeper than function bodies
    if device == "cpu":
        indent = 4
    elif device == "cuda":
        if func_type == "kernel":
            indent = 8
        else:
            indent = 4
    else:
        raise ValueError(f"Device {device} not supported for codegen")

    indent_block = " " * indent

    lines = []

    # argument vars
    # CPU kernels receive their primal and adjoint arguments packed in structs;
    # unpack both into local variables
    if device == "cpu" and func_type == "kernel":
        lines += ["//---------\n"]
        lines += ["// argument vars\n"]

        for var in adj.args:
            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]

        for var in adj.args:
            lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]

    # primal vars (tiles are initialized with gradient storage enabled here)
    lines += ["//---------\n"]
    lines += ["// primal vars\n"]

    for var in adj.variables:
        if is_tile(var.type):
            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
        elif var.constant is None:
            lines += [f"{var.ctype()} {var.emit()};\n"]
        else:
            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

        # insert a #line directive before the declaration, mapping it back to
        # the originating Python source line when available
        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
            lines.insert(-1, f"{line_directive}\n")

    # dual vars
    lines += ["//---------\n"]
    lines += ["// dual vars\n"]

    for var in adj.variables:
        name = var.emit_adj()
        ctype = var.ctype(value_type=True)

        if is_tile(var.type):
            if var.type.storage == "register":
                lines += [
                    f"{var.type.ctype()} {name}(0.0);\n"
                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
            elif var.type.storage == "shared":
                lines += [
                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
        else:
            # non-tile adjoints are zero-initialized
            lines += [f"{ctype} {name} = {{}};\n"]

        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
            lines.insert(-1, f"{line_directive}\n")

    # forward pass (replay version, re-computing primal values needed below)
    lines += ["//---------\n"]
    lines += ["// forward\n"]

    for f in adj.blocks[0].body_replay:
        lines += [f + "\n"]

    # reverse pass (adjoint statements are emitted in reverse order)
    lines += ["//---------\n"]
    lines += ["// reverse\n"]

    for l in reversed(adj.blocks[0].body_reverse):
        lines += [l + "\n"]

    # In grid-stride kernels the reverse body is in a for loop
    if device == "cuda" and func_type == "kernel":
        lines += ["continue;\n"]
    else:
        lines += ["return;\n"]

    # indent everything except #line directives, which stay unindented
    return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
4047
+
4048
+
4049
def codegen_func(adj, c_func_name: str, device="cpu", options=None):
    """Generate the forward and reverse C functions for a user-defined function.

    Args:
        adj: Adjoint object holding the transformed function.
        c_func_name: Mangled C name to emit the function under.
        device: Target device, "cpu" or "cuda".
        options: Module/kernel options; only "enable_backward" is consulted here.

    Returns:
        Source string containing the forward function (unless
        ``adj.skip_forward_codegen``) followed by the reverse function (unless
        ``adj.skip_reverse_codegen``).

    Raises:
        WarpCodegenError: If the declared return annotation disagrees with the
            values the function body actually returns.
        ValueError: If ``device`` is neither "cpu" nor "cuda".
    """
    if options is None:
        options = {}

    # validate the declared return annotation against the inferred return variables
    if adj.return_var is not None and "return" in adj.arg_types:
        if get_origin(adj.arg_types["return"]) is tuple:
            if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
                raise WarpCodegenError(
                    f"The function `{adj.fun_name}` has its return type "
                    f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
                    f"but the code returns {len(adj.return_var)} values."
                )
            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
                raise WarpCodegenError(
                    f"The function `{adj.fun_name}` has its return type "
                    f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
                    f"but the code returns a tuple with types `({', '.join(warp._src.context.type_str(x.type) for x in adj.return_var)})`."
                )
        elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` has its return type "
                f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
                f"but the code returns {len(adj.return_var)} values."
            )
        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` has its return type "
                f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
                f"but the code returns a value of type `{warp._src.context.type_str(adj.return_var[0].type)}`."
            )
        elif (
            isinstance(adj.return_var[0].type, warp._src.types.fixedarray)
            and type(adj.arg_types["return"]) is warp._src.types.array
        ):
            # If the return statement yields a `fixedarray` while the function is annotated
            # to return a standard `array`, then raise an error since the `fixedarray` storage
            # allocated on the stack will be freed once the function exits, meaning that the
            # resulting `array` instance will point to an invalid data.
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` returns a fixed-size array "
                f"whereas it has its return type annotated as "
                f"`{warp._src.context.type_str(adj.arg_types['return'])}`."
            )

    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
    # a direct mapping to a Python source line.
    func_line_directive = ""
    if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
        func_line_directive = f"{line_directive}\n"

    # forward header: a single return value is returned directly, otherwise void
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
    else:
        return_type = "void"

    # multiple outputs are passed back through reference parameters instead
    has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1

    forward_args = []
    reverse_args = []

    # forward args (in custom reverse mode only the declared inputs are shared
    # with the reverse signature)
    for i, arg in enumerate(adj.args):
        s = f"{arg.ctype()} {arg.emit()}"
        forward_args.append(s)
        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
            reverse_args.append(s)
    if has_multiple_outputs:
        for i, arg in enumerate(adj.return_var):
            forward_args.append(arg.ctype() + " & ret_" + str(i))
            reverse_args.append(arg.ctype() + " & ret_" + str(i))

    # reverse args
    for i, arg in enumerate(adj.args):
        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
            break
        # indexed array gradients are regular arrays
        if isinstance(arg.type, indexedarray):
            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
    if has_multiple_outputs:
        for i, arg in enumerate(adj.return_var):
            reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
    elif return_type != "void":
        reverse_args.append(return_type + " & adj_ret")
    # custom output reverse args (user-declared)
    if adj.custom_reverse_mode:
        for arg in adj.args[adj.custom_reverse_num_input_args :]:
            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")

    if device == "cpu":
        forward_template = cpu_forward_function_template
        reverse_template = cpu_reverse_function_template
    elif device == "cuda":
        forward_template = cuda_forward_function_template
        reverse_template = cuda_reverse_function_template
    else:
        raise ValueError(f"Device {device} is not supported")

    # codegen body
    forward_body = codegen_func_forward(adj, func_type="function", device=device)

    s = ""
    if not adj.skip_forward_codegen:
        s += forward_template.format(
            name=c_func_name,
            return_type=return_type,
            forward_args=indent(forward_args),
            forward_body=forward_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
            line_directive=func_line_directive,
        )

    if not adj.skip_reverse_codegen:
        if adj.custom_reverse_mode:
            # the user supplied the adjoint code themselves; emit it verbatim
            reverse_body = "\t// user-defined adjoint code\n" + forward_body
        else:
            # only emit a real adjoint if something actually differentiates through this function
            if options.get("enable_backward", True) and adj.used_by_backward_kernel:
                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
            else:
                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
        s += reverse_template.format(
            name=c_func_name,
            return_type=return_type,
            reverse_args=indent(reverse_args),
            forward_body=forward_body,
            reverse_body=reverse_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
            line_directive=func_line_directive,
        )

    return s
4186
+
4187
+
4188
def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
    """Wrap user-supplied native code snippets in forward/replay/reverse function shells.

    Args:
        adj: Adjoint object describing the snippet's signature.
        name: C function name for the forward/reverse pair.
        snippet: Native source for the forward body.
        adj_snippet: Native source for the reverse body (may be empty/None).
        replay_snippet: Optional native source for a tape-replay variant.

    Returns:
        Concatenated source of the emitted functions.
    """
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
    else:
        return_type = "void"

    # snippet arguments keep their user-facing names (the var_ prefix is stripped)
    forward_args = [f"{arg.ctype()} {arg.emit().replace('var_', '')}" for arg in adj.args]
    reverse_args = list(forward_args)

    for arg in adj.args:
        if isinstance(arg.type, indexedarray):
            # indexed array gradients are regular arrays
            plain = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(f"{plain.ctype()} & adj_{arg.label}")
        else:
            reverse_args.append(f"{arg.ctype()} & adj_{arg.label}")
    if return_type != "void":
        reverse_args.append(f"{return_type} & adj_ret")

    # keyword arguments shared by every emitted function
    common = {
        "return_type": return_type,
        "filename": adj.filename,
        "lineno": adj.fun_lineno,
        "line_directive": "",
    }

    pieces = [
        cuda_forward_function_template.format(
            name=name,
            forward_args=indent(forward_args),
            forward_body=snippet,
            **common,
        )
    ]

    if replay_snippet is not None:
        pieces.append(
            cuda_forward_function_template.format(
                name="replay_" + name,
                forward_args=indent(forward_args),
                forward_body=replay_snippet,
                **common,
            )
        )

    pieces.append(
        cuda_reverse_function_template.format(
            name=name,
            reverse_args=indent(reverse_args),
            forward_body=snippet,
            reverse_body=adj_snippet if adj_snippet else "",
            **common,
        )
    )

    return "".join(pieces)
4256
+
4257
+
4258
def codegen_kernel(kernel, device, options):
    """Generate the forward and (optionally) backward entry points for a kernel.

    Args:
        kernel: Kernel object; its ``options`` override the module-level ones.
        device: Target device, "cpu" or "cuda".
        options: Module-level options; "enable_backward" gates backward codegen.

    Returns:
        Source string: an argument struct (CPU only) followed by the formatted
        kernel template(s).

    Raises:
        ValueError: If ``device`` is neither "cpu" nor "cuda".
    """
    # Update the module's options with the ones defined on the kernel, if any.
    options = dict(options)
    options.update(kernel.options)

    adj = kernel.adj

    # CPU kernels receive their arguments packed in a struct, read back as
    # _wp_args / _wp_adj_args inside the generated bodies
    args_struct = ""
    if device == "cpu":
        args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
        for i in adj.args:
            args_struct += f" {i.ctype()} {i.label};\n"
        args_struct += "};\n"

    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
    # a direct mapping to a Python source line.
    func_line_directive = ""
    if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
        func_line_directive = f"{line_directive}\n"

    if device == "cpu":
        template_forward = cpu_kernel_template_forward
        template_backward = cpu_kernel_template_backward
    elif device == "cuda":
        template_forward = cuda_kernel_template_forward
        template_backward = cuda_kernel_template_backward
    else:
        raise ValueError(f"Device {device} is not supported")

    template = ""
    template_fmt_args = {
        "name": kernel.get_mangled_name(),
    }

    # build forward signature
    # CPU kernels take (dim, task_index) and read arguments from the args struct;
    # CUDA kernels take each argument directly
    forward_args = ["wp::launch_bounds_t dim"]
    if device == "cpu":
        forward_args.append("size_t task_index")
    else:
        for arg in adj.args:
            forward_args.append(arg.ctype() + " var_" + arg.label)

    forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
    template_fmt_args.update(
        {
            "forward_args": indent(forward_args),
            "forward_body": forward_body,
            "line_directive": func_line_directive,
        }
    )
    template += template_forward

    if options["enable_backward"]:
        # build reverse signature (CUDA additionally takes the adjoint arguments;
        # on CPU they arrive through the _wp_adj_args struct instead)
        reverse_args = ["wp::launch_bounds_t dim"]
        if device == "cpu":
            reverse_args.append("size_t task_index")
        else:
            for arg in adj.args:
                reverse_args.append(arg.ctype() + " var_" + arg.label)
            for arg in adj.args:
                # indexed array gradients are regular arrays
                if isinstance(arg.type, indexedarray):
                    _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
                    reverse_args.append(_arg.ctype() + " adj_" + arg.label)
                else:
                    reverse_args.append(arg.ctype() + " adj_" + arg.label)

        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
        template_fmt_args.update(
            {
                "reverse_args": indent(reverse_args),
                "reverse_body": reverse_body,
            }
        )
        template += template_backward

    s = template.format(**template_fmt_args)
    return args_struct + s
4338
+
4339
+
4340
def codegen_module(kernel, device, options):
    """Emit the CPU-side module entry points for a kernel.

    Args:
        kernel: Kernel object; its ``options`` override the module-level ones.
        device: Target device; anything other than "cpu" produces no output.
        options: Module-level options; "enable_backward" gates backward codegen.

    Returns:
        The formatted module source for CPU targets, or "" otherwise.
    """
    if device != "cpu":
        return ""

    # kernel-level options take precedence over module-level options
    merged_options = {**options, **kernel.options}

    fmt_args = {
        "name": kernel.get_mangled_name(),
    }

    template = cpu_module_template_forward
    if merged_options["enable_backward"]:
        template += cpu_module_template_backward

    return template.format(**fmt_args)