warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,4252 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from __future__ import annotations
16
+ # TODO: Remove after cleaning up the public API.
17
17
 
18
- import ast
19
- import builtins
20
- import ctypes
21
- import enum
22
- import functools
23
- import hashlib
24
- import inspect
25
- import itertools
26
- import math
27
- import re
28
- import sys
29
- import textwrap
30
- import types
31
- from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
18
+ from warp._src import codegen as _codegen
32
19
 
33
- import warp.config
34
- from warp.types import *
35
20
 
36
- # used as a globally accessible copy
37
- # of current compile options (block_dim) etc
38
- options = {}
21
+ def __getattr__(name):
22
+ from warp._src.utils import get_deprecated_api
39
23
 
40
-
41
- class WarpCodegenError(RuntimeError):
42
- def __init__(self, message):
43
- super().__init__(message)
44
-
45
-
46
- class WarpCodegenTypeError(TypeError):
47
- def __init__(self, message):
48
- super().__init__(message)
49
-
50
-
51
- class WarpCodegenAttributeError(AttributeError):
52
- def __init__(self, message):
53
- super().__init__(message)
54
-
55
-
56
- class WarpCodegenKeyError(KeyError):
57
- def __init__(self, message):
58
- super().__init__(message)
59
-
60
-
61
- # map operator to function name
62
- builtin_operators: dict[type[ast.AST], str] = {}
63
-
64
- # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
65
- # nice overview of python operators
66
-
67
- builtin_operators[ast.Add] = "add"
68
- builtin_operators[ast.Sub] = "sub"
69
- builtin_operators[ast.Mult] = "mul"
70
- builtin_operators[ast.MatMult] = "mul"
71
- builtin_operators[ast.Div] = "div"
72
- builtin_operators[ast.FloorDiv] = "floordiv"
73
- builtin_operators[ast.Pow] = "pow"
74
- builtin_operators[ast.Mod] = "mod"
75
- builtin_operators[ast.UAdd] = "pos"
76
- builtin_operators[ast.USub] = "neg"
77
- builtin_operators[ast.Not] = "unot"
78
-
79
- builtin_operators[ast.Gt] = ">"
80
- builtin_operators[ast.Lt] = "<"
81
- builtin_operators[ast.GtE] = ">="
82
- builtin_operators[ast.LtE] = "<="
83
- builtin_operators[ast.Eq] = "=="
84
- builtin_operators[ast.NotEq] = "!="
85
-
86
- builtin_operators[ast.BitAnd] = "bit_and"
87
- builtin_operators[ast.BitOr] = "bit_or"
88
- builtin_operators[ast.BitXor] = "bit_xor"
89
- builtin_operators[ast.Invert] = "invert"
90
- builtin_operators[ast.LShift] = "lshift"
91
- builtin_operators[ast.RShift] = "rshift"
92
-
93
- comparison_chain_strings = [
94
- builtin_operators[ast.Gt],
95
- builtin_operators[ast.Lt],
96
- builtin_operators[ast.LtE],
97
- builtin_operators[ast.GtE],
98
- builtin_operators[ast.Eq],
99
- builtin_operators[ast.NotEq],
100
- ]
101
-
102
-
103
- def values_check_equal(a, b):
104
- if isinstance(a, Sequence) and isinstance(b, Sequence):
105
- if len(a) != len(b):
106
- return False
107
-
108
- return all(x == y for x, y in zip(a, b))
109
-
110
- return a == b
111
-
112
-
113
- def op_str_is_chainable(op: str) -> builtins.bool:
114
- return op in comparison_chain_strings
115
-
116
-
117
- def get_closure_cell_contents(obj):
118
- """Retrieve a closure's cell contents or `None` if it's empty."""
119
- try:
120
- return obj.cell_contents
121
- except ValueError:
122
- pass
123
-
124
- return None
125
-
126
-
127
- def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
128
- """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
129
- # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
130
- if not annotations:
131
- return {}
132
-
133
- if not any(isinstance(x, str) for x in annotations.values()):
134
- # No annotation to un-stringize.
135
- return annotations
136
-
137
- if isinstance(obj, type):
138
- # class
139
- globals = {}
140
- module_name = getattr(obj, "__module__", None)
141
- if module_name:
142
- module = sys.modules.get(module_name, None)
143
- if module:
144
- globals = getattr(module, "__dict__", {})
145
- locals = dict(vars(obj))
146
- unwrap = obj
147
- elif isinstance(obj, types.ModuleType):
148
- # module
149
- globals = obj.__dict__
150
- locals = {}
151
- unwrap = None
152
- elif callable(obj):
153
- # function
154
- globals = getattr(obj, "__globals__", {})
155
- # Capture the variables from the surrounding scope.
156
- closure_vars = zip(
157
- obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
158
- )
159
- locals = {k: v for k, v in closure_vars if v is not None}
160
- unwrap = obj
161
- else:
162
- raise TypeError(f"{obj!r} is not a module, class, or callable.")
163
-
164
- if unwrap is not None:
165
- while True:
166
- if hasattr(unwrap, "__wrapped__"):
167
- unwrap = unwrap.__wrapped__
168
- continue
169
- if isinstance(unwrap, functools.partial):
170
- unwrap = unwrap.func
171
- continue
172
- break
173
- if hasattr(unwrap, "__globals__"):
174
- globals = unwrap.__globals__
175
-
176
- # "Inject" type parameters into the local namespace
177
- # (unless they are shadowed by assignments *in* the local namespace),
178
- # as a way of emulating annotation scopes when calling `eval()`
179
- type_params = getattr(obj, "__type_params__", ())
180
- if type_params:
181
- locals = {param.__name__: param for param in type_params} | locals
182
-
183
- return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
184
-
185
-
186
- def get_annotations(obj: Any) -> Mapping[str, Any]:
187
- """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
188
- # This backports `inspect.get_annotations()` for Python 3.9 and older.
189
- # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
190
- if isinstance(obj, type):
191
- annotations = obj.__dict__.get("__annotations__", {})
192
- else:
193
- annotations = getattr(obj, "__annotations__", {})
194
-
195
- # Evaluating annotations can be done using the `eval_str` parameter with
196
- # the official function from the `inspect` module.
197
- return eval_annotations(annotations, obj)
198
-
199
-
200
- def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
201
- """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
202
- # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
203
- spec = inspect.getfullargspec(func)
204
- return spec._replace(annotations=eval_annotations(spec.annotations, func))
205
-
206
-
207
- def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
208
- indent = "\t"
209
-
210
- # handle empty structs
211
- if len(inst._cls.vars) == 0:
212
- return f"{inst._cls.key}()"
213
-
214
- lines = []
215
- lines.append(f"{inst._cls.key}(")
216
-
217
- for field_name, _ in inst._cls.ctype._fields_:
218
- field_value = getattr(inst, field_name, None)
219
-
220
- if isinstance(field_value, StructInstance):
221
- field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
222
-
223
- if use_repr:
224
- lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
225
- else:
226
- lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
227
-
228
- lines.append(f"{indent * depth})")
229
- return "\n".join(lines)
230
-
231
-
232
- class StructInstance:
233
- def __init__(self, cls: Struct, ctype):
234
- super().__setattr__("_cls", cls)
235
-
236
- # maintain a c-types object for the top-level instance the struct
237
- if not ctype:
238
- super().__setattr__("_ctype", cls.ctype())
239
- else:
240
- super().__setattr__("_ctype", ctype)
241
-
242
- # create Python attributes for each of the struct's variables
243
- for field, var in cls.vars.items():
244
- if isinstance(var.type, warp.codegen.Struct):
245
- self.__dict__[field] = var.type.instance_type(ctype=getattr(self._ctype, field))
246
- elif isinstance(var.type, warp.types.array):
247
- self.__dict__[field] = None
248
- else:
249
- self.__dict__[field] = var.type()
250
-
251
- def __getattribute__(self, name):
252
- cls = super().__getattribute__("_cls")
253
- if name == "native_name":
254
- return cls.native_name
255
-
256
- var = cls.vars.get(name)
257
- if var is not None:
258
- if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
259
- # Each field stored in a `StructInstance` is exposed as
260
- # a standard Python attribute but also has a `ctypes`
261
- # equivalent that is being updated in `__setattr__`.
262
- # However, when assigning in place an object such as a vec/mat
263
- # (e.g.: `my_struct.my_vec[0] = 1.23`), the `__setattr__` method
264
- # from `StructInstance` isn't called, and the synchronization
265
- # mechanism has no chance of updating the underlying ctype data.
266
- # As a workaround, we catch here all attempts at accessing such
267
- # objects and directly return their underlying ctype since
268
- # the Python-facing Warp vectors and matrices are implemented
269
- # using `ctypes.Array` anyways.
270
- return getattr(self._ctype, name)
271
-
272
- return super().__getattribute__(name)
273
-
274
- def __setattr__(self, name, value):
275
- if name not in self._cls.vars:
276
- raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")
277
-
278
- var = self._cls.vars[name]
279
-
280
- # update our ctype flat copy
281
- if isinstance(var.type, array):
282
- if value is None:
283
- # create array with null pointer
284
- setattr(self._ctype, name, array_t())
285
- else:
286
- # wp.array
287
- assert isinstance(value, array)
288
- assert types_equal(value.dtype, var.type.dtype), (
289
- f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
290
- )
291
- setattr(self._ctype, name, value.__ctype__())
292
-
293
- # workaround to prevent gradient buffers being garbage collected
294
- # since users can do struct.array.requires_grad = False the gradient array
295
- # would be collected while the struct ctype still holds a reference to it
296
- super().__setattr__("_" + name + "_grad", value.grad)
297
-
298
- elif isinstance(var.type, Struct):
299
- # assign structs by-value, otherwise we would have problematic cases transferring ownership
300
- # of the underlying ctypes data between shared Python struct instances
301
-
302
- if not isinstance(value, StructInstance):
303
- raise RuntimeError(
304
- f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
305
- )
306
-
307
- # destination attribution on self
308
- dest = getattr(self, name)
309
-
310
- if dest._cls.key is not value._cls.key:
311
- raise RuntimeError(
312
- f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
313
- )
314
-
315
- # update all nested ctype vars by deep copy
316
- for n in dest._cls.vars:
317
- setattr(dest, n, getattr(value, n))
318
-
319
- # early return to avoid updating our Python StructInstance
320
- return
321
-
322
- elif issubclass(var.type, ctypes.Array):
323
- # vector/matrix type, e.g. vec3
324
- if value is None:
325
- setattr(self._ctype, name, var.type())
326
- elif type(value) == var.type:
327
- setattr(self._ctype, name, value)
328
- else:
329
- # conversion from list/tuple, ndarray, etc.
330
- setattr(self._ctype, name, var.type(value))
331
-
332
- else:
333
- # primitive type
334
- if value is None:
335
- # zero initialize
336
- setattr(self._ctype, name, var.type._type_())
337
- else:
338
- if hasattr(value, "_type_"):
339
- # assigning warp type value (e.g.: wp.float32)
340
- value = value.value
341
- # float16 needs conversion to uint16 bits
342
- if var.type == warp.float16:
343
- setattr(self._ctype, name, float_to_half_bits(value))
344
- else:
345
- setattr(self._ctype, name, value)
346
-
347
- # update Python instance
348
- super().__setattr__(name, value)
349
-
350
- def __ctype__(self):
351
- return self._ctype
352
-
353
- def __repr__(self):
354
- return struct_instance_repr_recursive(self, 0, use_repr=True)
355
-
356
- def __str__(self):
357
- return struct_instance_repr_recursive(self, 0, use_repr=False)
358
-
359
- def to(self, device):
360
- """Copies this struct with all array members moved onto the given device.
361
-
362
- Arrays already living on the desired device are referenced as-is, while
363
- arrays being moved are copied.
364
- """
365
- out = self._cls()
366
- stack = [(self, out, k, v) for k, v in self._cls.vars.items()]
367
- while stack:
368
- src, dst, name, var = stack.pop()
369
- value = getattr(src, name)
370
- if isinstance(var.type, array):
371
- # array_t
372
- setattr(dst, name, value.to(device))
373
- elif isinstance(var.type, Struct):
374
- # nested struct
375
- new_struct = value._cls()
376
- setattr(dst, name, new_struct)
377
- # The call to `setattr()` just above makes a copy of `new_struct`
378
- # so we need to reference that new instance of the struct.
379
- new_struct = getattr(dst, name)
380
- stack.extend((value, new_struct, k, v) for k, v in value._cls.vars.items())
381
- else:
382
- setattr(dst, name, value)
383
-
384
- return out
385
-
386
- # type description used in numpy structured arrays
387
- def numpy_dtype(self):
388
- return self._cls.numpy_dtype()
389
-
390
- # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
391
- def numpy_value(self):
392
- npvalue = []
393
- for name, var in self._cls.vars.items():
394
- # get the attribute value
395
- value = getattr(self._ctype, name)
396
-
397
- if isinstance(var.type, array):
398
- # array_t
399
- npvalue.append(value.numpy_value())
400
- elif isinstance(var.type, Struct):
401
- # nested struct
402
- npvalue.append(value.numpy_value())
403
- elif issubclass(var.type, ctypes.Array):
404
- if len(var.type._shape_) == 1:
405
- # vector
406
- npvalue.append(list(value))
407
- else:
408
- # matrix
409
- npvalue.append([list(row) for row in value])
410
- else:
411
- # scalar
412
- if var.type == warp.float16:
413
- npvalue.append(half_bits_to_float(value))
414
- else:
415
- npvalue.append(value)
416
-
417
- return tuple(npvalue)
418
-
419
-
420
class Struct:
    """Compile-time definition of a Warp struct type built from an annotated Python class.

    Holds the field variables (``vars``), a generated ``ctypes.Structure``
    layout (``ctype``), a content hash with the derived ``native_name``, the
    constructor ``Function`` objects used when the struct is called in kernel
    code, and the Python instance class (``instance_type``) used on the host.
    """

    # SHA-256 digest of the struct key and each field's name/type code (set in __init__)
    hash: bytes

    def __init__(self, key: str, cls: type, module: warp.context.Module):
        """Build the struct definition from the annotations of ``cls``.

        Args:
            key: Struct name; hashed and used as the prefix of ``native_name``.
            cls: Class whose annotations declare the fields, in declaration order.
            module: Owning module; the struct is registered with it when it is a
                ``warp.context.Module`` instance.

        Raises:
            RuntimeError: If ``cls`` is a ``Sequence`` instance.
        """
        self.key = key
        self.cls = cls
        self.module = module
        self.vars: dict[str, Var] = {}

        if isinstance(self.cls, Sequence):
            raise RuntimeError("Warp structs must be defined as base classes")

        # one Var per annotated field
        annotations = get_annotations(self.cls)
        for label, type in annotations.items():
            self.vars[label] = Var(label, type)

        # ctypes field list mirroring the struct layout:
        # arrays -> array_t, nested structs -> their ctype, vec/mat -> ctypes.Array, scalars -> _type_
        fields = []
        for label, var in self.vars.items():
            if isinstance(var.type, array):
                fields.append((label, array_t))
            elif isinstance(var.type, Struct):
                fields.append((label, var.type.ctype))
            elif issubclass(var.type, ctypes.Array):
                fields.append((label, var.type))
            else:
                # HACK: fp16 requires conversion functions from warp.so
                if var.type is warp.float16:
                    warp.init()
                fields.append((label, var.type._type_))

        class StructType(ctypes.Structure):
            # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
            _fields_ = fields or [("_dummy_", ctypes.c_byte)]

        self.ctype = StructType

        # Compute the hash. We can cache the hash because it's static, even with nested structs.
        # All field types are specified in the annotations, so they're resolved at declaration time.
        ch = hashlib.sha256()

        ch.update(bytes(self.key, "utf-8"))

        for name, type_hint in annotations.items():
            s = f"{name}:{warp.types.get_type_code(type_hint)}"
            ch.update(bytes(s, "utf-8"))

            # recurse on nested structs
            if isinstance(type_hint, Struct):
                ch.update(type_hint.hash)

        self.hash = ch.digest()

        # generate unique identifier for structs in native code
        # (struct key + first 8 hex digits of the hash)
        hash_suffix = f"{self.hash.hex()[:8]}"
        self.native_name = f"{self.key}_{hash_suffix}"

        # create default constructor (zero-initialize)
        self.default_constructor = warp.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types={},
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        # build a constructor that takes each param as a value
        input_types = {label: var.type for label, var in self.vars.items()}

        self.value_constructor = warp.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types=input_types,
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        # the by-value constructor is exposed as an overload of the default one
        self.default_constructor.add_overload(self.value_constructor)

        if isinstance(module, warp.context.Module):
            module.register_struct(self)

        # Define class for instances of this struct
        # To enable autocomplete on s, we inherit from self.cls.
        # For example,

        # @wp.struct
        # class A:
        #     # annotations
        #     ...

        # The type annotations are inherited in A(), allowing autocomplete in kernels
        class NewStructInstance(self.cls, StructInstance):
            def __init__(inst, ctype=None):
                StructInstance.__init__(inst, self, ctype)

        # make sure warp.types.get_type_code works with this StructInstance
        NewStructInstance.cls = self.cls
        NewStructInstance.native_name = self.native_name

        self.instance_type = NewStructInstance

    def __call__(self):
        """
        Create a new host-side instance of this struct (an ``instance_type`` object).
        The instance class derives from ``self.cls``, so it exposes the same annotations.
        """
        return self.instance_type()

    def initializer(self):
        """Return the default (zero-initializing) constructor Function, which also carries the by-value overload."""
        return self.default_constructor

    # return structured NumPy dtype, including field names, formats, and offsets
    def numpy_dtype(self):
        """Return a dict usable with ``numpy.dtype``: field names, formats, byte offsets and total itemsize.

        Offsets and itemsize come from the generated ctypes layout, so the NumPy
        view matches the native memory layout.
        """
        names = []
        formats = []
        offsets = []
        for name, var in self.vars.items():
            names.append(name)
            offsets.append(getattr(self.ctype, name).offset)
            if isinstance(var.type, array):
                # array_t
                formats.append(array_t.numpy_dtype())
            elif isinstance(var.type, Struct):
                # nested struct
                formats.append(var.type.numpy_dtype())
            elif issubclass(var.type, ctypes.Array):
                scalar_typestr = type_typestr(var.type._wp_scalar_type_)
                if len(var.type._shape_) == 1:
                    # vector
                    formats.append(f"{var.type._length_}{scalar_typestr}")
                else:
                    # matrix
                    formats.append(f"{var.type._shape_}{scalar_typestr}")
            else:
                # scalar
                formats.append(type_typestr(var.type))

        return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}

    # constructs a Warp struct instance from a pointer to the ctype
    def from_ptr(self, ptr):
        """Create a struct instance by reading each field from memory at address ``ptr``.

        Array fields cannot be reconstructed and are filled with stub ``array``
        annotations (dtype/ndim only). Nested structs and vector/matrix fields
        recurse through their own ``from_ptr``.

        Raises:
            RuntimeError: If ``ptr`` is falsy (NULL).
        """
        if not ptr:
            raise RuntimeError("NULL pointer exception")

        # create a new struct instance
        instance = self()

        for name, var in self.vars.items():
            # byte offset of the field inside the generated ctypes layout
            offset = getattr(self.ctype, name).offset
            if isinstance(var.type, array):
                # We could reconstruct wp.array from array_t, but it's problematic.
                # There's no guarantee that the original wp.array is still allocated and
                # no easy way to make a backref.
                # Instead, we just create a stub annotation, which is not a fully usable array object.
                setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
            elif isinstance(var.type, Struct):
                # nested struct
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            else:
                # scalar
                cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
                if var.type == warp.float16:
                    setattr(instance, name, half_bits_to_float(cvalue))
                else:
                    setattr(instance, name, cvalue.value)

        return instance
596
-
597
-
598
class Reference:
    """Marks a type as a reference to a value of ``value_type``.

    When not requesting the value type, code generation spells a reference as a
    pointer to ``value_type`` (see ``Var.type_to_ctype``).
    """

    def __init__(self, value_type):
        self.value_type = value_type

    def __repr__(self):
        # readable in error messages that interpolate types, instead of "<... object at 0x...>"
        return f"{type(self).__name__}({self.value_type!r})"


def is_reference(type: Any) -> builtins.bool:
    """Return True if ``type`` is a :class:`Reference` wrapper."""
    return isinstance(type, Reference)


def strip_reference(arg: Any) -> Any:
    """Return the referenced value type if ``arg`` is a Reference, otherwise ``arg`` unchanged."""
    if is_reference(arg):
        return arg.value_type
    else:
        return arg
612
-
613
-
614
def compute_type_str(base_name, template_params):
    """Return the C++ template spelling ``base_name<p0, p1, ...>``.

    Each template parameter is rendered as follows:

    - booleans become ``true``/``false`` and ints become their decimal value,
    - generic Warp types (``_wp_generic_type_str_``) expand recursively to ``wp::`` templates,
    - ctypes-style scalar types (``_type_``) become ``bool`` or ``wp::<name>``,
    - tile types use their ``ctype()`` and structs use their ``native_name``,
    - anything else falls back to its ``__name__``.

    With no template parameters, ``base_name`` is returned unchanged.
    """
    if not template_params:
        return base_name

    def render(param):
        if isinstance(param, builtins.bool):
            return "true" if param else "false"
        if isinstance(param, int):
            return str(param)
        if hasattr(param, "_wp_generic_type_str_"):
            return compute_type_str(f"wp::{param._wp_generic_type_str_}", param._wp_type_params_)
        if hasattr(param, "_type_"):
            return "bool" if param.__name__ == "bool" else f"wp::{param.__name__}"
        if is_tile(param):
            return param.ctype()
        if isinstance(param, Struct):
            return param.native_name
        return param.__name__

    rendered = ", ".join(render(p) for p in template_params)
    return f"{base_name}<{rendered}>"
638
-
639
-
640
class Var:
    """A typed value (argument, local or temporary) in the generated SSA code.

    Besides its label and type, a Var records an optional compile-time constant,
    whether it requires gradients, and, for arrays, whether a kernel reads from
    or writes to it (propagated to parent arrays of views).
    """

    def __init__(
        self,
        label: str,
        type: type,
        requires_grad: builtins.bool = False,
        constant: Any = None,
        prefix: builtins.bool = True,
        relative_lineno: int | None = None,
    ):
        """Create a variable.

        Args:
            label: Name of the variable; emitted as ``<prefix>_<label>`` when ``prefix`` is set.
            type: Variable type; Python ``float``/``int``/``bool`` are mapped to
                ``float32``/``int32``/``bool``.
            requires_grad: Whether the variable requires gradients.
            constant: Compile-time constant value, or None if not constant.
            prefix: Whether ``emit()`` prepends a prefix to the label.
            relative_lineno: Line (relative to the function) of the statement that created it.
        """
        # convert built-in types to wp types
        if type == float:
            type = float32
        elif type == int:
            type = int32
        elif type == builtins.bool:
            type = bool

        self.label = label
        self.type = type
        self.requires_grad = requires_grad
        self.constant = constant
        self.prefix = prefix

        # records whether this Var has been read from in a kernel function (array only)
        self.is_read = False
        # records whether this Var has been written to in a kernel function (array only)
        self.is_write = False

        # used to associate a view array Var with its parent array Var
        self.parent = None

        # Used to associate the variable with the Python statement that resulted in it being created.
        self.relative_lineno = relative_lineno

    def __str__(self):
        return self.label

    @staticmethod
    def dtype_to_ctype(t: type) -> str:
        """Return the C++ spelling of the value type ``t`` (generic template, struct, or scalar)."""
        if hasattr(t, "_wp_generic_type_str_"):
            return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
        elif isinstance(t, Struct):
            return t.native_name
        elif hasattr(t, "_wp_native_name_"):
            return f"wp::{t._wp_native_name_}"
        elif t.__name__ in ("bool", "int", "float"):
            return t.__name__

        return f"wp::{t.__name__}"

    @staticmethod
    def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
        """Return the C++ spelling of ``t``, including arrays, tuples, tiles, structs and references.

        References are spelled as pointers unless ``value_type`` is True, in which
        case the referenced value type is returned.
        """
        if isinstance(t, fixedarray):
            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
            dtypestr = ", ".join(template_args)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}_t<{dtypestr}>"
        elif is_array(t):
            dtypestr = Var.dtype_to_ctype(t.dtype)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}_t<{dtypestr}>"
        elif get_origin(t) is tuple:
            dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
            return f"wp::tuple_t<{dtypestr}>"
        elif is_tuple(t):
            dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
            classstr = f"wp::{type(t).__name__}"
            return f"{classstr}<{dtypestr}>"
        elif is_tile(t):
            return t.ctype()
        elif isinstance(t, type) and issubclass(t, StructInstance):
            # ensure the actual Struct name is used instead of "NewStructInstance"
            return t.native_name
        elif is_reference(t):
            if not value_type:
                return Var.type_to_ctype(t.value_type) + "*"

            return Var.type_to_ctype(t.value_type)

        return Var.dtype_to_ctype(t)

    def ctype(self, value_type: builtins.bool = False) -> str:
        """Return the C++ spelling of this variable's type (see ``type_to_ctype``)."""
        return Var.type_to_ctype(self.type, value_type)

    def emit(self, prefix: str = "var"):
        """Return the generated-code name of this variable, e.g. ``var_3`` or ``adj_3``."""
        if self.prefix:
            return f"{prefix}_{self.label}"
        else:
            return self.label

    def emit_adj(self):
        """Return the name of this variable's adjoint (``adj_<label>``)."""
        return self.emit("adj")

    def mark_read(self):
        """Marks this Var as having been read from in a kernel (array only)."""
        if not is_array(self.type):
            return

        self.is_read = True

        # recursively update all parent states
        parent = self.parent
        while parent is not None:
            parent.is_read = True
            parent = parent.parent

    def mark_write(self, **kwargs):
        """Marks this Var as having been written to in a kernel (array only).

        Keyword Args:
            kernel_name, filename, lineno: Optional location used in the
                write-after-read warning; all three must be given to be used.
        """
        if not is_array(self.type):
            return

        # detect if we are writing to an array after reading from it within the same kernel
        if self.is_read and warp.config.verify_autograd_array_access:
            # every key is tested explicitly: the chained form `"a" and "b" and "c" in kwargs`
            # only checks the last key and can raise KeyError on the others
            if all(key in kwargs for key in ("kernel_name", "filename", "lineno")):
                print(
                    f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
                )
            else:
                print(
                    f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
                )
        self.is_write = True

        # recursively update all parent states
        parent = self.parent
        while parent is not None:
            parent.is_write = True
            parent = parent.parent
