warp-lang 1.9.1-py3-none-manylinux_2_34_aarch64.whl → 1.10.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
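The dominant change in this release is the relocation of the implementation into the private warp/_src/ package: each public module (e.g. warp/codegen.py +6 -4251) is reduced to a thin wrapper over its private counterpart (e.g. warp/_src/codegen.py +4361 -0). A minimal sketch of such a wrapper layout, in Python, is shown below; the actual shim contents of warp-lang 1.10.0 are not included in this diff, so the re-export line is an assumption:

    # Hypothetical contents of a public shim module such as warp/codegen.py
    # after the 1.10.0 restructuring (illustrative assumption, not taken
    # from the wheel): re-export the private implementation under the
    # long-standing public name.
    from warp._src.codegen import *  # noqa: F401,F403

If the shims follow this pattern, existing imports such as "from warp.codegen import Adjoint" keep resolving, while the module bodies now live under warp/_src/.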
warp/_src/codegen.py ADDED
@@ -0,0 +1,4361 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import ast
19
+ import builtins
20
+ import ctypes
21
+ import enum
22
+ import functools
23
+ import hashlib
24
+ import inspect
25
+ import itertools
26
+ import math
27
+ import re
28
+ import sys
29
+ import textwrap
30
+ import types
31
+ from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
32
+
33
+ import warp._src.config
34
+ from warp._src.types import *
35
+
36
+ _wp_module_name_ = "warp.codegen"
37
+
38
+ # used as a globally accessible copy
39
+ # of current compile options (block_dim) etc
40
+ options = {}
41
+
42
+
43
+ class WarpCodegenError(RuntimeError):
44
+ def __init__(self, message):
45
+ super().__init__(message)
46
+
47
+
48
+ class WarpCodegenTypeError(TypeError):
49
+ def __init__(self, message):
50
+ super().__init__(message)
51
+
52
+
53
+ class WarpCodegenAttributeError(AttributeError):
54
+ def __init__(self, message):
55
+ super().__init__(message)
56
+
57
+
58
+ def get_node_name_safe(node):
59
+ """Safely get a string representation of an AST node for error messages.
60
+
61
+ This handles different AST node types (Name, Subscript, etc.) without
62
+ raising AttributeError when accessing attributes that may not exist.
63
+ """
64
+ if hasattr(node, "id"):
65
+ return node.id
66
+ elif hasattr(node, "value") and hasattr(node, "slice"):
67
+ # Subscript node like inputs[tid]
68
+ base_name = get_node_name_safe(node.value)
69
+ return f"{base_name}[...]"
70
+ else:
71
+ return f"<{type(node).__name__}>"
72
+
73
+
74
+ class WarpCodegenKeyError(KeyError):
75
+ def __init__(self, message):
76
+ super().__init__(message)
77
+
78
+
79
+ # map operator to function name
80
+ builtin_operators: dict[type[ast.AST], str] = {}
81
+
82
+ # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
83
+ # nice overview of python operators
84
+
85
+ builtin_operators[ast.Add] = "add"
86
+ builtin_operators[ast.Sub] = "sub"
87
+ builtin_operators[ast.Mult] = "mul"
88
+ builtin_operators[ast.MatMult] = "mul"
89
+ builtin_operators[ast.Div] = "div"
90
+ builtin_operators[ast.FloorDiv] = "floordiv"
91
+ builtin_operators[ast.Pow] = "pow"
92
+ builtin_operators[ast.Mod] = "mod"
93
+ builtin_operators[ast.UAdd] = "pos"
94
+ builtin_operators[ast.USub] = "neg"
95
+ builtin_operators[ast.Not] = "unot"
96
+
97
+ builtin_operators[ast.Gt] = ">"
98
+ builtin_operators[ast.Lt] = "<"
99
+ builtin_operators[ast.GtE] = ">="
100
+ builtin_operators[ast.LtE] = "<="
101
+ builtin_operators[ast.Eq] = "=="
102
+ builtin_operators[ast.NotEq] = "!="
103
+
104
+ builtin_operators[ast.BitAnd] = "bit_and"
105
+ builtin_operators[ast.BitOr] = "bit_or"
106
+ builtin_operators[ast.BitXor] = "bit_xor"
107
+ builtin_operators[ast.Invert] = "invert"
108
+ builtin_operators[ast.LShift] = "lshift"
109
+ builtin_operators[ast.RShift] = "rshift"
110
+
111
+ comparison_chain_strings = [
112
+ builtin_operators[ast.Gt],
113
+ builtin_operators[ast.Lt],
114
+ builtin_operators[ast.LtE],
115
+ builtin_operators[ast.GtE],
116
+ builtin_operators[ast.Eq],
117
+ builtin_operators[ast.NotEq],
118
+ ]
119
+
120
+
121
+ def values_check_equal(a, b):
122
+ if isinstance(a, Sequence) and isinstance(b, Sequence):
123
+ if len(a) != len(b):
124
+ return False
125
+
126
+ return all(x == y for x, y in zip(a, b))
127
+
128
+ return a == b
129
+
130
+
131
+ def op_str_is_chainable(op: str) -> builtins.bool:
132
+ return op in comparison_chain_strings
133
+
134
+
135
+ def get_closure_cell_contents(obj):
136
+ """Retrieve a closure's cell contents or `None` if it's empty."""
137
+ try:
138
+ return obj.cell_contents
139
+ except ValueError:
140
+ pass
141
+
142
+ return None
143
+
144
+
145
+ def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
146
+ """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
147
+ # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
148
+ if not annotations:
149
+ return {}
150
+
151
+ if not any(isinstance(x, str) for x in annotations.values()):
152
+ # No annotation to un-stringize.
153
+ return annotations
154
+
155
+ if isinstance(obj, type):
156
+ # class
157
+ globals = {}
158
+ module_name = getattr(obj, "__module__", None)
159
+ if module_name:
160
+ module = sys.modules.get(module_name, None)
161
+ if module:
162
+ globals = getattr(module, "__dict__", {})
163
+ locals = dict(vars(obj))
164
+ unwrap = obj
165
+ elif isinstance(obj, types.ModuleType):
166
+ # module
167
+ globals = obj.__dict__
168
+ locals = {}
169
+ unwrap = None
170
+ elif callable(obj):
171
+ # function
172
+ globals = getattr(obj, "__globals__", {})
173
+ # Capture the variables from the surrounding scope.
174
+ closure_vars = zip(
175
+ obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
176
+ )
177
+ locals = {k: v for k, v in closure_vars if v is not None}
178
+ unwrap = obj
179
+ else:
180
+ raise TypeError(f"{obj!r} is not a module, class, or callable.")
181
+
182
+ if unwrap is not None:
183
+ while True:
184
+ if hasattr(unwrap, "__wrapped__"):
185
+ unwrap = unwrap.__wrapped__
186
+ continue
187
+ if isinstance(unwrap, functools.partial):
188
+ unwrap = unwrap.func
189
+ continue
190
+ break
191
+ if hasattr(unwrap, "__globals__"):
192
+ globals = unwrap.__globals__
193
+
194
+ # "Inject" type parameters into the local namespace
195
+ # (unless they are shadowed by assignments *in* the local namespace),
196
+ # as a way of emulating annotation scopes when calling `eval()`
197
+ type_params = getattr(obj, "__type_params__", ())
198
+ if type_params:
199
+ locals = {param.__name__: param for param in type_params} | locals
200
+
201
+ return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
202
+
203
+
204
+ def get_annotations(obj: Any) -> Mapping[str, Any]:
205
+ """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
206
+ # Python 3.10+: Use the built-in inspect.get_annotations() which handles
207
+ # PEP 649 (deferred annotation evaluation) in Python 3.14+
208
+ if hasattr(inspect, "get_annotations"):
209
+ # eval_str=True ensures stringized annotations from PEP 563 are evaluated
210
+ return inspect.get_annotations(obj, eval_str=True)
211
+ else:
212
+ # Python 3.9 and older: Manual backport of inspect.get_annotations()
213
+ # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
214
+ if isinstance(obj, type):
215
+ annotations = obj.__dict__.get("__annotations__", {})
216
+ else:
217
+ annotations = getattr(obj, "__annotations__", {})
218
+
219
+ return eval_annotations(annotations, obj)
220
+
221
+
222
+ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
223
+ """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
224
+ spec = inspect.getfullargspec(func)
225
+
226
+ # Python 3.10+: Use inspect.get_annotations()
227
+ if hasattr(inspect, "get_annotations"):
228
+ # Capture closure variables to handle cases like `foo.Data` where `foo` is a closure variable
229
+ closure_vars = dict(
230
+ zip(func.__code__.co_freevars, (get_closure_cell_contents(x) for x in (func.__closure__ or ())))
231
+ )
232
+ # Filter out None values from empty cells
233
+ closure_vars = {k: v for k, v in closure_vars.items() if v is not None}
234
+ return spec._replace(annotations=inspect.get_annotations(func, eval_str=True, locals=closure_vars))
235
+ else:
236
+ # Python 3.9 and older: Manually un-stringize annotations
237
+ # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
238
+ return spec._replace(annotations=eval_annotations(spec.annotations, func))
239
+
240
+
241
+ def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
242
+ indent = "\t"
243
+
244
+ # handle empty structs
245
+ if len(inst._cls.vars) == 0:
246
+ return f"{inst._cls.key}()"
247
+
248
+ lines = []
249
+ lines.append(f"{inst._cls.key}(")
250
+
251
+ for field_name, _ in inst._cls.ctype._fields_:
252
+ field_value = getattr(inst, field_name, None)
253
+
254
+ if isinstance(field_value, StructInstance):
255
+ field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
256
+
257
+ if use_repr:
258
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
259
+ else:
260
+ lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
261
+
262
+ lines.append(f"{indent * depth})")
263
+ return "\n".join(lines)
264
+
265
+
266
+ class StructInstance:
267
+ def __init__(self, ctype):
268
+ # maintain a c-types object for the top-level instance the struct
269
+ super().__setattr__("_ctype", ctype)
270
+
271
+ # create Python attributes for each of the struct's variables
272
+ for k, cst in type(self)._constructors:
273
+ self.__dict__[k] = cst(ctype)
274
+
275
+ def __setattr__(self, name, value):
276
+ try:
277
+ self._setters[name](self, value)
278
+ except KeyError as err:
279
+ raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}") from err
280
+
281
+ def __ctype__(self):
282
+ return self._ctype
283
+
284
+ def __repr__(self):
285
+ return struct_instance_repr_recursive(self, 0, use_repr=True)
286
+
287
+ def __str__(self):
288
+ return struct_instance_repr_recursive(self, 0, use_repr=False)
289
+
290
+ def assign(self, value):
291
+ """Assigns the values of another struct instance to this one."""
292
+ if not isinstance(value, StructInstance):
293
+ raise RuntimeError(
294
+ f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
295
+ )
296
+
297
+ if self._cls.key is not value._cls.key:
298
+ raise RuntimeError(
299
+ f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
300
+ )
301
+
302
+ # update all nested ctype vars by deep copy
303
+ for n in self._cls.vars:
304
+ setattr(self, n, getattr(value, n))
305
+
306
+ def to(self, device):
307
+ """Copies this struct with all array members moved onto the given device.
308
+
309
+ Arrays already living on the desired device are referenced as-is, while
310
+ arrays being moved are copied.
311
+ """
312
+ out = self._cls()
313
+ stack = [(self, out, k, v) for k, v in self._cls.vars.items()]
314
+ while stack:
315
+ src, dst, name, var = stack.pop()
316
+ value = getattr(src, name)
317
+ if isinstance(var.type, array):
318
+ # array_t
319
+ setattr(dst, name, value.to(device))
320
+ elif isinstance(var.type, Struct):
321
+ # nested struct
322
+ new_struct = var.type()
323
+ setattr(dst, name, new_struct)
324
+ # The call to `setattr()` just above makes a copy of `new_struct`
325
+ # so we need to reference that new instance of the struct.
326
+ new_struct = getattr(dst, name)
327
+ stack.extend((value, new_struct, k, v) for k, v in var.type.vars.items())
328
+ else:
329
+ setattr(dst, name, value)
330
+
331
+ return out
332
+
333
+ # type description used in numpy structured arrays
334
+ def numpy_dtype(self):
335
+ return self._cls.numpy_dtype()
336
+
337
+ # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
338
+ def numpy_value(self):
339
+ npvalue = []
340
+ for name, var in self._cls.vars.items():
341
+ # get the attribute value
342
+ value = getattr(self._ctype, name)
343
+
344
+ if isinstance(var.type, array):
345
+ # array_t
346
+ npvalue.append(value.numpy_value())
347
+ elif isinstance(var.type, Struct):
348
+ # nested struct
349
+ npvalue.append(value.numpy_value())
350
+ elif issubclass(var.type, ctypes.Array):
351
+ if len(var.type._shape_) == 1:
352
+ # vector
353
+ npvalue.append(list(value))
354
+ else:
355
+ # matrix
356
+ npvalue.append([list(row) for row in value])
357
+ else:
358
+ # scalar
359
+ if var.type == warp.float16:
360
+ npvalue.append(half_bits_to_float(value))
361
+ else:
362
+ npvalue.append(value)
363
+
364
+ return tuple(npvalue)
365
+
366
+
367
+ def _make_struct_field_constructor(field: str, var_type: type):
368
+ if isinstance(var_type, Struct):
369
+ return lambda ctype: var_type.instance_type(ctype=getattr(ctype, field))
370
+ elif isinstance(var_type, warp._src.types.array):
371
+ return lambda ctype: None
372
+ elif issubclass(var_type, ctypes.Array):
373
+ # for vector/matrices, the Python attribute aliases the ctype one
374
+ return lambda ctype: getattr(ctype, field)
375
+ else:
376
+ return lambda ctype: var_type()
377
+
378
+
379
+ def _make_struct_field_setter(cls, field: str, var_type: type):
380
+ def set_array_value(inst, value):
381
+ if value is None:
382
+ # create array with null pointer
383
+ setattr(inst._ctype, field, array_t())
384
+ else:
385
+ # wp.array
386
+ assert isinstance(value, array)
387
+ assert types_equal(value.dtype, var_type.dtype), (
388
+ f"assign to struct member variable {field} failed, expected type {type_repr(var_type.dtype)}, got type {type_repr(value.dtype)}"
389
+ )
390
+ setattr(inst._ctype, field, value.__ctype__())
391
+
392
+ # workaround to prevent gradient buffers being garbage collected
393
+ # since users can do struct.array.requires_grad = False the gradient array
394
+ # would be collected while the struct ctype still holds a reference to it
395
+ if value.requires_grad:
396
+ cls.__setattr__(inst, "_" + field + "_grad", value.grad)
397
+
398
+ cls.__setattr__(inst, field, value)
399
+
400
+ def set_struct_value(inst, value):
401
+ getattr(inst, field).assign(value)
402
+
403
+ def set_vector_value(inst, value):
404
+ # vector/matrix type, e.g. vec3
405
+ if value is None:
406
+ setattr(inst._ctype, field, var_type())
407
+ elif type(value) is var_type:
408
+ setattr(inst._ctype, field, value)
409
+ else:
410
+ # conversion from list/tuple, ndarray, etc.
411
+ setattr(inst._ctype, field, var_type(value))
412
+
413
+ # no need to update the Python attribute,
414
+ # it's already aliasing the ctype one
415
+
416
+ def set_primitive_value(inst, value):
417
+ # primitive type
418
+ if value is None:
419
+ # zero initialize
420
+ setattr(inst._ctype, field, var_type._type_())
421
+ else:
422
+ if hasattr(value, "_type_"):
423
+ # assigning warp type value (e.g.: wp.float32)
424
+ value = value.value
425
+ # float16 needs conversion to uint16 bits
426
+ if var_type == warp.float16:
427
+ setattr(inst._ctype, field, float_to_half_bits(value))
428
+ else:
429
+ setattr(inst._ctype, field, value)
430
+
431
+ cls.__setattr__(inst, field, value)
432
+
433
+ if isinstance(var_type, array):
434
+ return set_array_value
435
+ elif isinstance(var_type, Struct):
436
+ return set_struct_value
437
+ elif issubclass(var_type, ctypes.Array):
438
+ return set_vector_value
439
+ else:
440
+ return set_primitive_value
441
+
442
+
443
+ class Struct:
444
+ hash: bytes
445
+
446
+ def __init__(self, key: str, cls: type, module: warp._src.context.Module):
447
+ self.key = key
448
+ self.cls = cls
449
+ self.module = module
450
+ self.vars: dict[str, Var] = {}
451
+
452
+ if isinstance(self.cls, Sequence):
453
+ raise RuntimeError("Warp structs must be defined as base classes")
454
+
455
+ annotations = get_annotations(self.cls)
456
+ for label, type_ in annotations.items():
457
+ self.vars[label] = Var(label, type_)
458
+
459
+ fields = []
460
+ for label, var in self.vars.items():
461
+ if isinstance(var.type, array):
462
+ fields.append((label, array_t))
463
+ elif isinstance(var.type, Struct):
464
+ fields.append((label, var.type.ctype))
465
+ elif issubclass(var.type, ctypes.Array):
466
+ fields.append((label, var.type))
467
+ else:
468
+ # HACK: fp16 requires conversion functions from warp.so
469
+ if var.type is warp.float16:
470
+ warp.init()
471
+ fields.append((label, var.type._type_))
472
+
473
+ class StructType(ctypes.Structure):
474
+ # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
475
+ _fields_ = fields or [("_dummy_", ctypes.c_byte)]
476
+
477
+ self.ctype = StructType
478
+
479
+ # Compute the hash. We can cache the hash because it's static, even with nested structs.
480
+ # All field types are specified in the annotations, so they're resolved at declaration time.
481
+ ch = hashlib.sha256()
482
+
483
+ ch.update(bytes(self.key, "utf-8"))
484
+
485
+ for name, type_hint in annotations.items():
486
+ s = f"{name}:{warp._src.types.get_type_code(type_hint)}"
487
+ ch.update(bytes(s, "utf-8"))
488
+
489
+ # recurse on nested structs
490
+ if isinstance(type_hint, Struct):
491
+ ch.update(type_hint.hash)
492
+
493
+ self.hash = ch.digest()
494
+
495
+ # generate unique identifier for structs in native code
496
+ hash_suffix = f"{self.hash.hex()[:8]}"
497
+ self.native_name = f"{self.key}_{hash_suffix}"
498
+
499
+ # create default constructor (zero-initialize)
500
+ self.default_constructor = warp._src.context.Function(
501
+ func=None,
502
+ key=self.native_name,
503
+ namespace="",
504
+ value_func=lambda *_: self,
505
+ input_types={},
506
+ initializer_list_func=lambda *_: False,
507
+ native_func=self.native_name,
508
+ )
509
+
510
+ # build a constructor that takes each param as a value
511
+ input_types = {label: var.type for label, var in self.vars.items()}
512
+
513
+ self.value_constructor = warp._src.context.Function(
514
+ func=None,
515
+ key=self.native_name,
516
+ namespace="",
517
+ value_func=lambda *_: self,
518
+ input_types=input_types,
519
+ initializer_list_func=lambda *_: False,
520
+ native_func=self.native_name,
521
+ )
522
+
523
+ self.default_constructor.add_overload(self.value_constructor)
524
+
525
+ if isinstance(module, warp._src.context.Module):
526
+ module.register_struct(self)
527
+
528
+ # Define class for instances of this struct
529
+ # To enable autocomplete on s, we inherit from self.cls.
530
+ # For example,
531
+
532
+ # @wp.struct
533
+ # class A:
534
+ # # annotations
535
+ # ...
536
+
537
+ # The type annotations are inherited in A(), allowing autocomplete in kernels
538
+ class NewStructInstance(self.cls, StructInstance):
539
+ cls: ClassVar[type] = self.cls
540
+ native_name: ClassVar[str] = self.native_name
541
+
542
+ _cls: ClassVar[type] = self
543
+ _constructors: ClassVar[list[tuple[str, Callable]]] = [
544
+ (field, _make_struct_field_constructor(field, var.type)) for field, var in self.vars.items()
545
+ ]
546
+ _setters: ClassVar[dict[str, Callable]] = {
547
+ field: _make_struct_field_setter(self.cls, field, var.type) for field, var in self.vars.items()
548
+ }
549
+
550
+ def __init__(inst, ctype=None):
551
+ StructInstance.__init__(inst, ctype or self.ctype())
552
+
553
+ self.instance_type = NewStructInstance
554
+
555
+ def __call__(self):
556
+ """
557
+ This function returns s = StructInstance(self)
558
+ s uses self.cls as template.
559
+ """
560
+ return self.instance_type()
561
+
562
+ def initializer(self):
563
+ return self.default_constructor
564
+
565
+ # return structured NumPy dtype, including field names, formats, and offsets
566
+ def numpy_dtype(self):
567
+ names = []
568
+ formats = []
569
+ offsets = []
570
+ for name, var in self.vars.items():
571
+ names.append(name)
572
+ offsets.append(getattr(self.ctype, name).offset)
573
+ if isinstance(var.type, array):
574
+ # array_t
575
+ formats.append(array_t.numpy_dtype())
576
+ elif isinstance(var.type, Struct):
577
+ # nested struct
578
+ formats.append(var.type.numpy_dtype())
579
+ elif issubclass(var.type, ctypes.Array):
580
+ scalar_typestr = type_typestr(var.type._wp_scalar_type_)
581
+ if len(var.type._shape_) == 1:
582
+ # vector
583
+ formats.append(f"{var.type._length_}{scalar_typestr}")
584
+ else:
585
+ # matrix
586
+ formats.append(f"{var.type._shape_}{scalar_typestr}")
587
+ else:
588
+ # scalar
589
+ formats.append(type_typestr(var.type))
590
+
591
+ return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}
592
+
593
+ # constructs a Warp struct instance from a pointer to the ctype
594
+ def from_ptr(self, ptr):
595
+ if not ptr:
596
+ raise RuntimeError("NULL pointer exception")
597
+
598
+ # create a new struct instance
599
+ instance = self()
600
+
601
+ for name, var in self.vars.items():
602
+ offset = getattr(self.ctype, name).offset
603
+ if isinstance(var.type, array):
604
+ # We could reconstruct wp.array from array_t, but it's problematic.
605
+ # There's no guarantee that the original wp.array is still allocated and
606
+ # no easy way to make a backref.
607
+ # Instead, we just create a stub annotation, which is not a fully usable array object.
608
+ setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
609
+ elif isinstance(var.type, Struct):
610
+ # nested struct
611
+ value = var.type.from_ptr(ptr + offset)
612
+ setattr(instance, name, value)
613
+ elif issubclass(var.type, ctypes.Array):
614
+ # vector/matrix
615
+ value = var.type.from_ptr(ptr + offset)
616
+ setattr(instance, name, value)
617
+ else:
618
+ # scalar
619
+ cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
620
+ if var.type == warp.float16:
621
+ setattr(instance, name, half_bits_to_float(cvalue))
622
+ else:
623
+ setattr(instance, name, cvalue.value)
624
+
625
+ return instance
626
+
627
+
628
+ class Reference:
629
+ def __init__(self, value_type):
630
+ self.value_type = value_type
631
+
632
+
633
+ def is_reference(type: Any) -> builtins.bool:
634
+ return isinstance(type, Reference)
635
+
636
+
637
+ def strip_reference(arg: Any) -> Any:
638
+ if is_reference(arg):
639
+ return arg.value_type
640
+ else:
641
+ return arg
642
+
643
+
644
+ def compute_type_str(base_name, template_params):
645
+ if not template_params:
646
+ return base_name
647
+
648
+ def param2str(p):
649
+ if isinstance(p, builtins.bool):
650
+ return "true" if p else "false"
651
+ if isinstance(p, int):
652
+ return str(p)
653
+ elif hasattr(p, "_wp_generic_type_str_"):
654
+ return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
655
+ elif hasattr(p, "_type_"):
656
+ if p.__name__ == "bool":
657
+ return "bool"
658
+ else:
659
+ return f"wp::{p.__name__}"
660
+ elif is_tile(p):
661
+ return p.ctype()
662
+ elif isinstance(p, Struct):
663
+ return p.native_name
664
+
665
+ return p.__name__
666
+
667
+ return f"{base_name}<{', '.join(map(param2str, template_params))}>"
668
+
669
+
670
+ class Var:
671
+ def __init__(
672
+ self,
673
+ label: str,
674
+ type: type,
675
+ requires_grad: builtins.bool = False,
676
+ constant: builtins.bool | None = None,
677
+ prefix: builtins.bool = True,
678
+ relative_lineno: int | None = None,
679
+ ):
680
+ # convert built-in types to wp types
681
+ if type == float:
682
+ type = float32
683
+ elif type == int:
684
+ type = int32
685
+ elif type == builtins.bool:
686
+ type = bool
687
+
688
+ self.label = label
689
+ self.type = type
690
+ self.requires_grad = requires_grad
691
+ self.constant = constant
692
+ self.prefix = prefix
693
+
694
+ # records whether this Var has been read from in a kernel function (array only)
695
+ self.is_read = False
696
+ # records whether this Var has been written to in a kernel function (array only)
697
+ self.is_write = False
698
+
699
+ # used to associate a view array Var with its parent array Var
700
+ self.parent = None
701
+
702
+ # Used to associate the variable with the Python statement that resulted in it being created.
703
+ self.relative_lineno = relative_lineno
704
+
705
+ def __str__(self):
706
+ return self.label
707
+
708
+ @staticmethod
709
+ def dtype_to_ctype(t: type) -> str:
710
+ if hasattr(t, "_wp_generic_type_str_"):
711
+ return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
712
+ elif isinstance(t, Struct):
713
+ return t.native_name
714
+ elif hasattr(t, "_wp_native_name_"):
715
+ return f"wp::{t._wp_native_name_}"
716
+ elif t.__name__ in ("bool", "int", "float"):
717
+ return t.__name__
718
+
719
+ return f"wp::{t.__name__}"
720
+
721
+ @staticmethod
722
+ def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
723
+ if isinstance(t, fixedarray):
724
+ template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
725
+ dtypestr = ", ".join(template_args)
726
+ classstr = f"wp::{type(t).__name__}"
727
+ return f"{classstr}_t<{dtypestr}>"
728
+ elif is_array(t):
729
+ dtypestr = Var.dtype_to_ctype(t.dtype)
730
+ classstr = f"wp::{type(t).__name__}"
731
+ return f"{classstr}_t<{dtypestr}>"
732
+ elif get_origin(t) is tuple:
733
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
734
+ return f"wp::tuple_t<{dtypestr}>"
735
+ elif is_tuple(t):
736
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
737
+ classstr = f"wp::{type(t).__name__}"
738
+ return f"{classstr}<{dtypestr}>"
739
+ elif is_tile(t):
740
+ return t.ctype()
741
+ elif isinstance(t, type) and issubclass(t, StructInstance):
742
+ # ensure the actual Struct name is used instead of "NewStructInstance"
743
+ return t.native_name
744
+ elif is_reference(t):
745
+ if not value_type:
746
+ return Var.type_to_ctype(t.value_type) + "*"
747
+
748
+ return Var.type_to_ctype(t.value_type)
749
+
750
+ return Var.dtype_to_ctype(t)
751
+
752
+ def ctype(self, value_type: builtins.bool = False) -> str:
753
+ return Var.type_to_ctype(self.type, value_type)
754
+
755
+ def emit(self, prefix: str = "var"):
756
+ if self.prefix:
757
+ return f"{prefix}_{self.label}"
758
+ else:
759
+ return self.label
760
+
761
+ def emit_adj(self):
762
+ return self.emit("adj")
763
+
764
+ def mark_read(self):
765
+ """Marks this Var as having been read from in a kernel (array only)."""
766
+ if not is_array(self.type):
767
+ return
768
+
769
+ self.is_read = True
770
+
771
+ # recursively update all parent states
772
+ parent = self.parent
773
+ while parent is not None:
774
+ parent.is_read = True
775
+ parent = parent.parent
776
+
777
+ def mark_write(self, **kwargs):
778
+ """Marks this Var has having been written to in a kernel (array only)."""
779
+ if not is_array(self.type):
780
+ return
781
+
782
+ # detect if we are writing to an array after reading from it within the same kernel
783
+ if self.is_read and warp._src.config.verify_autograd_array_access:
784
+ if "kernel_name" and "filename" and "lineno" in kwargs:
785
+ print(
786
+ f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
787
+ )
788
+ else:
789
+ print(
790
+ f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
791
+ )
792
+ self.is_write = True
793
+
794
+ # recursively update all parent states
795
+ parent = self.parent
796
+ while parent is not None:
797
+ parent.is_write = True
798
+ parent = parent.parent
799
+
800
+
801
+ class Block:
802
+ # Represents a basic block of instructions, e.g.: list
803
+ # of straight line instructions inside a for-loop or conditional
804
+
805
+ def __init__(self):
806
+ # list of statements inside this block
807
+ self.body_forward = []
808
+ self.body_replay = []
809
+ self.body_reverse = []
810
+
811
+ # list of vars declared in this block
812
+ self.vars = []
813
+
814
+
815
+ def apply_defaults(
816
+ bound_args: inspect.BoundArguments,
817
+ values: Mapping[str, Any],
818
+ ):
819
+ # Similar to Python's `inspect.BoundArguments.apply_defaults()`
820
+ # but with the possibility to pass an augmented set of default values.
821
+ arguments = bound_args.arguments
822
+ new_arguments = []
823
+ for name in bound_args._signature.parameters.keys():
824
+ if name in arguments:
825
+ new_arguments.append((name, arguments[name]))
826
+ elif name in values:
827
+ new_arguments.append((name, values[name]))
828
+
829
+ bound_args.arguments = dict(new_arguments)
830
+
831
+
832
+ def func_match_args(func, arg_types, kwarg_types):
833
+ try:
834
+ # Try to bind the given arguments to the function's signature.
835
+ # This is not checking whether the argument types are matching,
836
+ # rather it's just assigning each argument to the corresponding
837
+ # function parameter.
838
+ bound_arg_types = func.signature.bind(*arg_types, **kwarg_types)
839
+ except TypeError:
840
+ return False
841
+
842
+ # Populate the bound arguments with any default values.
843
+ default_arg_types = {
844
+ k: None if v is None else get_arg_type(v)
845
+ for k, v in func.defaults.items()
846
+ if k not in bound_arg_types.arguments
847
+ }
848
+ apply_defaults(bound_arg_types, default_arg_types)
849
+ bound_arg_types = tuple(bound_arg_types.arguments.values())
850
+
851
+ # Check the given argument types against the ones defined on the function.
852
+ for bound_arg_type, func_arg_type in zip(bound_arg_types, func.input_types.values()):
853
+ # Let the `value_func` callback infer the type.
854
+ if bound_arg_type is None:
855
+ continue
856
+
857
+ # if arg type registered as Any, treat as
858
+ # template allowing any type to match
859
+ if func_arg_type == Any:
860
+ continue
861
+
862
+ # handle function refs as a special case
863
+ if func_arg_type == Callable and isinstance(bound_arg_type, warp._src.context.Function):
864
+ continue
865
+
866
+ # check arg type matches input variable type
867
+ if not types_equal(func_arg_type, strip_reference(bound_arg_type), match_generic=True):
868
+ return False
869
+
870
+ return True
871
+
872
+
873
+ def get_arg_type(arg: Var | Any) -> type:
874
+ if isinstance(arg, str):
875
+ return str
876
+
877
+ if isinstance(arg, Sequence):
878
+ return tuple(get_arg_type(x) for x in arg)
879
+
880
+ if is_array(arg):
881
+ return arg
882
+
883
+ if get_origin(arg) is tuple:
884
+ return tuple(get_arg_type(x) for x in get_args(arg))
885
+
886
+ if is_tuple(arg):
887
+ return arg
888
+
889
+ if isinstance(arg, (type, warp._src.context.Function)):
890
+ return arg
891
+
892
+ if isinstance(arg, Var):
893
+ if get_origin(arg.type) is tuple:
894
+ return get_args(arg.type)
895
+
896
+ return arg.type
897
+
898
+ return type(arg)
899
+
900
+
901
+ def get_arg_value(arg: Any) -> Any:
902
+ if isinstance(arg, Sequence):
903
+ return tuple(get_arg_value(x) for x in arg)
904
+
905
+ if isinstance(arg, (type, warp._src.context.Function)):
906
+ return arg
907
+
908
+ if isinstance(arg, Var):
909
+ if is_tuple(arg.type):
910
+ return tuple(get_arg_value(x) for x in arg.type.values)
911
+
912
+ if arg.constant is not None:
913
+ return arg.constant
914
+
915
+ return arg
916
+
917
+
918
+ class Adjoint:
919
+ # Source code transformer, this class takes a Python function and
920
+ # generates forward and backward SSA forms of the function instructions
921
+
922
+ def __init__(
923
+ adj,
924
+ func: Callable[..., Any],
925
+ overload_annotations=None,
926
+ is_user_function=False,
927
+ skip_forward_codegen=False,
928
+ skip_reverse_codegen=False,
929
+ custom_reverse_mode=False,
930
+ custom_reverse_num_input_args=-1,
931
+ transformers: list[ast.NodeTransformer] | None = None,
932
+ source: str | None = None,
933
+ ):
934
+ adj.func = func
935
+
936
+ adj.is_user_function = is_user_function
937
+
938
+ # whether the generation of the forward code is skipped for this function
939
+ adj.skip_forward_codegen = skip_forward_codegen
940
+ # whether the generation of the adjoint code is skipped for this function
941
+ adj.skip_reverse_codegen = skip_reverse_codegen
942
+ # Whether this function is used by a kernel that has has the backward pass enabled.
943
+ adj.used_by_backward_kernel = False
944
+
945
+ # extract name of source file
946
+ adj.filename = inspect.getsourcefile(func) or "unknown source file"
947
+ # get source file line number where function starts
948
+ adj.fun_lineno = 0
949
+ adj.source = source
950
+ if adj.source is None:
951
+ adj.source, adj.fun_lineno = adj.extract_function_source(func)
952
+
953
+ assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"
954
+
955
+ # Indicates where the function definition starts (excludes decorators)
956
+ adj.fun_def_lineno = None
957
+
958
+ # get function source code
959
+ # ensures that indented class methods can be parsed as kernels
960
+ adj.source = textwrap.dedent(adj.source)
961
+
962
+ adj.source_lines = adj.source.splitlines()
963
+
964
+ if transformers is None:
965
+ transformers = []
966
+
967
+ # build AST and apply node transformers
968
+ adj.tree = ast.parse(adj.source)
969
+ adj.transformers = transformers
970
+ for transformer in transformers:
971
+ adj.tree = transformer.visit(adj.tree)
972
+
973
+ adj.fun_name = adj.tree.body[0].name
974
+
975
+ # for keeping track of line number in function code
976
+ adj.lineno = None
977
+
978
+ # whether the forward code shall be used for the reverse pass and a custom
979
+ # function signature is applied to the reverse version of the function
980
+ adj.custom_reverse_mode = custom_reverse_mode
981
+ # the number of function arguments that pertain to the forward function
982
+ # input arguments (i.e. the number of arguments that are not adjoint arguments)
983
+ adj.custom_reverse_num_input_args = custom_reverse_num_input_args
984
+
985
+ # parse argument types
986
+ argspec = get_full_arg_spec(func)
987
+
988
+ # ensure all arguments are annotated
989
+ if overload_annotations is None:
990
+ # use source-level argument annotations
991
+ if len(argspec.annotations) < len(argspec.args):
992
+ raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
993
+ adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
994
+ else:
995
+ # use overload argument annotations
996
+ for arg_name in argspec.args:
997
+ if arg_name not in overload_annotations:
998
+ raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
999
+ adj.arg_types = overload_annotations.copy()
1000
+
1001
+ adj.args = []
1002
+ adj.symbols = {}
1003
+
1004
+ for name, type in adj.arg_types.items():
1005
+ # skip return hint
1006
+ if name == "return":
1007
+ continue
1008
+
1009
+ # add variable for argument
1010
+ arg = Var(name, type, requires_grad=False)
1011
+ adj.args.append(arg)
1012
+
1013
+ # pre-populate symbol dictionary with function argument names
1014
+ # this is to avoid registering false references to overshadowed modules
1015
+ adj.symbols[name] = arg
1016
+
1017
+ # Indicates whether there are unresolved static expressions in the function.
1018
+ # These stem from wp.static() expressions that could not be evaluated at declaration time.
1019
+ # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
1020
+ adj.has_unresolved_static_expressions = False
1021
+
1022
+ # try to replace static expressions by their constant result if the
1023
+ # expression can be evaluated at declaration time
1024
+ adj.static_expressions: dict[str, Any] = {}
1025
+ if "static" in adj.source:
1026
+ adj.replace_static_expressions()
1027
+
1028
+ # There are cases where a same module might be rebuilt multiple times,
1029
+ # for example when kernels are nested inside of functions, or when
1030
+ # a kernel's launch raises an exception. Ideally we'd always want to
1031
+ # avoid rebuilding kernels but some corner cases seem to depend on it,
1032
+ # so we only avoid rebuilding kernels that errored out to give a chance
1033
+ # for unit testing errors being spit out from kernels.
1034
+ adj.skip_build = False
1035
+
1036
+ # allocate extra space for a function call that requires its
1037
+ # own shared memory space, we treat shared memory as a stack
1038
+ # where each function pushes and pops space off, the extra
1039
+ # quantity is the 'roofline' amount required for the entire kernel
1040
+ def alloc_shared_extra(adj, num_bytes):
1041
+ adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
1042
+
1043
+ # returns the total number of bytes for a function
1044
+ # based on it's own requirements + worst case
1045
+ # requirements of any dependent functions
1046
+ def get_total_required_shared(adj):
1047
+ total_shared = 0
1048
+
1049
+ for var in adj.variables:
1050
+ if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
1051
+ total_shared += var.type.size_in_bytes()
1052
+
1053
+ return total_shared + adj.max_required_extra_shared_memory
1054
+
1055
+ @staticmethod
1056
+ def extract_function_source(func: Callable) -> tuple[str, int]:
1057
+ try:
1058
+ _, fun_lineno = inspect.getsourcelines(func)
1059
+ source = inspect.getsource(func)
1060
+ except OSError as e:
1061
+ raise RuntimeError(
1062
+ "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
1063
+ "please save it to a file and use `importlib` if needed."
1064
+ ) from e
1065
+ return source, fun_lineno
1066
+
1067
+ # generate function ssa form and adjoint
1068
+ def build(adj, builder, default_builder_options=None):
1069
+ # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
1070
+ for arg in adj.args:
1071
+ arg.is_read = False
1072
+ arg.is_write = False
1073
+
1074
+ if adj.skip_build:
1075
+ return
1076
+
1077
+ adj.builder = builder
1078
+
1079
+ if default_builder_options is None:
1080
+ default_builder_options = {}
1081
+
1082
+ if adj.builder:
1083
+ adj.builder_options = adj.builder.options
1084
+ else:
1085
+ adj.builder_options = default_builder_options
1086
+
1087
+ global options
1088
+ options = adj.builder_options
1089
+
1090
+ adj.symbols = {} # map from symbols to adjoint variables
1091
+ adj.variables = [] # list of local variables (in order)
1092
+
1093
+ adj.return_var = None # return type for function or kernel
1094
+ adj.loop_symbols = [] # symbols at the start of each loop
1095
+ adj.loop_const_iter_symbols = (
1096
+ set()
1097
+ ) # constant iteration variables for static loops (mutating them does not raise an error)
1098
+
1099
+ # blocks
1100
+ adj.blocks = [Block()]
1101
+ adj.loop_blocks = []
1102
+
1103
+ # holds current indent level
1104
+ adj.indentation = ""
1105
+
1106
+ # used to generate new label indices
1107
+ adj.label_count = 0
1108
+
1109
+ # tracks how much additional shared memory is required by any dependent function calls
1110
+ adj.max_required_extra_shared_memory = 0
1111
+
1112
+ # update symbol map for each argument
1113
+ for a in adj.args:
1114
+ adj.symbols[a.label] = a
1115
+
1116
+ # recursively evaluate function body
1117
+ try:
1118
+ adj.eval(adj.tree.body[0])
1119
+ except Exception as original_exc:
1120
+ try:
1121
+ lineno = adj.lineno + adj.fun_lineno
1122
+ line = adj.source_lines[adj.lineno]
1123
+ msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
1124
+
1125
+ # Combine the new message with the original exception's arguments
1126
+ new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
1127
+
1128
+ # Enhance the original exception with parser context before re-raising.
1129
+ # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
1130
+ raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
1131
+ finally:
1132
+ adj.skip_build = True
1133
+ adj.builder = None
1134
+
1135
+ if builder is not None:
1136
+ for a in adj.args:
1137
+ if isinstance(a.type, Struct):
1138
+ builder.build_struct_recursive(a.type)
1139
+ elif isinstance(a.type, warp._src.types.array) and isinstance(a.type.dtype, Struct):
1140
+ builder.build_struct_recursive(a.type.dtype)
1141
+
1142
+ # release builder reference for GC
1143
+ adj.builder = None
1144
+
1145
+ # code generation methods
1146
+ def format_template(adj, template, input_vars, output_var):
1147
+ # output var is always the 0th index
1148
+ args = [output_var, *input_vars]
1149
+ s = template.format(*args)
1150
+
1151
+ return s
1152
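+ # Example (hypothetical template, assuming str.format semantics): with
+ # template "{0} = wp::foo({1}, {2})", output_var "var_0" and input_vars
+ # ["var_1", "var_2"], format_template() produces
+ # "var_0 = wp::foo(var_1, var_2)".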
+
1153
+ # generates a list of formatted args
1154
+ def format_args(adj, prefix, args):
1155
+ arg_strs = []
1156
+
1157
+ for a in args:
1158
+ if isinstance(a, warp._src.context.Function):
1159
+ # functions don't have a var_ prefix so strip it off here
1160
+ if prefix == "var":
1161
+ arg_strs.append(f"{a.namespace}{a.native_func}")
1162
+ else:
1163
+ arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
1164
+ elif is_reference(a.type):
1165
+ arg_strs.append(f"{prefix}_{a}")
1166
+ elif isinstance(a, Var):
1167
+ arg_strs.append(a.emit(prefix))
1168
+ else:
1169
+ raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
1170
+
1171
+ return arg_strs
1172
+
1173
+ # generates argument string for a forward function call
1174
+ def format_forward_call_args(adj, args, use_initializer_list):
1175
+ arg_str = ", ".join(adj.format_args("var", args))
1176
+ if use_initializer_list:
1177
+ return f"{{{arg_str}}}"
1178
+ return arg_str
1179
+
1180
+ # generates argument string for a reverse function call
1181
+ def format_reverse_call_args(
1182
+ adj,
1183
+ args_var,
1184
+ args,
1185
+ args_out,
1186
+ use_initializer_list,
1187
+ has_output_args=True,
1188
+ require_original_output_arg=False,
1189
+ ):
1190
+ formatted_var = adj.format_args("var", args_var)
1191
+ formatted_out = []
1192
+ if has_output_args and (require_original_output_arg or len(args_out) > 1):
1193
+ formatted_out = adj.format_args("var", args_out)
1194
+ formatted_var_adj = adj.format_args(
1195
+ "&adj" if use_initializer_list else "adj",
1196
+ args,
1197
+ )
1198
+ formatted_out_adj = adj.format_args("adj", args_out)
1199
+
1200
+ if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
1201
+ # there are no adjoint arguments, so we don't need to call the reverse function
1202
+ return None
1203
+
1204
+ if use_initializer_list:
1205
+ var_str = f"{{{', '.join(formatted_var)}}}"
1206
+ out_str = f"{{{', '.join(formatted_out)}}}"
1207
+ adj_str = f"{{{', '.join(formatted_var_adj)}}}"
1208
+ out_adj_str = ", ".join(formatted_out_adj)
1209
+ if len(args_out) > 1:
1210
+ arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
1211
+ else:
1212
+ arg_str = ", ".join([var_str, adj_str, out_adj_str])
1213
+ else:
1214
+ arg_str = ", ".join(formatted_var + formatted_out + formatted_var_adj + formatted_out_adj)
1215
+ return arg_str
1216
+
1217
+ def indent(adj):
1218
+ adj.indentation = adj.indentation + " "
1219
+
1220
+ def dedent(adj):
1221
+ adj.indentation = adj.indentation[:-4]
1222
+
1223
+ def begin_block(adj, name="block"):
1224
+ b = Block()
1225
+
1226
+ # give block a unique id
1227
+ b.label = name + "_" + str(adj.label_count)
1228
+ adj.label_count += 1
1229
+
1230
+ adj.blocks.append(b)
1231
+ return b
1232
+
1233
+ def end_block(adj):
1234
+ return adj.blocks.pop()
1235
+
1236
+ def add_var(adj, type=None, constant=None):
1237
+ index = len(adj.variables)
1238
+ name = str(index)
1239
+
1240
+ # allocate new variable
1241
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
1242
+
1243
+ adj.variables.append(v)
1244
+
1245
+ adj.blocks[-1].vars.append(v)
1246
+
1247
+ return v
1248
+
1249
+ def register_var(adj, var):
1250
+ # We sometimes initialize `Var` instances that might be thrown away
1251
+ # afterwards, so this method allows deferring their registration among
1252
+ # the list of primal vars until later on, instead of registering them
1253
+ # immediately if we were to use `adj.add_var()` or `adj.add_constant()`.
1254
+
1255
+ if isinstance(var, (Reference, warp._src.context.Function)):
1256
+ return var
1257
+
1258
+ if isinstance(var, int):
1259
+ return adj.add_constant(var)
1260
+
1261
+ if var.label is None:
1262
+ return adj.add_var(var.type, var.constant)
1263
+
1264
+ return var
1265
+
1266
+ def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
1267
+ """Get a line directive for the given statement.
1268
+
1269
+ Args:
1270
+ statement: The statement to get the line directive for.
1271
+ relative_lineno: The line number of the statement relative to the function.
1272
+
1273
+ Returns:
1274
+ A line directive for the given statement, or None if no line directive is needed.
1275
+ """
1276
+
1277
+ if adj.filename == "unknown source file" or adj.fun_lineno == 0:
1278
+ # Early return if function is not associated with a source file or is otherwise invalid
1279
+ # TODO: Get line directives working with wp.map() functions
1280
+ return None
1281
+
1282
+ # lineinfo is enabled by default in debug mode regardless of the builder option; we don't want to unnecessarily
1283
+ # emit line directives in generated code if it's not being compiled with line information
1284
+ build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp._src.config.mode
1285
+
1286
+ lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"
1287
+
1288
+ if relative_lineno is not None and lineinfo_enabled and warp._src.config.line_directives:
1289
+ is_comment = statement.strip().startswith("//")
1290
+ if not is_comment:
1291
+ line = relative_lineno + adj.fun_lineno
1292
+ # Convert backslashes to forward slashes for CUDA compatibility
1293
+ normalized_path = adj.filename.replace("\\", "/")
1294
+ return f'#line {line} "{normalized_path}"'
1295
+ return None
1296
+
1297
+ def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
1298
+ """Append a statement to the forward pass."""
1299
+
1300
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1301
+ adj.blocks[-1].body_forward.append(line_directive)
1302
+
1303
+ adj.blocks[-1].body_forward.append(adj.indentation + statement)
1304
+
1305
+ if not skip_replay:
1306
+ if line_directive:
1307
+ adj.blocks[-1].body_replay.append(line_directive)
1308
+
1309
+ if replay:
1310
+ # if custom replay specified then output it
1311
+ adj.blocks[-1].body_replay.append(adj.indentation + replay)
1312
+ else:
1313
+ # by default just replay the original statement
1314
+ adj.blocks[-1].body_replay.append(adj.indentation + statement)
1315
+
1316
+ # append a statement to the reverse pass
1317
+ def add_reverse(adj, statement: str) -> None:
1318
+ """Append a statement to the reverse pass."""
1319
+
1320
+ adj.blocks[-1].body_reverse.append(adj.indentation + statement)
1321
+
1322
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1323
+ adj.blocks[-1].body_reverse.append(line_directive)
1324
+
1325
+ def add_constant(adj, n):
1326
+ output = adj.add_var(type=type(n), constant=n)
1327
+ return output
1328
+
1329
+ def load(adj, var):
1330
+ if is_reference(var.type):
1331
+ var = adj.add_builtin_call("load", [var])
1332
+ return var
1333
+
1334
+ def add_comp(adj, op_strings, left, comps):
1335
+ output = adj.add_var(builtins.bool)
1336
+
1337
+ left = adj.load(left)
1338
+ s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
1339
+
1340
+ prev_comp_var = None
1341
+
1342
+ for op, comp in zip(op_strings, comps):
1343
+ comp_chainable = op_str_is_chainable(op)
1344
+ if comp_chainable and prev_comp_var:
1345
+ # We restrict chaining to operands of the same type
1346
+ if prev_comp_var.type is comp.type:
1347
+ prev_comp_var = adj.load(prev_comp_var)
1348
+ comp_var = adj.load(comp)
1349
+ s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
1350
+ else:
1351
+ raise WarpCodegenTypeError(
1352
+ f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
1353
+ )
1354
+ else:
1355
+ comp_var = adj.load(comp)
1356
+ s += op + " " + comp_var.emit() + ") "
1357
+
1358
+ prev_comp_var = comp_var
1359
+
1360
+ s = s.rstrip() + ";"
1361
+
1362
+ adj.add_forward(s)
1363
+
1364
+ return output
1365
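+ # Sketch of the generated code (variable names illustrative): for the
+ # Python expression `a < b < c`, add_comp() emits something like
+ #   var_out = ((var_a < var_b) && (var_b < var_c));
+ # i.e. chained comparisons are expanded into pairwise tests joined by &&.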
+
1366
+ def add_bool_op(adj, op_string, exprs):
1367
+ exprs = [adj.load(expr) for expr in exprs]
1368
+ output = adj.add_var(builtins.bool)
1369
+ command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
1370
+ adj.add_forward(command)
1371
+
1372
+ return output
1373
+
1374
+ def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
1375
+ if not func.is_builtin():
1376
+ # user-defined function
1377
+ overload = func.get_overload(arg_types, kwarg_types)
1378
+ if overload is not None:
1379
+ return overload
1380
+ else:
1381
+ # if func is overloaded then perform overload resolution here
1382
+ # we validate argument types before they go to generated native code
1383
+ for f in func.overloads:
1384
+ # skip type checking for variadic functions
1385
+ if not f.variadic:
1386
+ # check that argument counts are compatible (there may be some default args)
1387
+ if len(f.input_types) < len(arg_types) + len(kwarg_types):
1388
+ continue
1389
+
1390
+ if not func_match_args(f, arg_types, kwarg_types):
1391
+ continue
1392
+
1393
+ # check output dimensions match expectations
1394
+ if min_outputs:
1395
+ value_type = f.value_func(None, None)
1396
+ if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
1397
+ continue
1398
+
1399
+ # found a match, use it
1400
+ return f
1401
+
1402
+ # unresolved function, report error
1403
+ arg_type_reprs = []
1404
+
1405
+ for x in itertools.chain(arg_types, kwarg_types.values()):
1406
+ if isinstance(x, warp._src.context.Function):
1407
+ arg_type_reprs.append("function")
1408
+ else:
1409
+ # shorten Warp primitive type names
1410
+ if isinstance(x, Sequence):
1411
+ if len(x) != 1:
1412
+ raise WarpCodegenError("Argument must not be the result from a multi-valued function")
1413
+ arg_type = x[0]
1414
+ else:
1415
+ arg_type = x
1416
+
1417
+ arg_type_reprs.append(type_repr(arg_type))
1418
+
1419
+ raise WarpCodegenError(
1420
+ f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
1421
+ )
1422
+
1423
+ def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
1424
+ # Extract the types and values passed as arguments to the function call.
1425
+ arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
1426
+ kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
1427
+
1428
+ # Resolve the exact function signature among any existing overload.
1429
+ func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
1430
+
1431
+ # Bind the positional and keyword arguments to the function's signature
1432
+ # in order to process them as Python does it.
1433
+ bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
1434
+
1435
+ # Type args are the "compile time" argument values we get from codegen.
1436
+ # For example, when calling `wp.vec3f(...)` from within a kernel,
1437
+ # this translates in fact to calling the `vector()` built-in augmented
1438
+ # with the type args `length=3, dtype=float`.
1439
+ # Eventually, these need to be passed to the underlying C++ function,
1440
+ # so we update the arguments with the type args here.
1441
+ if type_args:
1442
+ for arg in type_args:
1443
+ if arg in bound_args.arguments:
1444
+ # In case of conflict, ideally we'd throw an error since
1445
+ # what comes from codegen should be the source of truth
1446
+ # and users also passing the same value as an argument
1447
+ # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
1448
+ # However, for backward compatibility, we allow that form
1449
+ # as long as the values are equal.
1450
+ if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
1451
+ continue
1452
+
1453
+ raise RuntimeError(
1454
+ f"Remove the extraneous `{arg}` parameter "
1455
+ f"when calling the templated version of "
1456
+ f"`wp.{func.native_func}()`"
1457
+ )
1458
+
1459
+ type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
1460
+ apply_defaults(bound_args, type_vars)
1461
+
1462
+ if func.defaults:
1463
+ default_vars = {
1464
+ k: Var(None, type=type(v), constant=v)
1465
+ for k, v in func.defaults.items()
1466
+ if k not in bound_args.arguments and v is not None
1467
+ }
1468
+ apply_defaults(bound_args, default_vars)
1469
+
1470
+ bound_args = bound_args.arguments
1471
+
1472
+ # if it is a user-function then build it recursively
1473
+ if not func.is_builtin():
1474
+ # If the function called is a user function,
1475
+ # we need to ensure its adjoint is also being generated.
1476
+ if adj.used_by_backward_kernel:
1477
+ func.adj.used_by_backward_kernel = True
1478
+
1479
+ if adj.builder is None:
1480
+ func.build(None)
1481
+
1482
+ elif func not in adj.builder.functions:
1483
+ adj.builder.build_function(func)
1484
+ # add custom grad, replay functions to the list of functions
1485
+ # to be built later (invalid code could be generated if we built them now)
1486
+ # so that they are not missed when only the forward function is imported
1487
+ # from another module
1488
+ if func.custom_grad_func:
1489
+ adj.builder.deferred_functions.append(func.custom_grad_func)
1490
+ if func.custom_replay_func:
1491
+ adj.builder.deferred_functions.append(func.custom_replay_func)
1492
+
1493
+ # Resolve the return value based on the types and values of the given arguments.
1494
+ bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
1495
+ bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
1496
+
1497
+ return_type = func.value_func(
1498
+ {k: strip_reference(v) for k, v in bound_arg_types.items()},
1499
+ bound_arg_values,
1500
+ )
1501
+
1502
+ # Handle the special case where a Var instance is returned from the `value_func`
1503
+ # callback, in which case we replace the call with a reference to that variable.
1504
+ if isinstance(return_type, Var):
1505
+ return adj.register_var(return_type)
1506
+ elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
1507
+ return tuple(adj.register_var(x) for x in return_type)
1508
+
1509
+ if get_origin(return_type) is tuple:
1510
+ types = get_args(return_type)
1511
+ return_type = warp._src.types.tuple_t(types=types, values=(None,) * len(types))
1512
+
1513
+ # immediately allocate output variables so we can pass them into the dispatch method
1514
+ if return_type is None:
1515
+ # void function
1516
+ output = None
1517
+ output_list = []
1518
+ elif not isinstance(return_type, Sequence) or len(return_type) == 1:
1519
+ # single return value function
1520
+ if isinstance(return_type, Sequence):
1521
+ return_type = return_type[0]
1522
+ output = adj.add_var(return_type)
1523
+ output_list = [output]
1524
+ else:
1525
+ # multiple return value function
1526
+ output = [adj.add_var(v) for v in return_type]
1527
+ output_list = output
1528
+
1529
+ # If we have a built-in that requires special handling to dispatch
1530
+ # the arguments to the underlying C++ function, then we can resolve
1531
+ # these using the `dispatch_func`. Since this is only called from
1532
+ # within codegen, we pass it directly `codegen.Var` objects,
1533
+ # which allows for some more advanced resolution to be performed,
1534
+ # for example by checking whether an argument corresponds to
1535
+ # a literal value or references a variable.
1536
+ extra_shared_memory = 0
1537
+ if func.lto_dispatch_func is not None:
1538
+ func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
1539
+ func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
1540
+ )
1541
+ elif func.dispatch_func is not None:
1542
+ func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
1543
+ else:
1544
+ func_args = tuple(bound_args.values())
1545
+ template_args = ()
1546
+
1547
+ func_args = tuple(adj.register_var(x) for x in func_args)
1548
+ func_name = compute_type_str(func.native_func, template_args)
1549
+ use_initializer_list = func.initializer_list_func(bound_args, return_type)
1550
+
1551
+ fwd_args = []
1552
+ for func_arg in func_args:
1553
+ if not isinstance(func_arg, (Reference, warp._src.context.Function)):
1554
+ func_arg_var = adj.load(func_arg)
1555
+ else:
1556
+ func_arg_var = func_arg
1557
+
1558
+ # if the argument is a function (and not a builtin), then build it recursively
1559
+ if isinstance(func_arg_var, warp._src.context.Function) and not func_arg_var.is_builtin():
1560
+ if adj.used_by_backward_kernel:
1561
+ func_arg_var.adj.used_by_backward_kernel = True
1562
+
1563
+ adj.builder.build_function(func_arg_var)
1564
+
1565
+ fwd_args.append(strip_reference(func_arg_var))
1566
+
1567
+ if return_type is None:
1568
+ # handles expression (zero output) functions, e.g.: void do_something();
1569
+ forward_call = (
1570
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
1571
+ )
1572
+ replay_call = forward_call
1573
+ if func.custom_replay_func is not None or func.replay_snippet is not None:
1574
+ replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
1575
+
1576
+ elif not isinstance(return_type, Sequence) or len(return_type) == 1:
1577
+ # handle simple function (one output)
1578
+ forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
1579
+ replay_call = forward_call
1580
+ if func.custom_replay_func is not None:
1581
+ replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
1582
+
1583
+ else:
1584
+ # handle multiple value functions
1585
+ forward_call = (
1586
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
1587
+ )
1588
+ replay_call = forward_call
1589
+
1590
+ if func.skip_replay:
1591
+ adj.add_forward(forward_call, replay="// " + replay_call)
1592
+ else:
1593
+ adj.add_forward(forward_call, replay=replay_call)
1594
+
1595
+ if func.is_differentiable and len(func_args):
1596
+ adj_args = tuple(strip_reference(x) for x in func_args)
1597
+ reverse_has_output_args = (
1598
+ func.require_original_output_arg or len(output_list) > 1
1599
+ ) and func.custom_grad_func is None
1600
+ arg_str = adj.format_reverse_call_args(
1601
+ fwd_args,
1602
+ adj_args,
1603
+ output_list,
1604
+ use_initializer_list,
1605
+ has_output_args=reverse_has_output_args,
1606
+ require_original_output_arg=func.require_original_output_arg,
1607
+ )
1608
+ if arg_str is not None:
1609
+ reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
1610
+ adj.add_reverse(reverse_call)
1611
+
1612
+ # update our smem roofline requirements based on any
1613
+ # shared memory required by the dependent function call
1614
+ if not func.is_builtin():
1615
+ adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
1616
+ else:
1617
+ adj.alloc_shared_extra(extra_shared_memory)
1618
+
1619
+ return output
1620
+
1621
+ def add_builtin_call(adj, func_name, args, min_outputs=None):
1622
+ func = warp._src.context.builtin_functions[func_name]
1623
+ return adj.add_call(func, args, {}, {}, min_outputs=min_outputs)
1624
+
1625
+ def add_return(adj, var):
1626
+ if var is None or len(var) == 0:
1627
+ # NOTE: If this kernel gets compiled for a CUDA device, then we need
1628
+ # to convert the return; into a continue; in codegen_func_forward()
1629
+ adj.add_forward("return;", f"goto label{adj.label_count};")
1630
+ elif len(var) == 1:
1631
+ adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
1632
+ adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
1633
+ else:
1634
+ for i, v in enumerate(var):
1635
+ adj.add_forward(f"ret_{i} = {v.emit()};")
1636
+ adj.add_reverse(f"adj_{v} += adj_ret_{i};")
1637
+ adj.add_forward("return;", f"goto label{adj.label_count};")
1638
+
1639
+ adj.add_reverse(f"label{adj.label_count}:;")
1640
+
1641
+ adj.label_count += 1
1642
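+ # Sketch of the generated pattern for `return x` (names illustrative):
+ #   forward: return var_x;          replay: goto label0;
+ #   reverse (after reversal): label0:;  adj_x += adj_ret;
+ # so the replay path skips the return while the reverse pass resumes
+ # gradient accumulation from the matching label.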
+
1643
+ # define an if statement
1644
+ def begin_if(adj, cond):
1645
+ cond = adj.load(cond)
1646
+ adj.add_forward(f"if ({cond.emit()}) {{")
1647
+ adj.add_reverse("}")
1648
+
1649
+ adj.indent()
1650
+
1651
+ def end_if(adj, cond):
1652
+ adj.dedent()
1653
+
1654
+ adj.add_forward("}")
1655
+ cond = adj.load(cond)
1656
+ adj.add_reverse(f"if ({cond.emit()}) {{")
1657
+
1658
+ def begin_else(adj, cond):
1659
+ cond = adj.load(cond)
1660
+ adj.add_forward(f"if (!{cond.emit()}) {{")
1661
+ adj.add_reverse("}")
1662
+
1663
+ adj.indent()
1664
+
1665
+ def end_else(adj, cond):
1666
+ adj.dedent()
1667
+
1668
+ adj.add_forward("}")
1669
+ cond = adj.load(cond)
1670
+ adj.add_reverse(f"if (!{cond.emit()}) {{")
1671
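+ # Note (descriptive): begin_if()/end_if() and begin_else()/end_else() push
+ # the brace tokens in swapped order on the reverse list; because
+ # body_reverse is emitted back-to-front, the reverse pass still
+ # reconstructs a well-formed `if (cond) { ... }` block that replays the
+ # branch while accumulating adjoints.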
+
1672
+ # define a for-loop
1673
+ def begin_for(adj, iter):
1674
+ cond_block = adj.begin_block("for")
1675
+ adj.loop_blocks.append(cond_block)
1676
+ adj.add_forward(f"start_{cond_block.label}:;")
1677
+ adj.indent()
1678
+
1679
+ # evaluate cond
1680
+ adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto end_{cond_block.label};")
1681
+
1682
+ # evaluate iter
1683
+ val = adj.add_builtin_call("iter_next", [iter])
1684
+
1685
+ adj.begin_block()
1686
+
1687
+ return val
1688
+
1689
+ def end_for(adj, iter):
1690
+ body_block = adj.end_block()
1691
+ cond_block = adj.end_block()
1692
+ adj.loop_blocks.pop()
1693
+
1694
+ ####################
1695
+ # forward pass
1696
+
1697
+ for i in cond_block.body_forward:
1698
+ adj.blocks[-1].body_forward.append(i)
1699
+
1700
+ for i in body_block.body_forward:
1701
+ adj.blocks[-1].body_forward.append(i)
1702
+
1703
+ adj.add_forward(f"goto start_{cond_block.label};", skip_replay=True)
1704
+
1705
+ adj.dedent()
1706
+ adj.add_forward(f"end_{cond_block.label}:;", skip_replay=True)
1707
+
1708
+ ####################
1709
+ # reverse pass
1710
+
1711
+ reverse = []
1712
+
1713
+ # reverse iterator
1714
+ reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
1715
+
1716
+ for i in cond_block.body_forward:
1717
+ reverse.append(i)
1718
+
1719
+ # zero adjoints
1720
+ for i in body_block.vars:
1721
+ if is_tile(i.type):
1722
+ if i.type.owner:
1723
+ reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
1724
+ else:
1725
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
1726
+
1727
+ # replay
1728
+ for i in body_block.body_replay:
1729
+ reverse.append(i)
1730
+
1731
+ # reverse
1732
+ for i in reversed(body_block.body_reverse):
1733
+ reverse.append(i)
1734
+
1735
+ reverse.append(adj.indentation + f"\tgoto start_{cond_block.label};")
1736
+ reverse.append(adj.indentation + f"end_{cond_block.label}:;")
1737
+
1738
+ adj.blocks[-1].body_reverse.extend(reversed(reverse))
1739
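+ # Sketch of the forward code produced for a dynamic loop (label index
+ # illustrative):
+ #   start_for_0:;
+ #       if (iter_cmp(var_iter) == 0) goto end_for_0;
+ #       var_i = wp::iter_next(var_iter);
+ #       ... body ...
+ #       goto start_for_0;
+ #   end_for_0:;
+ # The reverse pass runs the same loop over wp::iter_reverse(var_iter),
+ # replaying the body before accumulating its adjoints.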
+
1740
+ # define a while loop
1741
+ def begin_while(adj, cond):
1742
+ # evaluate condition in its own block
1743
+ # so we can control replay
1744
+ cond_block = adj.begin_block("while")
1745
+ adj.loop_blocks.append(cond_block)
1746
+ cond_block.body_forward.append(f"start_{cond_block.label}:;")
1747
+
1748
+ c = adj.eval(cond)
1749
+ c = adj.load(c)
1750
+
1751
+ cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
1752
+
1753
+ # begin block around loop
1754
+ adj.begin_block()
1755
+ adj.indent()
1756
+
1757
+ def end_while(adj):
1758
+ adj.dedent()
1759
+ body_block = adj.end_block()
1760
+ cond_block = adj.end_block()
1761
+ adj.loop_blocks.pop()
1762
+
1763
+ ####################
1764
+ # forward pass
1765
+
1766
+ for i in cond_block.body_forward:
1767
+ adj.blocks[-1].body_forward.append(i)
1768
+
1769
+ for i in body_block.body_forward:
1770
+ adj.blocks[-1].body_forward.append(i)
1771
+
1772
+ adj.blocks[-1].body_forward.append(f"goto start_{cond_block.label};")
1773
+ adj.blocks[-1].body_forward.append(f"end_{cond_block.label}:;")
1774
+
1775
+ ####################
1776
+ # reverse pass
1777
+ reverse = []
1778
+
1779
+ # cond
1780
+ for i in cond_block.body_forward:
1781
+ reverse.append(i)
1782
+
1783
+ # zero adjoints of local vars
1784
+ for i in body_block.vars:
1785
+ reverse.append(f"{i.emit_adj()} = {{}};")
1786
+
1787
+ # replay
1788
+ for i in body_block.body_replay:
1789
+ reverse.append(i)
1790
+
1791
+ # reverse
1792
+ for i in reversed(body_block.body_reverse):
1793
+ reverse.append(i)
1794
+
1795
+ reverse.append(f"goto start_{cond_block.label};")
1796
+ reverse.append(f"end_{cond_block.label}:;")
1797
+
1798
+ # output
1799
+ adj.blocks[-1].body_reverse.extend(reversed(reverse))
1800
+
1801
+ def emit_FunctionDef(adj, node):
1802
+ adj.fun_def_lineno = node.lineno
1803
+
1804
+ for f in node.body:
1805
+ # Skip variable creation for standalone constants, including docstrings
1806
+ if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
1807
+ continue
1808
+ adj.eval(f)
1809
+
1810
+ if adj.return_var is not None and len(adj.return_var) == 1:
1811
+ if not isinstance(node.body[-1], ast.Return):
1812
+ adj.add_forward("return {};", skip_replay=True)
1813
+
1814
+ # native function case: return type is specified, eg -> int or -> wp.float32
1815
+ is_func_native = False
1816
+ if node.decorator_list is not None and len(node.decorator_list) == 1:
1817
+ obj = node.decorator_list[0]
1818
+ if isinstance(obj, ast.Call):
1819
+ if isinstance(obj.func, ast.Attribute):
1820
+ if obj.func.attr == "func_native":
1821
+ is_func_native = True
1822
+ if is_func_native and node.returns is not None:
1823
+ if isinstance(node.returns, ast.Name): # python built-in type
1824
+ var = Var(label="return_type", type=eval(node.returns.id))
1825
+ elif isinstance(node.returns, ast.Attribute): # warp type
1826
+ var = Var(label="return_type", type=eval(node.returns.attr))
1827
+ else:
1828
+ raise WarpCodegenTypeError("Native function return type not recognized")
1829
+ adj.return_var = (var,)
1830
+
1831
+ def emit_If(adj, node):
1832
+ if len(node.body) == 0:
1833
+ return None
1834
+
1835
+ # eval condition
1836
+ cond = adj.eval(node.test)
1837
+
1838
+ if cond.constant is not None:
1839
+ # resolve constant condition
1840
+ if cond.constant:
1841
+ for stmt in node.body:
1842
+ adj.eval(stmt)
1843
+ else:
1844
+ for stmt in node.orelse:
1845
+ adj.eval(stmt)
1846
+ return None
1847
+
1848
+ # save symbol map
1849
+ symbols_prev = adj.symbols.copy()
1850
+
1851
+ # eval body
1852
+ adj.begin_if(cond)
1853
+
1854
+ for stmt in node.body:
1855
+ adj.eval(stmt)
1856
+
1857
+ adj.end_if(cond)
1858
+
1859
+ # detect existing symbols with conflicting definitions (variables assigned inside the branch)
1860
+ # and resolve with a phi (select) function
1861
+ for items in symbols_prev.items():
1862
+ sym = items[0]
1863
+ var1 = items[1]
1864
+ var2 = adj.symbols[sym]
1865
+
1866
+ if var1 != var2:
1867
+ # insert a phi function that selects var1, var2 based on cond
1868
+ out = adj.add_builtin_call("where", [cond, var2, var1])
1869
+ adj.symbols[sym] = out
1870
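+ # Illustrative example: given `x = a` before the branch and `x = b`
+ # assigned inside `if cond:`, the symbol map now disagrees, so the
+ # merged value becomes x = where(cond, b, a), a select acting as the
+ # phi node of classic SSA construction.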
+
1871
+ symbols_prev = adj.symbols.copy()
1872
+
1873
+ # evaluate 'else' statement as if (!cond)
1874
+ if len(node.orelse) > 0:
1875
+ adj.begin_else(cond)
1876
+
1877
+ for stmt in node.orelse:
1878
+ adj.eval(stmt)
1879
+
1880
+ adj.end_else(cond)
1881
+
1882
+ # detect existing symbols with conflicting definitions (variables assigned inside the else)
1883
+ # and resolve with a phi (select) function
1884
+ for items in symbols_prev.items():
1885
+ sym = items[0]
1886
+ var1 = items[1]
1887
+ var2 = adj.symbols[sym]
1888
+
1889
+ if var1 != var2:
1890
+ # insert a phi function that selects var1, var2 based on cond
1891
+ # note the reversed order of vars since we want to use !cond as our select
1892
+ out = adj.add_builtin_call("where", [cond, var1, var2])
1893
+ adj.symbols[sym] = out
1894
+
1895
+ def emit_IfExp(adj, node):
1896
+ cond = adj.eval(node.test)
1897
+
1898
+ if cond.constant is not None:
1899
+ return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)
1900
+
1901
+ adj.begin_if(cond)
1902
+ body = adj.eval(node.body)
1903
+ adj.end_if(cond)
1904
+
1905
+ adj.begin_else(cond)
1906
+ orelse = adj.eval(node.orelse)
1907
+ adj.end_else(cond)
1908
+
1909
+ return adj.add_builtin_call("where", [cond, body, orelse])
1910
+
1911
+ def emit_Compare(adj, node):
1912
+ # node.left, node.ops (list of ops), node.comparators (things to compare to)
1913
+ # e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
1914
+
1915
+ left = adj.eval(node.left)
1916
+ comps = [adj.eval(comp) for comp in node.comparators]
1917
+ op_strings = [builtin_operators[type(op)] for op in node.ops]
1918
+
1919
+ return adj.add_comp(op_strings, left, comps)
1920
+
1921
+ def emit_BoolOp(adj, node):
1922
+ # op, expr list values
1923
+
1924
+ op = node.op
1925
+ if isinstance(op, ast.And):
1926
+ func = "&&"
1927
+ elif isinstance(op, ast.Or):
1928
+ func = "||"
1929
+ else:
1930
+ raise WarpCodegenKeyError(f"Op {op} is not supported")
1931
+
1932
+ return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
1933
+
1934
+ def emit_Name(adj, node):
1935
+ # lookup symbol, if it has already been assigned to a variable then return the existing mapping
1936
+ if node.id in adj.symbols:
1937
+ return adj.symbols[node.id]
1938
+
1939
+ obj = adj.resolve_external_reference(node.id)
1940
+
1941
+ if obj is None:
1942
+ raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
1943
+
1944
+ if warp._src.types.is_value(obj):
1945
+ # evaluate constant
1946
+ out = adj.add_constant(obj)
1947
+ adj.symbols[node.id] = out
1948
+ return out
1949
+
1950
+ # the named object is either a function, class name, or module
1951
+ # pass it back to the caller for processing
1952
+ if isinstance(obj, warp._src.context.Function):
1953
+ return obj
1954
+ if isinstance(obj, type):
1955
+ return obj
1956
+ if isinstance(obj, Struct):
1957
+ adj.builder.build_struct_recursive(obj)
1958
+ return obj
1959
+ if isinstance(obj, types.ModuleType):
1960
+ return obj
1961
+
1962
+ raise TypeError(f"Invalid external reference type: {type(obj)}")
1963
+
1964
+ @staticmethod
1965
+ def resolve_type_attribute(var_type: type, attr: str):
1966
+ if isinstance(var_type, type) and type_is_value(var_type):
1967
+ if attr == "dtype":
1968
+ return type_scalar_type(var_type)
1969
+ elif attr == "length":
1970
+ return type_size(var_type)
1971
+
1972
+ return getattr(var_type, attr, None)
1973
+
1974
+ def vector_component_index(adj, component, vector_type):
1975
+ if len(component) != 1:
1976
+ raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
1977
+
1978
+ dim = vector_type._shape_[0]
1979
+ swizzles = "xyzw"[0:dim]
1980
+ if component not in swizzles:
1981
+ raise WarpCodegenAttributeError(
1982
+ f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
1983
+ )
1984
+
1985
+ index = swizzles.index(component)
1986
+ index = adj.add_constant(index)
1987
+ return index
1988
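+ # e.g. (illustrative): `.y` on a wp.vec3 value resolves to the constant
+ # component index 1, which is then passed to the extract() builtin.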
+
1989
+ def transform_component(adj, component):
1990
+ if len(component) != 1:
1991
+ raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")
1992
+
1993
+ if component not in ("p", "q"):
1994
+ raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
1995
+
1996
+ return component
1997
+
1998
+ @staticmethod
1999
+ def is_differentiable_value_type(var_type):
2000
+ # checks that the argument type is a value type (i.e., not an array)
2001
+ # possibly holding differentiable values (for which gradients must be accumulated)
2002
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
2003
+
2004
+ def emit_Attribute(adj, node):
2005
+ if hasattr(node, "is_adjoint"):
2006
+ node.value.is_adjoint = True
2007
+
2008
+ aggregate = adj.eval(node.value)
2009
+
2010
+ try:
2011
+ if isinstance(aggregate, Var) and aggregate.constant is not None:
2012
+ # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
2013
+ return aggregate
2014
+
2015
+ if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
2016
+ out = getattr(aggregate, node.attr)
2017
+
2018
+ if warp._src.types.is_value(out):
2019
+ return adj.add_constant(out)
2020
+ if isinstance(out, (enum.IntEnum, enum.IntFlag)):
2021
+ return adj.add_constant(int(out))
2022
+
2023
+ return out
2024
+
2025
+ if hasattr(node, "is_adjoint"):
2026
+ # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
2027
+ attr_name = aggregate.label + "." + node.attr
2028
+ attr_type = aggregate.type.vars[node.attr].type
2029
+
2030
+ return Var(attr_name, attr_type)
2031
+
2032
+ aggregate_type = strip_reference(aggregate.type)
2033
+
2034
+ # reading a vector or quaternion component
2035
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
2036
+ index = adj.vector_component_index(node.attr, aggregate_type)
2037
+
2038
+ return adj.add_builtin_call("extract", [aggregate, index])
2039
+
2040
+ elif type_is_transformation(aggregate_type):
2041
+ component = adj.transform_component(node.attr)
2042
+
2043
+ if component == "p":
2044
+ return adj.add_builtin_call("transform_get_translation", [aggregate])
2045
+ else:
2046
+ return adj.add_builtin_call("transform_get_rotation", [aggregate])
2047
+
2048
+ else:
2049
+ attr_var = aggregate_type.vars[node.attr]
2050
+
2051
+ # represent pointer types as uint64
2052
+ if isinstance(attr_var.type, pointer_t):
2053
+ cast = f"({Var.dtype_to_ctype(uint64)}*)"
2054
+ adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
2055
+ attr_type = Reference(uint64)
2056
+ else:
2057
+ cast = ""
2058
+ adj_cast = ""
2059
+ attr_type = Reference(attr_var.type)
2060
+
2061
+ attr = adj.add_var(attr_type)
2062
+
2063
+ if is_reference(aggregate.type):
2064
+ adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
2065
+ else:
2066
+ adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")
2067
+
2068
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
2069
+ adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
2070
+ else:
2071
+ adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")
2072
+
2073
+ return attr
2074
+
2075
+ except (KeyError, AttributeError) as e:
2076
+ # Try resolving as type attribute
2077
+ aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
2078
+
2079
+ type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
2080
+ if type_attribute is not None:
2081
+ return type_attribute
2082
+
2083
+ if isinstance(aggregate, Var):
2084
+ node_name = get_node_name_safe(node.value)
2085
+ raise WarpCodegenAttributeError(
2086
+ f"Error, `{node.attr}` is not an attribute of '{node_name}' ({type_repr(aggregate.type)})"
2087
+ ) from e
2088
+ raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
2089
+
2090
+ def emit_Assert(adj, node):
2091
+ # eval condition
2092
+ cond = adj.eval(node.test)
2093
+ cond = adj.load(cond)
2094
+
2095
+ source_segment = ast.get_source_segment(adj.source, node)
2096
+ # If a message was provided with the assert, quotation marks can interfere with the generated code
2097
+ escaped_segment = source_segment.replace('"', '\\"')
2098
+
2099
+ adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
2100
+
2101
+ def emit_Constant(adj, node):
2102
+ if node.value is None:
2103
+ raise WarpCodegenTypeError("None type unsupported")
2104
+ else:
2105
+ return adj.add_constant(node.value)
2106
+
2107
+ def emit_BinOp(adj, node):
2108
+ # evaluate binary operator arguments
2109
+
2110
+ if warp._src.config.verify_autograd_array_access:
2111
+ # array overwrite tracking: in-place operators are a special case
2112
+ # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write
2113
+ # so we save the current arg read flags and restore them after lhs eval
2114
+ is_read_states = []
2115
+ for arg in adj.args:
2116
+ is_read_states.append(arg.is_read)
2117
+
2118
+ # evaluate lhs binary operator argument
2119
+ left = adj.eval(node.left)
2120
+
2121
+ if warp._src.config.verify_autograd_array_access:
2122
+ # restore arg read flags
2123
+ for i, arg in enumerate(adj.args):
2124
+ arg.is_read = is_read_states[i]
2125
+
2126
+ # evaluate rhs binary operator argument
2127
+ right = adj.eval(node.right)
2128
+
2129
+ name = builtin_operators[type(node.op)]
2130
+
2131
+ try:
2132
+ # Check if there is any user-defined overload for this operator
2133
+ user_func = adj.resolve_external_reference(name)
2134
+ if isinstance(user_func, warp._src.context.Function):
2135
+ return adj.add_call(user_func, (left, right), {}, {})
2136
+ except WarpCodegenError:
2137
+ pass
2138
+
2139
+ return adj.add_builtin_call(name, [left, right])
2140
+
2141
+ def emit_UnaryOp(adj, node):
2142
+ # evaluate unary op arguments
2143
+ arg = adj.eval(node.operand)
2144
+
2145
+ # evaluate expression to a compile-time constant if arg is a constant
2146
+ if arg.constant is not None and math.isfinite(arg.constant):
2147
+ if isinstance(node.op, ast.USub):
2148
+ return adj.add_constant(-arg.constant)
2149
+
2150
+ name = builtin_operators[type(node.op)]
2151
+
2152
+ return adj.add_builtin_call(name, [arg])
2153
+
2154
+ def materialize_redefinitions(adj, symbols):
2155
+ # detect symbols with conflicting definitions (assigned inside the for loop)
2156
+ for items in symbols.items():
2157
+ sym = items[0]
2158
+ if adj.is_constant_iter_symbol(sym):
2159
+ # ignore constant overwriting in for-loops if it is a loop iterator
2160
+ # (it is no problem to unroll static loops multiple times in sequence)
2161
+ continue
2162
+
2163
+ var1 = items[1]
2164
+ var2 = adj.symbols[sym]
2165
+
2166
+ if var1 != var2:
2167
+ if warp._src.config.verbose and not adj.custom_reverse_mode:
2168
+ lineno = adj.lineno + adj.fun_lineno
2169
+ line = adj.source_lines[adj.lineno]
2170
+ msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
2171
+ print(msg)
2172
+
2173
+ if var1.constant is not None:
2174
+ raise WarpCodegenError(
2175
+ f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
2176
+ )
2177
+
2178
+ # overwrite the old variable value (violates SSA)
2179
+ adj.add_builtin_call("assign", [var1, var2])
2180
+
2181
+ # reset the symbol to point to the original variable
2182
+ adj.symbols[sym] = var1
2183
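+ # Sketch (illustrative): if `x` enters a dynamic loop as var_1 and the
+ # body rebinds it to a new SSA variable var_2, materialize_redefinitions()
+ # emits assign(var_1, var_2) so the next iteration observes the update,
+ # then points the symbol back at var_1, deliberately violating SSA for
+ # the loop-carried value.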
+
2184
+ def emit_While(adj, node):
2185
+ adj.begin_while(node.test)
2186
+
2187
+ adj.loop_symbols.append(adj.symbols.copy())
2188
+
2189
+ # eval body
2190
+ for s in node.body:
2191
+ adj.eval(s)
2192
+
2193
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2194
+ adj.loop_symbols.pop()
2195
+
2196
+ adj.end_while()
2197
+
2198
+ def eval_num(adj, a):
2199
+ if isinstance(a, ast.Constant):
2200
+ return True, a.value
2201
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
2202
+ # Negative constant
2203
+ return True, -a.operand.value
2204
+
2205
+ # try and resolve the expression to an object
2206
+ # e.g.: wp.constant in the globals scope
2207
+ obj, _ = adj.resolve_static_expression(a)
2208
+
2209
+ if obj is None:
2210
+ obj = adj.eval(a)
2211
+
2212
+ if isinstance(obj, Var) and obj.constant is not None:
2213
+ obj = obj.constant
2214
+
2215
+ return warp._src.types.is_int(obj), obj
2216
+
2217
+ # detects whether a loop contains a break (or continue) statement
2218
+ def contains_break(adj, body):
2219
+ for s in body:
2220
+ if isinstance(s, ast.Break):
2221
+ return True
2222
+ elif isinstance(s, ast.Continue):
2223
+ return True
2224
+ elif isinstance(s, ast.If):
2225
+ if adj.contains_break(s.body):
2226
+ return True
2227
+ if adj.contains_break(s.orelse):
2228
+ return True
2229
+ else:
2230
+ # note that nested for or while loops containing a break statement
2231
+ # do not affect the current loop
2232
+ pass
2233
+
2234
+ return False
2235
+
2236
+ # returns a constant range() if unrollable, otherwise None
2237
+ def get_unroll_range(adj, loop):
2238
+ if (
2239
+ not isinstance(loop.iter, ast.Call)
2240
+ or not isinstance(loop.iter.func, ast.Name)
2241
+ or loop.iter.func.id != "range"
2242
+ or len(loop.iter.args) == 0
2243
+ or len(loop.iter.args) > 3
2244
+ ):
2245
+ return None
2246
+
2247
+ # if all range() arguments are numeric constants we will unroll
2248
+ # note that this only handles trivial constants; it will not unroll
2249
+ # constant compile-time expressions e.g.: range(0, 3*2)
2250
+
2251
+ # Evaluate the arguments and check that they are numeric constants
2252
+ # It is important to do this in one pass, so that if evaluating these arguments has side effects
2253
+ # the code does not get generated more than once
2254
+ range_args = [adj.eval_num(arg) for arg in loop.iter.args]
2255
+ arg_is_numeric, arg_values = zip(*range_args)
2256
+
2257
+ if all(arg_is_numeric):
2258
+ # All arguments are numeric constants
2259
+
2260
+ # range(end)
2261
+ if len(loop.iter.args) == 1:
2262
+ start = 0
2263
+ end = arg_values[0]
2264
+ step = 1
2265
+
2266
+ # range(start, end)
2267
+ elif len(loop.iter.args) == 2:
2268
+ start = arg_values[0]
2269
+ end = arg_values[1]
2270
+ step = 1
2271
+
2272
+ # range(start, end, step)
2273
+ elif len(loop.iter.args) == 3:
2274
+ start = arg_values[0]
2275
+ end = arg_values[1]
2276
+ step = arg_values[2]
2277
+
2278
+ # test if we're above max unroll count
2279
+ max_iters = abs(end - start) // abs(step)
2280
+
2281
+ if "max_unroll" in adj.builder_options:
2282
+ max_unroll = adj.builder_options["max_unroll"]
2283
+ else:
2284
+ max_unroll = warp._src.config.max_unroll
2285
+
2286
+ ok_to_unroll = True
2287
+
2288
+ if max_iters > max_unroll:
2289
+ if warp._src.config.verbose:
2290
+ print(
2291
+ f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
2292
+ )
2293
+ ok_to_unroll = False
2294
+
2295
+ elif adj.contains_break(loop.body):
2296
+ if warp._src.config.verbose:
2297
+ print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
2298
+ ok_to_unroll = False
2299
+
2300
+ if ok_to_unroll:
2301
+ return range(start, end, step)
2302
+
2303
+ # Unroll is not possible; the range needs to be evaluated dynamically
2304
+ range_call = adj.add_builtin_call(
2305
+ "range",
2306
+ [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
2307
+ )
2308
+ return range_call
2309
+
2310
+ def record_constant_iter_symbol(adj, sym):
2311
+ adj.loop_const_iter_symbols.add(sym)
2312
+
2313
+ def is_constant_iter_symbol(adj, sym):
2314
+ return sym in adj.loop_const_iter_symbols
2315
+
2316
+ def emit_For(adj, node):
2317
+ # try and unroll simple range() statements that use constant args
2318
+ unroll_range = adj.get_unroll_range(node)
2319
+
2320
+ if isinstance(unroll_range, range):
2321
+ const_iter_sym = node.target.id
2322
+ # prevent constant conflicts in `materialize_redefinitions()`
2323
+ adj.record_constant_iter_symbol(const_iter_sym)
2324
+
2325
+ # unroll static for-loop
2326
+ for i in unroll_range:
2327
+ const_iter = adj.add_constant(i)
2328
+ adj.symbols[const_iter_sym] = const_iter
2329
+
2330
+ # eval body
2331
+ for s in node.body:
2332
+ adj.eval(s)
2333
+
2334
+ # otherwise generate a dynamic loop
2335
+ else:
2336
+ # evaluate the Iterable -- only if not previously evaluated when trying to unroll
2337
+ if unroll_range is not None:
2338
+ # Range has already been evaluated when trying to unroll, do not re-evaluate
2339
+ iter = unroll_range
2340
+ else:
2341
+ iter = adj.eval(node.iter)
2342
+
2343
+ adj.symbols[node.target.id] = adj.begin_for(iter)
2344
+
2345
+ # for loops should be side-effect free; here we store a copy
2346
+ adj.loop_symbols.append(adj.symbols.copy())
2347
+
2348
+ # eval body
2349
+ for s in node.body:
2350
+ adj.eval(s)
2351
+
2352
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2353
+ adj.loop_symbols.pop()
2354
+
2355
+ adj.end_for(iter)
2356
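+ # Illustrative example: `for i in range(3):` with a body under the
+ # max_unroll limit is emitted three times with i bound to the constants
+ # 0, 1 and 2, while `for i in range(n):` (n not a compile-time constant)
+ # falls back to the dynamic begin_for()/end_for() path above.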
+
2357
+ def emit_Break(adj, node):
2358
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2359
+
2360
+ adj.add_forward(f"goto end_{adj.loop_blocks[-1].label};")
2361
+
2362
+ def emit_Continue(adj, node):
2363
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
2364
+
2365
+ adj.add_forward(f"goto start_{adj.loop_blocks[-1].label};")
2366
+
2367
+ def emit_Expr(adj, node):
2368
+ return adj.eval(node.value)
2369
+
2370
+ def check_tid_in_func_error(adj, node):
2371
+ if adj.is_user_function:
2372
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
2373
+ lineno = adj.lineno + adj.fun_lineno
2374
+ line = adj.source_lines[adj.lineno]
2375
+ raise WarpCodegenError(
2376
+ "tid() may only be called from a Warp kernel, not a Warp function. "
2377
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
2378
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
2379
+ )
2380
+
2381
+ def resolve_arg(adj, arg):
2382
+ # Always try to start with evaluating the argument since it can help
2383
+ # detecting some issues such as global variables being accessed.
2384
+ try:
2385
+ var = adj.eval(arg)
2386
+ except (WarpCodegenError, WarpCodegenKeyError) as e:
2387
+ error = e
2388
+ else:
2389
+ error = None
2390
+
2391
+ # Check if we can resolve the argument as a static expression.
2392
+ # If not, return the variable resulting from evaluating the argument.
2393
+ expr, _ = adj.resolve_static_expression(arg)
2394
+ if expr is None:
2395
+ if error is not None:
2396
+ raise error
2397
+
2398
+ return var
2399
+
2400
+ if isinstance(expr, (type, Struct, Var, warp._src.context.Function)):
2401
+ return expr
2402
+
2403
+ if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
2404
+ return adj.add_constant(int(expr))
2405
+
2406
+ return adj.add_constant(expr)
2407
+
2408
+ def emit_Call(adj, node):
2409
+ adj.check_tid_in_func_error(node)
2410
+
2411
+ # try and lookup function in globals by
2412
+ # resolving path (e.g.: module.submodule.attr)
2413
+ if hasattr(node.func, "warp_func"):
2414
+ func = node.func.warp_func
2415
+ path = []
2416
+ else:
2417
+ func, path = adj.resolve_static_expression(node.func)
2418
+ if func is None:
2419
+ func = adj.eval(node.func)
2420
+
2421
+ if adj.is_static_expression(func):
2422
+ # try to evaluate wp.static() expressions
2423
+ obj, code = adj.evaluate_static_expression(node)
2424
+ if obj is not None:
2425
+ adj.static_expressions[code] = obj
2426
+ if isinstance(obj, warp._src.context.Function):
2427
+ # special handling for wp.static() evaluating to a function
2428
+ return obj
2429
+ else:
2430
+ out = adj.add_constant(obj)
2431
+ return out
2432
+
2433
+ type_args = {}
2434
+
2435
+ if len(path) > 0 and not isinstance(func, warp._src.context.Function):
2436
+ attr = path[-1]
2437
+ caller = func
2438
+ func = None
2439
+
2440
+ # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
2441
+ if attr in warp._src.context.builtin_functions:
2442
+ func = warp._src.context.builtin_functions[attr]
2443
+
2444
+ # vector class type e.g.: wp.vec3f constructor
2445
+ if func is None and hasattr(caller, "_wp_generic_type_str_"):
2446
+ func = warp._src.context.builtin_functions.get(caller._wp_constructor_)
2447
+
2448
+ # scalar class type e.g.: wp.int8 constructor
2449
+ if func is None and hasattr(caller, "__name__") and caller.__name__ in warp._src.context.builtin_functions:
2450
+ func = warp._src.context.builtin_functions.get(caller.__name__)
2451
+
2452
+ # struct constructor
2453
+ if func is None and isinstance(caller, Struct):
2454
+ if adj.builder is not None:
2455
+ adj.builder.build_struct_recursive(caller)
2456
+ if node.args or node.keywords:
2457
+ func = caller.value_constructor
2458
+ else:
2459
+ func = caller.default_constructor
2460
+
2461
+ # lambda function
2462
+ if func is None and getattr(caller, "__name__", None) == "<lambda>":
2463
+ raise NotImplementedError("Lambda expressions are not yet supported")
2464
+
2465
+ if hasattr(caller, "_wp_type_args_"):
2466
+ type_args = caller._wp_type_args_
2467
+
2468
+ if func is None:
2469
+ raise WarpCodegenError(
2470
+ f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
2471
+ )
2472
+
2473
+ # get expected return count, e.g.: for multi-assignment
2474
+ min_outputs = None
2475
+ if hasattr(node, "expects"):
2476
+ min_outputs = node.expects
2477
+
2478
+ # Evaluate all positional and keywords arguments.
2479
+ args = tuple(adj.resolve_arg(x) for x in node.args)
2480
+ kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
2481
+
2482
+ out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2483
+
2484
+ if warp._src.config.verify_autograd_array_access:
2485
+ # Extract the types and values passed as arguments to the function call.
2486
+ arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
2487
+ kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
2488
+
2489
+ # Resolve the exact function signature among any existing overload.
2490
+ resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
2491
+
2492
+ # update arg read/write states according to what happens to that arg in the called function
2493
+ if hasattr(resolved_func, "adj"):
2494
+ for i, arg in enumerate(args):
2495
+ if resolved_func.adj.args[i].is_write:
2496
+ kernel_name = adj.fun_name
2497
+ filename = adj.filename
2498
+ lineno = adj.lineno + adj.fun_lineno
2499
+ arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2500
+ if resolved_func.adj.args[i].is_read:
2501
+ arg.mark_read()
2502
+
2503
+ return out
2504
+
2505
+ def emit_Index(adj, node):
2506
+ # the ast.Index node appears in Python 3.8 and earlier
2507
+ # when performing array subscripts, e.g.: x = arr[i],
2508
+ # but from Python 3.9 onwards it does not appear
2509
+
2510
+ if hasattr(node, "is_adjoint"):
2511
+ node.value.is_adjoint = True
2512
+
2513
+ return adj.eval(node.value)
2514
+
2515
+ def eval_indices(adj, target_type, indices):
2516
+ nodes = indices
2517
+ if hasattr(target_type, "_wp_generic_type_hint_"):
2518
+ indices = []
2519
+ for dim, node in enumerate(nodes):
2520
+ if isinstance(node, ast.Slice):
2521
+ # In the context of slicing a vec/mat type, indices are expected
2522
+ # to be compile-time constants, hence we can infer the actual slice
2523
+ # bounds also at compile-time.
2524
+ length = target_type._shape_[dim]
2525
+ step = 1 if node.step is None else adj.eval(node.step).constant
2526
+
2527
+ if node.lower is None:
2528
+ start = length - 1 if step < 0 else 0
2529
+ else:
2530
+ start = adj.eval(node.lower).constant
2531
+ start = min(max(start, -length), length)
2532
+ start = start + length if start < 0 else start
2533
+
2534
+ if node.upper is None:
2535
+ stop = -1 if step < 0 else length
2536
+ else:
2537
+ stop = adj.eval(node.upper).constant
2538
+ stop = min(max(stop, -length), length)
2539
+ stop = stop + length if stop < 0 else stop
2540
+
2541
+ slice = adj.add_builtin_call("slice", (start, stop, step))
2542
+ indices.append(slice)
2543
+ else:
2544
+ indices.append(adj.eval(node))
2545
+
2546
+ return tuple(indices)
2547
+ else:
2548
+ return tuple(adj.eval(x) for x in nodes)
2549
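+ # Worked example (hypothetical): for a vec4 `v[::-1]`, length = 4 and
+ # step = -1, so start defaults to length - 1 = 3 and stop defaults to -1,
+ # producing slice(3, -1, -1), i.e. the component order w, z, y, x.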
+
2550
+ def emit_indexing(adj, target, indices):
2551
+ target_type = strip_reference(target.type)
2552
+ indices = adj.eval_indices(target_type, indices)
2553
+
2554
+ if is_array(target_type):
2555
+ if len(indices) == target_type.ndim and all(
2556
+ warp._src.types.type_is_int(strip_reference(x.type)) for x in indices
2557
+ ):
2558
+ # handles array loads (where each dimension has an index specified)
2559
+ out = adj.add_builtin_call("address", [target, *indices])
2560
+
2561
+ if warp._src.config.verify_autograd_array_access:
2562
+ target.mark_read()
2563
+
2564
+ else:
2565
+ if isinstance(target_type, warp._src.types.array):
2566
+ # In order to reduce the number of overloads needed in the C
2567
+ # implementation to support combinations of int/slice indices,
2568
+ # we convert all integer indices into slices, and set their
2569
+ # step to 0 if they are representing an integer index.
2570
+ new_indices = []
2571
+ for idx in indices:
2572
+ if not warp._src.types.is_slice(strip_reference(idx.type)):
2573
+ new_idx = adj.add_builtin_call("slice", (idx, idx, 0))
2574
+ new_indices.append(new_idx)
2575
+ else:
2576
+ new_indices.append(idx)
2577
+
2578
+ indices = new_indices
2579
+
2580
+ # handles array views (fewer indices than dimensions)
2581
+ out = adj.add_builtin_call("view", [target, *indices])
2582
+
2583
+ if warp._src.config.verify_autograd_array_access:
2584
+ # store reference to target Var to propagate downstream read/write state back to root arg Var
2585
+ out.parent = target
2586
+
2587
+ # view arg inherits target Var's read/write states
2588
+ out.is_read = target.is_read
2589
+ out.is_write = target.is_write
2590
+
2591
+ elif is_tile(target_type):
2592
+ if len(indices) >= len(target_type.shape): # equality for scalars, inequality for composite types
2593
+ # handles extracting a single element from a tile
2594
+ out = adj.add_builtin_call("tile_extract", [target, *indices])
2595
+ elif len(indices) < len(target_type.shape):
2596
+ # handles tile views
2597
+ out = adj.add_builtin_call("tile_view", [target, indices])
2598
+ else:
2599
+ raise RuntimeError(
2600
+ f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
2601
+ )
2602
+
2603
+ else:
2604
+ # handles non-array type indexing, e.g: vec3, mat33, etc
2605
+ out = adj.add_builtin_call("extract", [target, *indices])
2606
+
2607
+ return out
2608
+
2609
+ # from a list of lists of indices, strip the first `count` indices
2610
+ @staticmethod
2611
+ def strip_indices(indices, count):
2612
+ dim = count
2613
+ while count > 0:
2614
+ ij = indices[0]
2615
+ indices = indices[1:]
2616
+ count -= len(ij)
2617
+
2618
+ # report straddling like in `arr2d[0][1,2]` as a syntax error
2619
+ if count < 0:
2620
+ raise WarpCodegenError(
2621
+ f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
2622
+ )
2623
+
2624
+ return indices
2625
+
2626
+ def recurse_subscript(adj, node, indices):
2627
+ if isinstance(node, ast.Name):
2628
+ target = adj.eval(node)
2629
+ return target, indices
2630
+
2631
+ if isinstance(node, ast.Subscript):
2632
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
2633
+ return adj.eval(node), indices
2634
+
2635
+ if isinstance(node.slice, ast.Tuple):
2636
+ ij = node.slice.elts
2637
+ elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
2638
+ # The node `ast.Index` is deprecated in Python 3.9.
2639
+ ij = node.slice.value.elts
2640
+ elif isinstance(node.slice, ast.ExtSlice):
2641
+ # The node `ast.ExtSlice` is deprecated in Python 3.9.
2642
+ ij = node.slice.dims
2643
+ else:
2644
+ ij = [node.slice]
2645
+
2646
+ indices = [ij, *indices] # prepend
2647
+
2648
+ target, indices = adj.recurse_subscript(node.value, indices)
2649
+
2650
+ target_type = strip_reference(target.type)
2651
+ if is_array(target_type):
2652
+ flat_indices = [i for ij in indices for i in ij]
2653
+ if len(flat_indices) > target_type.ndim:
2654
+ target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
2655
+ indices = adj.strip_indices(indices, target_type.ndim)
2656
+
2657
+ return target, indices
2658
+
2659
+ target = adj.eval(node)
2660
+ return target, indices
2661
+
2662
+ # returns the object being indexed, and the list of indices
2663
+ def eval_subscript(adj, node):
2664
+ target, indices = adj.recurse_subscript(node, [])
2665
+ flat_indices = [i for ij in indices for i in ij]
2666
+ return target, flat_indices
2667
+
2668
+ def emit_Subscript(adj, node):
2669
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
2670
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
2671
+ node.slice.is_adjoint = True
2672
+ var = adj.eval(node.slice)
2673
+ var_name = var.label
2674
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
2675
+ return var
2676
+
2677
+ target, indices = adj.eval_subscript(node)
2678
+
2679
+ return adj.emit_indexing(target, indices)
2680
+
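+ # Illustrative sketch of the wp.adjoint[...] path above, as it would appear
+ # in a custom gradient function (hypothetical user code):
+ #
+ #   @wp.func_grad(my_func)
+ #   def my_func_grad(x: float, adj_ret: float):
+ #       wp.adjoint[x] += adj_ret * 2.0   # resolves to the generated adj_x variable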
2681
+ def emit_Slice(adj, node):
2682
+ start = SLICE_BEGIN if node.lower is None else adj.eval(node.lower)
2683
+ stop = SLICE_END if node.upper is None else adj.eval(node.upper)
2684
+ step = 1 if node.step is None else adj.eval(node.step)
2685
+ return adj.add_builtin_call("slice", (start, stop, step))
2686
+
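+ # Illustrative sketch: `v[2:]` yields slice(2, SLICE_END, 1) and `v[::-1]`
+ # yields slice(SLICE_BEGIN, SLICE_END, -1); the sentinels mark omitted
+ # bounds, leaving it to later stages to substitute the sliced dimension's
+ # actual extent.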
2687
+ def emit_Assign(adj, node):
2688
+ if len(node.targets) != 1:
2689
+ raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
2690
+
2691
+ # Check if the rhs corresponds to an unsupported construct.
2692
+ # Tuples are supported in the context of assigning multiple variables
2693
+ # at once, but not for simple assignments like `x = (1, 2, 3)`.
2694
+ # Therefore, we need to catch this specific case here instead of
2695
+ # more generally in `adj.eval()`.
2696
+ if isinstance(node.value, ast.List):
2697
+ raise WarpCodegenError(
2698
+ "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2699
+ )
2700
+
2701
+ lhs = node.targets[0]
2702
+
2703
+ if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
2704
+ # record the expected number of outputs on the node
2705
+ # we do this so we can decide which function to
2706
+ # call based on the number of expected outputs
2707
+ node.value.expects = len(lhs.elts)
2708
+
2709
+ # evaluate rhs
2710
+ if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
2711
+ rhs = [adj.eval(v) for v in node.value.elts]
2712
+ else:
2713
+ rhs = adj.eval(node.value)
2714
+
2715
+ # handle the case where we are assigning multiple output variables
2716
+ if isinstance(lhs, ast.Tuple):
2717
+ subtype = getattr(rhs, "type", None)
2718
+
2719
+ if isinstance(subtype, warp._src.types.tuple_t):
2720
+ if len(rhs.type.types) != len(lhs.elts):
2721
+ raise WarpCodegenError(
2722
+ f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
2723
+ )
2724
+ rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))
2725
+
2726
+ names = []
2727
+ for v in lhs.elts:
2728
+ if isinstance(v, ast.Name):
2729
+ names.append(v.id)
2730
+ else:
2731
+ raise WarpCodegenError(
2732
+ "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
2733
+ )
2734
+
2735
+ if len(names) != len(rhs):
2736
+ raise WarpCodegenError(
2737
+ f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
2738
+ )
2739
+
2740
+ out = rhs
2741
+ for name, rhs in zip(names, out):
2742
+ if name in adj.symbols:
2743
+ if not types_equal(rhs.type, adj.symbols[name].type):
2744
+ raise WarpCodegenTypeError(
2745
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
2746
+ )
2747
+
2748
+ adj.symbols[name] = rhs
2749
+
2750
+ # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
2751
+ elif isinstance(lhs, ast.Subscript):
2752
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
2753
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
2754
+ lhs.slice.is_adjoint = True
2755
+ src_var = adj.eval(lhs.slice)
2756
+ var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
2757
+ adj.add_forward(f"{var.emit()} = {rhs.emit()};")
2758
+ return
2759
+
2760
+ target, indices = adj.eval_subscript(lhs)
2761
+
2762
+ target_type = strip_reference(target.type)
2763
+ indices = adj.eval_indices(target_type, indices)
2764
+
2765
+ if is_array(target_type):
2766
+ adj.add_builtin_call("array_store", [target, *indices, rhs])
2767
+
2768
+ if warp._src.config.verify_autograd_array_access:
2769
+ kernel_name = adj.fun_name
2770
+ filename = adj.filename
2771
+ lineno = adj.lineno + adj.fun_lineno
2772
+
2773
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2774
+
2775
+ elif is_tile(target_type):
2776
+ adj.add_builtin_call("assign", [target, *indices, rhs])
2777
+
2778
+ elif (
2779
+ type_is_vector(target_type)
2780
+ or type_is_quaternion(target_type)
2781
+ or type_is_matrix(target_type)
2782
+ or type_is_transformation(target_type)
2783
+ ):
2784
+ # recursively unwind AST, stopping at penultimate node
2785
+ root = lhs
2786
+ while hasattr(root.value, "value"):
2787
+ root = root.value
2788
+ # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
2789
+ if hasattr(root, "attr") and root.attr == "adjoint":
2790
+ attr = adj.add_builtin_call("index", [target, *indices])
2791
+ adj.add_builtin_call("store", [attr, rhs])
2792
+ return
2793
+
2794
+ # TODO: array vec component case
2795
+ if is_reference(target.type):
2796
+ attr = adj.add_builtin_call("indexref", [target, *indices])
2797
+ adj.add_builtin_call("store", [attr, rhs])
2798
+
2799
+ if warp._src.config.verbose and not adj.custom_reverse_mode:
2800
+ lineno = adj.lineno + adj.fun_lineno
2801
+ line = adj.source_lines[adj.lineno]
2802
+ node_source = adj.get_node_source(lhs.value)
2803
+ print(
2804
+ f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2805
+ )
2806
+ else:
2807
+ if warp._src.config.enable_vector_component_overwrites:
2808
+ out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
2809
+
2810
+ # re-point target symbol to out var
2811
+ for id in adj.symbols:
2812
+ if adj.symbols[id] == target:
2813
+ adj.symbols[id] = out
2814
+ break
2815
+ else:
2816
+ adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
2817
+
2818
+ else:
2819
+ raise WarpCodegenError(
2820
+ f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
2821
+ )
2822
+
2823
+ elif isinstance(lhs, ast.Name):
2824
+ # symbol name
2825
+ name = lhs.id
2826
+
2827
+ # check type matches if symbol already defined
2828
+ if name in adj.symbols:
2829
+ if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
2830
+ raise WarpCodegenTypeError(
2831
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
2832
+ )
2833
+
2834
+ if isinstance(node.value, ast.Tuple):
2835
+ out = rhs
2836
+ elif isinstance(rhs, Sequence):
2837
+ out = adj.add_builtin_call("tuple", rhs)
2838
+ elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
2839
+ out = adj.add_builtin_call("copy", [rhs])
2840
+ else:
2841
+ out = rhs
2842
+
2843
+ # update symbol map (assumes lhs is a Name node)
2844
+ adj.symbols[name] = out
2845
+
2846
+ elif isinstance(lhs, ast.Attribute):
2847
+ aggregate = adj.eval(lhs.value)
2848
+ aggregate_type = strip_reference(aggregate.type)
2849
+
2850
+ # assigning to a vector or quaternion component
2851
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
2852
+ index = adj.vector_component_index(lhs.attr, aggregate_type)
2853
+
2854
+ if is_reference(aggregate.type):
2855
+ attr = adj.add_builtin_call("indexref", [aggregate, index])
2856
+ adj.add_builtin_call("store", [attr, rhs])
2857
+ else:
2858
+ if warp._src.config.enable_vector_component_overwrites:
2859
+ out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
2860
+
2861
+ # re-point target symbol to out var
2862
+ for id in adj.symbols:
2863
+ if adj.symbols[id] == aggregate:
2864
+ adj.symbols[id] = out
2865
+ break
2866
+ else:
2867
+ adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2868
+
2869
+ elif type_is_transformation(aggregate_type):
2870
+ component = adj.transform_component(lhs.attr)
2871
+
2872
+ # TODO: x[i,j].p = rhs case
2873
+ if is_reference(aggregate.type):
2874
+ raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
2875
+
2876
+ if component == "p":
2877
+ return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
2878
+ else:
2879
+ return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
2880
+
2881
+ else:
2882
+ attr = adj.emit_Attribute(lhs)
2883
+ if is_reference(attr.type):
2884
+ adj.add_builtin_call("store", [attr, rhs])
2885
+ else:
2886
+ adj.add_builtin_call("assign", [attr, rhs])
2887
+
2888
+ if warp._src.config.verbose and not adj.custom_reverse_mode:
2889
+ lineno = adj.lineno + adj.fun_lineno
2890
+ line = adj.source_lines[adj.lineno]
2891
+ msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
2892
+ print(msg)
2893
+
2894
+ else:
2895
+ raise WarpCodegenError("Error, unsupported assignment statement.")
2896
+
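+ # Illustrative sketch of the tuple-assignment path above (hypothetical user
+ # code, assuming `import warp as wp`):
+ #
+ #   @wp.func
+ #   def minmax(a: float, b: float):
+ #       return wp.min(a, b), wp.max(a, b)
+ #
+ #   @wp.kernel
+ #   def k(vals: wp.array(dtype=float)):
+ #       lo, hi = minmax(vals[0], vals[1])   # emit_Assign unpacks both outputs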
2897
+ def emit_Return(adj, node):
2898
+ if node.value is None:
2899
+ var = None
2900
+ elif isinstance(node.value, ast.Tuple):
2901
+ var = tuple(adj.eval(arg) for arg in node.value.elts)
2902
+ else:
2903
+ var = adj.eval(node.value)
2904
+ if not isinstance(var, list) and not isinstance(var, tuple):
2905
+ var = (var,)
2906
+
2907
+ if adj.return_var is not None:
2908
+ old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
2909
+ new_ctypes = tuple(v.ctype(value_type=True) for v in var)
2910
+ if old_ctypes != new_ctypes:
2911
+ raise WarpCodegenTypeError(
2912
+ f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
2913
+ )
2914
+
2915
+ if var is not None:
2916
+ adj.return_var = ()
2917
+ for ret in var:
2918
+ if is_reference(ret.type):
2919
+ ret_var = adj.add_builtin_call("copy", [ret])
2920
+ else:
2921
+ ret_var = ret
2922
+ adj.return_var += (ret_var,)
2923
+
2924
+ adj.add_return(adj.return_var)
2925
+
2926
+ def emit_AugAssign(adj, node):
2927
+ lhs = node.target
2928
+
2929
+ # replace augmented assignment with assignment statement + binary op (default behaviour)
2930
+ def make_new_assign_statement():
2931
+ new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
2932
+ adj.eval(new_node)
2933
+
2934
+ rhs = adj.eval(node.value)
2935
+
2936
+ if isinstance(lhs, ast.Subscript):
2937
+ # wp.adjoint[var] appears in custom grad functions, and does not require
2938
+ # special consideration in the AugAssign case
2939
+ if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
2940
+ make_new_assign_statement()
2941
+ return
2942
+
2943
+ target, indices = adj.eval_subscript(lhs)
2944
+
2945
+ target_type = strip_reference(target.type)
2946
+ indices = adj.eval_indices(target_type, indices)
2947
+
2948
+ if is_array(target_type):
2949
+ # target types int8, uint8, int16, and uint16 are not suitable for atomic array accumulation
2950
+ if target_type.dtype in warp._src.types.non_atomic_types:
2951
+ make_new_assign_statement()
2952
+ return
2953
+
2954
+ # the same holds true for vecs/mats/quats that are composed of these types
2955
+ if (
2956
+ type_is_vector(target_type.dtype)
2957
+ or type_is_quaternion(target_type.dtype)
2958
+ or type_is_matrix(target_type.dtype)
2959
+ or type_is_transformation(target_type.dtype)
2960
+ ):
2961
+ dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
2962
+ if dtype in warp._src.types.non_atomic_types:
2963
+ make_new_assign_statement()
2964
+ return
2965
+
2966
+ kernel_name = adj.fun_name
2967
+ filename = adj.filename
2968
+ lineno = adj.lineno + adj.fun_lineno
2969
+
2970
+ if isinstance(node.op, ast.Add):
2971
+ adj.add_builtin_call("atomic_add", [target, *indices, rhs])
2972
+
2973
+ if warp._src.config.verify_autograd_array_access:
2974
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2975
+
2976
+ elif isinstance(node.op, ast.Sub):
2977
+ adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
2978
+
2979
+ if warp._src.config.verify_autograd_array_access:
2980
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2981
+
2982
+ elif isinstance(node.op, ast.BitAnd):
2983
+ adj.add_builtin_call("atomic_and", [target, *indices, rhs])
2984
+
2985
+ if warp._src.config.verify_autograd_array_access:
2986
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2987
+
2988
+ elif isinstance(node.op, ast.BitOr):
2989
+ adj.add_builtin_call("atomic_or", [target, *indices, rhs])
2990
+
2991
+ if warp._src.config.verify_autograd_array_access:
2992
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2993
+
2994
+ elif isinstance(node.op, ast.BitXor):
2995
+ adj.add_builtin_call("atomic_xor", [target, *indices, rhs])
2996
+
2997
+ if warp._src.config.verify_autograd_array_access:
2998
+ target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2999
+ else:
3000
+ if warp._src.config.verbose:
3001
+ print(f"Warning: in-place op {node.op} is not differentiable")
3002
+ make_new_assign_statement()
3003
+ return
3004
+
3005
+ elif (
3006
+ type_is_vector(target_type)
3007
+ or type_is_quaternion(target_type)
3008
+ or type_is_matrix(target_type)
3009
+ or type_is_transformation(target_type)
3010
+ ):
3011
+ if isinstance(node.op, ast.Add):
3012
+ adj.add_builtin_call("add_inplace", [target, *indices, rhs])
3013
+ elif isinstance(node.op, ast.Sub):
3014
+ adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
3015
+ elif isinstance(node.op, ast.BitAnd):
3016
+ adj.add_builtin_call("bit_and_inplace", [target, *indices, rhs])
3017
+ elif isinstance(node.op, ast.BitOr):
3018
+ adj.add_builtin_call("bit_or_inplace", [target, *indices, rhs])
3019
+ elif isinstance(node.op, ast.BitXor):
3020
+ adj.add_builtin_call("bit_xor_inplace", [target, *indices, rhs])
3021
+ else:
3022
+ if warp._src.config.verbose:
3023
+ print(f"Warning: in-place op {node.op} is not differentiable")
3024
+ make_new_assign_statement()
3025
+ return
3026
+
3027
+ elif is_tile(target.type):
3028
+ if isinstance(node.op, ast.Add):
3029
+ adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
3030
+ elif isinstance(node.op, ast.Sub):
3031
+ adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
3032
+ elif isinstance(node.op, ast.BitAnd):
3033
+ adj.add_builtin_call("tile_bit_and_inplace", [target, *indices, rhs])
3034
+ elif isinstance(node.op, ast.BitOr):
3035
+ adj.add_builtin_call("tile_bit_or_inplace", [target, *indices, rhs])
3036
+ elif isinstance(node.op, ast.BitXor):
3037
+ adj.add_builtin_call("tile_bit_xor_inplace", [target, *indices, rhs])
3038
+ else:
3039
+ if warp._src.config.verbose:
3040
+ print(f"Warning: in-place op {node.op} is not differentiable")
3041
+ make_new_assign_statement()
3042
+ return
3043
+
3044
+ else:
3045
+ raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
3046
+
3047
+ elif isinstance(lhs, ast.Name):
3048
+ target = adj.eval(node.target)
3049
+
3050
+ if is_tile(target.type) and is_tile(rhs.type):
3051
+ if isinstance(node.op, ast.Add):
3052
+ adj.add_builtin_call("add_inplace", [target, rhs])
3053
+ elif isinstance(node.op, ast.Sub):
3054
+ adj.add_builtin_call("sub_inplace", [target, rhs])
3055
+ elif isinstance(node.op, ast.BitAnd):
3056
+ adj.add_builtin_call("bit_and_inplace", [target, rhs])
3057
+ elif isinstance(node.op, ast.BitOr):
3058
+ adj.add_builtin_call("bit_or_inplace", [target, rhs])
3059
+ elif isinstance(node.op, ast.BitXor):
3060
+ adj.add_builtin_call("bit_xor_inplace", [target, rhs])
3061
+ else:
3062
+ make_new_assign_statement()
3063
+ return
3064
+ else:
3065
+ make_new_assign_statement()
3066
+ return
3067
+
3068
+ # TODO
3069
+ elif isinstance(lhs, ast.Attribute):
3070
+ make_new_assign_statement()
3071
+ return
3072
+
3073
+ else:
3074
+ make_new_assign_statement()
3075
+ return
3076
+
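+ # Illustrative sketch of the atomic path above (hypothetical user code):
+ #
+ #   @wp.kernel
+ #   def histogram(bins: wp.array(dtype=wp.int32), values: wp.array(dtype=wp.int32)):
+ #       i = wp.tid()
+ #       bins[values[i]] += 1   # lowered to atomic_add, safe under contention
+ #
+ # int8/uint8/int16/uint16 arrays instead fall back to a plain read-modify-write
+ # through make_new_assign_statement(), as handled above.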
3077
+ def emit_Tuple(adj, node):
3078
+ elements = tuple(adj.eval(x) for x in node.elts)
3079
+ return adj.add_builtin_call("tuple", elements)
3080
+
3081
+ def emit_Pass(adj, node):
3082
+ pass
3083
+
3084
+ node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
3085
+ ast.FunctionDef: emit_FunctionDef,
3086
+ ast.If: emit_If,
3087
+ ast.IfExp: emit_IfExp,
3088
+ ast.Compare: emit_Compare,
3089
+ ast.BoolOp: emit_BoolOp,
3090
+ ast.Name: emit_Name,
3091
+ ast.Attribute: emit_Attribute,
3092
+ ast.Constant: emit_Constant,
3093
+ ast.BinOp: emit_BinOp,
3094
+ ast.UnaryOp: emit_UnaryOp,
3095
+ ast.While: emit_While,
3096
+ ast.For: emit_For,
3097
+ ast.Break: emit_Break,
3098
+ ast.Continue: emit_Continue,
3099
+ ast.Expr: emit_Expr,
3100
+ ast.Call: emit_Call,
3101
+ ast.Index: emit_Index, # Deprecated in 3.9
3102
+ ast.Subscript: emit_Subscript,
3103
+ ast.Slice: emit_Slice,
3104
+ ast.Assign: emit_Assign,
3105
+ ast.Return: emit_Return,
3106
+ ast.AugAssign: emit_AugAssign,
3107
+ ast.Tuple: emit_Tuple,
3108
+ ast.Pass: emit_Pass,
3109
+ ast.Assert: emit_Assert,
3110
+ }
3111
+
3112
+ def eval(adj, node):
3113
+ if hasattr(node, "lineno"):
3114
+ adj.set_lineno(node.lineno - 1)
3115
+
3116
+ try:
3117
+ emit_node = adj.node_visitors[type(node)]
3118
+ except KeyError as e:
3119
+ type_name = type(node).__name__
3120
+ namespace = "ast." if isinstance(node, ast.AST) else ""
3121
+ raise WarpCodegenError(f"Construct `{namespace}{type_name}` not supported in kernels.") from e
3122
+
3123
+ return emit_node(adj, node)
3124
+
3125
+ # helper to evaluate expressions of the form
3126
+ # obj1.obj2.obj3.attr in the function's global scope
3127
+ def resolve_path(adj, path):
3128
+ if len(path) == 0:
3129
+ return None
3130
+
3131
+ # if root is overshadowed by local symbols, bail out
3132
+ if path[0] in adj.symbols:
3133
+ return None
3134
+
3135
+ # look up in closure/global variables
3136
+ expr = adj.resolve_external_reference(path[0])
3137
+
3138
+ # Support Warp types in kernels without the `wp.` module prefix (e.g. v = vec3(0.0, 0.2, 0.4)):
3139
+ if expr is None:
3140
+ expr = getattr(warp, path[0], None)
3141
+
3142
+ # look up in builtins
3143
+ if expr is None:
3144
+ expr = __builtins__.get(path[0])
3145
+
3146
+ if expr is not None:
3147
+ for i in range(1, len(path)):
3148
+ if hasattr(expr, path[i]):
3149
+ expr = getattr(expr, path[i])
3150
+
3151
+ return expr
3152
+
3153
+ # retrieves a dictionary of all closure and global variables and their values
3154
+ # to be used in the evaluation context of wp.static() expressions
3155
+ def get_static_evaluation_context(adj):
3156
+ closure_vars = dict(
3157
+ zip(
3158
+ adj.func.__code__.co_freevars,
3159
+ [c.cell_contents for c in (adj.func.__closure__ or [])],
3160
+ )
3161
+ )
3162
+
3163
+ vars_dict = {}
3164
+ vars_dict.update(adj.func.__globals__)
3165
+ # variables captured in closure have precedence over global vars
3166
+ vars_dict.update(closure_vars)
3167
+
3168
+ return vars_dict
3169
+
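+ # Illustrative sketch of the precedence above (hypothetical user code):
+ #
+ #   K = 2.0                          # global
+ #   def make_kernel():
+ #       K = 3.0                      # closure variable shadows the global
+ #       @wp.kernel
+ #       def k(a: wp.array(dtype=float)):
+ #           a[wp.tid()] = wp.static(K)   # evaluates to 3.0
+ #       return k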
3170
+ def is_static_expression(adj, func):
3171
+ return (
3172
+ isinstance(func, types.FunctionType)
3173
+ and func.__module__ == "warp._src.builtins"
3174
+ and func.__qualname__ == "static"
3175
+ )
3176
+
3177
+ # verify that the return value of a wp.static() expression is supported inside a Warp kernel
3178
+ def verify_static_return_value(adj, value):
3179
+ if value is None:
3180
+ raise ValueError("None is returned")
3181
+ if warp._src.types.is_value(value):
3182
+ return True
3183
+ if warp._src.types.is_array(value):
3184
+ # more useful explanation for the common case of creating a Warp array
3185
+ raise ValueError("a Warp array cannot be created inside Warp kernels")
3186
+ if isinstance(value, str):
3187
+ # we want to support cases such as `print(wp.static("test"))`
3188
+ return True
3189
+ if isinstance(value, warp._src.context.Function):
3190
+ return True
3191
+
3192
+ def verify_struct(s: StructInstance, attr_path: list[str]):
3193
+ for key in s._cls.vars.keys():
3194
+ v = getattr(s, key)
3195
+ if issubclass(type(v), StructInstance):
3196
+ verify_struct(v, [*attr_path, key])
3197
+ else:
3198
+ try:
3199
+ adj.verify_static_return_value(v)
3200
+ except ValueError as e:
3201
+ raise ValueError(
3202
+ f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
3203
+ ) from e
3204
+
3205
+ if issubclass(type(value), StructInstance):
3206
+ return verify_struct(value, [])
3207
+
3208
+ raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
3209
+
3210
+ # find the source code string of an AST node
3211
+ @staticmethod
3212
+ def extract_node_source_from_lines(source_lines, node) -> str | None:
3213
+ if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
3214
+ return None
3215
+
3216
+ start_line = node.lineno - 1 # line numbers start at 1
3217
+ start_col = node.col_offset
3218
+
3219
+ if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
3220
+ end_line = node.end_lineno - 1
3221
+ end_col = node.end_col_offset
3222
+ else:
3223
+ # fallback for Python versions before 3.8
3224
+ # we have to find the end line and column manually
3225
+ end_line = start_line
3226
+ end_col = start_col
3227
+ parenthesis_count = 1
3228
+ for lineno in range(start_line, len(source_lines)):
3229
+ if lineno == start_line:
3230
+ c_start = start_col
3231
+ else:
3232
+ c_start = 0
3233
+ line = source_lines[lineno]
3234
+ for i in range(c_start, len(line)):
3235
+ c = line[i]
3236
+ if c == "(":
3237
+ parenthesis_count += 1
3238
+ elif c == ")":
3239
+ parenthesis_count -= 1
3240
+ if parenthesis_count == 0:
3241
+ end_col = i
3242
+ end_line = lineno
3243
+ break
3244
+ if parenthesis_count == 0:
3245
+ break
3246
+
3247
+ if start_line == end_line:
3248
+ # single-line expression
3249
+ return source_lines[start_line][start_col:end_col]
3250
+ else:
3251
+ # multi-line expression
3252
+ lines = []
3253
+ # first line (from start_col to the end)
3254
+ lines.append(source_lines[start_line][start_col:])
3255
+ # middle lines (entire lines)
3256
+ lines.extend(source_lines[start_line + 1 : end_line])
3257
+ # last line (from the start to end_col)
3258
+ lines.append(source_lines[end_line][:end_col])
3259
+ return "".join(lines).strip()
3260
+
3261
+ @staticmethod
3262
+ def extract_lambda_source(func, only_body=False) -> str | None:
3263
+ try:
3264
+ source_lines = inspect.getsourcelines(func)[0]
3265
+ source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
3266
+ except OSError as e:
3267
+ raise WarpCodegenError(
3268
+ "Could not access lambda function source code. Please use a named function instead."
3269
+ ) from e
3270
+ source = "".join(source_lines)
3271
+ source = source[source.index("lambda") :].rstrip()
3272
+ # Remove trailing unbalanced parentheses
3273
+ while source.count("(") < source.count(")"):
3274
+ source = source[:-1]
3275
+ # extract lambda expression up until a comma, e.g. in the case of
3276
+ # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
3277
+ si = max(source.rfind(")"), source.find(":"))
3278
+ ci = source.find(",", si)
3279
+ if ci != -1:
3280
+ source = source[:ci]
3281
+ tree = ast.parse(source)
3282
+ lambda_source = None
3283
+ for node in ast.walk(tree):
3284
+ if isinstance(node, ast.Lambda):
3285
+ if only_body:
3286
+ # extract the body of the lambda function
3287
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
3288
+ else:
3289
+ # extract the entire lambda function
3290
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
3291
+ break
3292
+ return lambda_source
3293
+
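+ # Illustrative behavior: given the source
+ #   map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)
+ # extract_lambda_source returns "lambda a: (a + 2.0, a + 3.0)", or just
+ # "(a + 2.0, a + 3.0)" when only_body=True.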
3294
+ def extract_node_source(adj, node) -> str | None:
3295
+ return adj.extract_node_source_from_lines(adj.source_lines, node)
3296
+
3297
+ # handles a wp.static() expression and returns the resulting object and a string representing the code
3298
+ # of the static expression
3299
+ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
3300
+ if len(node.args) == 1:
3301
+ static_code = adj.extract_node_source(node.args[0])
3302
+ elif len(node.keywords) == 1:
3303
+ static_code = adj.extract_node_source(node.keywords[0])
3304
+ else:
3305
+ raise WarpCodegenError("warp.static() requires a single argument or keyword")
3306
+ if static_code is None:
3307
+ raise WarpCodegenError("Error extracting source code from wp.static() expression")
3308
+
3309
+ # Since this is an expression, it can safely be collapsed onto a single line.
3310
+ static_code = static_code.replace("\n", "")
3311
+ code_to_eval = static_code # code to be evaluated
3312
+
3313
+ vars_dict = adj.get_static_evaluation_context()
3314
+ # add constant variables to the static call context
3315
+ constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
3316
+ vars_dict.update(constant_vars)
3317
+
3318
+ # Replace all constant `len()` expressions with their value.
3319
+ if "len" in static_code:
3320
+ len_expr_ctx = vars_dict.copy()
3321
+ constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
3322
+ len_expr_ctx.update(constant_types)
3323
+ len_expr_ctx.update({"len": warp._src.types.type_length})
3324
+
3325
+ # We want to replace the expression code in-place,
3326
+ # so reparse it to get the correct column info.
3327
+ len_value_locs: list[tuple[int, int, int]] = []
3328
+ expr_tree = ast.parse(static_code)
3329
+ assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
3330
+ expr_root = expr_tree.body[0].value
3331
+ for expr_node in ast.walk(expr_root):
3332
+ if (
3333
+ isinstance(expr_node, ast.Call)
3334
+ and getattr(expr_node.func, "id", None) == "len"
3335
+ and len(expr_node.args) == 1
3336
+ ):
3337
+ len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
3338
+ try:
3339
+ len_value = eval(len_expr, len_expr_ctx)
3340
+ except Exception:
3341
+ pass
3342
+ else:
3343
+ len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
3344
+
3345
+ if len_value_locs:
3346
+ new_static_code = ""
3347
+ loc = 0
3348
+ for value, start, end in len_value_locs:
3349
+ new_static_code += f"{static_code[loc:start]}{value}"
3350
+ loc = end
3351
+
3352
+ new_static_code += static_code[len_value_locs[-1][2] :]
3353
+ code_to_eval = new_static_code
3354
+
3355
+ try:
3356
+ value = eval(code_to_eval, vars_dict)
3357
+ if isinstance(value, (enum.IntEnum, enum.IntFlag)):
3358
+ value = int(value)
3359
+ if warp._src.config.verbose:
3360
+ print(f"Evaluated static command: {static_code} = {value}")
3361
+ except NameError as e:
3362
+ raise WarpCodegenError(
3363
+ f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
3364
+ ) from e
3365
+ except Exception as e:
3366
+ raise WarpCodegenError(
3367
+ f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
3368
+ ) from e
3369
+
3370
+ try:
3371
+ adj.verify_static_return_value(value)
3372
+ except ValueError as e:
3373
+ raise WarpCodegenError(
3374
+ f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
3375
+ ) from e
3376
+
3377
+ return value, static_code
3378
+
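+ # Minimal usage sketch of the evaluation above (hypothetical user code):
+ #
+ #   SCALE = 2.0   # global, known at module build time
+ #
+ #   @wp.kernel
+ #   def k(a: wp.array(dtype=float)):
+ #       i = wp.tid()
+ #       a[i] = a[i] * wp.static(SCALE * 3.0)   # folded to the constant 6.0
+ #
+ # Constant len() calls are folded first, so with v of type wp.vec3 in scope,
+ # wp.static(len(v) - 1) is rewritten to wp.static(3 - 1) before eval().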
3379
+ # try to replace wp.static() expressions by their evaluated value if the
3380
+ # expression can be evaluated
3381
+ def replace_static_expressions(adj):
3382
+ class StaticExpressionReplacer(ast.NodeTransformer):
3383
+ def visit_Call(self, node):
3384
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
3385
+ if adj.is_static_expression(func):
3386
+ try:
3387
+ # the static expression will execute as long as it is valid and
3388
+ # only depends on global or captured variables
3389
+ obj, code = adj.evaluate_static_expression(node)
3390
+ if code is not None:
3391
+ adj.static_expressions[code] = obj
3392
+ if isinstance(obj, warp._src.context.Function):
3393
+ name_node = ast.Name("__warp_func__")
3394
+ # we add a pointer to the Warp function here so that we can refer to it later at
3395
+ # codegen time (note that the function key itself is not sufficient to uniquely
3396
+ # identify the function, as the function may be redefined between the current time
3397
+ # of wp.static() declaration and the time of codegen during module building)
3398
+ name_node.warp_func = obj
3399
+ return ast.copy_location(name_node, node)
3400
+ else:
3401
+ return ast.copy_location(ast.Constant(value=obj), node)
3402
+ except Exception:
3403
+ # Ignoring failing static expressions should generally not be an issue because only
3404
+ # one of these cases should be possible:
3405
+ # 1) the static expression itself is invalid code, in which case the module cannot be
3406
+ # built at all,
3407
+ # 2) the static expression contains a reference to a local (even if constant) variable
3408
+ # (and is therefore not executable and raises this exception), in which
3409
+ # case changing the constant, or the code affecting this constant, would lead to
3410
+ # a different module hash anyway.
3411
+ # In any case, we mark this Adjoint to have unresolvable static expressions.
3412
+ # This will trigger a code generation step even if the module hash is unchanged.
3413
+ adj.has_unresolved_static_expressions = True
3414
+ pass
3415
+
3416
+ return self.generic_visit(node)
3417
+
3418
+ adj.tree = StaticExpressionReplacer().visit(adj.tree)
3419
+
3420
+ # Evaluates a static expression that does not depend on runtime values
3421
+ # if eval_types is True, try resolving the path using evaluated type information as well
3422
+ def resolve_static_expression(adj, root_node, eval_types=True):
3423
+ attributes = []
3424
+
3425
+ node = root_node
3426
+ while isinstance(node, ast.Attribute):
3427
+ attributes.append(node.attr)
3428
+ node = node.value
3429
+
3430
+ if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
3431
+ # support for operators returning modules
3432
+ # i.e. operator_name(*operator_args).x.y.z
3433
+ operator_args = node.args
3434
+ operator_name = node.func.id
3435
+
3436
+ if operator_name == "type":
3437
+ if len(operator_args) != 1:
3438
+ raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
3439
+
3440
+ # type() operator
3441
+ var = adj.eval(operator_args[0])
3442
+
3443
+ if isinstance(var, Var):
3444
+ var_type = strip_reference(var.type)
3445
+ # Allow accessing type attributes, for instance array.dtype
3446
+ while attributes:
3447
+ attr_name = attributes.pop()
3448
+ var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
3449
+
3450
+ if var_type is None:
3451
+ raise WarpCodegenAttributeError(
3452
+ f"{attr_name} is not an attribute of {type_repr(prev_type)}"
3453
+ )
3454
+
3455
+ return var_type, [str(var_type)]
3456
+ else:
3457
+ raise WarpCodegenError(f"Cannot deduce the type of {var}")
3458
+
3459
+ # reverse list since ast presents it in backward order
3460
+ path = [*reversed(attributes)]
3461
+ if isinstance(node, ast.Name):
3462
+ path.insert(0, node.id)
3463
+
3464
+ # Try resolving path from captured context
3465
+ captured_obj = adj.resolve_path(path)
3466
+ if captured_obj is not None:
3467
+ return captured_obj, path
3468
+
3469
+ return None, path
3470
+
3471
+ def resolve_external_reference(adj, name: str):
3472
+ try:
3473
+ # look up in closure variables
3474
+ idx = adj.func.__code__.co_freevars.index(name)
3475
+ obj = adj.func.__closure__[idx].cell_contents
3476
+ except ValueError:
3477
+ # look up in global variables
3478
+ obj = adj.func.__globals__.get(name)
3479
+ return obj
3480
+
3481
+ # annotate generated code with the original source code line
3482
+ def set_lineno(adj, lineno):
3483
+ if adj.lineno is None or adj.lineno != lineno:
3484
+ line = lineno + adj.fun_lineno
3485
+ source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
3486
+ adj.add_forward(f"// {source} <L {line}>")
3487
+ adj.add_reverse(f"// adj: {source} <L {line}>")
3488
+ adj.lineno = lineno
3489
+
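+ # Illustrative output: for Python source line 42 containing `x = a + b`,
+ # the forward and reverse streams are annotated with
+ #
+ #   // x = a + b <L 42>
+ #   // adj: x = a + b <L 42>
+ #
+ # (the source text is padded to a fixed width in the actual output).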
3490
+ def get_node_source(adj, node):
3491
+ # return the Python code corresponding to the given AST node
3492
+ return ast.get_source_segment(adj.source, node)
3493
+
3494
+ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]:
3495
+ """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
3496
+
3497
+ local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3498
+
3499
+ constants: dict[str, Any] = {}
3500
+ types: dict[Struct | type, Any] = {}
3501
+ functions: dict[warp._src.context.Function, Any] = {}
3502
+
3503
+ for node in ast.walk(adj.tree):
3504
+ if isinstance(node, ast.Name) and node.id not in local_variables:
3505
+ # look up in closure/global variables
3506
+ obj = adj.resolve_external_reference(node.id)
3507
+ if warp._src.types.is_value(obj):
3508
+ constants[node.id] = obj
3509
+
3510
+ elif isinstance(node, ast.Attribute):
3511
+ obj, path = adj.resolve_static_expression(node, eval_types=False)
3512
+ if warp._src.types.is_value(obj):
3513
+ constants[".".join(path)] = obj
3514
+
3515
+ elif isinstance(node, ast.Call):
3516
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
3517
+ if isinstance(func, warp._src.context.Function) and not func.is_builtin():
3518
+ # calling user-defined function
3519
+ functions[func] = None
3520
+ elif isinstance(func, Struct):
3521
+ # calling struct constructor
3522
+ types[func] = None
3523
+ elif isinstance(func, type) and warp._src.types.type_is_value(func):
3524
+ # calling value type constructor
3525
+ types[func] = None
3526
+
3527
+ elif isinstance(node, ast.Assign):
3528
+ # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
3529
+ lhs = node.targets[0]
3530
+ if isinstance(lhs, ast.Tuple):
3531
+ for v in lhs.elts:
3532
+ if isinstance(v, ast.Name):
3533
+ local_variables.add(v.id)
3534
+ elif isinstance(lhs, ast.Name):
3535
+ local_variables.add(lhs.id)
3536
+
3537
+ return constants, types, functions
3538
+
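+ # Illustrative sketch (hypothetical user code): for
+ #
+ #   GRAVITY = -9.8
+ #
+ #   @wp.kernel
+ #   def k(a: wp.array(dtype=float)):
+ #       a[wp.tid()] = GRAVITY
+ #
+ # get_references() reports GRAVITY under constants; an assignment to GRAVITY
+ # inside the kernel would mark it as a local and suppress the entry.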
3539
+
3540
+ # ----------------
3541
+ # code generation
3542
+
3543
+ cpu_module_header = """
3544
+ #define WP_TILE_BLOCK_DIM {block_dim}
3545
+ #define WP_NO_CRT
3546
+ #include "builtin.h"
3547
+
3548
+ // avoid namespacing the float type when casting to float, since wp::float(x) is not valid in C++
3549
+ #define float(x) cast_float(x)
3550
+ #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
3551
+
3552
+ #define int(x) cast_int(x)
3553
+ #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
3554
+
3555
+ #define builtin_tid1d() wp::tid(task_index, dim)
3556
+ #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
3557
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
3558
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
3559
+
3560
+ #define builtin_block_dim() wp::block_dim()
3561
+
3562
+ """
3563
+
3564
+ cuda_module_header = """
3565
+ #define WP_TILE_BLOCK_DIM {block_dim}
3566
+ #define WP_NO_CRT
3567
+ #include "builtin.h"
3568
+
3569
+ // Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
3570
+ #if defined(__CUDACC__) && !defined(_MSC_VER)
3571
+ #define __debugbreak() __brkpt()
3572
+ #endif
3573
+
3574
+ // avoid namespacing the float type when casting to float, since wp::float(x) is not valid in C++
3575
+ #define float(x) cast_float(x)
3576
+ #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
3577
+
3578
+ #define int(x) cast_int(x)
3579
+ #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
3580
+
3581
+ #define builtin_tid1d() wp::tid(_idx, dim)
3582
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
3583
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
3584
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
3585
+
3586
+ #define builtin_block_dim() wp::block_dim()
3587
+
3588
+ """
3589
+
3590
+ struct_template = """
3591
+ struct {name}
3592
+ {{
3593
+ {struct_body}
3594
+
3595
+ {defaulted_constructor_def}
3596
+ CUDA_CALLABLE {name}({forward_args})
3597
+ {forward_initializers}
3598
+ {{
3599
+ }}
3600
+
3601
+ CUDA_CALLABLE {name}& operator += (const {name}& rhs)
3602
+ {{{prefix_add_body}
3603
+ return *this;}}
3604
+
3605
+ }};
3606
+
3607
+ static CUDA_CALLABLE void adj_{name}({reverse_args})
3608
+ {{
3609
+ {reverse_body}}}
3610
+
3611
+ // Required when compiling adjoints.
3612
+ CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
3613
+ {{
3614
+ return {name}();
3615
+ }}
3616
+
3617
+ CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
3618
+ {{
3619
+ {atomic_add_body}}}
3620
+
3621
+
3622
+ """
3623
+
3624
+ cpu_forward_function_template = """
3625
+ // {filename}:{lineno}
3626
+ static {return_type} {name}(
3627
+ {forward_args})
3628
+ {{
3629
+ {forward_body}}}
3630
+
3631
+ """
3632
+
3633
+ cpu_reverse_function_template = """
3634
+ // {filename}:{lineno}
3635
+ static void adj_{name}(
3636
+ {reverse_args})
3637
+ {{
3638
+ {reverse_body}}}
3639
+
3640
+ """
3641
+
3642
+ cuda_forward_function_template = """
3643
+ // {filename}:{lineno}
3644
+ {line_directive}static CUDA_CALLABLE {return_type} {name}(
3645
+ {forward_args})
3646
+ {{
3647
+ {forward_body}{line_directive}}}
3648
+
3649
+ """
3650
+
3651
+ cuda_reverse_function_template = """
3652
+ // {filename}:{lineno}
3653
+ {line_directive}static CUDA_CALLABLE void adj_{name}(
3654
+ {reverse_args})
3655
+ {{
3656
+ {reverse_body}{line_directive}}}
3657
+
3658
+ """
3659
+
3660
+ cuda_kernel_template_forward = """
3661
+
3662
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
3663
+ {forward_args})
3664
+ {{
3665
+ {line_directive} wp::tile_shared_storage_t tile_mem;
3666
+
3667
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3668
+ {line_directive} _idx < dim.size;
3669
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3670
+ {{
3671
+ // reset shared memory allocator
3672
+ {line_directive} wp::tile_shared_storage_t::init();
3673
+
3674
+ {forward_body}{line_directive} }}
3675
+ {line_directive}}}
3676
+
3677
+ """
3678
+
3679
+ cuda_kernel_template_backward = """
3680
+
3681
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
3682
+ {reverse_args})
3683
+ {{
3684
+ {line_directive} wp::tile_shared_storage_t tile_mem;
3685
+
3686
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3687
+ {line_directive} _idx < dim.size;
3688
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3689
+ {{
3690
+ // reset shared memory allocator
3691
+ {line_directive} wp::tile_shared_storage_t::init();
3692
+
3693
+ {reverse_body}{line_directive} }}
3694
+ {line_directive}}}
3695
+
3696
+ """
3697
+
3698
+ cpu_kernel_template_forward = """
3699
+
3700
+ void {name}_cpu_kernel_forward(
3701
+ {forward_args},
3702
+ wp_args_{name} *_wp_args)
3703
+ {{
3704
+ {forward_body}}}
3705
+
3706
+ """
3707
+
3708
+ cpu_kernel_template_backward = """
3709
+
3710
+ void {name}_cpu_kernel_backward(
3711
+ {reverse_args},
3712
+ wp_args_{name} *_wp_args,
3713
+ wp_args_{name} *_wp_adj_args)
3714
+ {{
3715
+ {reverse_body}}}
3716
+
3717
+ """
3718
+
3719
+ cpu_module_template_forward = """
3720
+
3721
+ extern "C" {{
3722
+
3723
+ // Python CPU entry points
3724
+ WP_API void {name}_cpu_forward(
3725
+ wp::launch_bounds_t dim,
3726
+ wp_args_{name} *_wp_args)
3727
+ {{
3728
+ wp::tile_shared_storage_t tile_mem;
3729
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
3730
+ wp::shared_tile_storage = &tile_mem;
3731
+ #endif
3732
+
3733
+ for (size_t task_index = 0; task_index < dim.size; ++task_index)
3734
+ {{
3735
+ {name}_cpu_kernel_forward(dim, task_index, _wp_args);
3736
+ }}
3737
+ }}
3738
+
3739
+ }} // extern C
3740
+
3741
+ """
3742
+
3743
+ cpu_module_template_backward = """
3744
+
3745
+ extern "C" {{
3746
+
3747
+ WP_API void {name}_cpu_backward(
3748
+ wp::launch_bounds_t dim,
3749
+ wp_args_{name} *_wp_args,
3750
+ wp_args_{name} *_wp_adj_args)
3751
+ {{
3752
+ wp::tile_shared_storage_t tile_mem;
3753
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
3754
+ wp::shared_tile_storage = &tile_mem;
3755
+ #endif
3756
+
3757
+ for (size_t task_index = 0; task_index < dim.size; ++task_index)
3758
+ {{
3759
+ {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);
3760
+ }}
3761
+ }}
3762
+
3763
+ }} // extern C
3764
+
3765
+ """
3766
+
3767
+
3768
+ # converts a constant Python value to equivalent C-repr
3769
+ def constant_str(value):
3770
+ value_type = type(value)
3771
+
3772
+ if value_type == bool or value_type == builtins.bool:
3773
+ if value:
3774
+ return "true"
3775
+ else:
3776
+ return "false"
3777
+
3778
+ elif value_type == str:
3779
+ # ensure constant strings are correctly escaped
3780
+ return '"' + str(value.encode("unicode-escape").decode()) + '"'
3781
+
3782
+ elif isinstance(value, ctypes.Array):
3783
+ if value_type._wp_scalar_type_ == float16:
3784
+ # special case for float16, which is stored as uint16 in the ctypes.Array
3785
+ from warp._src.context import runtime
3786
+
3787
+ scalar_value = runtime.core.wp_half_bits_to_float
3788
+ else:
3789
+
3790
+ def scalar_value(x):
3791
+ return x
3792
+
3793
+ # list of scalar initializer values
3794
+ initlist = []
3795
+ for i in range(value._length_):
3796
+ x = ctypes.Array.__getitem__(value, i)
3797
+ initlist.append(str(scalar_value(x)).lower())
3798
+
3799
+ if value._wp_scalar_type_ is bool:
3800
+ dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
3801
+ else:
3802
+ dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"
3803
+
3804
+ # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
3805
+ return f"{dtypestr}{{{', '.join(initlist)}}}"
3806
+
3807
+ elif value_type in warp._src.types.scalar_types:
3808
+ # make sure we emit the value of objects, e.g. uint32
3809
+ return str(value.value)
3810
+
3811
+ elif issubclass(value_type, StructInstance):
3812
+ # constant struct instance
3813
+ arg_strs = []
3814
+ for key, var in value._cls.vars.items():
3815
+ attr = getattr(value, key)
3816
+ arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
3817
+ arg_str = ", ".join(arg_strs)
3818
+ return f"{value.native_name}({arg_str})"
3819
+
3820
+ elif value == math.inf:
3821
+ return "INFINITY"
3822
+
3823
+ elif math.isnan(value):
3824
+ return "NAN"
3825
+
3826
+ else:
3827
+ # otherwise just convert constant to string
3828
+ return str(value)
3829
+
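+ # Illustrative mappings produced by constant_str():
+ #
+ #   constant_str(True)                    -> "true"
+ #   constant_str("a\nb")                  -> '"a\\nb"'
+ #   constant_str(warp.uint32(7))          -> "7"          (scalar types emit .value)
+ #   constant_str(math.inf)                -> "INFINITY"
+ #   constant_str(warp.vec3(1.0, 2.0, 3.0))
+ #       -> "wp::initializer_array<3,wp::float32>{1.0, 2.0, 3.0}"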
3830
+
3831
+ def indent(args, stops=1):
3832
+ sep = ",\n"
3833
+ for _i in range(stops):
3834
+ sep += " "
3835
+
3836
+ # return sep + args.replace(", ", "," + sep)
3837
+ return sep.join(args)
3838
+
3839
+
3840
+ # generates a C function name based on the Python function name
3841
+ def make_full_qualified_name(func: Union[str, Callable]) -> str:
3842
+ if not isinstance(func, str):
3843
+ func = func.__qualname__
3844
+ return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
3845
+
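+ # Illustrative mangling:
+ #
+ #   make_full_qualified_name("MyClass.my_kernel")  ->  "MyClass__my_kernel"
+ #   make_full_qualified_name("outer.<locals>.f")   ->  "outer__locals__f"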
3846
+
3847
+ def codegen_struct(struct, device="cpu", indent_size=4):
3848
+ name = struct.native_name
3849
+
3850
+ body = []
3851
+ indent_block = " " * indent_size
3852
+
3853
+ if len(struct.vars) > 0:
3854
+ for label, var in struct.vars.items():
3855
+ body.append(var.ctype() + " " + label + ";\n")
3856
+ else:
3857
+ # for empty structs, emit a dummy attribute to avoid any compiler-specific alignment issues
3858
+ body.append("char _dummy_;\n")
3859
+
3860
+ forward_args = []
3861
+ reverse_args = []
3862
+
3863
+ forward_initializers = []
3864
+ reverse_body = []
3865
+ atomic_add_body = []
3866
+ prefix_add_body = []
3867
+
3868
+ # forward args
3869
+ for label, var in struct.vars.items():
3870
+ var_ctype = var.ctype()
3871
+ default_arg_def = " = {}" if forward_args else ""
3872
+ forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
3873
+ reverse_args.append(f"{var_ctype} const&")
3874
+
3875
+ namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
3876
+ atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
3877
+
3878
+ prefix = f"{indent_block}," if forward_initializers else ":"
3879
+ forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
3880
+
3881
+ # prefix-add operator
3882
+ for label, var in struct.vars.items():
3883
+ if not is_array(var.type):
3884
+ prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
3885
+
3886
+ # reverse args
3887
+ for label, var in struct.vars.items():
3888
+ reverse_args.append(var.ctype() + " & adj_" + label)
3889
+ if is_array(var.type):
3890
+ reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
3891
+ else:
3892
+ reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
3893
+
3894
+ reverse_args.append(name + " & adj_ret")
3895
+
3896
+ # emit an explicitly defaulted default constructor when the generated constructor takes arguments
3897
+ defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
3898
+
3899
+ return struct_template.format(
3900
+ name=name,
3901
+ struct_body="".join([indent_block + l for l in body]),
3902
+ forward_args=indent(forward_args),
3903
+ forward_initializers="".join(forward_initializers),
3904
+ reverse_args=indent(reverse_args),
3905
+ reverse_body="".join(reverse_body),
3906
+ prefix_add_body="".join(prefix_add_body),
3907
+ atomic_add_body="".join(atomic_add_body),
3908
+ defaulted_constructor_def=defaulted_constructor_def,
3909
+ )
3910
+
3911
+
3912
+ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3913
+ if device == "cpu":
3914
+ indent = 4
3915
+ elif device == "cuda":
3916
+ if func_type == "kernel":
3917
+ indent = 8
3918
+ else:
3919
+ indent = 4
3920
+ else:
3921
+ raise ValueError(f"Device {device} not supported for codegen")
3922
+
3923
+ indent_block = " " * indent
3924
+
3925
+ lines = []
3926
+
3927
+ # argument vars
3928
+ if device == "cpu" and func_type == "kernel":
3929
+ lines += ["//---------\n"]
3930
+ lines += ["// argument vars\n"]
3931
+
3932
+ for var in adj.args:
3933
+ lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
3934
+
3935
+ # primal vars
3936
+ lines += ["//---------\n"]
3937
+ lines += ["// primal vars\n"]
3938
+
3939
+ for var in adj.variables:
3940
+ if is_tile(var.type):
3941
+ lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
3942
+ elif var.constant is None:
3943
+ lines += [f"{var.ctype()} {var.emit()};\n"]
3944
+ else:
3945
+ lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3946
+
3947
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3948
+ lines.insert(-1, f"{line_directive}\n")
3949
+
3950
+ # forward pass
3951
+ lines += ["//---------\n"]
3952
+ lines += ["// forward\n"]
3953
+
3954
+ for f in adj.blocks[0].body_forward:
3955
+ if func_type == "kernel" and device == "cuda" and f.lstrip().startswith("return;"):
3956
+ # Use of grid-stride loops in CUDA kernels requires that we convert return; to continue;
3957
+ lines += [f.replace("return;", "continue;") + "\n"]
3958
+ else:
3959
+ lines += [f + "\n"]
3960
+
3961
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3962
+
3963
+
3964
+ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3965
+ if device == "cpu":
3966
+ indent = 4
3967
+ elif device == "cuda":
3968
+ if func_type == "kernel":
3969
+ indent = 8
3970
+ else:
3971
+ indent = 4
3972
+ else:
3973
+ raise ValueError(f"Device {device} not supported for codegen")
3974
+
3975
+ indent_block = " " * indent
3976
+
3977
+ lines = []
3978
+
3979
+ # argument vars
3980
+ if device == "cpu" and func_type == "kernel":
3981
+ lines += ["//---------\n"]
3982
+ lines += ["// argument vars\n"]
3983
+
3984
+ for var in adj.args:
3985
+ lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
3986
+
3987
+ for var in adj.args:
3988
+ lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
3989
+
3990
+ # primal vars
3991
+ lines += ["//---------\n"]
3992
+ lines += ["// primal vars\n"]
3993
+
3994
+ for var in adj.variables:
3995
+ if is_tile(var.type):
3996
+ lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
3997
+ elif var.constant is None:
3998
+ lines += [f"{var.ctype()} {var.emit()};\n"]
3999
+ else:
4000
+ lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
4001
+
4002
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
4003
+ lines.insert(-1, f"{line_directive}\n")
4004
+
4005
+ # dual vars
4006
+ lines += ["//---------\n"]
4007
+ lines += ["// dual vars\n"]
4008
+
4009
+ for var in adj.variables:
4010
+ name = var.emit_adj()
4011
+ ctype = var.ctype(value_type=True)
4012
+
4013
+ if is_tile(var.type):
4014
+ if var.type.storage == "register":
4015
+ lines += [
4016
+ f"{var.type.ctype()} {name}(0.0);\n"
4017
+ ] # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
4018
+ elif var.type.storage == "shared":
4019
+ lines += [
4020
+ f"{var.type.ctype()}& {name} = {var.emit()};\n"
4021
+ ] # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
4022
+ else:
4023
+ lines += [f"{ctype} {name} = {{}};\n"]
4024
+
4025
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
4026
+ lines.insert(-1, f"{line_directive}\n")
4027
+
4028
+ # forward pass
4029
+ lines += ["//---------\n"]
4030
+ lines += ["// forward\n"]
4031
+
4032
+ for f in adj.blocks[0].body_replay:
4033
+ lines += [f + "\n"]
4034
+
4035
+ # reverse pass
4036
+ lines += ["//---------\n"]
4037
+ lines += ["// reverse\n"]
4038
+
4039
+ for l in reversed(adj.blocks[0].body_reverse):
4040
+ lines += [l + "\n"]
4041
+
4042
+ # In grid-stride kernels the reverse body is in a for loop
4043
+ if device == "cuda" and func_type == "kernel":
4044
+ lines += ["continue;\n"]
4045
+ else:
4046
+ lines += ["return;\n"]
4047
+
4048
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
4049
+
4050
+
4051
+ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
4052
+ if options is None:
4053
+ options = {}
4054
+
4055
+ if adj.return_var is not None and "return" in adj.arg_types:
4056
+ if get_origin(adj.arg_types["return"]) is tuple:
4057
+ if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
4058
+ raise WarpCodegenError(
4059
+ f"The function `{adj.fun_name}` has its return type "
4060
+ f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
4061
+ f"but the code returns {len(adj.return_var)} values."
4062
+ )
4063
+ elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
4064
+ raise WarpCodegenError(
4065
+ f"The function `{adj.fun_name}` has its return type "
4066
+ f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
4067
+ f"but the code returns a tuple with types `({', '.join(warp._src.context.type_str(x.type) for x in adj.return_var)})`."
4068
+ )
4069
+ elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
4070
+ raise WarpCodegenError(
4071
+ f"The function `{adj.fun_name}` has its return type "
4072
+ f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
4073
+ f"but the code returns {len(adj.return_var)} values."
4074
+ )
4075
+ elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
4076
+ raise WarpCodegenError(
4077
+ f"The function `{adj.fun_name}` has its return type "
4078
+ f"annotated as `{warp._src.context.type_str(adj.arg_types['return'])}` "
4079
+ f"but the code returns a value of type `{warp._src.context.type_str(adj.return_var[0].type)}`."
4080
+ )
4081
+ elif (
4082
+ isinstance(adj.return_var[0].type, warp._src.types.fixedarray)
4083
+ and type(adj.arg_types["return"]) is warp._src.types.array
4084
+ ):
4085
+ # If the return statement yields a `fixedarray` while the function is annotated
4086
+ # to return a standard `array`, then raise an error since the `fixedarray` storage
4087
+ # allocated on the stack will be freed once the function exits, meaning that the
4088
+ # resulting `array` instance will point to invalid data.
4089
+ raise WarpCodegenError(
4090
+ f"The function `{adj.fun_name}` returns a fixed-size array "
4091
+ f"whereas it has its return type annotated as "
4092
+ f"`{warp._src.context.type_str(adj.arg_types['return'])}`."
4093
+ )
4094
+
+     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
+     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+     # a direct mapping to a Python source line.
+     func_line_directive = ""
+     if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+         func_line_directive = f"{line_directive}\n"
+
+     # forward header
+     if adj.return_var is not None and len(adj.return_var) == 1:
+         return_type = adj.return_var[0].ctype()
+     else:
+         return_type = "void"
+
+     has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1
+
+     forward_args = []
+     reverse_args = []
+
+     # forward args
+     for i, arg in enumerate(adj.args):
+         s = f"{arg.ctype()} {arg.emit()}"
+         forward_args.append(s)
+         if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+             reverse_args.append(s)
+     if has_multiple_outputs:
+         for i, arg in enumerate(adj.return_var):
+             forward_args.append(arg.ctype() + " & ret_" + str(i))
+             reverse_args.append(arg.ctype() + " & ret_" + str(i))
+
+     # reverse args
+     for i, arg in enumerate(adj.args):
+         if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+             break
+         # indexed array gradients are regular arrays
+         if isinstance(arg.type, indexedarray):
+             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+             reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+         else:
+             reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+     if has_multiple_outputs:
+         for i, arg in enumerate(adj.return_var):
+             reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
+     elif return_type != "void":
+         reverse_args.append(return_type + " & adj_ret")
+     # custom output reverse args (user-declared)
+     if adj.custom_reverse_mode:
+         for arg in adj.args[adj.custom_reverse_num_input_args :]:
+             reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
+
+     if device == "cpu":
+         forward_template = cpu_forward_function_template
+         reverse_template = cpu_reverse_function_template
+     elif device == "cuda":
+         forward_template = cuda_forward_function_template
+         reverse_template = cuda_reverse_function_template
+     else:
+         raise ValueError(f"Device {device} is not supported")
+
+     # codegen body
+     forward_body = codegen_func_forward(adj, func_type="function", device=device)
+
+     s = ""
+     if not adj.skip_forward_codegen:
+         s += forward_template.format(
+             name=c_func_name,
+             return_type=return_type,
+             forward_args=indent(forward_args),
+             forward_body=forward_body,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive=func_line_directive,
+         )
+
+     if not adj.skip_reverse_codegen:
+         if adj.custom_reverse_mode:
+             reverse_body = "\t// user-defined adjoint code\n" + forward_body
+         else:
+             if options.get("enable_backward", True) and adj.used_by_backward_kernel:
+                 reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+             else:
+                 reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
+         s += reverse_template.format(
+             name=c_func_name,
+             return_type=return_type,
+             reverse_args=indent(reverse_args),
+             forward_body=forward_body,
+             reverse_body=reverse_body,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive=func_line_directive,
+         )
+
+     return s
+
+
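The `custom_reverse_mode` branches above exist to support user-provided adjoints. For reference, a minimal sketch of the user-side pattern that triggers them, based on Warp's custom-gradient decorator (function names are illustrative):

import warp as wp

@wp.func
def square(x: float) -> float:
    return x * x

# With a user-defined adjoint, codegen keeps the forward body and prepends
# "// user-defined adjoint code" instead of generating a reverse pass.
@wp.func_grad(square)
def adj_square(x: float, adj_ret: float):
    wp.adjoint[x] += 2.0 * x * adj_ret

In this mode, arguments past `custom_reverse_num_input_args` are the user-declared adjoint outputs, which is why they are appended to `reverse_args` as-is rather than as `adj_`-prefixed gradients.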
+ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
+     if adj.return_var is not None and len(adj.return_var) == 1:
+         return_type = adj.return_var[0].ctype()
+     else:
+         return_type = "void"
+
+     forward_args = []
+     reverse_args = []
+
+     # forward args
+     for _i, arg in enumerate(adj.args):
+         s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+         forward_args.append(s)
+         reverse_args.append(s)
+
+     # reverse args
+     for _i, arg in enumerate(adj.args):
+         if isinstance(arg.type, indexedarray):
+             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+             reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+         else:
+             reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+     if return_type != "void":
+         reverse_args.append(return_type + " & adj_ret")
+
+     forward_template = cuda_forward_function_template
+     replay_template = cuda_forward_function_template
+     reverse_template = cuda_reverse_function_template
+
+     s = ""
+     s += forward_template.format(
+         name=name,
+         return_type=return_type,
+         forward_args=indent(forward_args),
+         forward_body=snippet,
+         filename=adj.filename,
+         lineno=adj.fun_lineno,
+         line_directive="",
+     )
+
+     if replay_snippet is not None:
+         s += replay_template.format(
+             name="replay_" + name,
+             return_type=return_type,
+             forward_args=indent(forward_args),
+             forward_body=replay_snippet,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive="",
+         )
+
+     if adj_snippet:
+         reverse_body = adj_snippet
+     else:
+         reverse_body = ""
+
+     s += reverse_template.format(
+         name=name,
+         return_type=return_type,
+         reverse_args=indent(reverse_args),
+         forward_body=snippet,
+         reverse_body=reverse_body,
+         filename=adj.filename,
+         lineno=adj.fun_lineno,
+         line_directive="",
+     )
+
+     return s
+
+
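`codegen_snippet` is the back end for native code snippets, where the forward, adjoint, and optional replay bodies are raw CUDA strings supplied by the user rather than generated code. Note that the forward arguments drop the `var_` prefix (`arg.emit().replace('var_', '')`) so the snippet can refer to parameters by their Python names. A sketch of the user-side declaration following Warp's `@wp.func_native` pattern (argument names and snippet contents are illustrative):

import warp as wp

snippet = """
out[tid] = a * x[tid];
"""

adj_snippet = """
adj_x[tid] += a * adj_out[tid];
adj_a += x[tid] * adj_out[tid];
"""

@wp.func_native(snippet, adj_snippet)
def scale(a: float, x: wp.array(dtype=float), out: wp.array(dtype=float), tid: int):
    ...

A `replay_snippet`, when given, is emitted as a separate `replay_<name>` forward function for use when the tape is replayed.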
+ def codegen_kernel(kernel, device, options):
+     # Update the module's options with the ones defined on the kernel, if any.
+     options = dict(options)
+     options.update(kernel.options)
+
+     adj = kernel.adj
+
+     args_struct = ""
+     if device == "cpu":
+         args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
+         for i in adj.args:
+             args_struct += f" {i.ctype()} {i.label};\n"
+         args_struct += "};\n"
+
+     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
+     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+     # a direct mapping to a Python source line.
+     func_line_directive = ""
+     if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+         func_line_directive = f"{line_directive}\n"
+
+     if device == "cpu":
+         template_forward = cpu_kernel_template_forward
+         template_backward = cpu_kernel_template_backward
+     elif device == "cuda":
+         template_forward = cuda_kernel_template_forward
+         template_backward = cuda_kernel_template_backward
+     else:
+         raise ValueError(f"Device {device} is not supported")
+
+     template = ""
+     template_fmt_args = {
+         "name": kernel.get_mangled_name(),
+     }
+
+     # build forward signature
+     forward_args = ["wp::launch_bounds_t dim"]
+     if device == "cpu":
+         forward_args.append("size_t task_index")
+     else:
+         for arg in adj.args:
+             forward_args.append(arg.ctype() + " var_" + arg.label)
+
+     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+     template_fmt_args.update(
+         {
+             "forward_args": indent(forward_args),
+             "forward_body": forward_body,
+             "line_directive": func_line_directive,
+         }
+     )
+     template += template_forward
+
+     if options["enable_backward"]:
+         # build reverse signature
+         reverse_args = ["wp::launch_bounds_t dim"]
+         if device == "cpu":
+             reverse_args.append("size_t task_index")
+         else:
+             for arg in adj.args:
+                 reverse_args.append(arg.ctype() + " var_" + arg.label)
+             for arg in adj.args:
+                 # indexed array gradients are regular arrays
+                 if isinstance(arg.type, indexedarray):
+                     _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                     reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+                 else:
+                     reverse_args.append(arg.ctype() + " adj_" + arg.label)
+
+         reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+         template_fmt_args.update(
+             {
+                 "reverse_args": indent(reverse_args),
+                 "reverse_body": reverse_body,
+             }
+         )
+         template += template_backward
+
+     s = template.format(**template_fmt_args)
+     return args_struct + s
+
+
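`codegen_kernel` emits the backward kernel only when `enable_backward` is still true after merging the module options with the kernel's own. For reference, a minimal sketch of turning backward codegen off using Warp's module-options API (the per-kernel override follows the same option name):

import warp as wp

# Skip backward code generation for every kernel defined in this module.
wp.set_module_options({"enable_backward": False})

@wp.kernel
def saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = a * x[tid] + y[tid]

When disabled, only the forward template is formatted, which shrinks the generated module and shortens compile times.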
+ def codegen_module(kernel, device, options):
+     if device != "cpu":
+         return ""
+
+     # Update the module's options with the ones defined on the kernel, if any.
+     options = dict(options)
+     options.update(kernel.options)
+
+     template = ""
+     template_fmt_args = {
+         "name": kernel.get_mangled_name(),
+     }
+
+     template += cpu_module_template_forward
+
+     if options["enable_backward"]:
+         template += cpu_module_template_backward
+
+     s = template.format(**template_fmt_args)
+     return s
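`codegen_module` only matters on the CPU path: it formats the exported forward (and, when enabled, backward) entry points that a launch goes through together with the `wp_args_...` struct built in `codegen_kernel`, while CUDA kernels are launched directly, hence the early return. A minimal end-to-end launch for reference:

import warp as wp

@wp.kernel
def inc(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] = x[tid] + 1.0

x = wp.zeros(8, dtype=float, device="cpu")
wp.launch(inc, dim=8, inputs=[x], device="cpu")
print(x.numpy())  # [1. 1. 1. 1. 1. 1. 1. 1.]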