769
-
770
-
771
class Block:
    """A basic block: a straight-line run of generated statements.

    Used for, e.g., the body of a for-loop or of a conditional. Each block
    keeps separate statement lists for the forward pass, the replayed forward
    statements, and the reverse (adjoint) pass, plus the Vars declared in it.
    """

    def __init__(self):
        # generated statements, one list per pass
        self.body_forward, self.body_replay, self.body_reverse = [], [], []
        # Vars allocated while this block was the innermost open block
        self.vars = []
783
-
784
-
785
def apply_defaults(
    bound_args: inspect.BoundArguments,
    values: Mapping[str, Any],
):
    """Fill in unbound parameters of ``bound_args`` from ``values``, keeping signature order.

    Similar to Python's ``inspect.BoundArguments.apply_defaults()``, but the
    fallback values come from ``values`` rather than the signature's own
    defaults. Parameters that are neither bound nor present in ``values`` are
    left out. ``bound_args.arguments`` is replaced in place.
    """
    arguments = bound_args.arguments
    # use the public `signature` attribute rather than the private `_signature` slot
    bound_args.arguments = {
        name: arguments[name] if name in arguments else values[name]
        for name in bound_args.signature.parameters
        if name in arguments or name in values
    }
800
-
801
-
802
def func_match_args(func, arg_types, kwarg_types):
    """Return True if the given argument types are acceptable for the overload ``func``.

    The arguments are first bound to ``func.signature``; this only assigns
    arguments to parameters and fails on arity/keyword mismatches. Missing
    parameters are then filled from ``func.defaults``. Finally, each bound type
    is compared to the declared input type. ``None`` (type to be inferred by
    ``value_func``) and ``Any`` always match, and a Warp function matches a
    ``Callable`` parameter.
    """
    try:
        bound = func.signature.bind(*arg_types, **kwarg_types)
    except TypeError:
        return False

    # types of the default values for parameters that were not given explicitly
    missing_defaults = {}
    for param_name, default_value in func.defaults.items():
        if param_name in bound.arguments:
            continue
        missing_defaults[param_name] = None if default_value is None else get_arg_type(default_value)
    apply_defaults(bound, missing_defaults)

    for given, expected in zip(tuple(bound.arguments.values()), func.input_types.values()):
        # unknown type (inferred later) or a template parameter accepting anything
        if given is None or expected == Any:
            continue

        # function references are matched by kind, not by exact type
        if expected == Callable and isinstance(given, warp.context.Function):
            continue

        if not types_equal(expected, strip_reference(given), match_generic=True):
            return False

    return True
841
-
842
-
843
def get_arg_type(arg: Var | Any) -> type:
    """Return the type used to match ``arg`` against function overloads.

    - strings map to ``str``; other sequences map element-wise to a tuple of types,
    - arrays, tuple types, plain types and Warp functions are returned as-is,
    - ``tuple[...]`` annotations expand to a tuple of their argument types,
    - Vars yield their declared type (``get_args`` of it for ``tuple[...]`` types),
    - any other value yields ``type(arg)``.
    """
    if isinstance(arg, str):
        result = str
    elif isinstance(arg, Sequence):
        result = tuple(map(get_arg_type, arg))
    elif is_array(arg):
        result = arg
    elif get_origin(arg) is tuple:
        result = tuple(map(get_arg_type, get_args(arg)))
    elif is_tuple(arg) or isinstance(arg, (type, warp.context.Function)):
        result = arg
    elif isinstance(arg, Var):
        result = get_args(arg.type) if get_origin(arg.type) is tuple else arg.type
    else:
        result = type(arg)
    return result
869
-
870
-
871
def get_arg_value(arg: Any) -> Any:
    """Return the compile-time value used for ``arg`` when resolving a call.

    - strings are returned unchanged,
    - other sequences map element-wise to a tuple of values,
    - types and Warp functions are returned as-is,
    - tuple-typed Vars yield a tuple of their element values, and Vars holding a
      constant yield that constant,
    - anything else (including non-constant Vars) is returned unchanged.
    """
    # A str is a Sequence whose items are themselves strings, so recursing into
    # it would never terminate (mirrors the str special case in get_arg_type).
    if isinstance(arg, str):
        return arg

    if isinstance(arg, Sequence):
        return tuple(get_arg_value(x) for x in arg)

    if isinstance(arg, (type, warp.context.Function)):
        return arg

    if isinstance(arg, Var):
        if is_tuple(arg.type):
            return tuple(get_arg_value(x) for x in arg.type.values)

        if arg.constant is not None:
            return arg.constant

    return arg
886
-
887
-
888
- class Adjoint:
889
- # Source code transformer, this class takes a Python function and
890
- # generates forward and backward SSA forms of the function instructions
891
-
892
def __init__(
    adj,
    func: Callable[..., Any],
    overload_annotations=None,
    is_user_function=False,
    skip_forward_codegen=False,
    skip_reverse_codegen=False,
    custom_reverse_mode=False,
    custom_reverse_num_input_args=-1,
    transformers: list[ast.NodeTransformer] | None = None,
    source: str | None = None,
):
    """Parse ``func`` into an AST and set up its argument variables.

    Args:
        func: Python function to transform.
        overload_annotations: Argument annotations to use instead of the
            source-level ones; must cover every positional parameter.
        is_user_function: Stored as ``adj.is_user_function``.
        skip_forward_codegen: Whether generation of the forward code is skipped.
        skip_reverse_codegen: Whether generation of the adjoint code is skipped.
        custom_reverse_mode: Whether the forward code is used for the reverse
            pass with a custom reverse signature.
        custom_reverse_num_input_args: Number of arguments that are forward
            inputs (not adjoint arguments) in custom reverse mode.
        transformers: AST node transformers applied, in order, to the parsed tree.
        source: Source text to use instead of reading it from ``func``
            (``fun_lineno`` then stays 0).

    Raises:
        WarpCodegenError: If a parameter has no annotation (or no overload annotation).
    """
    adj.func = func

    adj.is_user_function = is_user_function

    # whether the generation of the forward code is skipped for this function
    adj.skip_forward_codegen = skip_forward_codegen
    # whether the generation of the adjoint code is skipped for this function
    adj.skip_reverse_codegen = skip_reverse_codegen
    # Whether this function is used by a kernel that has the backward pass enabled.
    adj.used_by_backward_kernel = False

    # extract name of source file
    adj.filename = inspect.getsourcefile(func) or "unknown source file"
    # get source file line number where function starts
    adj.fun_lineno = 0
    adj.source = source
    if adj.source is None:
        adj.source, adj.fun_lineno = adj.extract_function_source(func)

    assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"

    # Indicates where the function definition starts (excludes decorators)
    adj.fun_def_lineno = None

    # get function source code
    # ensures that indented class methods can be parsed as kernels
    adj.source = textwrap.dedent(adj.source)

    adj.source_lines = adj.source.splitlines()

    if transformers is None:
        transformers = []

    # build AST and apply node transformers
    adj.tree = ast.parse(adj.source)
    adj.transformers = transformers
    for transformer in transformers:
        adj.tree = transformer.visit(adj.tree)

    adj.fun_name = adj.tree.body[0].name

    # for keeping track of line number in function code
    adj.lineno = None

    # whether the forward code shall be used for the reverse pass and a custom
    # function signature is applied to the reverse version of the function
    adj.custom_reverse_mode = custom_reverse_mode
    # the number of function arguments that pertain to the forward function
    # input arguments (i.e. the number of arguments that are not adjoint arguments)
    adj.custom_reverse_num_input_args = custom_reverse_num_input_args

    # parse argument types
    argspec = get_full_arg_spec(func)

    # ensure all arguments are annotated
    if overload_annotations is None:
        # use source-level argument annotations; check each parameter by name, since
        # the annotations dict may also contain "return" and would make a count comparison
        # accept a function with an unannotated parameter
        if any(arg_name not in argspec.annotations for arg_name in argspec.args):
            raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
        adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
    else:
        # use overload argument annotations
        for arg_name in argspec.args:
            if arg_name not in overload_annotations:
                raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
        adj.arg_types = overload_annotations.copy()

    adj.args = []
    adj.symbols = {}

    for name, type in adj.arg_types.items():
        # skip return hint
        if name == "return":
            continue

        # add variable for argument
        arg = Var(name, type, requires_grad=False)
        adj.args.append(arg)

        # pre-populate symbol dictionary with function argument names
        # this is to avoid registering false references to overshadowed modules
        adj.symbols[name] = arg

    # Indicates whether there are unresolved static expressions in the function.
    # These stem from wp.static() expressions that could not be evaluated at declaration time.
    # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
    adj.has_unresolved_static_expressions = False

    # try to replace static expressions by their constant result if the
    # expression can be evaluated at declaration time
    adj.static_expressions: dict[str, Any] = {}
    if "static" in adj.source:
        adj.replace_static_expressions()

    # There are cases where a same module might be rebuilt multiple times,
    # for example when kernels are nested inside of functions, or when
    # a kernel's launch raises an exception. Ideally we'd always want to
    # avoid rebuilding kernels but some corner cases seem to depend on it,
    # so we only avoid rebuilding kernels that errored out to give a chance
    # for unit testing errors being spit out from kernels.
    adj.skip_build = False
1005
-
1006
def alloc_shared_extra(adj, num_bytes):
    """Reserve ``num_bytes`` of extra shared memory for a dependent function call.

    Shared memory is treated as a stack: each function pushes and pops its own
    space, so only the largest request (the 'roofline' for the whole kernel)
    has to be remembered.
    """
    if num_bytes > adj.max_required_extra_shared_memory:
        adj.max_required_extra_shared_memory = num_bytes
1012
-
1013
def get_total_required_shared(adj):
    """Return the total shared-memory bytes this function needs.

    This is the size of every shared-storage tile the function owns, plus the
    worst-case extra space requested by any dependent function call.
    """
    owned_tile_bytes = sum(
        var.type.size_in_bytes()
        for var in adj.variables
        if is_tile(var.type) and var.type.storage == "shared" and var.type.owner
    )
    return owned_tile_bytes + adj.max_required_extra_shared_memory
1024
-
1025
@staticmethod
def extract_function_source(func: Callable) -> tuple[str, int]:
    """Return ``(source, first_line_number)`` for ``func``.

    ``inspect.getsource()`` is just ``"".join`` over the lines returned by
    ``inspect.getsourcelines()``, so the lines are fetched once and joined here
    instead of locating and reading the source twice.

    Raises:
        RuntimeError: If the source is unavailable, e.g. for code run through ``exec()``.
    """
    try:
        source_lines, fun_lineno = inspect.getsourcelines(func)
    except OSError as e:
        raise RuntimeError(
            "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
            "please save it to a file and use `importlib` if needed."
        ) from e
    return "".join(source_lines), fun_lineno
1036
-
1037
# generate function ssa form and adjoint
def build(adj, builder, default_builder_options=None):
    """Generate the function's SSA form by evaluating its AST.

    Resets all per-build state (symbols, variables, blocks, labels, shared
    memory bookkeeping) before evaluating ``adj.tree``. On a parse/codegen
    error the exception is re-raised, annotated with the source location when
    possible, and the function is marked ``skip_build`` so later builds skip it.
    Struct types used by arguments are then built through ``builder``.

    Args:
        builder: Module builder providing ``options`` and ``build_struct_recursive``; may be None.
        default_builder_options: Options used when ``builder`` is falsy.
    """
    # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
    for arg in adj.args:
        arg.is_read = False
        arg.is_write = False

    if adj.skip_build:
        return

    adj.builder = builder

    if default_builder_options is None:
        default_builder_options = {}

    if adj.builder:
        adj.builder_options = adj.builder.options
    else:
        adj.builder_options = default_builder_options

    global options
    options = adj.builder_options

    adj.symbols = {}  # map from symbols to adjoint variables
    adj.variables = []  # list of local variables (in order)

    adj.return_var = None  # return type for function or kernel
    adj.loop_symbols = []  # symbols at the start of each loop
    adj.loop_const_iter_symbols = (
        set()
    )  # constant iteration variables for static loops (mutating them does not raise an error)

    # blocks
    adj.blocks = [Block()]
    adj.loop_blocks = []

    # holds current indent level
    adj.indentation = ""

    # used to generate new label indices
    adj.label_count = 0

    # tracks how much additional shared memory is required by any dependent function calls
    adj.max_required_extra_shared_memory = 0

    # update symbol map for each argument
    for a in adj.args:
        adj.symbols[a.label] = a

    # recursively evaluate function body
    try:
        adj.eval(adj.tree.body[0])
    except Exception as original_exc:
        try:
            if adj.lineno is None:
                # No statement location is known yet, so there is no source context
                # to add; re-raise the original error instead of a TypeError from `None + int`.
                raise

            lineno = adj.lineno + adj.fun_lineno
            line = adj.source_lines[adj.lineno]
            msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'

            # Combine the new message with the original exception's arguments
            new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)

            try:
                enhanced_exc = type(original_exc)(*new_args)
            except Exception:
                # Some exception types cannot be rebuilt from a single message
                # (their constructor requires other arguments); keep the original.
                enhanced_exc = None
            if enhanced_exc is None:
                # bare raise: re-raises original_exc, the exception being handled
                raise

            # Enhance the original exception with parser context before re-raising.
            # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
            raise enhanced_exc.with_traceback(original_exc.__traceback__) from None
        finally:
            adj.skip_build = True
            adj.builder = None

    if builder is not None:
        for a in adj.args:
            if isinstance(a.type, Struct):
                builder.build_struct_recursive(a.type)
            elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                builder.build_struct_recursive(a.type.dtype)

    # release builder reference for GC
    adj.builder = None
1114
-
1115
# code generation methods
def format_template(adj, template, input_vars, output_var):
    """Fill ``template`` positionally: ``{0}`` is ``output_var``, ``{1}``, ``{2}``... are ``input_vars``."""
    return template.format(output_var, *input_vars)
1122
-
1123
# generates a list of formatted args
def format_args(adj, prefix, args):
    """Return the generated-code spellings of ``args``, using ``prefix`` (e.g. ``"var"``, ``"adj"``).

    - Warp functions are emitted by name (``namespace`` + ``native_func``); the
      ``var`` prefix is dropped for them, other prefixes are added as ``<prefix>_``.
    - Arguments whose type is a reference are emitted as ``<prefix>_<arg>``.
    - Other Vars are emitted through ``Var.emit(prefix)``.

    Raises:
        WarpCodegenTypeError: If an argument is neither a Var nor a Warp function.
    """
    arg_strs = []

    for a in args:
        if isinstance(a, warp.context.Function):
            # functions don't have a var_ prefix so strip it off here
            if prefix == "var":
                arg_strs.append(f"{a.namespace}{a.native_func}")
            else:
                arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
        elif is_reference(getattr(a, "type", None)):
            # getattr() so values without a `type` attribute (e.g. raw ints) reach the
            # explicit type error below instead of failing with an AttributeError
            arg_strs.append(f"{prefix}_{a}")
        elif isinstance(a, Var):
            arg_strs.append(a.emit(prefix))
        else:
            raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")

    return arg_strs
1142
-
1143
# generates argument string for a forward function call
def format_forward_call_args(adj, args, use_initializer_list):
    """Join the ``var``-prefixed call arguments with ``", "``, brace-wrapped when ``use_initializer_list`` is set."""
    joined = ", ".join(adj.format_args("var", args))
    return "{" + joined + "}" if use_initializer_list else joined
1149
-
1150
# generates argument string for a reverse function call
def format_reverse_call_args(
    adj,
    args_var,
    args,
    args_out,
    use_initializer_list,
    has_output_args=True,
    require_original_output_arg=False,
):
    """Build the argument string for a reverse (adjoint) function call.

    The call takes the primal inputs, optionally the primal outputs (only when
    ``has_output_args`` and either ``require_original_output_arg`` is set or
    there are several outputs), the input adjoints and the output adjoints.
    With ``use_initializer_list`` the first groups are brace-wrapped and input
    adjoints are passed by address (``&adj_``).

    Returns:
        The comma-separated argument string, or None when there are no adjoint
        arguments at all (no reverse call is needed).
    """
    primal_strs = adj.format_args("var", args_var)
    include_outputs = has_output_args and (require_original_output_arg or len(args_out) > 1)
    output_strs = adj.format_args("var", args_out) if include_outputs else []
    adj_prefix = "&adj" if use_initializer_list else "adj"
    adj_strs = adj.format_args(adj_prefix, args)
    out_adj_strs = adj.format_args("adj", args_out)

    if not adj_strs and not out_adj_strs:
        # there are no adjoint arguments, so we don't need to call the reverse function
        return None

    if not use_initializer_list:
        return ", ".join([*primal_strs, *output_strs, *adj_strs, *out_adj_strs])

    def braced(items):
        return "{" + ", ".join(items) + "}"

    pieces = [braced(primal_strs)]
    if len(args_out) > 1:
        pieces.append(braced(output_strs))
    pieces += [braced(adj_strs), ", ".join(out_adj_strs)]
    return ", ".join(pieces)
1186
-
1187
def indent(adj):
    """Increase the current code indentation by one level (four spaces).

    The step must match the four characters that ``dedent()`` strips, otherwise
    indent/dedent pairs do not cancel out.
    """
    adj.indentation += "    "
1189
-
1190
def dedent(adj):
    """Decrease the current code indentation by one level (drops the last four characters)."""
    current = adj.indentation
    adj.indentation = current[:-4]
1192
-
1193
def begin_block(adj, name="block"):
    """Open a new innermost Block labeled ``<name>_<counter>`` and return it."""
    new_block = Block()

    # unique label from the running counter
    new_block.label = "_".join((name, str(adj.label_count)))
    adj.label_count += 1

    adj.blocks.append(new_block)
    return new_block
1202
-
1203
def end_block(adj):
    """Close the innermost block and return it."""
    finished = adj.blocks.pop()
    return finished
1205
-
1206
def add_var(adj, type=None, constant=None):
    """Allocate a new SSA variable and register it.

    The Var is labeled with its index in ``adj.variables`` (``"0"``, ``"1"``, ...)
    and tagged with the current source line. It is recorded both in the
    function-wide variable list and in the innermost open block.
    """
    new_var = Var(str(len(adj.variables)), type=type, constant=constant, relative_lineno=adj.lineno)
    adj.variables.append(new_var)
    adj.blocks[-1].vars.append(new_var)
    return new_var
1218
-
1219
def register_var(adj, var):
    """Return the registered primal variable to use for ``var``.

    Var instances are sometimes created speculatively and thrown away, so their
    registration is deferred until here instead of going through
    ``add_var``/``add_constant`` immediately.

    - References and Warp functions pass through untouched,
    - Python ints become new constant variables,
    - Vars without a label are replaced by a newly allocated Var with the same type/constant,
    - labeled Vars are returned as-is (presumably already registered).
    """
    if isinstance(var, (Reference, warp.context.Function)):
        registered = var
    elif isinstance(var, int):
        registered = adj.add_constant(var)
    elif var.label is None:
        registered = adj.add_var(var.type, var.constant)
    else:
        registered = var
    return registered
1235
-
1236
def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
    """Return a ``#line`` directive for ``statement``, or None if none should be emitted.

    A directive is produced only when line information is being compiled in
    (the ``lineinfo`` builder option, or ``debug`` build mode, which enables it
    regardless of the option), ``warp.config.line_directives`` is set,
    ``relative_lineno`` is known, and the statement is not a ``//`` comment.

    Args:
        statement: The generated statement the directive would precede.
        relative_lineno: Line number of the statement relative to the function.
    """
    opts = adj.builder_options
    mode = opts.get("mode")
    if mode is None:
        mode = warp.config.mode

    wants_lineinfo = opts.get("lineinfo", False) or mode == "debug"

    if relative_lineno is None or not wants_lineinfo or not warp.config.line_directives:
        return None
    if statement.strip().startswith("//"):
        return None

    # CUDA expects forward slashes in #line paths
    source_path = adj.filename.replace("\\", "/")
    return f'#line {relative_lineno + adj.fun_lineno} "{source_path}"'
1261
-
1262
def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
    """Append a statement to the forward pass of the innermost block.

    Unless ``skip_replay`` is set, the statement (or ``replay`` when given and
    non-empty) is also appended to the block's replay list. Any ``#line``
    directive from ``get_line_directive`` precedes the statement in both lists.
    """
    directive = adj.get_line_directive(statement, adj.lineno)
    current = adj.blocks[-1]

    if directive:
        current.body_forward.append(directive)
    current.body_forward.append(adj.indentation + statement)

    if skip_replay:
        return

    if directive:
        current.body_replay.append(directive)
    # a custom replay statement replaces the original one
    replayed = replay if replay else statement
    current.body_replay.append(adj.indentation + replayed)
1280
-
1281
# append a statement to the reverse pass
def add_reverse(adj, statement: str) -> None:
    """Append a statement to the reverse pass of the innermost block.

    Any ``#line`` directive is appended *after* the statement, the opposite of
    ``add_forward``.
    """
    target = adj.blocks[-1].body_reverse
    target.append(adj.indentation + statement)

    directive = adj.get_line_directive(statement, adj.lineno)
    if directive:
        target.append(directive)
1289
-
1290
def add_constant(adj, n):
    """Allocate a new variable holding the constant ``n``, typed as ``type(n)``."""
    return adj.add_var(type=type(n), constant=n)
1293
-
1294
def load(adj, var):
    """Return a value Var for ``var``, emitting a ``load`` builtin call when it is a reference."""
    return adj.add_builtin_call("load", [var]) if is_reference(var.type) else var
1298
-
1299
def add_comp(adj, op_strings, left, comps):
    """Emit a (possibly chained) comparison into a new boolean variable and return it.

    ``left`` is compared against ``comps[0]`` with ``op_strings[0]``, and so on.
    For a chainable operator following a previous comparison, the previous
    right-hand operand is reused and the parts are joined with ``&&``. For example,
    ``a < b < c`` emits ``var_out = ((var_a < var_b) && (var_b < var_c));``.

    Raises:
        WarpCodegenTypeError: If a chained comparison mixes operand types.
    """
    output = adj.add_var(builtins.bool)

    left = adj.load(left)
    # one opening parenthesis per comparison; each part below closes one
    s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "

    prev_comp_var = None

    for op, comp in zip(op_strings, comps):
        comp_chainable = op_str_is_chainable(op)
        if comp_chainable and prev_comp_var:
            # We restrict chaining to operands of the same type
            if prev_comp_var.type is comp.type:
                prev_comp_var = adj.load(prev_comp_var)
                comp_var = adj.load(comp)
                s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
            else:
                raise WarpCodegenTypeError(
                    f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                )
        else:
            comp_var = adj.load(comp)
            s += op + " " + comp_var.emit() + ") "

        # the right-hand operand becomes the left side of the next chained comparison
        prev_comp_var = comp_var

    s = s.rstrip() + ";"

    adj.add_forward(s)

    return output
1330
-
1331
def add_bool_op(adj, op_string, exprs):
    """Join the loaded *exprs* with the C++ operator *op_string* into a new bool Var.

    Args:
        adj: The adjoint being generated.
        op_string: C++ boolean operator, e.g. ``"&&"`` or ``"||"``.
        exprs: Operand Vars (references are loaded first).

    Returns:
        The bool Var receiving the combined expression.
    """
    operands = [adj.load(e) for e in exprs]
    result = adj.add_var(builtins.bool)
    lhs = result.emit()
    rhs = f" {op_string} ".join([e.emit() for e in operands])
    adj.add_forward(f"{lhs} = {rhs};")
    return result
1338
-
1339
def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
    """Pick the overload of *func* that accepts the given argument types.

    User functions defer to ``func.get_overload``. For built-ins, overloads are
    scanned in order. Variadic overloads are accepted without checks. Other
    overloads must have enough inputs, match via ``func_match_args``, and, when
    *min_outputs* is set, return a sequence of exactly that many values.

    Raises:
        WarpCodegenError: If no overload matches, or if an argument is the result
            of a multi-valued function.
    """
    if func.is_builtin():
        given_count = len(arg_types) + len(kwarg_types)
        for candidate in func.overloads:
            if candidate.variadic:
                # variadic functions skip type checking entirely
                return candidate
            # not enough parameters (defaults may cover the remainder)
            if len(candidate.input_types) < given_count:
                continue
            if not func_match_args(candidate, arg_types, kwarg_types):
                continue
            if min_outputs:
                produced = candidate.value_func(None, None)
                if not isinstance(produced, Sequence) or len(produced) != min_outputs:
                    continue
            return candidate
    else:
        match = func.get_overload(arg_types, kwarg_types)
        if match is not None:
            return match

    # No overload matched: describe the argument types for the error message.
    described = []
    for arg_type in itertools.chain(arg_types, kwarg_types.values()):
        if isinstance(arg_type, warp.context.Function):
            described.append("function")
            continue
        if isinstance(arg_type, Sequence):
            if len(arg_type) != 1:
                raise WarpCodegenError("Argument must not be the result from a multi-valued function")
            arg_type = arg_type[0]
        described.append(type_repr(arg_type))

    raise WarpCodegenError(
        f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(described)}]"
    )
1387
-
1388
def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
    """Emit a call to *func* in the forward/replay passes and its adjoint in the reverse pass.

    Resolves the overload from the argument types, binds positional, keyword and
    codegen-supplied type arguments, builds user functions on demand, allocates the
    output Vars, and emits the forward, replay and (unless ``func.missing_grad``)
    reverse calls.

    Args:
        adj: The adjoint being generated.
        func: The function to call (built-in or user-defined).
        args: Positional arguments.
        kwargs: Keyword arguments, by name.
        type_args: Compile-time argument values supplied by codegen (e.g. ``length``
            and ``dtype`` when calling ``wp.vec3f(...)``); may be empty.
        min_outputs: Forwarded to ``resolve_func`` to require an overload returning
            this many values.

    Returns:
        ``None`` for void functions, a single Var for one return value, or a list of
        Vars for several. If ``func.value_func`` itself returns a Var (or a sequence
        of Vars), the registered Var (or a tuple of them) is returned instead and no
        call is emitted.

    Raises:
        RuntimeError: If a type argument is also passed explicitly with a different value.
        Errors raised by ``resolve_func`` (no matching overload) propagate.
    """
    # Extract the types and values passed as arguments to the function call.
    arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
    kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

    # Resolve the exact function signature among any existing overload.
    func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

    # Bind the positional and keyword arguments to the function's signature
    # in order to process them as Python does it.
    bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

    # Type args are the "compile time" argument values we get from codegen.
    # For example, when calling `wp.vec3f(...)` from within a kernel,
    # this translates in fact to calling the `vector()` built-in augmented
    # with the type args `length=3, dtype=float`.
    # Eventually, these need to be passed to the underlying C++ function,
    # so we update the arguments with the type args here.
    if type_args:
        for arg in type_args:
            if arg in bound_args.arguments:
                # In case of conflict, ideally we'd throw an error since
                # what comes from codegen should be the source of truth
                # and users also passing the same value as an argument
                # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
                # However, for backward compatibility, we allow that form
                # as long as the values are equal.
                if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
                    continue

                raise RuntimeError(
                    f"Remove the extraneous `{arg}` parameter "
                    f"when calling the templated version of "
                    f"`wp.{func.native_func}()`"
                )

        # Wrap the type args as constant Vars so they bind like regular arguments.
        type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
        apply_defaults(bound_args, type_vars)

    if func.defaults:
        # Fill in declared defaults (those not None) for parameters left unbound.
        default_vars = {
            k: Var(None, type=type(v), constant=v)
            for k, v in func.defaults.items()
            if k not in bound_args.arguments and v is not None
        }
        apply_defaults(bound_args, default_vars)

    # From here on, `bound_args` is the name -> value mapping, not the BoundArguments object.
    bound_args = bound_args.arguments

    # if it is a user-function then build it recursively
    if not func.is_builtin():
        # If the function called is a user function,
        # we need to ensure its adjoint is also being generated.
        if adj.used_by_backward_kernel:
            func.adj.used_by_backward_kernel = True

        if adj.builder is None:
            func.build(None)

        elif func not in adj.builder.functions:
            adj.builder.build_function(func)
            # add custom grad, replay functions to the list of functions
            # to be built later (invalid code could be generated if we built them now)
            # so that they are not missed when only the forward function is imported
            # from another module
            if func.custom_grad_func:
                adj.builder.deferred_functions.append(func.custom_grad_func)
            if func.custom_replay_func:
                adj.builder.deferred_functions.append(func.custom_replay_func)

    # Resolve the return value based on the types and values of the given arguments.
    bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
    bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}

    return_type = func.value_func(
        {k: strip_reference(v) for k, v in bound_arg_types.items()},
        bound_arg_values,
    )

    # Handle the special case where a Var instance is returned from the `value_func`
    # callback, in which case we replace the call with a reference to that variable.
    if isinstance(return_type, Var):
        return adj.register_var(return_type)
    elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
        return tuple(adj.register_var(x) for x in return_type)

    # A `tuple[...]` annotation becomes a Warp tuple type with unknown (None) values.
    if get_origin(return_type) is tuple:
        types = get_args(return_type)
        return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))

    # immediately allocate output variables so we can pass them into the dispatch method
    if return_type is None:
        # void function
        output = None
        output_list = []
    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # single return value function
        if isinstance(return_type, Sequence):
            return_type = return_type[0]
        output = adj.add_var(return_type)
        output_list = [output]
    else:
        # multiple return value function
        output = [adj.add_var(v) for v in return_type]
        output_list = output

    # If we have a built-in that requires special handling to dispatch
    # the arguments to the underlying C++ function, then we can resolve
    # these using the `dispatch_func`. Since this is only called from
    # within codegen, we pass it directly `codegen.Var` objects,
    # which allows for some more advanced resolution to be performed,
    # for example by checking whether an argument corresponds to
    # a literal value or references a variable.
    extra_shared_memory = 0
    if func.lto_dispatch_func is not None:
        func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
            func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
        )
    elif func.dispatch_func is not None:
        func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
    else:
        func_args = tuple(bound_args.values())
        template_args = ()

    func_args = tuple(adj.register_var(x) for x in func_args)
    func_name = compute_type_str(func.native_func, template_args)
    use_initializer_list = func.initializer_list_func(bound_args, return_type)

    fwd_args = []
    for func_arg in func_args:
        # Plain Vars are loaded to values; references and function objects pass through.
        if not isinstance(func_arg, (Reference, warp.context.Function)):
            func_arg_var = adj.load(func_arg)
        else:
            func_arg_var = func_arg

        # if the argument is a function (and not a builtin), then build it recursively
        if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
            if adj.used_by_backward_kernel:
                func_arg_var.adj.used_by_backward_kernel = True

            # NOTE(review): unlike the `func` branch above, this path does not handle
            # `adj.builder is None`; confirm a builder is always present here.
            adj.builder.build_function(func_arg_var)

        fwd_args.append(strip_reference(func_arg_var))

    if return_type is None:
        # handles expression (zero output) functions, e.g.: void do_something();
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        )
        replay_call = forward_call
        # NOTE(review): this branch also switches to `replay_` when `func.replay_snippet`
        # is set, while the single-output branch below only checks
        # `custom_replay_func` — confirm the asymmetry is intended.
        if func.custom_replay_func is not None or func.replay_snippet is not None:
            replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # handle simple function (one output)
        forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        replay_call = forward_call
        if func.custom_replay_func is not None:
            replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    else:
        # handle multiple value functions: outputs are passed as trailing arguments
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
        )
        replay_call = forward_call

    if func.skip_replay:
        # The replay line is still emitted, but commented out.
        adj.add_forward(forward_call, replay="// " + replay_call)
    else:
        adj.add_forward(forward_call, replay=replay_call)

    # Emit the adjoint call `adj_<native_func>(...)` unless the function has no gradient
    # or takes no arguments.
    if not func.missing_grad and len(func_args):
        adj_args = tuple(strip_reference(x) for x in func_args)
        reverse_has_output_args = (
            func.require_original_output_arg or len(output_list) > 1
        ) and func.custom_grad_func is None
        arg_str = adj.format_reverse_call_args(
            fwd_args,
            adj_args,
            output_list,
            use_initializer_list,
            has_output_args=reverse_has_output_args,
            require_original_output_arg=func.require_original_output_arg,
        )
        if arg_str is not None:
            reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
            adj.add_reverse(reverse_call)

    # update our smem roofline requirements based on any
    # shared memory required by the dependent function call
    if not func.is_builtin():
        adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
    else:
        adj.alloc_shared_extra(extra_shared_memory)

    return output
1585
-
1586
def add_builtin_call(adj, func_name, args, min_outputs=None):
    """Call the built-in named *func_name* with positional *args* (no kwargs or type args)."""
    builtin = warp.context.builtin_functions[func_name]
    return adj.add_call(builtin, args, kwargs={}, type_args={}, min_outputs=min_outputs)
1589
-
1590
def add_return(adj, var):
    """Emit a ``return`` of *var* (None, or a sequence of Vars) in both passes.

    In the forward pass the statement replays as ``goto label<N>;``. In the reverse
    pass the returned values' adjoints accumulate ``adj_ret`` (or ``adj_ret_<i>``
    for several values), and a ``label<N>:;`` target is emitted. ``adj.label_count``
    is then incremented.
    """
    label = f"label{adj.label_count}"
    jump = f"goto {label};"
    count = 0 if var is None else len(var)

    if count == 1:
        value = var[0]
        adj.add_forward(f"return {value.emit()};", jump)
        adj.add_reverse(f"adj_{value} += adj_ret;")
    elif count > 1:
        # multiple values are written to ret_<i> before a bare return
        for i, v in enumerate(var):
            adj.add_forward(f"ret_{i} = {v.emit()};")
            adj.add_reverse(f"adj_{v} += adj_ret_{i};")
        adj.add_forward("return;", jump)
    else:
        # NOTE: If this kernel gets compiled for a CUDA device, then we need
        # to convert the return; into a continue; in codegen_func_forward()
        adj.add_forward("return;", jump)

    adj.add_reverse(f"{label}:;")
    adj.label_count += 1
1607
-
1608
# define an if statement
def begin_if(adj, cond):
    """Open ``if (cond) {`` in the forward pass, push the matching ``}`` to the reverse pass, and indent."""
    test = adj.load(cond).emit()
    adj.add_forward("if (%s) {" % test)
    adj.add_reverse("}")
    adj.indent()
1615
-
1616
def end_if(adj, cond):
    """Dedent, close the forward ``if``, and push its ``if (cond) {`` opener to the reverse pass."""
    adj.dedent()
    adj.add_forward("}")
    adj.add_reverse("if (%s) {" % adj.load(cond).emit())
1622
-
1623
def begin_else(adj, cond):
    """Open the else branch as ``if (!cond) {`` in the forward pass; push ``}`` to the reverse pass and indent."""
    test = adj.load(cond).emit()
    adj.add_forward("if (!%s) {" % test)
    adj.add_reverse("}")
    adj.indent()
1629
-
1630
def end_else(adj, cond):
    """Dedent, close the forward else branch, and push ``if (!cond) {`` to the reverse pass."""
    adj.dedent()
    adj.add_forward("}")
    adj.add_reverse("if (!%s) {" % adj.load(cond).emit())
1636
-
1637
# define a for-loop
def begin_for(adj, iter):
    """Open a dynamic for-loop over the iterator Var *iter*.

    Opens a "for" condition block holding the start label, the exit test
    (``iter_cmp(iter) == 0``) and the ``iter_next`` call, then opens the body block.

    Returns:
        The Var produced by ``iter_next`` (the loop variable).
    """
    loop_block = adj.begin_block("for")
    adj.loop_blocks.append(loop_block)
    label = loop_block.label

    adj.add_forward(f"start_{label}:;")
    adj.indent()

    # leave the loop once the iterator is exhausted
    adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto end_{label};")

    # advance the iterator to produce the loop variable
    next_value = adj.add_builtin_call("iter_next", [iter])

    # body of the loop
    adj.begin_block()

    return next_value
1653
-
1654
def end_for(adj, iter):
    """Close a dynamic for-loop opened by ``begin_for``.

    Pops the body block and the "for" condition block (start label, exit test and
    ``iter_next`` call) and splices them into the enclosing block.

    Forward pass: condition block, then body, then ``goto start_<label>;`` and the
    ``end_<label>:;`` target.

    Reverse pass: a loop that runs over the reversed iterator
    (``wp::iter_reverse``). Each iteration re-runs the condition block, zeroes the
    body's local adjoints, replays the body's forward statements and runs its
    reverse statements back to front. The list is assembled in execution order and
    appended reversed to ``body_reverse``, which in this file is consumed back to
    front.
    """
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    # Loop-control lines are not replayed.
    adj.add_forward(f"goto start_{cond_block.label};", skip_replay=True)

    adj.dedent()
    adj.add_forward(f"end_{cond_block.label}:;", skip_replay=True)

    ####################
    # reverse pass

    reverse = []

    # reverse iterator
    reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")

    # The condition block (start label, exit test, iter_next) is reused verbatim.
    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints
    for i in body_block.vars:
        if is_tile(i.type):
            # NOTE(review): tiles whose type has no `owner` are not zeroed here —
            # presumably they alias storage owned elsewhere; confirm.
            if i.type.owner:
                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
        else:
            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(adj.indentation + f"\tgoto start_{cond_block.label};")
    reverse.append(adj.indentation + f"end_{cond_block.label}:;")

    # Appended back to front to match how body_reverse is consumed.
    adj.blocks[-1].body_reverse.extend(reversed(reverse))
1704
-
1705
# define a while loop
def begin_while(adj, cond):
    """Open a dynamic while-loop whose condition is the AST expression *cond*.

    The condition is evaluated inside its own "while" block. That block holds the
    start label and the exit test, so its statements can be re-emitted verbatim
    (and replay controlled) when the loop is closed. A body block is then opened
    and the indentation increased.
    """
    cond_block = adj.begin_block("while")
    adj.loop_blocks.append(cond_block)
    label = cond_block.label
    cond_block.body_forward.append(f"start_{label}:;")

    test = adj.load(adj.eval(cond))
    cond_block.body_forward.append(f"if (({test.emit()}) == false) goto end_{label};")

    # begin the block around the loop body
    adj.begin_block()
    adj.indent()
1721
-
1722
def end_while(adj):
    """Close a dynamic while-loop opened by ``begin_while``.

    Forward pass: the condition block (start label, exit test), then the body, then
    ``goto start_<label>;`` and the ``end_<label>:;`` target.

    Reverse pass: a loop that re-runs the condition block, zeroes the adjoints of
    every body local, replays the body's forward statements and runs its reverse
    statements back to front. The list is assembled in execution order and
    appended reversed to ``body_reverse``.
    """
    adj.dedent()
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    # Appended directly (not via add_forward), so no replay entries are produced.
    adj.blocks[-1].body_forward.append(f"goto start_{cond_block.label};")
    adj.blocks[-1].body_forward.append(f"end_{cond_block.label}:;")

    ####################
    # reverse pass
    reverse = []

    # cond
    # NOTE(review): the reverse loop re-evaluates the forward condition on the
    # current variable values; this block alone does not show that the reverse
    # iteration count matches the forward one — confirm intended semantics.
    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints of local vars
    # NOTE(review): unlike end_for, tile-typed locals are also reset with `= {}`
    # instead of `.grad_zero()`, and no adj.indentation prefix is used; confirm.
    for i in body_block.vars:
        reverse.append(f"{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(f"goto start_{cond_block.label};")
    reverse.append(f"end_{cond_block.label}:;")

    # output
    adj.blocks[-1].body_reverse.extend(reversed(reverse))
1765
-
1766
def emit_FunctionDef(adj, node):
    """Generate code for the body of the function definition *node*.

    Evaluates each body statement, skipping bare constant expressions such as
    docstrings. For single-return-value functions whose last statement is not a
    ``return``, a fallback ``return {};`` is emitted.

    For a function decorated with exactly one ``@<something>.func_native(...)``
    call and carrying a return annotation, ``adj.return_var`` is set from that
    annotation.

    Raises:
        WarpCodegenTypeError: If a native function's return annotation is neither a
            plain name nor an attribute.
    """
    adj.fun_def_lineno = node.lineno

    for f in node.body:
        # Skip variable creation for standalone constants, including docstrings
        if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
            continue
        adj.eval(f)

    # Fallback return so non-void functions always return a value (not replayed).
    if adj.return_var is not None and len(adj.return_var) == 1:
        if not isinstance(node.body[-1], ast.Return):
            adj.add_forward("return {};", skip_replay=True)

    # native function case: return type is specified, eg -> int or -> wp.float32
    # Only an attribute-style decorator call (e.g. `@wp.func_native(...)`) is detected.
    is_func_native = False
    if node.decorator_list is not None and len(node.decorator_list) == 1:
        obj = node.decorator_list[0]
        if isinstance(obj, ast.Call):
            if isinstance(obj.func, ast.Attribute):
                if obj.func.attr == "func_native":
                    is_func_native = True
    if is_func_native and node.returns is not None:
        # NOTE(review): eval() resolves the annotation text in this function's own
        # local/global scope, not the user's module. For `wp.float32` only the
        # attribute name `float32` is evaluated. The source is user kernel code
        # (assumed trusted).
        if isinstance(node.returns, ast.Name):  # python built-in type
            var = Var(label="return_type", type=eval(node.returns.id))
        elif isinstance(node.returns, ast.Attribute):  # warp type
            var = Var(label="return_type", type=eval(node.returns.attr))
        else:
            raise WarpCodegenTypeError("Native function return type not recognized")
        adj.return_var = (var,)
1795
-
1796
def emit_If(adj, node):
    """Generate code for an ``if``/``else`` statement.

    A compile-time constant condition emits only the taken branch. Otherwise both
    branches are emitted under ``if (cond)`` / ``if (!cond)``. Every symbol rebound
    inside a branch is then merged with a ``where`` select (a phi node) so that
    later code sees the correct value.

    Returns:
        Always ``None``.
    """
    if not node.body:
        return None

    cond = adj.eval(node.test)

    if cond.constant is not None:
        # statically resolved branch: no runtime test needed
        taken = node.body if cond.constant else node.orelse
        for stmt in taken:
            adj.eval(stmt)
        return None

    def merge_rebound(before, from_else):
        # Select between the pre-branch value and the branch value for each changed
        # symbol. For the else branch the operands swap, since `cond` is false there.
        for sym, old_var in before.items():
            new_var = adj.symbols[sym]
            if old_var != new_var:
                if from_else:
                    operands = [cond, old_var, new_var]
                else:
                    operands = [cond, new_var, old_var]
                adj.symbols[sym] = adj.add_builtin_call("where", operands)

    snapshot = adj.symbols.copy()
    adj.begin_if(cond)
    for stmt in node.body:
        adj.eval(stmt)
    adj.end_if(cond)
    merge_rebound(snapshot, from_else=False)

    if node.orelse:
        snapshot = adj.symbols.copy()
        adj.begin_else(cond)
        for stmt in node.orelse:
            adj.eval(stmt)
        adj.end_else(cond)
        merge_rebound(snapshot, from_else=True)

    return None
1859
-
1860
def emit_IfExp(adj, node):
    """Generate code for a conditional expression ``body if test else orelse``.

    A constant test evaluates only the selected operand. Otherwise each operand is
    evaluated inside its own guarded branch, and the result is a ``where`` select.
    """
    cond = adj.eval(node.test)

    if cond.constant is not None:
        chosen = node.body if cond.constant else node.orelse
        return adj.eval(chosen)

    adj.begin_if(cond)
    then_value = adj.eval(node.body)
    adj.end_if(cond)

    adj.begin_else(cond)
    else_value = adj.eval(node.orelse)
    adj.end_else(cond)

    return adj.add_builtin_call("where", [cond, then_value, else_value])
1875
-
1876
def emit_Compare(adj, node):
    """Generate code for a (possibly chained) comparison such as ``a < b <= c``.

    The left operand and each comparator are evaluated in source order. Each AST
    operator is mapped through ``builtin_operators``, and the result is built by
    ``add_comp``.
    """
    lhs = adj.eval(node.left)
    rhs_values = [adj.eval(expr) for expr in node.comparators]
    operators = [builtin_operators[type(o)] for o in node.ops]
    return adj.add_comp(operators, lhs, rhs_values)
1885
-
1886
def emit_BoolOp(adj, node):
    """Generate code for ``and``/``or`` over ``node.values`` (mapped to ``&&``/``||``).

    Raises:
        WarpCodegenKeyError: For any other boolean operator.
    """
    if isinstance(node.op, ast.And):
        c_operator = "&&"
    elif isinstance(node.op, ast.Or):
        c_operator = "||"
    else:
        raise WarpCodegenKeyError(f"Op {node.op} is not supported")

    operands = [adj.eval(value) for value in node.values]
    return adj.add_bool_op(c_operator, operands)
1898
-
1899
def emit_Name(adj, node):
    """Resolve an identifier.

    Names already bound in ``adj.symbols`` are returned directly. Otherwise the
    name is resolved as an external reference:

    - Value-typed objects become constant Vars, cached in ``adj.symbols``.
    - Functions, types and modules are returned as-is for the caller to handle.
    - Structs are built (recursively) and then returned.

    Raises:
        WarpCodegenKeyError: If the name cannot be resolved.
        TypeError: If it resolves to an unsupported kind of object.
    """
    name = node.id
    try:
        return adj.symbols[name]
    except KeyError:
        pass

    obj = adj.resolve_external_reference(name)
    if obj is None:
        raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))

    if warp.types.is_value(obj):
        constant_var = adj.add_constant(obj)
        adj.symbols[name] = constant_var
        return constant_var

    if isinstance(obj, (warp.context.Function, type)):
        return obj
    if isinstance(obj, Struct):
        adj.builder.build_struct_recursive(obj)
        return obj
    if isinstance(obj, types.ModuleType):
        return obj

    raise TypeError(f"Invalid external reference type: {type(obj)}")
1928
-
1929
@staticmethod
def resolve_type_attribute(var_type: type, attr: str):
    """Resolve ``attr`` on a type, returning ``None`` if it does not exist.

    Warp value types additionally expose ``dtype`` (their scalar type) and
    ``length`` (their size). Anything else falls back to ``getattr``.
    """
    is_value_type = isinstance(var_type, type) and type_is_value(var_type)
    if is_value_type and attr == "dtype":
        return type_scalar_type(var_type)
    if is_value_type and attr == "length":
        return type_size(var_type)
    return getattr(var_type, attr, None)
1938
-
1939
def vector_component_index(adj, component, vector_type):
    """Map a swizzle letter to a constant index Var.

    Valid letters are the first ``vector_type._shape_[0]`` characters of ``"xyzw"``.

    Raises:
        WarpCodegenAttributeError: If *component* is not a single valid letter.
    """
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")

    allowed = "xyzw"[: vector_type._shape_[0]]
    position = allowed.find(component)
    if position < 0:
        raise WarpCodegenAttributeError(
            f"Vector swizzle for {vector_type} must be one of {allowed}, got {component}"
        )

    return adj.add_constant(position)
1953
-
1954
def transform_component(adj, component):
    """Validate a transform attribute name and return it (``"p"`` or ``"q"``).

    Raises:
        WarpCodegenAttributeError: If *component* is not exactly ``"p"`` or ``"q"``.
    """
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")

    if component in ("p", "q"):
        return component

    raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
1962
-
1963
@staticmethod
def is_differentiable_value_type(var_type):
    """Return True for value types (not arrays) that may hold differentiable values.

    These are types whose scalar type is a float type, plus Warp structs. For such
    types, gradients must be accumulated.
    """
    scalar = type_scalar_type(var_type)
    if scalar in float_types:
        return True
    return isinstance(var_type, Struct)
1968
-
1969
def emit_Attribute(adj, node):
    """Generate code for an attribute access ``value.attr``.

    Handles, in order:

    - Constant aggregates (returned as-is).
    - Attributes of modules and types: value and int-enum results become constants.
    - Adjoint struct accesses: a Var spelled ``struct.attr``.
    - Vector/quaternion swizzles: ``extract``.
    - Transform ``p``/``q``: translation/rotation getters.
    - Struct fields: a reference Var. Pointer fields are exposed as ``uint64``.

    If lookup fails with KeyError/AttributeError, the attribute is retried as a type
    attribute (e.g. ``dtype``, ``length``).

    Raises:
        WarpCodegenAttributeError: If the attribute cannot be resolved.
    """
    if hasattr(node, "is_adjoint"):
        node.value.is_adjoint = True

    aggregate = adj.eval(node.value)

    try:
        if isinstance(aggregate, Var) and aggregate.constant is not None:
            # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
            return aggregate

        if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
            out = getattr(aggregate, node.attr)

            if warp.types.is_value(out):
                return adj.add_constant(out)
            if isinstance(out, (enum.IntEnum, enum.IntFlag)):
                return adj.add_constant(int(out))

            return out

        if hasattr(node, "is_adjoint"):
            # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
            attr_name = aggregate.label + "." + node.attr
            attr_type = aggregate.type.vars[node.attr].type

            return Var(attr_name, attr_type)

        aggregate_type = strip_reference(aggregate.type)

        # reading a vector or quaternion component
        if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
            index = adj.vector_component_index(node.attr, aggregate_type)

            return adj.add_builtin_call("extract", [aggregate, index])

        elif type_is_transformation(aggregate_type):
            component = adj.transform_component(node.attr)

            if component == "p":
                return adj.add_builtin_call("transform_get_translation", [aggregate])
            else:
                return adj.add_builtin_call("transform_get_rotation", [aggregate])

        else:
            attr_var = aggregate_type.vars[node.attr]

            # represent pointer types as uint64
            if isinstance(attr_var.type, pointer_t):
                cast = f"({Var.dtype_to_ctype(uint64)}*)"
                adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
                attr_type = Reference(uint64)
            else:
                cast = ""
                adj_cast = ""
                attr_type = Reference(attr_var.type)

            attr = adj.add_var(attr_type)

            # take the field's address through `->` for reference aggregates, `.` otherwise
            if is_reference(aggregate.type):
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
            else:
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

            # differentiable fields accumulate their adjoint; others are copied
            if adj.is_differentiable_value_type(strip_reference(attr_type)):
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
            else:
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

            return attr

    except (KeyError, AttributeError) as e:
        # Try resolving as type attribute
        aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate

        type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
        if type_attribute is not None:
            return type_attribute

        if isinstance(aggregate, Var):
            # `node.value` is only an `ast.Name` for simple accesses like `s.x`. For nested or
            # computed accesses (`s.a.b`, `f().x`) it has no `id`, and reading it here would
            # raise a confusing AttributeError that masks this error, so fall back to the
            # expression's source text.
            if isinstance(node.value, ast.Name):
                value_desc = node.value.id
            else:
                value_desc = ast.get_source_segment(adj.source, node.value) or aggregate.label
            raise WarpCodegenAttributeError(
                f"Error, `{node.attr}` is not an attribute of '{value_desc}' ({type_repr(aggregate.type)})"
            ) from e
        raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
2053
-
2054
def emit_Assert(adj, node):
    """Emit a C++ ``assert`` mirroring a Python ``assert`` statement.

    The statement's source text is embedded as a string literal through the comma
    operator, ``assert(("<source>", cond))``, so it shows up in the failure message.

    Args:
        adj: The adjoint being generated.
        node: The ``ast.Assert`` node.
    """
    # eval condition
    cond = adj.eval(node.test)
    cond = adj.load(cond)

    source_segment = ast.get_source_segment(adj.source, node)
    if source_segment is None:
        # get_source_segment() returns None when the node carries no location info.
        source_segment = "assert"

    # The segment goes inside a C string literal. Escape backslashes first (so the
    # escapes added afterwards are not doubled), then double quotes, then line breaks:
    # a multi-line assert would otherwise yield an unterminated string literal.
    escaped_segment = (
        source_segment.replace("\\", "\\\\").replace('"', '\\"').replace("\r", "\\r").replace("\n", "\\n")
    )

    adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
2064
-
2065
def emit_Constant(adj, node):
    """Turn a literal into a constant Var.

    Raises:
        WarpCodegenTypeError: For ``None`` literals.
    """
    literal = node.value
    if literal is None:
        raise WarpCodegenTypeError("None type unsupported")
    return adj.add_constant(literal)
2070
-
2071
def emit_BinOp(adj, node):
    """Generate code for a binary operator expression.

    A user-defined Warp function named after the operator (as found in
    ``builtin_operators``) takes precedence. Otherwise the built-in of that name is
    called.

    When ``warp.config.verify_autograd_array_access`` is enabled, argument read
    flags are restored after evaluating the left operand. This way an in-place
    update such as ``x[tid] = x[tid] + 1`` records only the write.
    """
    track_reads = warp.config.verify_autograd_array_access

    saved_read_flags = [arg.is_read for arg in adj.args] if track_reads else None

    # left operand
    left = adj.eval(node.left)

    if track_reads:
        for arg, was_read in zip(adj.args, saved_read_flags):
            arg.is_read = was_read

    # right operand
    right = adj.eval(node.right)

    op_name = builtin_operators[type(node.op)]

    try:
        # Check if there is any user-defined overload for this operator
        candidate = adj.resolve_external_reference(op_name)
        if isinstance(candidate, warp.context.Function):
            return adj.add_call(candidate, (left, right), {}, {})
    except WarpCodegenError:
        pass

    return adj.add_builtin_call(op_name, [left, right])
2104
-
2105
def emit_UnaryOp(adj, node):
    """Generate code for a unary operator applied to ``node.operand``.

    Negating a finite numeric compile-time constant is folded into a new constant.
    Everything else calls the built-in named by ``builtin_operators``.
    """
    # evaluate unary op arguments
    arg = adj.eval(node.operand)

    # Evaluate `-constant` at compile time. The operator is checked first, and
    # constants that are not real numbers (e.g. vector constants, for which
    # math.isfinite raises TypeError) are simply not folded; previously this raised
    # for every unary operator applied to such a constant.
    if isinstance(node.op, ast.USub) and arg.constant is not None:
        try:
            foldable = math.isfinite(arg.constant)
        except TypeError:
            foldable = False
        if foldable:
            return adj.add_constant(-arg.constant)

    name = builtin_operators[type(node.op)]

    return adj.add_builtin_call(name, [arg])
2117
-
2118
def materialize_redefinitions(adj, symbols):
    """Write back symbols that were rebound inside a dynamic loop body.

    For each symbol in the pre-loop snapshot *symbols* whose current binding
    differs (constant loop-iterator symbols excepted), emit an ``assign`` of the new
    value into the original variable and rebind the symbol to the original. This
    deliberately breaks SSA so the value carries across iterations.

    Raises:
        WarpCodegenError: If the original binding is a compile-time constant.
    """
    for sym, original in symbols.items():
        if adj.is_constant_iter_symbol(sym):
            # re-unrolling a static loop's iterator is harmless
            continue

        current = adj.symbols[sym]
        if not (original != current):
            continue

        if warp.config.verbose and not adj.custom_reverse_mode:
            lineno = adj.lineno + adj.fun_lineno
            line = adj.source_lines[adj.lineno]
            print(
                f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
            )

        if original.constant is not None:
            raise WarpCodegenError(
                f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
            )

        # overwrite the old variable value (violates SSA)
        adj.add_builtin_call("assign", [original, current])

        # point the symbol back at the original variable
        adj.symbols[sym] = original
2147
-
2148
def emit_While(adj, node):
    """Generate code for a dynamic ``while`` loop.

    The symbol table is snapshotted before the body. Symbols rebound in the body
    are then written back with ``materialize_redefinitions`` before the loop is
    closed.
    """
    adj.begin_while(node.test)

    adj.loop_symbols.append(adj.symbols.copy())

    for statement in node.body:
        adj.eval(statement)

    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.loop_symbols.pop()

    adj.end_while()
2161
-
2162
def eval_num(adj, a):
    """Try to evaluate an expression (e.g. a ``range()`` argument) to a compile-time value.

    Returns:
        A ``(is_numeric, value)`` pair. Literal constants and negated literals
        report ``True`` regardless of their type. Other expressions are resolved
        statically (or evaluated, unwrapping constant Vars), and ``is_numeric`` is
        ``warp.types.is_int(value)``.
    """
    if isinstance(a, ast.Constant):
        return True, a.value

    negated_literal = (
        isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant)
    )
    if negated_literal:
        return True, -a.operand.value

    # try and resolve the expression to an object, e.g.: wp.constant in the globals scope
    value, _ = adj.resolve_static_expression(a)
    if value is None:
        value = adj.eval(a)

    if isinstance(value, Var) and value.constant is not None:
        value = value.constant

    return warp.types.is_int(value), value
2180
-
2181
# detects whether a loop contains a break (or continue) statement
def contains_break(adj, body):
    """Return True if *body* contains a ``break`` or ``continue`` for the current loop.

    Both branches of ``if`` statements are searched recursively. Nested ``for`` and
    ``while`` loops are not searched, because their break/continue statements do
    not affect the current loop.
    """
    for stmt in body:
        if isinstance(stmt, (ast.Break, ast.Continue)):
            return True
        if isinstance(stmt, ast.If) and (adj.contains_break(stmt.body) or adj.contains_break(stmt.orelse)):
            return True
    return False
2199
-
2200
def get_unroll_range(adj, loop):
    """Decide how the ``for`` loop ``loop`` iterates over ``range(...)``.

    Returns:
        - a Python ``range``, meaning the loop is unrolled. This happens when
          every ``range()`` argument is a numeric constant (per ``eval_num``),
          the trip count does not exceed the ``max_unroll`` limit, and the body
          has no ``break``/``continue``;
        - the Var returned by a dynamic ``range`` builtin call, when the loop
          iterates over ``range(...)`` but cannot be unrolled. The arguments are
          evaluated only once, so their code is not emitted twice;
        - ``None`` when the iterable is not a 1- to 3-argument ``range()`` call.

    Raises:
        WarpCodegenError: if all arguments are constant and the step is zero.
    """
    call = loop.iter
    if (
        not isinstance(call, ast.Call)
        or not isinstance(call.func, ast.Name)
        or call.func.id != "range"
        or len(call.args) == 0
        or len(call.args) > 3
    ):
        return None

    # Evaluate the arguments in a single pass: if they have side effects, their
    # code must not be generated more than once. Only trivial constants unroll;
    # constant compile-time expressions such as range(0, 3*2) do not.
    range_args = [adj.eval_num(arg) for arg in call.args]
    arg_is_numeric, arg_values = zip(*range_args)

    if all(arg_is_numeric):
        if len(arg_values) == 1:
            start, end, step = 0, arg_values[0], 1
        elif len(arg_values) == 2:
            start, end, step = arg_values[0], arg_values[1], 1
        else:
            start, end, step = arg_values

        if step == 0:
            raise WarpCodegenError("range() arg 3 must not be zero")

        # Exact trip count (the same as len(range(start, end, step))), computed
        # with ceiling division. It is 0 for empty ranges such as range(5, 0), and
        # correct when the span is not a multiple of the step: range(0, 17, 2)
        # runs 9 times, not 17 // 2 == 8.
        max_iters = max(0, -((start - end) // step))

        if "max_unroll" in adj.builder_options:
            max_unroll = adj.builder_options["max_unroll"]
        else:
            max_unroll = warp.config.max_unroll

        ok_to_unroll = True

        if max_iters > max_unroll:
            if warp.config.verbose:
                print(
                    f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
                )
            ok_to_unroll = False

        elif adj.contains_break(loop.body):
            if warp.config.verbose:
                print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
            ok_to_unroll = False

        if ok_to_unroll:
            return range(start, end, step)

    # Unroll is not possible, the range needs to be evaluated dynamically
    range_call = adj.add_builtin_call(
        "range",
        [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
    )
    return range_call
def record_constant_iter_symbol(adj, sym):
    """Register ``sym`` as the iteration variable of an unrolled (constant-range) loop.

    ``emit_For`` records the symbol to prevent constant conflicts in
    ``materialize_redefinitions()``.
    """
    const_iter_symbols = adj.loop_const_iter_symbols
    const_iter_symbols.add(sym)
def is_constant_iter_symbol(adj, sym):
    """Return whether ``sym`` was registered as an unrolled-loop iteration variable."""
    const_iter_symbols = adj.loop_const_iter_symbols
    return sym in const_iter_symbols
def emit_For(adj, node):
    """Generate code for a ``for`` loop.

    When ``get_unroll_range`` returns a constant ``range``, the loop is unrolled:
    the body is evaluated once per iteration with the loop variable bound to a
    constant. Otherwise a dynamic loop is emitted between ``begin_for`` and
    ``end_for``. The symbol table is snapshotted on entry so that variables
    reassigned in the body are written back by ``materialize_redefinitions``.

    Raises:
        WarpCodegenError: if the loop target is not a single variable name
            (e.g. ``for i, j in ...``).
    """
    # only `for <name> in ...` is supported; other targets have no `.id`
    if not isinstance(node.target, ast.Name):
        raise WarpCodegenError("For-loop targets must be a single variable name, e.g.: for i in range(n)")
    target_name = node.target.id

    # try and unroll simple range() statements that use constant args
    loop_range = adj.get_unroll_range(node)

    if isinstance(loop_range, range):
        # prevent constant conflicts in `materialize_redefinitions()`
        adj.record_constant_iter_symbol(target_name)

        # unroll static for-loop
        for i in loop_range:
            adj.symbols[target_name] = adj.add_constant(i)
            for stmt in node.body:
                adj.eval(stmt)
        return

    # dynamic loop: reuse the range Var built while trying to unroll (if any)
    # so its arguments are not evaluated a second time
    iterable = loop_range if loop_range is not None else adj.eval(node.iter)

    adj.symbols[target_name] = adj.begin_for(iterable)

    # for loops should be side-effect free, here we store a copy
    adj.loop_symbols.append(adj.symbols.copy())

    for stmt in node.body:
        adj.eval(stmt)

    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.loop_symbols.pop()

    adj.end_for(iterable)
def emit_Break(adj, node):
    """Emit ``break``: write back loop-carried redefinitions, then jump to the innermost loop's end label."""
    entry_symbols = adj.loop_symbols[-1]
    adj.materialize_redefinitions(entry_symbols)

    innermost_label = adj.loop_blocks[-1].label
    adj.add_forward(f"goto end_{innermost_label};")
def emit_Continue(adj, node):
    """Emit ``continue``: write back loop-carried redefinitions, then jump to the innermost loop's start label."""
    entry_symbols = adj.loop_symbols[-1]
    adj.materialize_redefinitions(entry_symbols)

    innermost_label = adj.loop_blocks[-1].label
    adj.add_forward(f"goto start_{innermost_label};")
def emit_Expr(adj, node):
    """Evaluate an expression statement (e.g. a bare call) and return the result of evaluating it."""
    expression = node.value
    return adj.eval(expression)
def check_tid_in_func_error(adj, node):
    """Reject a call to an attribute named ``tid`` (e.g. ``wp.tid()``) inside a user function.

    Raises:
        WarpCodegenError: if ``adj.is_user_function`` is truthy and the callee of
            the call ``node`` is an attribute named ``tid``.
    """
    if not adj.is_user_function:
        return
    if getattr(node.func, "attr", None) != "tid":
        return

    lineno = adj.lineno + adj.fun_lineno
    line = adj.source_lines[adj.lineno]
    raise WarpCodegenError(
        "tid() may only be called from a Warp kernel, not a Warp function. "
        "Instead, obtain the indices from a @wp.kernel and pass them as "
        f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
    )
def resolve_arg(adj, arg):
    """Resolve a call argument, preferring a statically known value.

    The argument is always evaluated first, because evaluation can surface
    issues such as global variables being accessed. Then:

    - if it cannot be resolved statically, the evaluated Var is returned, or the
      codegen error raised during evaluation is re-raised;
    - static types, ``Struct``s, ``Var``s and Warp ``Function``s are returned as-is;
    - ``IntEnum``/``IntFlag`` members become an ``int`` constant;
    - any other static value becomes a constant.
    """
    try:
        evaluated = adj.eval(arg)
        eval_error = None
    except (WarpCodegenError, WarpCodegenKeyError) as exc:
        evaluated = None
        eval_error = exc

    static_value, _ = adj.resolve_static_expression(arg)
    if static_value is None:
        if eval_error is not None:
            raise eval_error
        return evaluated

    if isinstance(static_value, (type, Struct, Var, warp.context.Function)):
        return static_value

    if isinstance(static_value, (enum.IntEnum, enum.IntFlag)):
        static_value = int(static_value)

    return adj.add_constant(static_value)
def emit_Call(adj, node):
    """Generate code for a call expression (``ast.Call``).

    The callee is resolved from a ``warp_func`` attached to ``node.func`` if one
    is present. Otherwise it comes from a statically resolvable path in the
    function's scope, falling back to ``adj.eval(node.func)``. A ``wp.static(...)``
    call is evaluated at codegen time. When the callee was reached through a
    path and is not a Warp ``Function``, the last path component / caller object
    is mapped to a builtin, a vector or scalar type constructor, or a struct
    constructor.

    Returns:
        The result of ``adj.add_call``. For ``wp.static`` it is instead the static
        ``Function`` itself, or a constant created from the static value.

    Raises:
        NotImplementedError: if the callee is a lambda.
        WarpCodegenError: if no built-in or user-defined function matches.
    """
    adj.check_tid_in_func_error(node)

    # try and lookup function in globals by
    # resolving path (e.g.: module.submodule.attr)
    if hasattr(node.func, "warp_func"):
        func = node.func.warp_func
        path = []
    else:
        func, path = adj.resolve_static_expression(node.func)
        if func is None:
            func = adj.eval(node.func)

    if adj.is_static_expression(func):
        # try to evaluate wp.static() expressions
        obj, code = adj.evaluate_static_expression(node)
        if obj is not None:
            # cache the evaluated object keyed by its source code
            adj.static_expressions[code] = obj
            if isinstance(obj, warp.context.Function):
                # special handling for wp.static() evaluating to a function
                return obj
            else:
                out = adj.add_constant(obj)
                return out

    type_args = {}

    if len(path) > 0 and not isinstance(func, warp.context.Function):
        # the resolved object is not directly callable from a kernel; treat it as
        # the "caller" and look for a matching builtin/constructor instead
        attr = path[-1]
        caller = func
        func = None

        # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
        if attr in warp.context.builtin_functions:
            func = warp.context.builtin_functions[attr]

        # vector class type e.g.: wp.vec3f constructor
        if func is None and hasattr(caller, "_wp_generic_type_str_"):
            func = warp.context.builtin_functions.get(caller._wp_constructor_)

        # scalar class type e.g.: wp.int8 constructor
        if func is None and hasattr(caller, "__name__") and caller.__name__ in warp.context.builtin_functions:
            func = warp.context.builtin_functions.get(caller.__name__)

        # struct constructor
        if func is None and isinstance(caller, Struct):
            if adj.builder is not None:
                adj.builder.build_struct_recursive(caller)
            # any positional/keyword arguments select the value constructor
            if node.args or node.keywords:
                func = caller.value_constructor
            else:
                func = caller.default_constructor

        # lambda function
        if func is None and getattr(caller, "__name__", None) == "<lambda>":
            raise NotImplementedError("Lambda expressions are not yet supported")

        if hasattr(caller, "_wp_type_args_"):
            type_args = caller._wp_type_args_

        if func is None:
            raise WarpCodegenError(
                f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
            )

    # get expected return count, e.g.: for multi-assignment
    min_outputs = None
    if hasattr(node, "expects"):
        min_outputs = node.expects

    # Evaluate all positional and keywords arguments.
    args = tuple(adj.resolve_arg(x) for x in node.args)
    kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}

    out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)

    if warp.config.verify_autograd_array_access:
        # Extract the types and values passed as arguments to the function call.
        arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
        kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

        # Resolve the exact function signature among any existing overload.
        resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

        # update arg read/write states according to what happens to that arg in the called function
        # NOTE(review): only positional `args` are visited here; keyword arguments
        # do not propagate read/write state — confirm whether that is intended.
        if hasattr(resolved_func, "adj"):
            for i, arg in enumerate(args):
                if resolved_func.adj.args[i].is_write:
                    kernel_name = adj.fun_name
                    filename = adj.filename
                    lineno = adj.lineno + adj.fun_lineno
                    arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
                if resolved_func.adj.args[i].is_read:
                    arg.mark_read()

    return out
def emit_Index(adj, node):
    """Unwrap an ``ast.Index`` node and evaluate the expression it wraps.

    ``ast.Index`` only appears on Python 3.7/3.8 subscripts (e.g. ``x = arr[i]``);
    newer versions no longer produce it. If the node carries an ``is_adjoint``
    attribute, the wrapped expression is marked with ``is_adjoint = True`` before
    evaluation.
    """
    inner = node.value
    if hasattr(node, "is_adjoint"):
        inner.is_adjoint = True
    return adj.eval(inner)
def eval_indices(adj, target_type, indices):
    """Evaluate the index expressions of a subscript on ``target_type``.

    For types exposing ``_wp_generic_type_hint_`` (vector/matrix-like types),
    ``ast.Slice`` entries are resolved at compile time into a ``slice`` builtin
    call. Its ``(start, stop, step)`` bounds are normalized exactly like
    Python's ``slice.indices(length)``, where ``length`` is the type's extent
    along that dimension. All other entries, and all entries for other types,
    are evaluated normally.

    Args:
        adj: the adjoint/codegen context.
        target_type: type of the indexed object (references already stripped).
        indices: sequence of AST nodes, one per indexed dimension.

    Returns:
        A tuple with one evaluated index (or slice) per entry of ``indices``.

    Raises:
        WarpCodegenError: if a slice bound is not a compile-time constant, or a
            slice step is zero.
    """
    nodes = indices
    if not hasattr(target_type, "_wp_generic_type_hint_"):
        return tuple(adj.eval(x) for x in nodes)

    def _constant_bound(expr, what):
        # In the context of slicing a vec/mat type, bounds are expected to be
        # compile-time constants so the slice can be resolved at compile time.
        value = adj.eval(expr).constant
        if value is None:
            raise WarpCodegenError(
                f"Slice {what} must be a compile-time constant when slicing vector and matrix types"
            )
        return value

    evaluated = []
    for dim, node in enumerate(nodes):
        if not isinstance(node, ast.Slice):
            evaluated.append(adj.eval(node))
            continue

        length = target_type._shape_[dim]
        step = 1 if node.step is None else _constant_bound(node.step, "step")
        if step == 0:
            raise WarpCodegenError("Slice step cannot be zero")
        start = None if node.lower is None else _constant_bound(node.lower, "start")
        stop = None if node.upper is None else _constant_bound(node.upper, "stop")

        # Python's own normalization: clamps out-of-range bounds and, for
        # negative steps, uses -1 as the "before index 0" stop (e.g. v[2:-4:-1]
        # on a length-3 vector yields (2, -1, -1), i.e. elements 2, 1, 0).
        start, stop, step = slice(start, stop, step).indices(length)

        evaluated.append(adj.add_builtin_call("slice", (start, stop, step)))

    return tuple(evaluated)
def emit_indexing(adj, target, indices):
    """Emit the access of ``target`` at the AST index nodes ``indices``.

    The indices are first evaluated with ``eval_indices``. Then:

    - arrays: one index per dimension emits ``address``; fewer indices emit a
      ``view``;
    - tiles: one index per dimension emits ``tile_extract``; fewer emit a
      ``tile_view``; more raise ``RuntimeError``;
    - any other type (e.g. vec3, mat33) emits ``extract``.

    Returns:
        The Var produced by the emitted builtin call.
    """
    target_type = strip_reference(target.type)
    indices = adj.eval_indices(target_type, indices)

    if is_array(target_type):
        if len(indices) == target_type.ndim:
            # handles array loads (where each dimension has an index specified)
            out = adj.add_builtin_call("address", [target, *indices])

            if warp.config.verify_autograd_array_access:
                target.mark_read()

        else:
            # handles array views (fewer indices than dimensions)
            out = adj.add_builtin_call("view", [target, *indices])

            if warp.config.verify_autograd_array_access:
                # store reference to target Var to propagate downstream read/write state back to root arg Var
                out.parent = target

                # view arg inherits target Var's read/write states
                out.is_read = target.is_read
                out.is_write = target.is_write

    elif is_tile(target_type):
        if len(indices) == len(target_type.shape):
            # handles extracting a single element from a tile
            out = adj.add_builtin_call("tile_extract", [target, *indices])
        elif len(indices) < len(target_type.shape):
            # handles tile views (note: the index tuple is passed as one argument)
            out = adj.add_builtin_call("tile_view", [target, indices])
        else:
            raise RuntimeError(
                f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
            )

    else:
        # handles non-array type indexing, e.g: vec3, mat33, etc
        out = adj.add_builtin_call("extract", [target, *indices])

    return out
@staticmethod
def strip_indices(indices, count):
    """Drop leading index groups from ``indices`` until ``count`` indices are consumed.

    Args:
        indices: list of index groups (each group is a list of indices).
        count: number of individual indices to remove from the front.

    Returns:
        The remaining index groups.

    Raises:
        WarpCodegenError: if a group straddles the boundary, as in
            ``arr2d[0][1,2]``, which is reported as a syntax error.
    """
    total = count
    remaining = count
    while remaining > 0:
        head, indices = indices[0], indices[1:]
        remaining -= len(head)

    if remaining < 0:
        raise WarpCodegenError(
            f"Incorrect number of indices specified for array indexing, got {total - remaining} indices for a {total} dimensional array."
        )

    return indices
def recurse_subscript(adj, node, indices):
    """Walk a chain of subscripts (e.g. ``a[i][j, k]``) down to its base expression.

    Each ``ast.Subscript`` level prepends its own index group (a list of AST
    nodes) to ``indices``. On the way back up, whenever the target is an array
    and more indices have been collected than it has dimensions, the array is
    indexed with its first ``ndim`` indices via ``emit_indexing``, and those
    indices are stripped with ``strip_indices``.

    Returns:
        ``(target, indices)``: the evaluated target and the index groups not yet
        applied to it.
    """
    if isinstance(node, ast.Name):
        target = adj.eval(node)
        return target, indices

    if isinstance(node, ast.Subscript):
        # wp.adjoint[var] is handled as a whole by the Subscript visitor
        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
            return adj.eval(node), indices

        if isinstance(node.slice, ast.Tuple):
            ij = node.slice.elts
        elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
            # The node `ast.Index` is deprecated in Python 3.9.
            ij = node.slice.value.elts
        elif isinstance(node.slice, ast.ExtSlice):
            # The node `ast.ExtSlice` is deprecated in Python 3.9.
            ij = node.slice.dims
        else:
            ij = [node.slice]

        indices = [ij, *indices]  # prepend

        target, indices = adj.recurse_subscript(node.value, indices)

        target_type = strip_reference(target.type)
        if is_array(target_type):
            flat_indices = [i for ij in indices for i in ij]
            if len(flat_indices) > target_type.ndim:
                # consume exactly ndim indices on the array, leave the rest for the element
                target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
                indices = adj.strip_indices(indices, target_type.ndim)

        return target, indices

    target = adj.eval(node)
    return target, indices
def eval_subscript(adj, node):
    """Resolve a (possibly chained) subscript into the indexed object and its indices.

    Returns:
        ``(target, indices)``, where ``indices`` is the list of index groups from
        ``recurse_subscript`` concatenated into one flat list.
    """
    target, index_groups = adj.recurse_subscript(node, [])
    flat_indices = []
    for group in index_groups:
        flat_indices.extend(group)
    return target, flat_indices
def emit_Subscript(adj, node):
    """Generate code for a subscript expression.

    ``wp.adjoint[var]`` yields a Var named ``adj_<label>`` with ``var``'s type.
    Any other subscript is resolved with ``eval_subscript`` and emitted through
    ``emit_indexing``.
    """
    is_adjoint_access = hasattr(node.value, "attr") and node.value.attr == "adjoint"
    if not is_adjoint_access:
        target, indices = adj.eval_subscript(node)
        return adj.emit_indexing(target, indices)

    # handle adjoint of a variable, i.e. wp.adjoint[var]
    node.slice.is_adjoint = True
    source = adj.eval(node.slice)
    return Var(f"adj_{source.label}", type=source.type, constant=None, prefix=False)
def emit_Assign(adj, node):
    """Generate code for an assignment statement (``ast.Assign``).

    The supported targets are:

    - tuple targets (``x, y = ...``), from a tuple value, a tuple-typed Var, or a
      multi-output call;
    - subscripts: ``wp.adjoint[var]``, array elements, tiles, and
      vector/quaternion/matrix/transformation components;
    - plain names;
    - attributes: vector/quaternion components, transform ``p``/``q``, and struct
      members.

    Raises:
        WarpCodegenError: for chained assignment (``a = b = ...``), list values,
            non-name elements in a tuple target, unpacking count mismatches,
            unsupported subscript targets, transform attributes of array
            elements, or unsupported target kinds.
        WarpCodegenTypeError: when re-assigning an existing symbol with a
            different type.
    """
    if len(node.targets) != 1:
        raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

    # Check if the rhs corresponds to an unsupported construct.
    # Tuples are supported in the context of assigning multiple variables
    # at once, but not for simple assignments like `x = (1, 2, 3)`.
    # Therefore, we need to catch this specific case here instead of
    # more generally in `adj.eval()`.
    if isinstance(node.value, ast.List):
        raise WarpCodegenError(
            "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
        )

    lhs = node.targets[0]

    if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
        # record the expected number of outputs on the node
        # we do this so we can decide which function to
        # call based on the number of expected outputs
        node.value.expects = len(lhs.elts)

    # evaluate rhs
    if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
        rhs = [adj.eval(v) for v in node.value.elts]
    else:
        rhs = adj.eval(node.value)

    # handle the case where we are assigning multiple output variables
    if isinstance(lhs, ast.Tuple):
        subtype = getattr(rhs, "type", None)

        # a single tuple-typed value is split into one extracted Var per element
        if isinstance(subtype, warp.types.tuple_t):
            if len(rhs.type.types) != len(lhs.elts):
                raise WarpCodegenError(
                    f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                )
            rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

        names = []
        for v in lhs.elts:
            if isinstance(v, ast.Name):
                names.append(v.id)
            else:
                raise WarpCodegenError(
                    "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                )

        if len(names) != len(rhs):
            raise WarpCodegenError(
                f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
            )

        out = rhs
        for name, rhs in zip(names, out):
            if name in adj.symbols:
                if not types_equal(rhs.type, adj.symbols[name].type):
                    raise WarpCodegenTypeError(
                        f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                    )

            adj.symbols[name] = rhs

    # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
    elif isinstance(lhs, ast.Subscript):
        if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
            # handle adjoint of a variable, i.e. wp.adjoint[var]
            lhs.slice.is_adjoint = True
            src_var = adj.eval(lhs.slice)
            var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
            # written directly as forward code
            adj.add_forward(f"{var.emit()} = {rhs.emit()};")
            return

        target, indices = adj.eval_subscript(lhs)

        target_type = strip_reference(target.type)
        indices = adj.eval_indices(target_type, indices)

        if is_array(target_type):
            adj.add_builtin_call("array_store", [target, *indices, rhs])

            if warp.config.verify_autograd_array_access:
                kernel_name = adj.fun_name
                filename = adj.filename
                lineno = adj.lineno + adj.fun_lineno

                target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

        elif is_tile(target_type):
            adj.add_builtin_call("assign", [target, *indices, rhs])

        elif (
            type_is_vector(target_type)
            or type_is_quaternion(target_type)
            or type_is_matrix(target_type)
            or type_is_transformation(target_type)
        ):
            # recursively unwind AST, stopping at penultimate node
            root = lhs
            while hasattr(root.value, "value"):
                root = root.value
            # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
            if hasattr(root, "attr") and root.attr == "adjoint":
                attr = adj.add_builtin_call("index", [target, *indices])
                adj.add_builtin_call("store", [attr, rhs])
                return

            # TODO: array vec component case
            if is_reference(target.type):
                # component of an element that lives behind a reference: store through it
                attr = adj.add_builtin_call("indexref", [target, *indices])
                adj.add_builtin_call("store", [attr, rhs])

                if warp.config.verbose and not adj.custom_reverse_mode:
                    lineno = adj.lineno + adj.fun_lineno
                    line = adj.source_lines[adj.lineno]
                    node_source = adj.get_node_source(lhs.value)
                    print(
                        f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                    )
            else:
                if warp.config.enable_vector_component_overwrites:
                    # copy-with-update produces a new Var
                    out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])

                    # re-point target symbol to out var
                    for id in adj.symbols:
                        if adj.symbols[id] == target:
                            adj.symbols[id] = out
                            break
                else:
                    adj.add_builtin_call("assign_inplace", [target, *indices, rhs])

        else:
            raise WarpCodegenError(
                f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
            )

    elif isinstance(lhs, ast.Name):
        # symbol name
        name = lhs.id

        # check type matches if symbol already defined
        if name in adj.symbols:
            if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
                raise WarpCodegenTypeError(
                    f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                )

        if isinstance(node.value, ast.Tuple):
            out = rhs
        elif isinstance(rhs, Sequence):
            out = adj.add_builtin_call("tuple", rhs)
        elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
            # name-to-name and reference values are copied so the new symbol
            # gets its own Var
            out = adj.add_builtin_call("copy", [rhs])
        else:
            out = rhs

        # update symbol map (assumes lhs is a Name node)
        adj.symbols[name] = out

    elif isinstance(lhs, ast.Attribute):
        aggregate = adj.eval(lhs.value)
        aggregate_type = strip_reference(aggregate.type)

        # assigning to a vector or quaternion component
        if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
            index = adj.vector_component_index(lhs.attr, aggregate_type)

            if is_reference(aggregate.type):
                attr = adj.add_builtin_call("indexref", [aggregate, index])
                adj.add_builtin_call("store", [attr, rhs])
            else:
                if warp.config.enable_vector_component_overwrites:
                    out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])

                    # re-point target symbol to out var
                    for id in adj.symbols:
                        if adj.symbols[id] == aggregate:
                            adj.symbols[id] = out
                            break
                else:
                    adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])

        elif type_is_transformation(aggregate_type):
            component = adj.transform_component(lhs.attr)

            # TODO: x[i,j].p = rhs case
            if is_reference(aggregate.type):
                raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")

            if component == "p":
                return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
            else:
                return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])

        else:
            # struct member assignment
            attr = adj.emit_Attribute(lhs)
            if is_reference(attr.type):
                adj.add_builtin_call("store", [attr, rhs])
            else:
                adj.add_builtin_call("assign", [attr, rhs])

            if warp.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
                print(msg)

    else:
        raise WarpCodegenError("Error, unsupported assignment statement.")
def emit_Return(adj, node):
    """Generate code for a ``return`` statement.

    The returned value(s) are normalized to a tuple: a single value becomes a
    1-tuple, and a bare ``return`` yields ``None``. If an earlier ``return`` in
    the same function already produced values, the C types of both must match.
    Reference-typed values are copied before being stored in ``adj.return_var``,
    which is then passed to ``adj.add_return``.

    Raises:
        WarpCodegenTypeError: if the returned types differ from a previous
            ``return``, including a bare ``return`` after a ``return`` with values.
    """
    if node.value is None:
        var = None
    elif isinstance(node.value, ast.Tuple):
        var = tuple(adj.eval(arg) for arg in node.value.elts)
    else:
        var = adj.eval(node.value)
        if not isinstance(var, (list, tuple)):
            var = (var,)

    if adj.return_var is not None:
        old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
        if var is None:
            # A bare `return` cannot follow a `return` with values. Without this
            # check, iterating `var` below fails with an opaque
            # "'NoneType' object is not iterable" error.
            raise WarpCodegenTypeError(
                f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [] (bare `return` without a value)"
            )
        new_ctypes = tuple(v.ctype(value_type=True) for v in var)
        if old_ctypes != new_ctypes:
            raise WarpCodegenTypeError(
                f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
            )

    if var is not None:
        adj.return_var = ()
        for ret in var:
            # never return a reference; copy the referenced value instead
            if is_reference(ret.type):
                ret_var = adj.add_builtin_call("copy", [ret])
            else:
                ret_var = ret
            adj.return_var += (ret_var,)

    adj.add_return(adj.return_var)
def emit_AugAssign(adj, node):
    """Generate code for an augmented assignment such as ``x += y`` or ``arr[i] -= v``.

    Special handling:

    - ``+=``/``-=`` on array elements become ``atomic_add``/``atomic_sub``,
      unless the element scalar type is in ``warp.types.non_atomic_types``;
    - on vector/quaternion/matrix/transformation components they become
      ``add_inplace``/``sub_inplace``;
    - on subscripted tiles they become ``tile_add_inplace``/``tile_sub_inplace``;
    - between whole tile variables they become ``add_inplace``/``sub_inplace``.

    Every other case is rewritten as ``lhs = lhs <op> rhs`` and evaluated as a
    regular assignment.

    Raises:
        WarpCodegenError: if a subscripted target is not an array, vector,
            quaternion, matrix, transformation or tile.
    """
    lhs = node.target

    # replace augmented assignment with assignment statement + binary op (default behaviour)
    def make_new_assign_statement():
        new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
        adj.eval(new_node)

    # NOTE(review): `node.value` is evaluated here, and the fallback
    # `make_new_assign_statement()` evaluates it again as part of the BinOp, so
    # its code is presumably emitted twice on those paths. Confirm this is
    # harmless for side-effecting right-hand sides.
    rhs = adj.eval(node.value)

    if isinstance(lhs, ast.Subscript):
        # wp.adjoint[var] appears in custom grad functions, and does not require
        # special consideration in the AugAssign case
        if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
            make_new_assign_statement()
            return

        target, indices = adj.eval_subscript(lhs)

        target_type = strip_reference(target.type)
        indices = adj.eval_indices(target_type, indices)

        if is_array(target_type):
            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
            if target_type.dtype in warp.types.non_atomic_types:
                make_new_assign_statement()
                return

            # the same holds true for vecs/mats/quats that are composed of these types
            if (
                type_is_vector(target_type.dtype)
                or type_is_quaternion(target_type.dtype)
                or type_is_matrix(target_type.dtype)
                or type_is_transformation(target_type.dtype)
            ):
                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
                if dtype in warp.types.non_atomic_types:
                    make_new_assign_statement()
                    return

            kernel_name = adj.fun_name
            filename = adj.filename
            lineno = adj.lineno + adj.fun_lineno

            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("atomic_add", [target, *indices, rhs])

                if warp.config.verify_autograd_array_access:
                    target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("atomic_sub", [target, *indices, rhs])

                if warp.config.verify_autograd_array_access:
                    target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

        elif (
            type_is_vector(target_type)
            or type_is_quaternion(target_type)
            or type_is_matrix(target_type)
            or type_is_transformation(target_type)
        ):
            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("add_inplace", [target, *indices, rhs])
            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

        elif is_tile(target.type):
            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

        else:
            raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")

    elif isinstance(lhs, ast.Name):
        target = adj.eval(node.target)

        # only tile += tile / tile -= tile get a dedicated in-place op
        if is_tile(target.type) and is_tile(rhs.type):
            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("add_inplace", [target, rhs])
            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("sub_inplace", [target, rhs])
            else:
                make_new_assign_statement()
                return
        else:
            make_new_assign_statement()
            return

    # TODO
    elif isinstance(lhs, ast.Attribute):
        make_new_assign_statement()
        return

    else:
        make_new_assign_statement()
        return
def emit_Tuple(adj, node):
    """Evaluate each element of a tuple literal and pack the results with the ``tuple`` builtin."""
    items = tuple(map(adj.eval, node.elts))
    return adj.add_builtin_call("tuple", items)
def emit_Pass(adj, node):
    """Handle ``pass``: generates no code."""
    return None
# Dispatch table used by `eval()`: maps each supported Python AST node type to
# the `emit_*` handler that generates code for it. `eval()` rejects node types
# that are missing from this table with a WarpCodegenError.
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
    ast.FunctionDef: emit_FunctionDef,
    ast.If: emit_If,
    ast.IfExp: emit_IfExp,
    ast.Compare: emit_Compare,
    ast.BoolOp: emit_BoolOp,
    ast.Name: emit_Name,
    ast.Attribute: emit_Attribute,
    ast.Constant: emit_Constant,
    ast.BinOp: emit_BinOp,
    ast.UnaryOp: emit_UnaryOp,
    ast.While: emit_While,
    ast.For: emit_For,
    ast.Break: emit_Break,
    ast.Continue: emit_Continue,
    ast.Expr: emit_Expr,
    ast.Call: emit_Call,
    ast.Index: emit_Index,  # Deprecated in 3.9
    ast.Subscript: emit_Subscript,
    ast.Assign: emit_Assign,
    ast.Return: emit_Return,
    ast.AugAssign: emit_AugAssign,
    ast.Tuple: emit_Tuple,
    ast.Pass: emit_Pass,
    ast.Assert: emit_Assert,
}
def eval(adj, node):
    """Generate code for ``node`` by dispatching on its type through ``node_visitors``.

    When the node has a ``lineno``, ``node.lineno - 1`` is passed to
    ``adj.set_lineno`` first.

    Raises:
        WarpCodegenError: if no visitor is registered for the node's type.
    """
    if hasattr(node, "lineno"):
        adj.set_lineno(node.lineno - 1)

    node_type = type(node)
    try:
        visitor = adj.node_visitors[node_type]
    except KeyError as missing:
        qualifier = "ast." if isinstance(node, ast.AST) else ""
        raise WarpCodegenError(f"Construct `{qualifier}{node_type.__name__}` not supported in kernels.") from missing

    return visitor(adj, node)
def resolve_path(adj, path):
    """Resolve a dotted path such as ``obj1.obj2.obj3.attr`` in the function's global scope.

    The root ``path[0]`` is looked up in this order: closure/global variables
    (``adj.resolve_external_reference``), the ``warp`` module (so Warp types can
    be used without the module prefix, e.g. ``vec3(0.0, 0.2, 0.4)``), and Python
    builtins. Each subsequent component is applied with ``getattr`` only when
    the current object has that attribute.

    Returns:
        The resolved object. Returns ``None`` if ``path`` is empty, if its root
        is shadowed by a local symbol, or if the root cannot be found.
    """
    import builtins

    if len(path) == 0:
        return None

    # if root is overshadowed by local symbols, bail out
    if path[0] in adj.symbols:
        return None

    # look up in closure/global variables
    expr = adj.resolve_external_reference(path[0])

    # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
    if expr is None:
        expr = getattr(warp, path[0], None)

    # look up in builtins. Use the `builtins` module rather than `__builtins__`,
    # which is a CPython implementation detail (a dict in imported modules, but
    # the module itself in `__main__`, where `.get()` would fail).
    if expr is None:
        expr = getattr(builtins, path[0], None)

    if expr is not None:
        for attr_name in path[1:]:
            if hasattr(expr, attr_name):
                expr = getattr(expr, attr_name)

    return expr
def get_static_evaluation_context(adj):
    """Build the namespace in which ``wp.static()`` expressions are evaluated.

    The result is a fresh dict holding the user function's globals, overlaid
    with its closure variables, so a captured variable takes precedence over
    a global of the same name.
    """
    func = adj.func
    cells = func.__closure__ or []
    captured = {name: cell.cell_contents for name, cell in zip(func.__code__.co_freevars, cells)}

    context = dict(func.__globals__)
    context.update(captured)
    return context
3073
-
3074
def is_static_expression(adj, func):
    """Return whether *func* is the Python function ``static`` from ``warp.builtins``."""
    if not isinstance(func, types.FunctionType):
        return False
    return (func.__module__, func.__qualname__) == ("warp.builtins", "static")
3080
-
3081
def verify_static_return_value(adj, value):
    """Check that a ``wp.static()`` result can be used inside a Warp kernel.

    Accepted values: Warp values (``warp.types.is_value``), strings (e.g. for
    ``print(wp.static("test"))``), Warp functions, and struct instances whose
    fields, recursively, are all accepted values.

    Returns:
        True if the value is supported.

    Raises:
        ValueError: Describing why the value is unsupported. For structs, the
            message names the offending field path.
    """
    if value is None:
        raise ValueError("None is returned")
    if warp.types.is_value(value):
        return True
    if warp.types.is_array(value):
        # more useful explanation for the common case of creating a Warp array
        raise ValueError("a Warp array cannot be created inside Warp kernels")
    if isinstance(value, str):
        # we want to support cases such as `print(wp.static("test"))`
        return True
    if isinstance(value, warp.context.Function):
        return True

    def verify_struct(s: StructInstance, attr_path: list[str]):
        for key in s._cls.vars.keys():
            v = getattr(s, key)
            # Include the current field so error messages point at the exact member.
            field_path = [*attr_path, key]
            if issubclass(type(v), StructInstance):
                verify_struct(v, field_path)
            else:
                try:
                    adj.verify_static_return_value(v)
                except ValueError as e:
                    raise ValueError(
                        f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(field_path)}"
                    ) from e

    if issubclass(type(value), StructInstance):
        verify_struct(value, [])
        return True

    raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
3113
-
3114
@staticmethod
def extract_node_source_from_lines(source_lines, node) -> str | None:
    """Return the source text spanned by an AST *node*.

    Args:
        source_lines: The lines of the code *node* was parsed from.
        node: An AST node with position information.

    Returns:
        The node's source text, or None if the node has no ``lineno`` /
        ``col_offset``. Multi-line results are joined with ``"\\n"`` and
        stripped.

    Note:
        AST column offsets are UTF-8 *byte* offsets. They are converted to
        character offsets before slicing the ``str`` lines; otherwise any
        non-ASCII character before the node shifts the extracted span.
    """

    def to_char_col(line: str, byte_col: int) -> int:
        # number of characters encoded by the first `byte_col` UTF-8 bytes of `line`
        return len(line.encode("utf-8")[:byte_col].decode("utf-8", errors="ignore"))

    if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
        return None

    start_line = node.lineno - 1  # line numbers start at 1
    start_col = to_char_col(source_lines[start_line], node.col_offset)

    if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
        end_line = node.end_lineno - 1
        end_col = to_char_col(source_lines[end_line], node.end_col_offset)
    else:
        # Fallback for nodes without end positions: scan forward for the ")"
        # that closes the call enclosing the node (the count starts at 1 as we
        # begin inside its parentheses).
        end_line = start_line
        end_col = start_col
        parenthesis_count = 1
        for lineno in range(start_line, len(source_lines)):
            c_start = start_col if lineno == start_line else 0
            line = source_lines[lineno]
            for i in range(c_start, len(line)):
                c = line[i]
                if c == "(":
                    parenthesis_count += 1
                elif c == ")":
                    parenthesis_count -= 1
                    if parenthesis_count == 0:
                        end_col = i
                        end_line = lineno
                        break
            if parenthesis_count == 0:
                break

    if start_line == end_line:
        # single-line expression
        return source_lines[start_line][start_col:end_col]

    # multi-line expression: tail of the first line, full middle lines, head of the last line
    lines = [source_lines[start_line][start_col:]]
    lines.extend(source_lines[start_line + 1 : end_line])
    lines.append(source_lines[end_line][:end_col])
    return "\n".join(lines).strip()
3164
-
3165
@staticmethod
def extract_lambda_source(func, only_body=False) -> str | None:
    """Heuristically extract the source text of a lambda function.

    Args:
        func: The lambda whose source should be recovered.
        only_body: If True, return only the lambda's body expression instead
            of the whole ``lambda ...: ...`` text.

    Returns:
        The extracted source text, or None if no ``ast.Lambda`` node is found.

    Raises:
        WarpCodegenError: If ``inspect`` cannot access the source (``OSError``).
    """
    try:
        source_lines = inspect.getsourcelines(func)[0]
        # Trim the first line so it starts at the `lambda` keyword.
        # NOTE(review): `.index` raises ValueError (not caught here) if "lambda"
        # is not on the first source line of the lambda.
        source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
    except OSError as e:
        raise WarpCodegenError(
            "Could not access lambda function source code. Please use a named function instead."
        ) from e
    source = "".join(source_lines)
    source = source[source.index("lambda") :].rstrip()
    # Remove trailing unbalanced parentheses
    while source.count("(") < source.count(")"):
        source = source[:-1]
    # extract lambda expression up until a comma, e.g. in the case of
    # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
    si = max(source.find(")"), source.find(":"))
    ci = source.find(",", si)
    if ci != -1:
        source = source[:ci]
    # `source` starts with the same (trimmed) text as `source_lines`, so node
    # positions from this parse can be used to slice `source_lines` below.
    tree = ast.parse(source)
    lambda_source = None
    for node in ast.walk(tree):
        if isinstance(node, ast.Lambda):
            if only_body:
                # extract the body of the lambda function
                lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
            else:
                # extract the entire lambda function
                lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
            break
    return lambda_source
3197
-
3198
def extract_node_source(adj, node) -> str | None:
    """Return the source text of *node* within this function's source lines (None if unavailable)."""
    lines = adj.source_lines
    return adj.extract_node_source_from_lines(lines, node)
3200
-
3201
def evaluate_static_expression(adj, node) -> tuple[Any, str]:
    """Evaluate the argument of a ``wp.static()`` call at code-generation time.

    The argument's source text is extracted and flattened to a single line.
    It is then evaluated with ``eval`` in a namespace built from the user
    function's globals and closure variables plus all local symbols that have
    a constant value. Before evaluation, ``len(...)`` calls that can be
    computed from those constants or from symbol types (through
    ``warp.types.type_length``) are replaced by their values.

    Args:
        node: The ``ast.Call`` node of the ``wp.static()`` call.

    Returns:
        ``(value, static_code)``: the evaluated object (``IntEnum``/``IntFlag``
        members converted to ``int``) and the flattened source text of the
        static expression.

    Raises:
        WarpCodegenError: If the call has neither exactly one positional
            argument nor exactly one keyword, if its source cannot be
            extracted, if evaluation fails, or if the result is not a value
            supported inside kernels.
    """
    if len(node.args) == 1:
        static_code = adj.extract_node_source(node.args[0])
    elif len(node.keywords) == 1:
        # Use the keyword's value: the keyword node itself spans `name=value`,
        # which is not an expression that eval() accepts.
        static_code = adj.extract_node_source(node.keywords[0].value)
    else:
        raise WarpCodegenError("warp.static() requires a single argument or keyword")
    if static_code is None:
        raise WarpCodegenError("Error extracting source code from wp.static() expression")

    # Since this is an expression, we can enforce it to be defined on a single line.
    static_code = static_code.replace("\n", "")
    code_to_eval = static_code  # code to be evaluated

    vars_dict = adj.get_static_evaluation_context()
    # add constant variables to the static call context
    constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
    vars_dict.update(constant_vars)

    # Replace all constant `len()` expressions with their value.
    if "len" in static_code:
        len_expr_ctx = vars_dict.copy()
        constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
        len_expr_ctx.update(constant_types)
        len_expr_ctx.update({"len": warp.types.type_length})

        # AST column offsets are UTF-8 byte offsets, so slice the encoded code.
        code_bytes = static_code.encode("utf-8")

        # Reparse the flattened code to get column info for in-place replacement.
        len_value_locs: list[tuple[int, int, int]] = []
        expr_tree = ast.parse(static_code)
        assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
        expr_root = expr_tree.body[0].value
        for expr_node in ast.walk(expr_root):
            if (
                isinstance(expr_node, ast.Call)
                and getattr(expr_node.func, "id", None) == "len"
                and len(expr_node.args) == 1
            ):
                len_expr = code_bytes[expr_node.col_offset : expr_node.end_col_offset].decode("utf-8")
                try:
                    len_value = eval(len_expr, len_expr_ctx)
                except Exception:
                    # not computable at this point; leave the call in place
                    pass
                else:
                    len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))

        if len_value_locs:
            # ast.walk() yields nodes in no specified order, and a len() call may
            # be nested inside another one. Splice spans in source order and
            # skip any span inside one that has already been replaced.
            parts = []
            loc = 0
            for value, start, end in sorted(len_value_locs, key=lambda entry: entry[1]):
                if start < loc:
                    continue
                parts.append(code_bytes[loc:start].decode("utf-8"))
                parts.append(f"{value}")
                loc = end
            parts.append(code_bytes[loc:].decode("utf-8"))
            code_to_eval = "".join(parts)

    try:
        value = eval(code_to_eval, vars_dict)
        if isinstance(value, (enum.IntEnum, enum.IntFlag)):
            value = int(value)
        if warp.config.verbose:
            print(f"Evaluated static command: {static_code} = {value}")
    except NameError as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
        ) from e
    except Exception as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    try:
        adj.verify_static_return_value(value)
    except ValueError as e:
        raise WarpCodegenError(
            f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    return value, static_code
3282
-
3283
def replace_static_expressions(adj):
    """Replace ``wp.static()`` calls in ``adj.tree`` by their values where possible.

    Every successfully evaluated call is recorded in ``adj.static_expressions``
    (keyed by its source code) and replaced in the tree by:

    - an ``ast.Name("__warp_func__")`` node with a ``warp_func`` attribute
      when the result is a Warp function, or
    - an ``ast.Constant`` holding the value otherwise.

    Calls that raise during evaluation stay in the tree, and
    ``adj.has_unresolved_static_expressions`` is set to True.
    """

    class StaticExpressionReplacer(ast.NodeTransformer):
        def visit_Call(self, node):
            func, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if not adj.is_static_expression(func):
                return self.generic_visit(node)

            try:
                # Succeeds when the expression is valid and depends only on
                # global or captured variables.
                obj, code = adj.evaluate_static_expression(node)
                if code is not None:
                    adj.static_expressions[code] = obj

                if isinstance(obj, warp.context.Function):
                    replacement = ast.Name("__warp_func__")
                    # Keep a direct reference to the Warp function: its key alone
                    # is not unique, since the function may be redefined between
                    # this wp.static() evaluation and module building.
                    replacement.warp_func = obj
                else:
                    replacement = ast.Constant(value=obj)
                return ast.copy_location(replacement, node)
            except Exception:
                # Failures are tolerable: either the expression is invalid (the
                # module cannot be built anyway) or it references a local,
                # possibly constant, variable whose changes already alter the
                # module hash. Flag the adjoint so codegen runs even when the
                # module hash is unchanged.
                adj.has_unresolved_static_expressions = True

            return self.generic_visit(node)

    adj.tree = StaticExpressionReplacer().visit(adj.tree)
3323
-
3324
def resolve_static_expression(adj, root_node, eval_types=True):
    """Resolve an expression that does not depend on runtime values.

    A chain of attribute accesses such as ``a.b.c`` is unwound to its base
    node, and the resulting dotted path is resolved through
    ``adj.resolve_path``.

    When *eval_types* is True and the base is a ``type(x)`` call, ``x`` is
    evaluated, its reference-stripped type is taken, and the attribute chain
    (e.g. ``type(arr).dtype``) is resolved on that type through
    ``adj.resolve_type_attribute``.

    Returns:
        ``(obj, path)``: the resolved object (or None) and the list of path
        components. For ``type(x)`` chains, ``path`` is ``[str(resolved_type)]``.

    Raises:
        WarpCodegenError: If ``type()`` does not get exactly one argument, or
            its argument does not evaluate to a ``Var``.
        WarpCodegenAttributeError: If an attribute in a ``type(x)`` chain does
            not exist on the type.
    """
    # Unwind `.attr` accesses; names are collected outermost first.
    attrs = []
    base = root_node
    while isinstance(base, ast.Attribute):
        attrs.append(base.attr)
        base = base.value

    is_type_call = (
        eval_types
        and isinstance(base, ast.Call)
        and isinstance(base.func, ast.Name)
        and base.func.id == "type"
    )
    if is_type_call:
        call_args = base.args
        if len(call_args) != 1:
            raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(call_args)}")

        evaluated = adj.eval(call_args[0])
        if not isinstance(evaluated, Var):
            raise WarpCodegenError(f"Cannot deduce the type of {evaluated}")

        current = strip_reference(evaluated.type)
        # pop() yields the attributes left to right, e.g. `dtype` in type(a).dtype
        while attrs:
            attr_name = attrs.pop()
            owner = current
            current = adj.resolve_type_attribute(owner, attr_name)
            if current is None:
                raise WarpCodegenAttributeError(f"{attr_name} is not an attribute of {type_repr(owner)}")

        return current, [str(current)]

    path = attrs[::-1]
    if isinstance(base, ast.Name):
        path = [base.id, *path]

    # resolve_path() returns None when the path cannot be resolved
    return adj.resolve_path(path), path
3374
-
3375
def resolve_external_reference(adj, name: str):
    """Look up *name* among the user function's closure variables, then its globals.

    A closure variable whose cell is still empty falls back to the globals.

    Returns:
        The referenced object, or None if the name is found in neither.
    """
    func = adj.func
    freevars = func.__code__.co_freevars
    if name in freevars:
        cell = func.__closure__[freevars.index(name)]
        try:
            return cell.cell_contents
        except ValueError:
            # empty cell: the enclosing scope has not assigned the variable yet
            pass
    return func.__globals__.get(name)
3384
-
3385
def set_lineno(adj, lineno):
    """Annotate the generated forward and reverse code with a source line.

    Emits a ``// <source> <L n>`` comment (``// adj: ...`` in the reverse body).
    Here *lineno* is a 0-based index into ``adj.source_lines`` and ``n`` is
    ``lineno + adj.fun_lineno``. Nothing is emitted when *lineno* equals the
    last annotated line.
    """
    if adj.lineno is not None and adj.lineno == lineno:
        return

    absolute_line = lineno + adj.fun_lineno
    width = 80 - len(adj.indentation)
    text = adj.source_lines[lineno].strip().ljust(width, " ")

    adj.add_forward(f"// {text} <L {absolute_line}>")
    adj.add_reverse(f"// adj: {text} <L {absolute_line}>")
    adj.lineno = lineno
3393
-
3394
def get_node_source(adj, node):
    """Return the exact Python source segment of *node* taken from ``adj.source``."""
    source_text = adj.source
    return ast.get_source_segment(source_text, node)
3397
-
3398
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
    """Collect the outer-scope objects referenced by ``adj.tree``.

    Returns:
        ``(constants, types, functions)`` where

        - ``constants`` maps names and dotted paths to referenced Warp values,
        - ``types`` has a key for every struct or value type whose constructor
          is called,
        - ``functions`` has a key for every user-defined (non-builtin) Warp
          function that is called.

        The dict values are unused placeholders (``None``) for the latter two.
    """
    # Names assigned in the function body; once seen, later bare uses of them
    # are treated as locals rather than outer-scope references.
    shadowed = set()

    constants: dict[str, Any] = {}
    type_refs: dict[Struct | type, Any] = {}
    functions: dict[warp.context.Function, Any] = {}

    for node in ast.walk(adj.tree):
        if isinstance(node, ast.Name):
            if node.id in shadowed:
                continue
            # closure / global variable
            obj = adj.resolve_external_reference(node.id)
            if warp.types.is_value(obj):
                constants[node.id] = obj

        elif isinstance(node, ast.Attribute):
            obj, path = adj.resolve_static_expression(node, eval_types=False)
            if warp.types.is_value(obj):
                constants[".".join(path)] = obj

        elif isinstance(node, ast.Call):
            callee, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if isinstance(callee, warp.context.Function) and not callee.is_builtin():
                # call to a user-defined function
                functions[callee] = None
            elif isinstance(callee, Struct):
                # struct constructor
                type_refs[callee] = None
            elif isinstance(callee, type) and warp.types.type_is_value(callee):
                # value type constructor
                type_refs[callee] = None

        elif isinstance(node, ast.Assign):
            target = node.targets[0]
            candidates = target.elts if isinstance(target, ast.Tuple) else [target]
            shadowed.update(item.id for item in candidates if isinstance(item, ast.Name))

    return constants, type_refs, functions
3442
-
3443
-
3444
# ----------------
# code generation

# Prologue of generated CPU modules. `{block_dim}` is substituted into
# WP_TILE_BLOCK_DIM. The builtin_tid*/builtin_block_dim macros refer to
# `task_index` and `dim`, the loop variable and parameter of the CPU entry
# points below.
cpu_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(task_index, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)

#define builtin_block_dim() wp::block_dim()

"""

# Prologue of generated CUDA modules. Same as the CPU header, except the tid
# macros use `_idx`, the loop index of the CUDA kernel templates below, and
# __debugbreak() is mapped to a device breakpoint.
cuda_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
#if defined(__CUDACC__) && !defined(_MSC_VER)
#define __debugbreak() __brkpt()
#endif

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(_idx, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

#define builtin_block_dim() wp::block_dim()

"""

# C++ definition of a Warp struct, filled in by codegen_struct() via
# str.format (doubled braces are literal braces). It emits the members, the
# constructors, operator+=, the adjoint constructor adj_{name}, and the
# add()/adj_atomic_add() overloads used when compiling adjoints.
struct_template = """
struct {name}
{{
{struct_body}

{defaulted_constructor_def}
CUDA_CALLABLE {name}({forward_args})
{forward_initializers}
{{
}}

CUDA_CALLABLE {name}& operator += (const {name}& rhs)
{{{prefix_add_body}
return *this;}}

}};

static CUDA_CALLABLE void adj_{name}({reverse_args})
{{
{reverse_body}}}

// Required when compiling adjoints.
CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
{{
return {name}();
}}

CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
{{
{atomic_add_body}}}


"""

# User-function templates (forward and adjoint, CPU and CUDA). The leading
# comment records the Python `{filename}:{lineno}` of the function.
# `{line_directive}` is prefixed to selected lines; presumably it is empty
# unless source-line directives are enabled (TODO confirm at the call site).
cpu_forward_function_template = """
// {filename}:{lineno}
static {return_type} {name}(
{forward_args})
{{
{forward_body}}}

"""

cpu_reverse_function_template = """
// {filename}:{lineno}
static void adj_{name}(
{reverse_args})
{{
{reverse_body}}}

"""

cuda_forward_function_template = """
// {filename}:{lineno}
{line_directive}static CUDA_CALLABLE {return_type} {name}(
{forward_args})
{{
{forward_body}{line_directive}}}

"""

cuda_reverse_function_template = """
// {filename}:{lineno}
{line_directive}static CUDA_CALLABLE void adj_{name}(
{reverse_args})
{{
{reverse_body}{line_directive}}}

"""

# CUDA kernel entry points. Each thread visits indices `_idx` in a grid-stride
# loop (step blockDim.x * gridDim.x) up to dim.size and resets the tile
# shared-memory allocator before running the body for each index.
cuda_kernel_template_forward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
{forward_args})
{{
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive} _idx < dim.size;
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
{{
// reset shared memory allocator
{line_directive} wp::tile_alloc_shared(0, true);

{forward_body}{line_directive} }}
{line_directive}}}

"""

cuda_kernel_template_backward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
{reverse_args})
{{
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive} _idx < dim.size;
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
{{
// reset shared memory allocator
{line_directive} wp::tile_alloc_shared(0, true);

{reverse_body}{line_directive} }}
{line_directive}}}

"""

# CPU kernel bodies for a single task. Arguments arrive through a
# `wp_args_{name}` struct pointer (plus a second one for adjoints in the
# backward kernel).
cpu_kernel_template_forward = """

void {name}_cpu_kernel_forward(
{forward_args},
wp_args_{name} *_wp_args)
{{
{forward_body}}}

"""

cpu_kernel_template_backward = """

void {name}_cpu_kernel_backward(
{reverse_args},
wp_args_{name} *_wp_args,
wp_args_{name} *_wp_adj_args)
{{
{reverse_body}}}

"""

# Exported CPU entry points. They run the kernel sequentially for every
# task_index in [0, dim.size), initializing the shared-memory allocator before
# each task and checking it afterwards.
cpu_module_template_forward = """

extern "C" {{

// Python CPU entry points
WP_API void {name}_cpu_forward(
wp::launch_bounds_t dim,
wp_args_{name} *_wp_args)
{{
for (size_t task_index = 0; task_index < dim.size; ++task_index)
{{
// init shared memory allocator
wp::tile_alloc_shared(0, true);

{name}_cpu_kernel_forward(dim, task_index, _wp_args);

// check shared memory allocator
wp::tile_alloc_shared(0, false, true);

}}
}}

}} // extern C

"""

cpu_module_template_backward = """

extern "C" {{

WP_API void {name}_cpu_backward(
wp::launch_bounds_t dim,
wp_args_{name} *_wp_args,
wp_args_{name} *_wp_adj_args)
{{
for (size_t task_index = 0; task_index < dim.size; ++task_index)
{{
// initialize shared memory allocator
wp::tile_alloc_shared(0, true);

{name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

// check shared memory allocator
wp::tile_alloc_shared(0, false, true);
}}
}}

}} // extern C

"""
3669
-
3670
-
3671
def constant_str(value):
    """Convert a constant Python value to an equivalent C++ expression string.

    Handled, in order:

    - booleans -> ``true`` / ``false``
    - strings -> an escaped, double-quoted C string literal
    - Warp vector/matrix constants (``ctypes.Array`` with ``_wp_scalar_type_``)
      -> ``wp::initializer_array<N,T>{...}``
    - Warp scalar type instances -> their ``.value``
    - Warp struct instances -> a constructor call with converted fields
    - everything else -> ``str(value)``

    Non-finite floats become ``INFINITY``, ``-INFINITY`` or ``NAN`` (both as
    scalars and as array elements), because Python's spellings (``inf``,
    ``nan``) are not C++ literals.
    """
    value_type = type(value)

    if value_type == bool or value_type == builtins.bool:
        return "true" if value else "false"

    elif value_type == str:
        # `unicode-escape` escapes backslashes, control and non-ASCII
        # characters but leaves double quotes as-is; escape those explicitly
        # so the C string literal is not terminated early.
        escaped = value.encode("unicode-escape").decode()
        return '"' + escaped.replace('"', '\\"') + '"'

    elif isinstance(value, ctypes.Array):
        if value_type._wp_scalar_type_ == float16:
            # special case for float16, which is stored as uint16 in the ctypes.Array
            from warp.context import runtime

            scalar_value = runtime.core.wp_half_bits_to_float
        else:

            def scalar_value(x):
                return x

        def element_str(x):
            x = scalar_value(x)
            if isinstance(x, float) and not math.isfinite(x):
                if math.isnan(x):
                    return "NAN"
                return "INFINITY" if x > 0 else "-INFINITY"
            return str(x).lower()

        # list of scalar initializer values
        initlist = [element_str(ctypes.Array.__getitem__(value, i)) for i in range(value._length_)]

        if value._wp_scalar_type_ is bool:
            dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
        else:
            dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"

        # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
        return f"{dtypestr}{{{', '.join(initlist)}}}"

    elif value_type in warp.types.scalar_types:
        # make sure we emit the value of objects, e.g. uint32
        return str(value.value)

    elif issubclass(value_type, warp.codegen.StructInstance):
        # constant struct instance
        arg_strs = []
        for key, var in value._cls.vars.items():
            attr = getattr(value, key)
            arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
        arg_str = ", ".join(arg_strs)
        return f"{value.native_name}({arg_str})"

    elif value == math.inf:
        return "INFINITY"

    elif value == -math.inf:
        return "-INFINITY"

    elif math.isnan(value):
        return "NAN"

    else:
        # otherwise just convert constant to string
        return str(value)
3732
-
3733
-
3734
def indent(args, stops=1):
    """Join *args* with a comma, a newline and *stops* indentation units."""
    unit = " "
    separator = ",\n" + unit * stops
    return separator.join(args)
3741
-
3742
-
3743
def make_full_qualified_name(func: Union[str, Callable]) -> str:
    """Turn a function (or its qualified name) into a valid C identifier.

    Dots become double underscores and every character outside
    ``[0-9a-zA-Z_]`` is dropped, e.g. ``"outer.<locals>.inner"`` becomes
    ``"outer__locals__inner"``.
    """
    qualname = func if isinstance(func, str) else func.__qualname__
    mangled = qualname.replace(".", "__")
    return re.sub(r"[^0-9a-zA-Z_]+", "", mangled)
3748
-
3749
-
3750
def codegen_struct(struct, device="cpu", indent_size=4):
    """Generate the C++ definition of a Warp struct from ``struct_template``.

    For every field, the generated code includes:

    - a member declaration,
    - a constructor parameter (all but the first get ``= {}`` defaults) and a
      member initializer,
    - an ``operator+=`` term (non-array fields only),
    - adjoint-constructor parameters and update statements (arrays are
      assigned, other fields accumulated), and
    - an ``adj_atomic_add`` call.

    Args:
        struct: The struct to generate; uses ``native_name`` and ``vars``.
        device: Accepted for interface compatibility; unused here.
        indent_size: Number of spaces used to indent generated statements.
    """
    name = struct.native_name
    pad = " " * indent_size
    fields = list(struct.vars.items())

    # An empty struct gets a dummy member to avoid compiler-specific alignment issues.
    members = [f"{var.ctype()} {label};\n" for label, var in fields] or ["char _dummy_;\n"]

    ctor_params = []
    ctor_inits = []
    adj_primal_params = []
    adj_grad_params = []
    atomic_adds = []
    accumulates = []
    adj_updates = []

    for index, (label, var) in enumerate(fields):
        ctype = var.ctype()
        is_array_field = is_array(var.type)

        default = " = {}" if index > 0 else ""
        ctor_params.append(f"{ctype} const& {label}{default}")
        adj_primal_params.append(f"{ctype} const&")

        namespace = "wp::" if ctype.startswith("wp::") or ctype == "bool" else ""
        atomic_adds.append(f"{pad}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")

        separator = ":" if index == 0 else f"{pad},"
        ctor_inits.append(f"{pad}{separator} {label}{{{label}}}\n")

        if not is_array_field:
            accumulates.append(f"{pad}{label} += rhs.{label};\n")

        adj_grad_params.append(f"{ctype} & adj_{label}")
        operator = "=" if is_array_field else "+="
        adj_updates.append(f"{pad}adj_{label} {operator} adj_ret.{label};\n")

    reverse_params = [*adj_primal_params, *adj_grad_params, f"{name} & adj_ret"]

    return struct_template.format(
        name=name,
        struct_body="".join(pad + member for member in members),
        forward_args=indent(ctor_params),
        forward_initializers="".join(ctor_inits),
        reverse_args=indent(reverse_params),
        reverse_body="".join(adj_updates),
        prefix_add_body="".join(accumulates),
        atomic_add_body="".join(atomic_adds),
        # explicitly defaulted default constructor, needed once a user-defined constructor exists
        defaulted_constructor_def=f"{name}() = default;" if fields else "",
    )
3813
-
3814
-
3815
def codegen_func_forward(adj, func_type="kernel", device="cpu"):
    """Generate the indented C++ statements of the forward pass of *adj*.

    The output contains, in order:

    - argument unpacking from ``_wp_args`` (CPU kernels only),
    - declarations of all primal variables (tiles initialized via ``cinit``,
      constants declared ``const``), and
    - the forward body.

    For CUDA kernels, whose body sits inside a grid-stride loop, statements
    starting with ``return;`` are rewritten to ``continue;``. Lines starting
    with ``#line`` are emitted unindented.

    Raises:
        ValueError: If *device* is neither ``"cpu"`` nor ``"cuda"``.
    """
    if device == "cpu":
        indent_width = 4
    elif device == "cuda":
        indent_width = 8 if func_type == "kernel" else 4
    else:
        raise ValueError(f"Device {device} not supported for codegen")

    pad = " " * indent_width
    is_cuda_kernel = func_type == "kernel" and device == "cuda"
    lines = []

    if device == "cpu" and func_type == "kernel":
        lines += ["//---------\n", "// argument vars\n"]
        lines += [f"{arg.ctype()} {arg.emit()} = _wp_args->{arg.label};\n" for arg in adj.args]

    lines += ["//---------\n", "// primal vars\n"]
    for var in adj.variables:
        if is_tile(var.type):
            decl = f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"
        elif var.constant is None:
            decl = f"{var.ctype()} {var.emit()};\n"
        else:
            decl = f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"

        directive = adj.get_line_directive(decl, var.relative_lineno)
        if directive:
            lines.append(f"{directive}\n")
        lines.append(decl)

    lines += ["//---------\n", "// forward\n"]
    for stmt in adj.blocks[0].body_forward:
        if is_cuda_kernel and stmt.lstrip().startswith("return;"):
            stmt = stmt.replace("return;", "continue;")
        lines.append(stmt + "\n")

    return "".join(line.lstrip() if line.lstrip().startswith("#line") else pad + line for line in lines)
3865
-
3866
-
3867
def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    """Generate the indented C++ statements of the adjoint (reverse) pass of *adj*.

    The output contains, in order:

    - argument unpacking from ``_wp_args`` / ``_wp_adj_args`` (CPU kernels only),
    - primal variable declarations,
    - adjoint ("dual") variable declarations,
    - the forward replay body,
    - the reverse body (statements emitted in reverse order), and
    - a final ``continue;`` for CUDA kernels (grid-stride loop) or
      ``return;`` otherwise.

    Lines starting with ``#line`` are emitted unindented.

    Raises:
        ValueError: If *device* is neither ``"cpu"`` nor ``"cuda"``.
    """
    if device == "cpu":
        indent_width = 4
    elif device == "cuda":
        indent_width = 8 if func_type == "kernel" else 4
    else:
        raise ValueError(f"Device {device} not supported for codegen")

    pad = " " * indent_width
    lines = []

    def annotate(var):
        # Insert the variable's line directive (if any) before the most recently emitted line.
        directive = adj.get_line_directive(lines[-1], var.relative_lineno)
        if directive:
            lines.insert(-1, f"{directive}\n")

    if device == "cpu" and func_type == "kernel":
        lines += ["//---------\n", "// argument vars\n"]
        lines += [f"{arg.ctype()} {arg.emit()} = _wp_args->{arg.label};\n" for arg in adj.args]
        lines += [f"{arg.ctype()} {arg.emit_adj()} = _wp_adj_args->{arg.label};\n" for arg in adj.args]

    lines += ["//---------\n", "// primal vars\n"]
    for var in adj.variables:
        if is_tile(var.type):
            lines.append(f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n")
        elif var.constant is None:
            lines.append(f"{var.ctype()} {var.emit()};\n")
        else:
            lines.append(f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n")
        annotate(var)

    lines += ["//---------\n", "// dual vars\n"]
    for var in adj.variables:
        adj_name = var.emit_adj()
        value_ctype = var.ctype(value_type=True)

        if is_tile(var.type):
            if var.type.storage == "register":
                # register tiles get their own zero-initialized adjoint tile
                lines.append(f"{var.type.ctype()} {adj_name}(0.0);\n")
            elif var.type.storage == "shared":
                # shared tiles store primal and dual values together, so the adjoint aliases the primal variable
                lines.append(f"{var.type.ctype()}& {adj_name} = {var.emit()};\n")
        else:
            lines.append(f"{value_ctype} {adj_name} = {{}};\n")
        annotate(var)

    lines += ["//---------\n", "// forward\n"]
    lines += [stmt + "\n" for stmt in adj.blocks[0].body_replay]

    lines += ["//---------\n", "// reverse\n"]
    lines += [stmt + "\n" for stmt in reversed(adj.blocks[0].body_reverse)]

    # In grid-stride kernels the reverse body is in a for loop
    lines.append("continue;\n" if device == "cuda" and func_type == "kernel" else "return;\n")

    return "".join(line.lstrip() if line.lstrip().startswith("#line") else pad + line for line in lines)
3952
-
3953
-
3954
- def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3955
- if options is None:
3956
- options = {}
3957
-
3958
- if adj.return_var is not None and "return" in adj.arg_types:
3959
- if get_origin(adj.arg_types["return"]) is tuple:
3960
- if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
3961
- raise WarpCodegenError(
3962
- f"The function `{adj.fun_name}` has its return type "
3963
- f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3964
- f"but the code returns {len(adj.return_var)} values."
3965
- )
3966
- elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
3967
- raise WarpCodegenError(
3968
- f"The function `{adj.fun_name}` has its return type "
3969
- f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3970
- f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
3971
- )
3972
- elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
3973
- raise WarpCodegenError(
3974
- f"The function `{adj.fun_name}` has its return type "
3975
- f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3976
- f"but the code returns {len(adj.return_var)} values."
3977
- )
3978
- elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
3979
- raise WarpCodegenError(
3980
- f"The function `{adj.fun_name}` has its return type "
3981
- f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3982
- f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
3983
- )
3984
- elif (
3985
- isinstance(adj.return_var[0].type, warp.types.fixedarray)
3986
- and type(adj.arg_types["return"]) is warp.types.array
3987
- ):
3988
- # If the return statement yields a `fixedarray` while the function is annotated
3989
- # to return a standard `array`, then raise an error since the `fixedarray` storage
3990
- # allocated on the stack will be freed once the function exits, meaning that the
3991
- # resulting `array` instance will point to an invalid data.
3992
- raise WarpCodegenError(
3993
- f"The function `{adj.fun_name}` returns a fixed-size array "
3994
- f"whereas it has its return type annotated as "
3995
- f"`{warp.context.type_str(adj.arg_types['return'])}`."
3996
- )
3997
-
3998
- # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3999
- # This is used as a catch-all C-to-Python source line mapping for any code that does not have
4000
- # a direct mapping to a Python source line.
4001
- func_line_directive = ""
4002
- if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
4003
- func_line_directive = f"{line_directive}\n"
4004
-
4005
- # forward header
4006
- if adj.return_var is not None and len(adj.return_var) == 1:
4007
- return_type = adj.return_var[0].ctype()
4008
- else:
4009
- return_type = "void"
4010
-
4011
- has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1
4012
-
4013
- forward_args = []
4014
- reverse_args = []
4015
-
4016
- # forward args
4017
- for i, arg in enumerate(adj.args):
4018
- s = f"{arg.ctype()} {arg.emit()}"
4019
- forward_args.append(s)
4020
- if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
4021
- reverse_args.append(s)
4022
- if has_multiple_outputs:
4023
- for i, arg in enumerate(adj.return_var):
4024
- forward_args.append(arg.ctype() + " & ret_" + str(i))
4025
- reverse_args.append(arg.ctype() + " & ret_" + str(i))
4026
-
4027
- # reverse args
4028
- for i, arg in enumerate(adj.args):
4029
- if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
4030
- break
4031
- # indexed array gradients are regular arrays
4032
- if isinstance(arg.type, indexedarray):
4033
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
4034
- reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
4035
- else:
4036
- reverse_args.append(arg.ctype() + " & adj_" + arg.label)
4037
- if has_multiple_outputs:
4038
- for i, arg in enumerate(adj.return_var):
4039
- reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
4040
- elif return_type != "void":
4041
- reverse_args.append(return_type + " & adj_ret")
4042
- # custom output reverse args (user-declared)
4043
- if adj.custom_reverse_mode:
4044
- for arg in adj.args[adj.custom_reverse_num_input_args :]:
4045
- reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
4046
-
4047
- if device == "cpu":
4048
- forward_template = cpu_forward_function_template
4049
- reverse_template = cpu_reverse_function_template
4050
- elif device == "cuda":
4051
- forward_template = cuda_forward_function_template
4052
- reverse_template = cuda_reverse_function_template
4053
- else:
4054
- raise ValueError(f"Device {device} is not supported")
4055
-
4056
- # codegen body
4057
- forward_body = codegen_func_forward(adj, func_type="function", device=device)
4058
-
4059
- s = ""
4060
- if not adj.skip_forward_codegen:
4061
- s += forward_template.format(
4062
- name=c_func_name,
4063
- return_type=return_type,
4064
- forward_args=indent(forward_args),
4065
- forward_body=forward_body,
4066
- filename=adj.filename,
4067
- lineno=adj.fun_lineno,
4068
- line_directive=func_line_directive,
4069
- )
4070
-
4071
- if not adj.skip_reverse_codegen:
4072
- if adj.custom_reverse_mode:
4073
- reverse_body = "\t// user-defined adjoint code\n" + forward_body
4074
- else:
4075
- if options.get("enable_backward", True) and adj.used_by_backward_kernel:
4076
- reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
4077
- else:
4078
- reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
4079
- s += reverse_template.format(
4080
- name=c_func_name,
4081
- return_type=return_type,
4082
- reverse_args=indent(reverse_args),
4083
- forward_body=forward_body,
4084
- reverse_body=reverse_body,
4085
- filename=adj.filename,
4086
- lineno=adj.fun_lineno,
4087
- line_directive=func_line_directive,
4088
- )
4089
-
4090
- return s
4091
-
4092
-
4093
- def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
4094
- if adj.return_var is not None and len(adj.return_var) == 1:
4095
- return_type = adj.return_var[0].ctype()
4096
- else:
4097
- return_type = "void"
4098
-
4099
- forward_args = []
4100
- reverse_args = []
4101
-
4102
- # forward args
4103
- for _i, arg in enumerate(adj.args):
4104
- s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
4105
- forward_args.append(s)
4106
- reverse_args.append(s)
4107
-
4108
- # reverse args
4109
- for _i, arg in enumerate(adj.args):
4110
- if isinstance(arg.type, indexedarray):
4111
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
4112
- reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
4113
- else:
4114
- reverse_args.append(arg.ctype() + " & adj_" + arg.label)
4115
- if return_type != "void":
4116
- reverse_args.append(return_type + " & adj_ret")
4117
-
4118
- forward_template = cuda_forward_function_template
4119
- replay_template = cuda_forward_function_template
4120
- reverse_template = cuda_reverse_function_template
4121
-
4122
- s = ""
4123
- s += forward_template.format(
4124
- name=name,
4125
- return_type=return_type,
4126
- forward_args=indent(forward_args),
4127
- forward_body=snippet,
4128
- filename=adj.filename,
4129
- lineno=adj.fun_lineno,
4130
- line_directive="",
4131
- )
4132
-
4133
- if replay_snippet is not None:
4134
- s += replay_template.format(
4135
- name="replay_" + name,
4136
- return_type=return_type,
4137
- forward_args=indent(forward_args),
4138
- forward_body=replay_snippet,
4139
- filename=adj.filename,
4140
- lineno=adj.fun_lineno,
4141
- line_directive="",
4142
- )
4143
-
4144
- if adj_snippet:
4145
- reverse_body = adj_snippet
4146
- else:
4147
- reverse_body = ""
4148
-
4149
- s += reverse_template.format(
4150
- name=name,
4151
- return_type=return_type,
4152
- reverse_args=indent(reverse_args),
4153
- forward_body=snippet,
4154
- reverse_body=reverse_body,
4155
- filename=adj.filename,
4156
- lineno=adj.fun_lineno,
4157
- line_directive="",
4158
- )
4159
-
4160
- return s
4161
-
4162
-
4163
- def codegen_kernel(kernel, device, options):
4164
- # Update the module's options with the ones defined on the kernel, if any.
4165
- options = dict(options)
4166
- options.update(kernel.options)
4167
-
4168
- adj = kernel.adj
4169
-
4170
- args_struct = ""
4171
- if device == "cpu":
4172
- args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
4173
- for i in adj.args:
4174
- args_struct += f" {i.ctype()} {i.label};\n"
4175
- args_struct += "};\n"
4176
-
4177
- # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
4178
- # This is used as a catch-all C-to-Python source line mapping for any code that does not have
4179
- # a direct mapping to a Python source line.
4180
- func_line_directive = ""
4181
- if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
4182
- func_line_directive = f"{line_directive}\n"
4183
-
4184
- if device == "cpu":
4185
- template_forward = cpu_kernel_template_forward
4186
- template_backward = cpu_kernel_template_backward
4187
- elif device == "cuda":
4188
- template_forward = cuda_kernel_template_forward
4189
- template_backward = cuda_kernel_template_backward
4190
- else:
4191
- raise ValueError(f"Device {device} is not supported")
4192
-
4193
- template = ""
4194
- template_fmt_args = {
4195
- "name": kernel.get_mangled_name(),
4196
- }
4197
-
4198
- # build forward signature
4199
- forward_args = ["wp::launch_bounds_t dim"]
4200
- if device == "cpu":
4201
- forward_args.append("size_t task_index")
4202
- else:
4203
- for arg in adj.args:
4204
- forward_args.append(arg.ctype() + " var_" + arg.label)
4205
-
4206
- forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
4207
- template_fmt_args.update(
4208
- {
4209
- "forward_args": indent(forward_args),
4210
- "forward_body": forward_body,
4211
- "line_directive": func_line_directive,
4212
- }
4213
- )
4214
- template += template_forward
4215
-
4216
- if options["enable_backward"]:
4217
- # build reverse signature
4218
- reverse_args = ["wp::launch_bounds_t dim"]
4219
- if device == "cpu":
4220
- reverse_args.append("size_t task_index")
4221
- else:
4222
- for arg in adj.args:
4223
- reverse_args.append(arg.ctype() + " var_" + arg.label)
4224
- for arg in adj.args:
4225
- # indexed array gradients are regular arrays
4226
- if isinstance(arg.type, indexedarray):
4227
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
4228
- reverse_args.append(_arg.ctype() + " adj_" + arg.label)
4229
- else:
4230
- reverse_args.append(arg.ctype() + " adj_" + arg.label)
4231
-
4232
- reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
4233
- template_fmt_args.update(
4234
- {
4235
- "reverse_args": indent(reverse_args),
4236
- "reverse_body": reverse_body,
4237
- }
4238
- )
4239
- template += template_backward
4240
-
4241
- s = template.format(**template_fmt_args)
4242
- return args_struct + s
4243
-
4244
-
4245
- def codegen_module(kernel, device, options):
4246
- if device != "cpu":
4247
- return ""
4248
-
4249
- # Update the module's options with the ones defined on the kernel, if any.
4250
- options = dict(options)
4251
- options.update(kernel.options)
4252
-
4253
- template = ""
4254
- template_fmt_args = {
4255
- "name": kernel.get_mangled_name(),
4256
- }
4257
-
4258
- template += cpu_module_template_forward
4259
-
4260
- if options["enable_backward"]:
4261
- template += cpu_module_template_backward
4262
-
4263
- s = template.format(**template_fmt_args)
4264
- return s
24
+ return get_deprecated_api(_codegen, "wp", name)