warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -20,14 +20,16 @@ import functools
20
20
  import math
21
21
  from typing import Any, Callable, Mapping, Sequence
22
22
 
23
- import warp.build
24
- import warp.context
25
- import warp.utils
26
- from warp.codegen import Reference, Var, get_arg_value, strip_reference
27
- from warp.types import *
23
+ import warp._src.build
24
+ import warp._src.context
25
+ import warp._src.utils
26
+ from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
27
+ from warp._src.types import *
28
28
 
29
29
  from .context import add_builtin
30
30
 
31
+ _wp_module_name_ = "warp.builtins"
32
+
31
33
 
32
34
  def seq_check_equal(seq_1, seq_2):
33
35
  if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
@@ -61,11 +63,11 @@ def sametypes_create_value_func(default: TypeVar):
61
63
 
62
64
  def extract_tuple(arg, as_constant=False):
63
65
  if isinstance(arg, Var):
64
- if isinstance(arg.type, warp.types.tuple_t):
66
+ if isinstance(arg.type, warp._src.types.tuple_t):
65
67
  out = arg.type.values
66
68
  else:
67
69
  out = (arg,)
68
- elif isinstance(arg, warp.types.tuple_t):
70
+ elif isinstance(arg, warp._src.types.tuple_t):
69
71
  out = arg.values
70
72
  elif not isinstance(arg, Sequence):
71
73
  out = (arg,)
@@ -82,7 +84,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
82
84
  if arg_types is None:
83
85
  return int
84
86
 
85
- length = warp.types.type_length(arg_types["a"])
87
+ length = warp._src.types.type_length(arg_types["a"])
86
88
  return Var(None, type=int, constant=length)
87
89
 
88
90
 
@@ -126,6 +128,7 @@ add_builtin(
126
128
  value_func=sametypes_create_value_func(Scalar),
127
129
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
128
130
  group="Scalar Math",
131
+ is_differentiable=False,
129
132
  )
130
133
 
131
134
  add_builtin(
@@ -134,6 +137,7 @@ add_builtin(
134
137
  value_func=sametypes_create_value_func(Scalar),
135
138
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
136
139
  group="Scalar Math",
140
+ is_differentiable=False,
137
141
  )
138
142
  add_builtin(
139
143
  "nonzero",
@@ -141,6 +145,7 @@ add_builtin(
141
145
  value_func=sametypes_create_value_func(Scalar),
142
146
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
143
147
  group="Scalar Math",
148
+ is_differentiable=False,
144
149
  )
145
150
 
146
151
  add_builtin(
@@ -282,7 +287,36 @@ add_builtin(
282
287
  group="Scalar Math",
283
288
  require_original_output_arg=True,
284
289
  )
285
-
290
+ add_builtin(
291
+ "erf",
292
+ input_types={"x": Float},
293
+ value_func=sametypes_create_value_func(Float),
294
+ doc="Return the error function of ``x``.",
295
+ group="Scalar Math",
296
+ )
297
+ add_builtin(
298
+ "erfc",
299
+ input_types={"x": Float},
300
+ value_func=sametypes_create_value_func(Float),
301
+ doc="Return the complementary error function of ``x``.",
302
+ group="Scalar Math",
303
+ )
304
+ add_builtin(
305
+ "erfinv",
306
+ input_types={"x": Float},
307
+ value_func=sametypes_create_value_func(Float),
308
+ doc="Return the inverse error function of ``x``.",
309
+ group="Scalar Math",
310
+ require_original_output_arg=True,
311
+ )
312
+ add_builtin(
313
+ "erfcinv",
314
+ input_types={"x": Float},
315
+ value_func=sametypes_create_value_func(Float),
316
+ doc="Return the inverse complementary error function of ``x``.",
317
+ group="Scalar Math",
318
+ require_original_output_arg=True,
319
+ )
286
320
  add_builtin(
287
321
  "round",
288
322
  input_types={"x": Float},
@@ -292,6 +326,7 @@ add_builtin(
292
326
 
293
327
  This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
294
328
  Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
329
+ is_differentiable=False,
295
330
  )
296
331
 
297
332
  add_builtin(
@@ -302,6 +337,7 @@ add_builtin(
302
337
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
303
338
 
304
339
  It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
340
+ is_differentiable=False,
305
341
  )
306
342
 
307
343
  add_builtin(
@@ -314,6 +350,7 @@ add_builtin(
314
350
  In other words, it discards the fractional part of ``x``.
315
351
  It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
316
352
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
353
+ is_differentiable=False,
317
354
  )
318
355
 
319
356
  add_builtin(
@@ -322,6 +359,7 @@ add_builtin(
322
359
  value_func=sametypes_create_value_func(Float),
323
360
  group="Scalar Math",
324
361
  doc="""Return the largest integer that is less than or equal to ``x``.""",
362
+ is_differentiable=False,
325
363
  )
326
364
 
327
365
  add_builtin(
@@ -330,6 +368,7 @@ add_builtin(
330
368
  value_func=sametypes_create_value_func(Float),
331
369
  group="Scalar Math",
332
370
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
371
+ is_differentiable=False,
333
372
  )
334
373
 
335
374
  add_builtin(
@@ -340,6 +379,7 @@ add_builtin(
340
379
  doc="""Retrieve the fractional part of ``x``.
341
380
 
342
381
  In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
382
+ is_differentiable=False,
343
383
  )
344
384
 
345
385
  add_builtin(
@@ -348,6 +388,7 @@ add_builtin(
348
388
  value_type=builtins.bool,
349
389
  group="Scalar Math",
350
390
  doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
391
+ is_differentiable=False,
351
392
  )
352
393
  add_builtin(
353
394
  "isfinite",
@@ -355,6 +396,7 @@ add_builtin(
355
396
  value_type=builtins.bool,
356
397
  group="Vector Math",
357
398
  doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
399
+ is_differentiable=False,
358
400
  )
359
401
  add_builtin(
360
402
  "isfinite",
@@ -362,6 +404,7 @@ add_builtin(
362
404
  value_type=builtins.bool,
363
405
  group="Vector Math",
364
406
  doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
407
+ is_differentiable=False,
365
408
  )
366
409
  add_builtin(
367
410
  "isfinite",
@@ -369,6 +412,7 @@ add_builtin(
369
412
  value_type=builtins.bool,
370
413
  group="Vector Math",
371
414
  doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
415
+ is_differentiable=False,
372
416
  )
373
417
 
374
418
  add_builtin(
@@ -377,6 +421,7 @@ add_builtin(
377
421
  value_type=builtins.bool,
378
422
  doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
379
423
  group="Scalar Math",
424
+ is_differentiable=False,
380
425
  )
381
426
  add_builtin(
382
427
  "isnan",
@@ -384,6 +429,7 @@ add_builtin(
384
429
  value_type=builtins.bool,
385
430
  group="Vector Math",
386
431
  doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
432
+ is_differentiable=False,
387
433
  )
388
434
  add_builtin(
389
435
  "isnan",
@@ -391,6 +437,7 @@ add_builtin(
391
437
  value_type=builtins.bool,
392
438
  group="Vector Math",
393
439
  doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
440
+ is_differentiable=False,
394
441
  )
395
442
  add_builtin(
396
443
  "isnan",
@@ -398,6 +445,7 @@ add_builtin(
398
445
  value_type=builtins.bool,
399
446
  group="Vector Math",
400
447
  doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
448
+ is_differentiable=False,
401
449
  )
402
450
 
403
451
  add_builtin(
@@ -406,6 +454,7 @@ add_builtin(
406
454
  value_type=builtins.bool,
407
455
  group="Scalar Math",
408
456
  doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
457
+ is_differentiable=False,
409
458
  )
410
459
  add_builtin(
411
460
  "isinf",
@@ -413,6 +462,7 @@ add_builtin(
413
462
  value_type=builtins.bool,
414
463
  group="Vector Math",
415
464
  doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
465
+ is_differentiable=False,
416
466
  )
417
467
  add_builtin(
418
468
  "isinf",
@@ -420,6 +470,7 @@ add_builtin(
420
470
  value_type=builtins.bool,
421
471
  group="Vector Math",
422
472
  doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
473
+ is_differentiable=False,
423
474
  )
424
475
  add_builtin(
425
476
  "isinf",
@@ -427,6 +478,7 @@ add_builtin(
427
478
  value_type=builtins.bool,
428
479
  group="Vector Math",
429
480
  doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
481
+ is_differentiable=False,
430
482
  )
431
483
 
432
484
 
@@ -534,7 +586,7 @@ add_builtin(
534
586
  value_func=lambda arg_types, arg_values: warp.uint32,
535
587
  doc="Return the index of the minimum element of a vector ``a``.",
536
588
  group="Vector Math",
537
- missing_grad=True,
589
+ is_differentiable=False,
538
590
  )
539
591
  add_builtin(
540
592
  "argmax",
@@ -542,7 +594,7 @@ add_builtin(
542
594
  value_func=lambda arg_types, arg_values: warp.uint32,
543
595
  doc="Return the index of the maximum element of a vector ``a``.",
544
596
  group="Vector Math",
545
- missing_grad=True,
597
+ is_differentiable=False,
546
598
  )
547
599
 
548
600
  add_builtin(
@@ -867,7 +919,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
867
919
 
868
920
  if dtype is None:
869
921
  dtype = value_type
870
- elif not warp.types.scalars_equal(value_type, dtype):
922
+ elif not warp._src.types.scalars_equal(value_type, dtype):
871
923
  raise RuntimeError(
872
924
  f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
873
925
  )
@@ -888,7 +940,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
888
940
 
889
941
  if dtype is None:
890
942
  dtype = value_type
891
- elif not warp.types.scalars_equal(value_type, dtype):
943
+ elif not warp._src.types.scalars_equal(value_type, dtype):
892
944
  raise RuntimeError(
893
945
  f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
894
946
  )
@@ -971,7 +1023,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
971
1023
 
972
1024
  if dtype is None:
973
1025
  dtype = value_type
974
- elif not warp.types.scalars_equal(value_type, dtype):
1026
+ elif not warp._src.types.scalars_equal(value_type, dtype):
975
1027
  raise RuntimeError(
976
1028
  f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
977
1029
  )
@@ -981,7 +1033,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
981
1033
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
982
1034
 
983
1035
  if all(type_is_vector(x) for x in variadic_arg_types):
984
- warp.utils.warn(
1036
+ warp._src.utils.warn(
985
1037
  "the built-in `wp.matrix()` won't support taking column vectors as input "
986
1038
  "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
987
1039
  DeprecationWarning,
@@ -1010,7 +1062,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1010
1062
 
1011
1063
  if dtype is None:
1012
1064
  dtype = value_type
1013
- elif not warp.types.scalars_equal(value_type, dtype):
1065
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1014
1066
  raise RuntimeError(
1015
1067
  f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
1016
1068
  )
@@ -1182,48 +1234,18 @@ add_builtin(
1182
1234
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
1183
1235
  group="Vector Math",
1184
1236
  export=False,
1237
+ is_differentiable=False,
1185
1238
  )
1186
1239
 
1187
1240
 
1188
1241
  def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1189
- warp.utils.warn(
1190
- "the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
1191
- "and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
1192
- DeprecationWarning,
1193
- )
1194
1242
  if arg_types is None:
1195
1243
  return matrix(shape=(4, 4), dtype=Float)
1196
1244
 
1197
- dtype = arg_values.get("dtype", None)
1198
-
1199
- value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1200
- try:
1201
- value_type = scalar_infer_type(value_arg_types)
1202
- except RuntimeError:
1203
- raise RuntimeError(
1204
- "all values given when constructing a transformation matrix must have the same type"
1205
- ) from None
1206
-
1207
- if dtype is None:
1208
- dtype = value_type
1209
- elif not warp.types.scalars_equal(value_type, dtype):
1210
- raise RuntimeError(
1211
- f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1212
- )
1213
-
1214
- return matrix(shape=(4, 4), dtype=dtype)
1215
-
1216
-
1217
- def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1218
- # We're in the codegen stage where we emit the code calling the built-in.
1219
- # Further validate the given argument values if needed and map them
1220
- # to the underlying C++ function's runtime and template params.
1221
-
1222
- dtype = return_type._wp_scalar_type_
1223
-
1224
- func_args = tuple(v for k, v in args.items() if k != "dtype")
1225
- template_args = (4, 4, dtype)
1226
- return (func_args, template_args)
1245
+ raise RuntimeError(
1246
+ "the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
1247
+ "and 3D scale vector has been removed in favor of `wp.transform_compose()`."
1248
+ )
1227
1249
 
1228
1250
 
1229
1251
  add_builtin(
@@ -1237,13 +1259,14 @@ add_builtin(
1237
1259
  defaults={"dtype": None},
1238
1260
  value_func=matrix_transform_value_func,
1239
1261
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1240
- dispatch_func=matrix_transform_dispatch_func,
1241
1262
  native_func="mat_t",
1242
1263
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
1243
1264
  Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
1244
1265
 
1245
- .. warning::
1246
- This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
1266
+ .. versionremoved:: 1.10
1267
+ This function has been removed in favor of :func:`warp.math.transform_compose()`.
1268
+
1269
+ .. deprecated:: 1.8""",
1247
1270
  group="Vector Math",
1248
1271
  export=False,
1249
1272
  )
@@ -1438,7 +1461,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
1438
1461
 
1439
1462
  if dtype is None:
1440
1463
  dtype = value_type
1441
- elif not warp.types.scalars_equal(value_type, dtype):
1464
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1442
1465
  raise RuntimeError(
1443
1466
  f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
1444
1467
  )
@@ -1546,6 +1569,7 @@ add_builtin(
1546
1569
  group="Quaternion Math",
1547
1570
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1548
1571
  export=True,
1572
+ is_differentiable=False,
1549
1573
  )
1550
1574
 
1551
1575
  add_builtin(
@@ -1674,7 +1698,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1674
1698
  value_type = strip_reference(variadic_arg_types[0])
1675
1699
  if dtype is None:
1676
1700
  dtype = value_type
1677
- elif not warp.types.scalars_equal(value_type, dtype):
1701
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1678
1702
  raise RuntimeError(
1679
1703
  f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
1680
1704
  )
@@ -1687,7 +1711,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1687
1711
 
1688
1712
  if dtype is None:
1689
1713
  dtype = value_type
1690
- elif not warp.types.scalars_equal(value_type, dtype):
1714
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1691
1715
  raise RuntimeError(
1692
1716
  f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
1693
1717
  )
@@ -1712,7 +1736,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
1712
1736
  dtype = arg_values.get("dtype", None)
1713
1737
  if dtype is None:
1714
1738
  dtype = value_type
1715
- elif not warp.types.scalars_equal(value_type, dtype):
1739
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1716
1740
  raise RuntimeError(
1717
1741
  f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1718
1742
  )
@@ -1727,9 +1751,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1727
1751
 
1728
1752
  dtype = return_type._wp_scalar_type_
1729
1753
 
1730
- variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1754
+ variadic_args = args.get("args", ())
1755
+ variadic_arg_count = len(variadic_args)
1756
+
1757
+ if variadic_arg_count == 7:
1758
+ func_args = variadic_args
1759
+ else:
1760
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
1761
+ if "p" in args and "q" not in args:
1762
+ quat_ident = warp._src.codegen.Var(
1763
+ label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
1764
+ )
1765
+ func_args += (quat_ident,)
1731
1766
 
1732
- func_args = variadic_args
1733
1767
  template_args = (dtype,)
1734
1768
  return (func_args, template_args)
1735
1769
 
@@ -1737,7 +1771,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1737
1771
  add_builtin(
1738
1772
  "transformation",
1739
1773
  input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
1740
- defaults={"dtype": None},
1774
+ defaults={"q": None, "dtype": None},
1741
1775
  value_func=transformation_pq_value_func,
1742
1776
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1743
1777
  dispatch_func=transformation_dispatch_func,
@@ -1795,6 +1829,7 @@ add_builtin(
1795
1829
  group="Transformations",
1796
1830
  doc="Construct an identity transform with zero translation and identity rotation.",
1797
1831
  export=True,
1832
+ is_differentiable=False,
1798
1833
  )
1799
1834
 
1800
1835
  add_builtin(
@@ -1928,7 +1963,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1928
1963
 
1929
1964
  if dtype is None:
1930
1965
  dtype = value_type
1931
- elif not warp.types.scalars_equal(value_type, dtype):
1966
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1932
1967
  raise RuntimeError(
1933
1968
  f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
1934
1969
  )
@@ -2122,7 +2157,7 @@ add_builtin(
2122
2157
  value_func=tile_zeros_value_func,
2123
2158
  dispatch_func=tile_zeros_dispatch_func,
2124
2159
  variadic=False,
2125
- missing_grad=True,
2160
+ is_differentiable=False,
2126
2161
  doc="""Allocate a tile of zero-initialized items.
2127
2162
 
2128
2163
  :param shape: Shape of the output tile
@@ -2142,7 +2177,7 @@ add_builtin(
2142
2177
  value_func=tile_zeros_value_func,
2143
2178
  dispatch_func=tile_zeros_dispatch_func,
2144
2179
  variadic=False,
2145
- missing_grad=True,
2180
+ is_differentiable=False,
2146
2181
  hidden=True,
2147
2182
  group="Tile Primitives",
2148
2183
  export=False,
@@ -2194,7 +2229,7 @@ add_builtin(
2194
2229
  defaults={"storage": "register"},
2195
2230
  value_func=tile_ones_value_func,
2196
2231
  dispatch_func=tile_ones_dispatch_func,
2197
- missing_grad=True,
2232
+ is_differentiable=False,
2198
2233
  doc="""Allocate a tile of one-initialized items.
2199
2234
 
2200
2235
  :param shape: Shape of the output tile
@@ -2213,7 +2248,86 @@ add_builtin(
2213
2248
  defaults={"storage": "register"},
2214
2249
  value_func=tile_ones_value_func,
2215
2250
  dispatch_func=tile_ones_dispatch_func,
2216
- missing_grad=True,
2251
+ is_differentiable=False,
2252
+ hidden=True,
2253
+ group="Tile Primitives",
2254
+ export=False,
2255
+ )
2256
+
2257
+
2258
+ def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2259
+ # return generic type (for doc builds)
2260
+ if arg_types is None:
2261
+ return tile(dtype=Any, shape=Tuple[int, ...])
2262
+
2263
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2264
+
2265
+ if None in shape:
2266
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2267
+
2268
+ if "value" not in arg_values:
2269
+ raise TypeError("tile_full() missing required keyword argument 'value'")
2270
+
2271
+ if "dtype" not in arg_values:
2272
+ raise TypeError("tile_full() missing required keyword argument 'dtype'")
2273
+
2274
+ if "storage" not in arg_values:
2275
+ raise TypeError("tile_full() missing required keyword argument 'storage'")
2276
+
2277
+ if arg_values["storage"] not in {"shared", "register"}:
2278
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2279
+
2280
+ dtype = arg_values["dtype"]
2281
+
2282
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
2283
+
2284
+
2285
+ def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2286
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2287
+
2288
+ if None in shape:
2289
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2290
+
2291
+ dtype = arg_values["dtype"]
2292
+ value = arg_values["value"]
2293
+
2294
+ func_args = [value]
2295
+
2296
+ template_args = []
2297
+ template_args.append(dtype)
2298
+ template_args.extend(shape)
2299
+
2300
+ return (func_args, template_args)
2301
+
2302
+
2303
+ add_builtin(
2304
+ "tile_full",
2305
+ input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
2306
+ defaults={"storage": "register"},
2307
+ value_func=tile_full_value_func,
2308
+ dispatch_func=tile_full_dispatch_func,
2309
+ is_differentiable=False,
2310
+ doc="""Allocate a tile filled with the specified value.
2311
+
2312
+ :param shape: Shape of the output tile
2313
+ :param value: Value to fill the tile with
2314
+ :param dtype: Data type of output tile's elements
2315
+ :param storage: The storage location for the tile: ``"register"`` for registers
2316
+ (default) or ``"shared"`` for shared memory.
2317
+ :returns: A tile filled with the specified value""",
2318
+ group="Tile Primitives",
2319
+ export=False,
2320
+ )
2321
+
2322
+
2323
+ # overload for scalar shape
2324
+ add_builtin(
2325
+ "tile_full",
2326
+ input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
2327
+ defaults={"storage": "register"},
2328
+ value_func=tile_full_value_func,
2329
+ dispatch_func=tile_full_dispatch_func,
2330
+ is_differentiable=False,
2217
2331
  hidden=True,
2218
2332
  group="Tile Primitives",
2219
2333
  export=False,
@@ -2275,13 +2389,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2275
2389
  args = arg_values["args"]
2276
2390
 
2277
2391
  if len(args) == 1:
2278
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
2392
+ start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
2279
2393
  stop = args[0]
2280
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2394
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2281
2395
  elif len(args) == 2:
2282
2396
  start = args[0]
2283
2397
  stop = args[1]
2284
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2398
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2285
2399
  elif len(args) == 3:
2286
2400
  start = args[0]
2287
2401
  stop = args[1]
@@ -2304,7 +2418,7 @@ add_builtin(
2304
2418
  value_func=tile_arange_value_func,
2305
2419
  dispatch_func=tile_arange_dispatch_func,
2306
2420
  variadic=True,
2307
- missing_grad=True,
2421
+ is_differentiable=False,
2308
2422
  doc="""Generate a tile of linearly spaced elements.
2309
2423
 
2310
2424
  :param args: Variable-length positional arguments, interpreted as:
@@ -3099,7 +3213,7 @@ add_builtin(
3099
3213
  :param shape: Shape of the returned slice
3100
3214
  :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
3101
3215
  group="Tile Primitives",
3102
- missing_grad=True,
3216
+ is_differentiable=False,
3103
3217
  export=False,
3104
3218
  )
3105
3219
 
@@ -3346,7 +3460,32 @@ add_builtin(
3346
3460
 
3347
3461
  add_builtin(
3348
3462
  "assign",
3349
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
3463
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
3464
+ value_func=tile_assign_value_func,
3465
+ group="Tile Primitives",
3466
+ export=False,
3467
+ hidden=True,
3468
+ )
3469
+
3470
+ add_builtin(
3471
+ "assign",
3472
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
3473
+ value_func=tile_assign_value_func,
3474
+ group="Tile Primitives",
3475
+ export=False,
3476
+ hidden=True,
3477
+ )
3478
+
3479
+ add_builtin(
3480
+ "assign",
3481
+ input_types={
3482
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3483
+ "i": int,
3484
+ "j": int,
3485
+ "k": int,
3486
+ "l": int,
3487
+ "src": Any,
3488
+ },
3350
3489
  value_func=tile_assign_value_func,
3351
3490
  group="Tile Primitives",
3352
3491
  export=False,
@@ -3355,7 +3494,15 @@ add_builtin(
3355
3494
 
3356
3495
  add_builtin(
3357
3496
  "assign",
3358
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
3497
+ input_types={
3498
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3499
+ "i": int,
3500
+ "j": int,
3501
+ "k": int,
3502
+ "l": int,
3503
+ "m": int,
3504
+ "src": Any,
3505
+ },
3359
3506
  value_func=tile_assign_value_func,
3360
3507
  group="Tile Primitives",
3361
3508
  export=False,
@@ -3370,6 +3517,8 @@ add_builtin(
3370
3517
  "j": int,
3371
3518
  "k": int,
3372
3519
  "l": int,
3520
+ "m": int,
3521
+ "n": int,
3373
3522
  "src": Any,
3374
3523
  },
3375
3524
  value_func=tile_assign_value_func,
@@ -3391,7 +3540,7 @@ def tile_value_func(arg_types, arg_values):
3391
3540
 
3392
3541
  if preserve_type:
3393
3542
  dtype = arg_types["x"]
3394
- shape = (warp.codegen.options["block_dim"],)
3543
+ shape = (warp._src.codegen.options["block_dim"],)
3395
3544
 
3396
3545
  return tile(dtype=dtype, shape=shape)
3397
3546
 
@@ -3399,18 +3548,18 @@ def tile_value_func(arg_types, arg_values):
3399
3548
  if type_is_vector(arg_types["x"]):
3400
3549
  dtype = arg_types["x"]._wp_scalar_type_
3401
3550
  length = arg_types["x"]._shape_[0]
3402
- shape = (length, warp.codegen.options["block_dim"])
3551
+ shape = (length, warp._src.codegen.options["block_dim"])
3403
3552
  elif type_is_quaternion(arg_types["x"]):
3404
3553
  dtype = arg_types["x"]._wp_scalar_type_
3405
- shape = (4, warp.codegen.options["block_dim"])
3554
+ shape = (4, warp._src.codegen.options["block_dim"])
3406
3555
  elif type_is_matrix(arg_types["x"]):
3407
3556
  dtype = arg_types["x"]._wp_scalar_type_
3408
3557
  rows = arg_types["x"]._shape_[0]
3409
3558
  cols = arg_types["x"]._shape_[1]
3410
- shape = (rows, cols, warp.codegen.options["block_dim"])
3559
+ shape = (rows, cols, warp._src.codegen.options["block_dim"])
3411
3560
  else:
3412
3561
  dtype = arg_types["x"]
3413
- shape = (warp.codegen.options["block_dim"],)
3562
+ shape = (warp._src.codegen.options["block_dim"],)
3414
3563
 
3415
3564
  return tile(dtype=dtype, shape=shape)
3416
3565
 
@@ -3500,17 +3649,17 @@ def untile_value_func(arg_types, arg_values):
3500
3649
  if not is_tile(t):
3501
3650
  raise TypeError(f"untile() argument must be a tile, got {t!r}")
3502
3651
 
3503
- if t.shape[-1] != warp.codegen.options["block_dim"]:
3652
+ if t.shape[-1] != warp._src.codegen.options["block_dim"]:
3504
3653
  raise ValueError(
3505
- f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
3654
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
3506
3655
  )
3507
3656
 
3508
3657
  if len(t.shape) == 1:
3509
3658
  return t.dtype
3510
3659
  elif len(t.shape) == 2:
3511
- return warp.types.vector(t.shape[0], t.dtype)
3660
+ return warp._src.types.vector(t.shape[0], t.dtype)
3512
3661
  elif len(t.shape) == 3:
3513
- return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3662
+ return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3514
3663
  else:
3515
3664
  raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
3516
3665
 
@@ -3572,7 +3721,36 @@ def tile_extract_value_func(arg_types, arg_values):
3572
3721
  # force the input tile to shared memory
3573
3722
  arg_types["a"].storage = "shared"
3574
3723
 
3575
- return arg_types["a"].dtype
3724
+ # count the number of indices (all parameters except the tile "a")
3725
+ num_indices = len(arg_types) - 1
3726
+ tile_dtype = arg_types["a"].dtype
3727
+ tile_shape = arg_types["a"].shape
3728
+
3729
+ if type_is_vector(tile_dtype):
3730
+ if num_indices == len(tile_shape):
3731
+ return tile_dtype
3732
+ elif num_indices == len(tile_shape) + 1:
3733
+ return tile_dtype._wp_scalar_type_
3734
+ else:
3735
+ raise IndexError(
3736
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3737
+ )
3738
+ elif type_is_matrix(tile_dtype):
3739
+ if num_indices == len(tile_shape):
3740
+ return tile_dtype
3741
+ elif num_indices == len(tile_shape) + 2:
3742
+ return tile_dtype._wp_scalar_type_
3743
+ else:
3744
+ raise IndexError(
3745
+ f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
3746
+ )
3747
+ else:
3748
+ # scalar element: index count must exactly match tile rank
3749
+ if num_indices == len(tile_shape):
3750
+ return tile_dtype
3751
+ raise IndexError(
3752
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3753
+ )
3576
3754
 
3577
3755
 
3578
3756
  add_builtin(
@@ -3596,7 +3774,7 @@ add_builtin(
3596
3774
 
3597
3775
  add_builtin(
3598
3776
  "tile_extract",
3599
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
3777
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
3600
3778
  value_func=tile_extract_value_func,
3601
3779
  variadic=False,
3602
3780
  doc="""Extract a single element from the tile.
@@ -3607,7 +3785,7 @@ add_builtin(
3607
3785
 
3608
3786
  :param a: Tile to extract the element from
3609
3787
  :param i: Coordinate of element on first dimension
3610
- :param j: Coordinate of element on the second dimension
3788
+ :param j: Coordinate of element on the second dimension, or vector index
3611
3789
  :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3612
3790
  group="Tile Primitives",
3613
3791
  hidden=True,
@@ -3616,7 +3794,57 @@ add_builtin(
3616
3794
 
3617
3795
  add_builtin(
3618
3796
  "tile_extract",
3619
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
3797
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
3798
+ value_func=tile_extract_value_func,
3799
+ variadic=False,
3800
+ doc="""Extract a single element from the tile.
3801
+
3802
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3803
+
3804
+ Note that this may incur additional synchronization if the source tile is a register tile.
3805
+
3806
+ :param a: Tile to extract the element from
3807
+ :param i: Coordinate of element on first dimension
3808
+ :param j: Coordinate of element on the second dimension, or first matrix index
3809
+ :param k: Coordinate of element on the third dimension, or vector index, or second matrix index
3810
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3811
+ group="Tile Primitives",
3812
+ hidden=True,
3813
+ export=False,
3814
+ )
3815
+
3816
+ add_builtin(
3817
+ "tile_extract",
3818
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
3819
+ value_func=tile_extract_value_func,
3820
+ variadic=False,
3821
+ doc="""Extract a single element from the tile.
3822
+
3823
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3824
+
3825
+ Note that this may incur additional synchronization if the source tile is a register tile.
3826
+
3827
+ :param a: Tile to extract the element from
3828
+ :param i: Coordinate of element on first dimension
3829
+ :param j: Coordinate of element on the second dimension
3830
+ :param k: Coordinate of element on the third dimension, or first matrix index
3831
+ :param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
3832
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3833
+ group="Tile Primitives",
3834
+ hidden=True,
3835
+ export=False,
3836
+ )
3837
+
3838
+ add_builtin(
3839
+ "tile_extract",
3840
+ input_types={
3841
+ "a": tile(dtype=Any, shape=Tuple[int, ...]),
3842
+ "i": int,
3843
+ "j": int,
3844
+ "k": int,
3845
+ "l": int,
3846
+ "m": int,
3847
+ },
3620
3848
  value_func=tile_extract_value_func,
3621
3849
  variadic=False,
3622
3850
  doc="""Extract a single element from the tile.
@@ -3629,7 +3857,9 @@ add_builtin(
3629
3857
  :param i: Coordinate of element on first dimension
3630
3858
  :param j: Coordinate of element on the second dimension
3631
3859
  :param k: Coordinate of element on the third dimension
3632
- :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3860
+ :param l: Coordinate of element on the fourth dimension, or first matrix index
3861
+ :param m: Vector index, or second matrix index
3862
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3633
3863
  group="Tile Primitives",
3634
3864
  hidden=True,
3635
3865
  export=False,
@@ -3637,7 +3867,15 @@ add_builtin(
3637
3867
 
3638
3868
  add_builtin(
3639
3869
  "tile_extract",
3640
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
3870
+ input_types={
3871
+ "a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
3872
+ "i": int,
3873
+ "j": int,
3874
+ "k": int,
3875
+ "l": int,
3876
+ "m": int,
3877
+ "n": int,
3878
+ },
3641
3879
  value_func=tile_extract_value_func,
3642
3880
  variadic=False,
3643
3881
  doc="""Extract a single element from the tile.
@@ -3651,6 +3889,8 @@ add_builtin(
3651
3889
  :param j: Coordinate of element on the second dimension
3652
3890
  :param k: Coordinate of element on the third dimension
3653
3891
  :param l: Coordinate of element on the fourth dimension
3892
+ :param m: Vector index, or first matrix index
3893
+ :param n: Second matrix index
3654
3894
  :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3655
3895
  group="Tile Primitives",
3656
3896
  hidden=True,
@@ -3737,49 +3977,160 @@ add_builtin(
3737
3977
  export=False,
3738
3978
  )
3739
3979
 
3740
-
3741
- def tile_transpose_value_func(arg_types, arg_values):
3742
- # return generic type (for doc builds)
3743
- if arg_types is None:
3744
- return tile(dtype=Any, shape=Tuple[int, int])
3745
-
3746
- if len(arg_types) != 1:
3747
- raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
3748
-
3749
- t = arg_types["a"]
3750
-
3751
- if not is_tile(t):
3752
- raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
3753
-
3754
- layout = None
3755
-
3756
- # flip layout
3757
- if t.layout == "rowmajor":
3758
- layout = "colmajor"
3759
- elif t.layout == "colmajor":
3760
- layout = "rowmajor"
3761
-
3762
- # force the input tile to shared memory
3763
- t.storage = "shared"
3764
-
3765
- return tile(
3766
- dtype=t.dtype,
3767
- shape=t.shape[::-1],
3768
- storage=t.storage,
3769
- strides=t.strides[::-1],
3770
- layout=layout,
3771
- owner=False,
3772
- )
3773
-
3774
-
3775
3980
  add_builtin(
3776
- "tile_transpose",
3777
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
3778
- value_func=tile_transpose_value_func,
3779
- variadic=True,
3780
- doc="""Transpose a tile.
3781
-
3782
- For shared memory tiles, this operation will alias the input tile.
3981
+ "tile_bit_and_inplace",
3982
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3983
+ value_func=tile_inplace_value_func,
3984
+ group="Tile Primitives",
3985
+ hidden=True,
3986
+ export=False,
3987
+ is_differentiable=False,
3988
+ )
3989
+ add_builtin(
3990
+ "tile_bit_and_inplace",
3991
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3992
+ value_func=tile_inplace_value_func,
3993
+ group="Tile Primitives",
3994
+ hidden=True,
3995
+ export=False,
3996
+ is_differentiable=False,
3997
+ )
3998
+ add_builtin(
3999
+ "tile_bit_and_inplace",
4000
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4001
+ value_func=tile_inplace_value_func,
4002
+ group="Tile Primitives",
4003
+ hidden=True,
4004
+ export=False,
4005
+ is_differentiable=False,
4006
+ )
4007
+ add_builtin(
4008
+ "tile_bit_and_inplace",
4009
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4010
+ value_func=tile_inplace_value_func,
4011
+ group="Tile Primitives",
4012
+ hidden=True,
4013
+ export=False,
4014
+ is_differentiable=False,
4015
+ )
4016
+
4017
+ add_builtin(
4018
+ "tile_bit_or_inplace",
4019
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4020
+ value_func=tile_inplace_value_func,
4021
+ group="Tile Primitives",
4022
+ hidden=True,
4023
+ export=False,
4024
+ is_differentiable=False,
4025
+ )
4026
+ add_builtin(
4027
+ "tile_bit_or_inplace",
4028
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4029
+ value_func=tile_inplace_value_func,
4030
+ group="Tile Primitives",
4031
+ hidden=True,
4032
+ export=False,
4033
+ is_differentiable=False,
4034
+ )
4035
+ add_builtin(
4036
+ "tile_bit_or_inplace",
4037
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4038
+ value_func=tile_inplace_value_func,
4039
+ group="Tile Primitives",
4040
+ hidden=True,
4041
+ export=False,
4042
+ is_differentiable=False,
4043
+ )
4044
+ add_builtin(
4045
+ "tile_bit_or_inplace",
4046
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4047
+ value_func=tile_inplace_value_func,
4048
+ group="Tile Primitives",
4049
+ hidden=True,
4050
+ export=False,
4051
+ is_differentiable=False,
4052
+ )
4053
+
4054
+ add_builtin(
4055
+ "tile_bit_xor_inplace",
4056
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4057
+ value_func=tile_inplace_value_func,
4058
+ group="Tile Primitives",
4059
+ hidden=True,
4060
+ export=False,
4061
+ is_differentiable=False,
4062
+ )
4063
+ add_builtin(
4064
+ "tile_bit_xor_inplace",
4065
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4066
+ value_func=tile_inplace_value_func,
4067
+ group="Tile Primitives",
4068
+ hidden=True,
4069
+ export=False,
4070
+ is_differentiable=False,
4071
+ )
4072
+ add_builtin(
4073
+ "tile_bit_xor_inplace",
4074
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4075
+ value_func=tile_inplace_value_func,
4076
+ group="Tile Primitives",
4077
+ hidden=True,
4078
+ export=False,
4079
+ is_differentiable=False,
4080
+ )
4081
+ add_builtin(
4082
+ "tile_bit_xor_inplace",
4083
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4084
+ value_func=tile_inplace_value_func,
4085
+ group="Tile Primitives",
4086
+ hidden=True,
4087
+ export=False,
4088
+ is_differentiable=False,
4089
+ )
4090
+
4091
+
4092
+ def tile_transpose_value_func(arg_types, arg_values):
4093
+ # return generic type (for doc builds)
4094
+ if arg_types is None:
4095
+ return tile(dtype=Any, shape=Tuple[int, int])
4096
+
4097
+ if len(arg_types) != 1:
4098
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
4099
+
4100
+ t = arg_types["a"]
4101
+
4102
+ if not is_tile(t):
4103
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
4104
+
4105
+ layout = None
4106
+
4107
+ # flip layout
4108
+ if t.layout == "rowmajor":
4109
+ layout = "colmajor"
4110
+ elif t.layout == "colmajor":
4111
+ layout = "rowmajor"
4112
+
4113
+ # force the input tile to shared memory
4114
+ t.storage = "shared"
4115
+
4116
+ return tile(
4117
+ dtype=t.dtype,
4118
+ shape=t.shape[::-1],
4119
+ storage=t.storage,
4120
+ strides=t.strides[::-1],
4121
+ layout=layout,
4122
+ owner=False,
4123
+ )
4124
+
4125
+
4126
+ add_builtin(
4127
+ "tile_transpose",
4128
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
4129
+ value_func=tile_transpose_value_func,
4130
+ variadic=True,
4131
+ doc="""Transpose a tile.
4132
+
4133
+ For shared memory tiles, this operation will alias the input tile.
3783
4134
  Register tiles will first be transferred to shared memory before transposition.
3784
4135
 
3785
4136
  :param a: Tile to transpose with ``shape=(M,N)``
@@ -3910,6 +4261,80 @@ add_builtin(
3910
4261
  )
3911
4262
 
3912
4263
 
4264
+ def tile_sum_axis_value_func(arg_types, arg_values):
4265
+ if arg_types is None:
4266
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4267
+
4268
+ a = arg_types["a"]
4269
+
4270
+ if not is_tile(a):
4271
+ raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
4272
+
4273
+ # force input tile to shared
4274
+ a.storage = "shared"
4275
+
4276
+ axis = arg_values["axis"]
4277
+ shape = a.shape
4278
+
4279
+ if axis < 0 or axis >= len(shape):
4280
+ raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4281
+
4282
+ # shape is identical less the axis reduction is along
4283
+ if len(shape) > 1:
4284
+ new_shape = shape[:axis] + shape[axis + 1 :]
4285
+ else:
4286
+ new_shape = (1,)
4287
+
4288
+ return tile(dtype=a.dtype, shape=new_shape)
4289
+
4290
+
4291
+ def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
4292
+ tile = arg_values["a"]
4293
+ axis_var = arg_values["axis"]
4294
+ if not hasattr(axis_var, "constant") or axis_var.constant is None:
4295
+ raise ValueError("tile_sum() axis must be a compile-time constant")
4296
+ axis = axis_var.constant
4297
+
4298
+ return ((tile,), (axis,))
4299
+
4300
+
4301
+ add_builtin(
4302
+ "tile_sum",
4303
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4304
+ value_func=tile_sum_axis_value_func,
4305
+ dispatch_func=tile_sum_axis_dispatch_func,
4306
+ doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
4307
+
4308
+ :param a: The input tile. Must reside in shared memory.
4309
+ :param axis: The tile axis to compute the sum across. Must be a compile-time constant.
4310
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4311
+
4312
+ Example:
4313
+
4314
+ .. code-block:: python
4315
+
4316
+ @wp.kernel
4317
+ def compute():
4318
+
4319
+ t = wp.tile_ones(dtype=float, shape=(8, 8))
4320
+ s = wp.tile_sum(t, axis=0)
4321
+
4322
+ print(s)
4323
+
4324
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
4325
+
4326
+ Prints:
4327
+
4328
+ .. code-block:: text
4329
+
4330
+ [8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
4331
+
4332
+ """,
4333
+ group="Tile Primitives",
4334
+ export=False,
4335
+ )
4336
+
4337
+
3913
4338
  def tile_sort_value_func(arg_types, arg_values):
3914
4339
  # return generic type (for doc builds)
3915
4340
  if arg_types is None:
@@ -3986,6 +4411,7 @@ add_builtin(
3986
4411
  """,
3987
4412
  group="Tile Primitives",
3988
4413
  export=False,
4414
+ is_differentiable=False,
3989
4415
  )
3990
4416
 
3991
4417
 
@@ -4039,6 +4465,7 @@ add_builtin(
4039
4465
  """,
4040
4466
  group="Tile Primitives",
4041
4467
  export=False,
4468
+ is_differentiable=False,
4042
4469
  )
4043
4470
 
4044
4471
 
@@ -4092,6 +4519,7 @@ add_builtin(
4092
4519
  """,
4093
4520
  group="Tile Primitives",
4094
4521
  export=False,
4522
+ is_differentiable=False,
4095
4523
  )
4096
4524
 
4097
4525
 
@@ -4144,6 +4572,7 @@ add_builtin(
4144
4572
  """,
4145
4573
  group="Tile Primitives",
4146
4574
  export=False,
4575
+ is_differentiable=False,
4147
4576
  )
4148
4577
 
4149
4578
 
@@ -4196,10 +4625,10 @@ add_builtin(
4196
4625
  """,
4197
4626
  group="Tile Primitives",
4198
4627
  export=False,
4628
+ is_differentiable=False,
4199
4629
  )
4200
4630
 
4201
4631
 
4202
- # does type propagation for load()
4203
4632
  def tile_reduce_value_func(arg_types, arg_values):
4204
4633
  if arg_types is None:
4205
4634
  return tile(dtype=Scalar, shape=(1,))
@@ -4253,6 +4682,88 @@ add_builtin(
4253
4682
  """,
4254
4683
  group="Tile Primitives",
4255
4684
  export=False,
4685
+ is_differentiable=False,
4686
+ )
4687
+
4688
+
4689
def tile_reduce_axis_value_func(arg_types, arg_values):
    """Infer the result tile type for the axis-wise ``tile_reduce()`` overload.

    :param arg_types: Mapping from argument name to its type, or ``None`` when
        only the generic return type is requested.
    :param arg_values: Mapping from argument name to its value; ``"axis"`` is
        read from it and compared against the rank of the input tile.
    :returns: A tile type with the input's dtype whose shape is the input shape
        with the ``axis`` entry removed (``(1,)`` when the input is 1-D).
    :raises TypeError: If ``a`` is not a tile.
    :raises ValueError: If ``axis`` is negative or not smaller than the tile rank.
    """
    # no concrete types available: answer with the generic tile signature
    if arg_types is None:
        return tile(dtype=Scalar, shape=Tuple[int, ...])

    src = arg_types["a"]
    if not is_tile(src):
        raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {src!r}")

    # the reduction requires its input in shared memory, so mark the tile type accordingly
    src.storage = "shared"

    axis = arg_values["axis"]
    dims = src.shape
    rank = len(dims)

    if axis < 0 or axis >= rank:
        raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {rank} dimensions")

    # drop the reduced axis; a 1-D input collapses to a single-element tile
    reduced_dims = dims[:axis] + dims[axis + 1 :] if rank > 1 else (1,)

    return tile(dtype=src.dtype, shape=reduced_dims)
4714
+
4715
+
4716
+ add_builtin(
4717
+ "tile_reduce",
4718
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4719
+ value_func=tile_reduce_axis_value_func,
4720
+ native_func="tile_reduce_axis",
4721
+ doc="""Apply a custom reduction operator across a tile axis.
4722
+
4723
+ This function cooperatively performs a reduction using the provided operator across an axis of the tile.
4724
+
4725
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
4726
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
4727
+ :param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
4728
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4729
+
4730
+ Example:
4731
+
4732
+ .. code-block:: python
4733
+
4734
+ TILE_M = wp.constant(4)
4735
+ TILE_N = wp.constant(2)
4736
+
4737
+ @wp.kernel
4738
+ def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
4739
+
4740
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
4741
+ b = wp.tile_reduce(wp.add, a, axis=1)
4742
+ wp.tile_store(y, b)
4743
+
4744
+ arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
4745
+
4746
+ x = wp.array(arr, dtype=float)
4747
+ y = wp.zeros(TILE_M, dtype=float)
4748
+
4749
+ wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
4750
+
4751
+ print(x.numpy())
4752
+ print(y.numpy())
4753
+
4754
+ Prints:
4755
+
4756
+ .. code-block:: text
4757
+
4758
+ [[0. 1.]
4759
+ [2. 3.]
4760
+ [4. 5.]
4761
+ [6. 7.]]
4762
+ [ 1. 5. 9. 13.]
4763
+ """,
4764
+ group="Tile Primitives",
4765
+ export=False,
4766
+ is_differentiable=False,
4256
4767
  )
4257
4768
 
4258
4769
 
@@ -4316,6 +4827,7 @@ add_builtin(
4316
4827
  """,
4317
4828
  group="Tile Primitives",
4318
4829
  export=False,
4830
+ is_differentiable=False,
4319
4831
  )
4320
4832
 
4321
4833
 
@@ -4379,6 +4891,7 @@ add_builtin(
4379
4891
  """,
4380
4892
  group="Tile Primitives",
4381
4893
  export=False,
4894
+ is_differentiable=False,
4382
4895
  )
4383
4896
 
4384
4897
 
@@ -4632,6 +5145,7 @@ add_builtin(
4632
5145
  doc="WIP",
4633
5146
  group="Utility",
4634
5147
  hidden=True,
5148
+ is_differentiable=False,
4635
5149
  )
4636
5150
 
4637
5151
  add_builtin(
@@ -4647,6 +5161,7 @@ add_builtin(
4647
5161
  doc="WIP",
4648
5162
  group="Utility",
4649
5163
  hidden=True,
5164
+ is_differentiable=False,
4650
5165
  )
4651
5166
 
4652
5167
  add_builtin(
@@ -4656,6 +5171,7 @@ add_builtin(
4656
5171
  doc="WIP",
4657
5172
  group="Utility",
4658
5173
  hidden=True,
5174
+ is_differentiable=False,
4659
5175
  )
4660
5176
 
4661
5177
  add_builtin(
@@ -4707,6 +5223,7 @@ add_builtin(
4707
5223
  :param low: The lower bound of the bounding box in BVH space
4708
5224
  :param high: The upper bound of the bounding box in BVH space""",
4709
5225
  export=False,
5226
+ is_differentiable=False,
4710
5227
  )
4711
5228
 
4712
5229
  add_builtin(
@@ -4722,6 +5239,7 @@ add_builtin(
4722
5239
  :param start: The start of the ray in BVH space
4723
5240
  :param dir: The direction of the ray in BVH space""",
4724
5241
  export=False,
5242
+ is_differentiable=False,
4725
5243
  )
4726
5244
 
4727
5245
  add_builtin(
@@ -4732,6 +5250,7 @@ add_builtin(
4732
5250
  doc="""Move to the next bound returned by the query.
4733
5251
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
4734
5252
  export=False,
5253
+ is_differentiable=False,
4735
5254
  )
4736
5255
 
4737
5256
  add_builtin(
@@ -5066,12 +5585,13 @@ add_builtin(
5066
5585
  group="Geometry",
5067
5586
  doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
5068
5587
 
5069
- This query can be used to iterate over all triangles inside a volume.
5588
+ This query can be used to iterate over all bounding boxes of the triangles inside a volume.
5070
5589
 
5071
5590
  :param id: The mesh identifier
5072
5591
  :param low: The lower bound of the bounding box in mesh space
5073
5592
  :param high: The upper bound of the bounding box in mesh space""",
5074
5593
  export=False,
5594
+ is_differentiable=False,
5075
5595
  )
5076
5596
 
5077
5597
  add_builtin(
@@ -5079,10 +5599,11 @@ add_builtin(
5079
5599
  input_types={"query": MeshQueryAABB, "index": int},
5080
5600
  value_type=builtins.bool,
5081
5601
  group="Geometry",
5082
- doc="""Move to the next triangle overlapping the query bounding box.
5602
+ doc="""Move to the next triangle whose bounding box overlaps the query bounding box.
5083
5603
 
5084
5604
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
5085
5605
  export=False,
5606
+ is_differentiable=False,
5086
5607
  )
5087
5608
 
5088
5609
  add_builtin(
@@ -5112,6 +5633,7 @@ add_builtin(
5112
5633
 
5113
5634
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
5114
5635
  export=False,
5636
+ is_differentiable=False,
5115
5637
  )
5116
5638
 
5117
5639
  add_builtin(
@@ -5123,6 +5645,7 @@ add_builtin(
5123
5645
 
5124
5646
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
5125
5647
  export=False,
5648
+ is_differentiable=False,
5126
5649
  )
5127
5650
 
5128
5651
  add_builtin(
@@ -5136,6 +5659,7 @@ add_builtin(
5136
5659
 
5137
5660
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
5138
5661
  export=False,
5662
+ is_differentiable=False,
5139
5663
  )
5140
5664
 
5141
5665
  add_builtin(
@@ -5145,15 +5669,34 @@ add_builtin(
5145
5669
  group="Geometry",
5146
5670
  doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5147
5671
 
5672
+ This function works with single precision, may return incorrect results in some case.
5673
+
5674
+ Returns > 0 if triangles intersect.""",
5675
+ export=False,
5676
+ is_differentiable=False,
5677
+ )
5678
+
5679
+
5680
+ add_builtin(
5681
+ "intersect_tri_tri",
5682
+ input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
5683
+ value_type=int,
5684
+ group="Geometry",
5685
+ doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5686
+
5687
+ This function works with double precision, results are more accurate than the single precision version.
5688
+
5148
5689
  Returns > 0 if triangles intersect.""",
5149
5690
  export=False,
5691
+ is_differentiable=False,
5150
5692
  )
5151
5693
 
5694
+
5152
5695
  add_builtin(
5153
5696
  "mesh_get",
5154
5697
  input_types={"id": uint64},
5155
5698
  value_type=Mesh,
5156
- missing_grad=True,
5699
+ is_differentiable=False,
5157
5700
  group="Geometry",
5158
5701
  doc="""Retrieves the mesh given its index.""",
5159
5702
  export=False,
@@ -5166,6 +5709,7 @@ add_builtin(
5166
5709
  group="Geometry",
5167
5710
  doc="""Evaluates the face normal the mesh given a face index.""",
5168
5711
  export=False,
5712
+ is_differentiable=False,
5169
5713
  )
5170
5714
 
5171
5715
  add_builtin(
@@ -5175,6 +5719,7 @@ add_builtin(
5175
5719
  group="Geometry",
5176
5720
  doc="""Returns the point of the mesh given a index.""",
5177
5721
  export=False,
5722
+ is_differentiable=False,
5178
5723
  )
5179
5724
 
5180
5725
  add_builtin(
@@ -5184,6 +5729,7 @@ add_builtin(
5184
5729
  group="Geometry",
5185
5730
  doc="""Returns the velocity of the mesh given a index.""",
5186
5731
  export=False,
5732
+ is_differentiable=False,
5187
5733
  )
5188
5734
 
5189
5735
  add_builtin(
@@ -5193,6 +5739,7 @@ add_builtin(
5193
5739
  group="Geometry",
5194
5740
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
5195
5741
  export=False,
5742
+ is_differentiable=False,
5196
5743
  )
5197
5744
 
5198
5745
 
@@ -5233,12 +5780,32 @@ add_builtin(
5233
5780
  # ---------------------------------
5234
5781
  # Iterators
5235
5782
 
5236
- add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
5237
5783
  add_builtin(
5238
- "iter_next", input_types={"query": HashGridQuery}, value_type=int, group="Utility", export=False, hidden=True
5784
+ "iter_next",
5785
+ input_types={"range": range_t},
5786
+ value_type=int,
5787
+ group="Utility",
5788
+ export=False,
5789
+ hidden=True,
5790
+ is_differentiable=False,
5791
+ )
5792
+ add_builtin(
5793
+ "iter_next",
5794
+ input_types={"query": HashGridQuery},
5795
+ value_type=int,
5796
+ group="Utility",
5797
+ export=False,
5798
+ hidden=True,
5799
+ is_differentiable=False,
5239
5800
  )
5240
5801
  add_builtin(
5241
- "iter_next", input_types={"query": MeshQueryAABB}, value_type=int, group="Utility", export=False, hidden=True
5802
+ "iter_next",
5803
+ input_types={"query": MeshQueryAABB},
5804
+ value_type=int,
5805
+ group="Utility",
5806
+ export=False,
5807
+ hidden=True,
5808
+ is_differentiable=False,
5242
5809
  )
5243
5810
 
5244
5811
  add_builtin(
@@ -5249,6 +5816,7 @@ add_builtin(
5249
5816
  group="Utility",
5250
5817
  doc="""Returns the range in reversed order.""",
5251
5818
  export=False,
5819
+ is_differentiable=False,
5252
5820
  )
5253
5821
 
5254
5822
  # ---------------------------------
@@ -5268,8 +5836,8 @@ _volume_supported_value_types = {
5268
5836
 
5269
5837
 
5270
5838
def _is_volume_type_supported(dtype):
    """Return ``True`` if ``dtype`` matches, per ``types_equal``, any entry of ``_volume_supported_value_types``."""
    return any(types_equal(candidate, dtype) for candidate in _volume_supported_value_types)
5275
5843
 
@@ -5397,6 +5965,7 @@ add_builtin(
5397
5965
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
5398
5966
 
5399
5967
  If the voxel at this index does not exist, this function returns the background value.""",
5968
+ is_differentiable=False,
5400
5969
  )
5401
5970
 
5402
5971
 
@@ -5417,6 +5986,7 @@ add_builtin(
5417
5986
  export=False,
5418
5987
  group="Volumes",
5419
5988
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5989
+ is_differentiable=False,
5420
5990
  )
5421
5991
 
5422
5992
  add_builtin(
@@ -5447,6 +6017,7 @@ add_builtin(
5447
6017
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
5448
6018
 
5449
6019
  If the voxel at this index does not exist, this function returns the background value""",
6020
+ is_differentiable=False,
5450
6021
  )
5451
6022
 
5452
6023
  add_builtin(
@@ -5455,6 +6026,7 @@ add_builtin(
5455
6026
  group="Volumes",
5456
6027
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5457
6028
  export=False,
6029
+ is_differentiable=False,
5458
6030
  )
5459
6031
 
5460
6032
  add_builtin(
@@ -5475,6 +6047,7 @@ add_builtin(
5475
6047
  doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
5476
6048
 
5477
6049
  If the voxel at this index does not exist, this function returns the background value.""",
6050
+ is_differentiable=False,
5478
6051
  )
5479
6052
 
5480
6053
  add_builtin(
@@ -5483,6 +6056,7 @@ add_builtin(
5483
6056
  group="Volumes",
5484
6057
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5485
6058
  export=False,
6059
+ is_differentiable=False,
5486
6060
  )
5487
6061
 
5488
6062
  add_builtin(
@@ -5501,6 +6075,7 @@ add_builtin(
5501
6075
  doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
5502
6076
 
5503
6077
  If the voxel at this index does not exist, this function returns the background value.""",
6078
+ is_differentiable=False,
5504
6079
  )
5505
6080
 
5506
6081
  add_builtin(
@@ -5509,6 +6084,7 @@ add_builtin(
5509
6084
  group="Volumes",
5510
6085
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5511
6086
  export=False,
6087
+ is_differentiable=False,
5512
6088
  )
5513
6089
 
5514
6090
 
@@ -5590,6 +6166,7 @@ add_builtin(
5590
6166
  If the voxel at this index does not exist, this function returns -1.
5591
6167
  This function is available for both index grids and classical volumes.
5592
6168
  """,
6169
+ is_differentiable=False,
5593
6170
  )
5594
6171
 
5595
6172
  add_builtin(
@@ -5631,6 +6208,7 @@ add_builtin(
5631
6208
  value_type=uint32,
5632
6209
  group="Random",
5633
6210
  doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
6211
+ is_differentiable=False,
5634
6212
  )
5635
6213
 
5636
6214
  add_builtin(
@@ -5642,6 +6220,7 @@ add_builtin(
5642
6220
 
5643
6221
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
5644
6222
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
6223
+ is_differentiable=False,
5645
6224
  )
5646
6225
 
5647
6226
  add_builtin(
@@ -5650,6 +6229,7 @@ add_builtin(
5650
6229
  value_type=int,
5651
6230
  group="Random",
5652
6231
  doc="Return a random integer in the range [-2^31, 2^31).",
6232
+ is_differentiable=False,
5653
6233
  )
5654
6234
  add_builtin(
5655
6235
  "randi",
@@ -5657,6 +6237,7 @@ add_builtin(
5657
6237
  value_type=int,
5658
6238
  group="Random",
5659
6239
  doc="Return a random integer between [low, high).",
6240
+ is_differentiable=False,
5660
6241
  )
5661
6242
  add_builtin(
5662
6243
  "randu",
@@ -5664,6 +6245,7 @@ add_builtin(
5664
6245
  value_type=uint32,
5665
6246
  group="Random",
5666
6247
  doc="Return a random unsigned integer in the range [0, 2^32).",
6248
+ is_differentiable=False,
5667
6249
  )
5668
6250
  add_builtin(
5669
6251
  "randu",
@@ -5671,6 +6253,7 @@ add_builtin(
5671
6253
  value_type=uint32,
5672
6254
  group="Random",
5673
6255
  doc="Return a random unsigned integer between [low, high).",
6256
+ is_differentiable=False,
5674
6257
  )
5675
6258
  add_builtin(
5676
6259
  "randf",
@@ -5678,6 +6261,7 @@ add_builtin(
5678
6261
  value_type=float,
5679
6262
  group="Random",
5680
6263
  doc="Return a random float between [0.0, 1.0).",
6264
+ is_differentiable=False,
5681
6265
  )
5682
6266
  add_builtin(
5683
6267
  "randf",
@@ -5685,6 +6269,7 @@ add_builtin(
5685
6269
  value_type=float,
5686
6270
  group="Random",
5687
6271
  doc="Return a random float between [low, high).",
6272
+ is_differentiable=False,
5688
6273
  )
5689
6274
  add_builtin(
5690
6275
  "randn",
@@ -5692,6 +6277,7 @@ add_builtin(
5692
6277
  value_type=float,
5693
6278
  group="Random",
5694
6279
  doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
6280
+ is_differentiable=False,
5695
6281
  )
5696
6282
 
5697
6283
  add_builtin(
@@ -5700,6 +6286,7 @@ add_builtin(
5700
6286
  value_type=int,
5701
6287
  group="Random",
5702
6288
  doc="Inverse-transform sample a cumulative distribution function.",
6289
+ is_differentiable=False,
5703
6290
  )
5704
6291
  add_builtin(
5705
6292
  "sample_triangle",
@@ -5707,6 +6294,7 @@ add_builtin(
5707
6294
  value_type=vec2,
5708
6295
  group="Random",
5709
6296
  doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
6297
+ is_differentiable=False,
5710
6298
  )
5711
6299
  add_builtin(
5712
6300
  "sample_unit_ring",
@@ -5714,6 +6302,7 @@ add_builtin(
5714
6302
  value_type=vec2,
5715
6303
  group="Random",
5716
6304
  doc="Uniformly sample a ring in the xy plane.",
6305
+ is_differentiable=False,
5717
6306
  )
5718
6307
  add_builtin(
5719
6308
  "sample_unit_disk",
@@ -5721,6 +6310,7 @@ add_builtin(
5721
6310
  value_type=vec2,
5722
6311
  group="Random",
5723
6312
  doc="Uniformly sample a disk in the xy plane.",
6313
+ is_differentiable=False,
5724
6314
  )
5725
6315
  add_builtin(
5726
6316
  "sample_unit_sphere_surface",
@@ -5728,6 +6318,7 @@ add_builtin(
5728
6318
  value_type=vec3,
5729
6319
  group="Random",
5730
6320
  doc="Uniformly sample a unit sphere surface.",
6321
+ is_differentiable=False,
5731
6322
  )
5732
6323
  add_builtin(
5733
6324
  "sample_unit_sphere",
@@ -5735,6 +6326,7 @@ add_builtin(
5735
6326
  value_type=vec3,
5736
6327
  group="Random",
5737
6328
  doc="Uniformly sample a unit sphere.",
6329
+ is_differentiable=False,
5738
6330
  )
5739
6331
  add_builtin(
5740
6332
  "sample_unit_hemisphere_surface",
@@ -5742,6 +6334,7 @@ add_builtin(
5742
6334
  value_type=vec3,
5743
6335
  group="Random",
5744
6336
  doc="Uniformly sample a unit hemisphere surface.",
6337
+ is_differentiable=False,
5745
6338
  )
5746
6339
  add_builtin(
5747
6340
  "sample_unit_hemisphere",
@@ -5749,6 +6342,7 @@ add_builtin(
5749
6342
  value_type=vec3,
5750
6343
  group="Random",
5751
6344
  doc="Uniformly sample a unit hemisphere.",
6345
+ is_differentiable=False,
5752
6346
  )
5753
6347
  add_builtin(
5754
6348
  "sample_unit_square",
@@ -5756,6 +6350,7 @@ add_builtin(
5756
6350
  value_type=vec2,
5757
6351
  group="Random",
5758
6352
  doc="Uniformly sample a unit square.",
6353
+ is_differentiable=False,
5759
6354
  )
5760
6355
  add_builtin(
5761
6356
  "sample_unit_cube",
@@ -5763,6 +6358,7 @@ add_builtin(
5763
6358
  value_type=vec3,
5764
6359
  group="Random",
5765
6360
  doc="Uniformly sample a unit cube.",
6361
+ is_differentiable=False,
5766
6362
  )
5767
6363
 
5768
6364
  add_builtin(
@@ -5774,6 +6370,7 @@ add_builtin(
5774
6370
 
5775
6371
  :param state: RNG state
5776
6372
  :param lam: The expected value of the distribution""",
6373
+ is_differentiable=False,
5777
6374
  )
5778
6375
 
5779
6376
  add_builtin(
@@ -5841,7 +6438,7 @@ add_builtin(
5841
6438
  value_type=vec2,
5842
6439
  group="Random",
5843
6440
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
5844
- missing_grad=True,
6441
+ is_differentiable=False,
5845
6442
  )
5846
6443
  add_builtin(
5847
6444
  "curlnoise",
@@ -5850,7 +6447,7 @@ add_builtin(
5850
6447
  value_type=vec3,
5851
6448
  group="Random",
5852
6449
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5853
- missing_grad=True,
6450
+ is_differentiable=False,
5854
6451
  )
5855
6452
  add_builtin(
5856
6453
  "curlnoise",
@@ -5859,7 +6456,7 @@ add_builtin(
5859
6456
  value_type=vec3,
5860
6457
  group="Random",
5861
6458
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5862
- missing_grad=True,
6459
+ is_differentiable=False,
5863
6460
  )
5864
6461
 
5865
6462
 
@@ -5891,10 +6488,17 @@ add_builtin(
5891
6488
  dispatch_func=printf_dispatch_func,
5892
6489
  group="Utility",
5893
6490
  doc="Allows printing formatted strings using C-style format specifiers.",
6491
+ is_differentiable=False,
6492
+ )
6493
+
6494
+ add_builtin(
6495
+ "print",
6496
+ input_types={"value": Any},
6497
+ doc="Print variable to stdout",
6498
+ export=False,
6499
+ group="Utility",
5894
6500
  )
5895
6501
 
5896
- add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
5897
-
5898
6502
  add_builtin(
5899
6503
  "breakpoint",
5900
6504
  input_types={},
@@ -5903,6 +6507,7 @@ add_builtin(
5903
6507
  group="Utility",
5904
6508
  namespace="",
5905
6509
  native_func="__debugbreak",
6510
+ is_differentiable=False,
5906
6511
  )
5907
6512
 
5908
6513
  # helpers
@@ -5920,6 +6525,7 @@ add_builtin(
5920
6525
  This function may not be called from user-defined Warp functions.""",
5921
6526
  namespace="",
5922
6527
  native_func="builtin_tid1d",
6528
+ is_differentiable=False,
5923
6529
  )
5924
6530
 
5925
6531
  add_builtin(
@@ -5930,6 +6536,7 @@ add_builtin(
5930
6536
  doc="Returns the number of threads in the current block.",
5931
6537
  namespace="",
5932
6538
  native_func="builtin_block_dim",
6539
+ is_differentiable=False,
5933
6540
  )
5934
6541
 
5935
6542
  add_builtin(
@@ -5944,6 +6551,7 @@ add_builtin(
5944
6551
  This function may not be called from user-defined Warp functions.""",
5945
6552
  namespace="",
5946
6553
  native_func="builtin_tid2d",
6554
+ is_differentiable=False,
5947
6555
  )
5948
6556
 
5949
6557
  add_builtin(
@@ -5958,6 +6566,7 @@ add_builtin(
5958
6566
  This function may not be called from user-defined Warp functions.""",
5959
6567
  namespace="",
5960
6568
  native_func="builtin_tid3d",
6569
+ is_differentiable=False,
5961
6570
  )
5962
6571
 
5963
6572
  add_builtin(
@@ -5972,17 +6581,37 @@ add_builtin(
5972
6581
  This function may not be called from user-defined Warp functions.""",
5973
6582
  namespace="",
5974
6583
  native_func="builtin_tid4d",
6584
+ is_differentiable=False,
5975
6585
  )
5976
6586
 
5977
6587
 
6588
def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Compute the return type of ``copy()`` from the type of its ``a`` argument.

    :param arg_types: Mapping from argument name to its type; ``"a"`` is read.
    :param arg_values: Mapping from argument name to its value (unused).
    :returns: The type of ``a`` unchanged, except for tiles whose storage is
        ``"shared"``: those yield a new tile type with the same dtype, shape,
        storage, strides and layout, created with ``owner=True``.
    """
    src = arg_types["a"]

    # anything that is not a shared tile keeps its type as-is
    if not (is_tile(src) and src.storage == "shared"):
        return src

    # if the input is a shared tile, we force a copy by returning an owning tile type
    return tile(
        dtype=src.dtype,
        shape=src.shape,
        storage=src.storage,
        strides=src.strides,
        layout=src.layout,
        owner=True,
    )
6603
+
6604
+
5978
6605
  add_builtin(
5979
6606
  "copy",
5980
6607
  input_types={"a": Any},
5981
- value_func=lambda arg_types, arg_values: arg_types["a"],
6608
+ value_func=copy_value_func,
5982
6609
  hidden=True,
5983
6610
  export=False,
5984
6611
  group="Utility",
5985
6612
  )
6613
+
6614
+
5986
6615
  add_builtin(
5987
6616
  "assign",
5988
6617
  input_types={"dest": Any, "src": Any},
@@ -5992,61 +6621,88 @@ add_builtin(
5992
6621
  )
5993
6622
 
5994
6623
 
5995
- def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
5996
- warp.utils.warn(
5997
- "wp.select() is deprecated and will be removed in a future\n"
5998
- "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
5999
- category=DeprecationWarning,
6000
- )
6001
-
6002
- func_args = tuple(args.values())
6003
- template_args = ()
6624
def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Value function for the removed ``select()`` builtin.

    :param arg_types: Mapping from argument name to its type, or ``None`` when
        only the generic return type is requested.
    :param arg_values: Mapping from argument name to its value (unused).
    :returns: ``Any`` when ``arg_types`` is ``None``.
    :raises RuntimeError: Whenever concrete argument types are supplied, pointing
        the caller at ``wp.where()``.
    """
    # any call with concrete types is a use of the removed builtin
    if arg_types is not None:
        raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")

    return Any
6006
6629
 
6007
6630
 
6008
6631
  add_builtin(
6009
6632
  "select",
6010
6633
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
6011
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6012
- dispatch_func=select_dispatch_func,
6634
+ value_func=select_value_func,
6013
6635
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6014
6636
 
6015
- .. deprecated:: 1.7
6637
+ .. versionremoved:: 1.10
6016
6638
  Use :func:`where` instead, which has the more intuitive argument order:
6017
- ``where(cond, value_if_true, value_if_false)``.""",
6639
+ ``where(cond, value_if_true, value_if_false)``.
6640
+
6641
+ .. deprecated:: 1.7""",
6018
6642
  group="Utility",
6019
6643
  )
6020
6644
  for t in int_types:
6021
6645
  add_builtin(
6022
6646
  "select",
6023
6647
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
6024
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6025
- dispatch_func=select_dispatch_func,
6648
+ value_func=select_value_func,
6026
6649
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6027
6650
 
6028
- .. deprecated:: 1.7
6651
+ .. versionremoved:: 1.10
6029
6652
  Use :func:`where` instead, which has the more intuitive argument order:
6030
- ``where(cond, value_if_true, value_if_false)``.""",
6653
+ ``where(cond, value_if_true, value_if_false)``.
6654
+
6655
+ .. deprecated:: 1.7""",
6031
6656
  group="Utility",
6032
6657
  )
6033
6658
  add_builtin(
6034
6659
  "select",
6035
6660
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
6036
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6037
- dispatch_func=select_dispatch_func,
6661
+ value_func=select_value_func,
6038
6662
  doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
6039
6663
 
6040
- .. deprecated:: 1.7
6664
+ .. versionremoved:: 1.10
6041
6665
  Use :func:`where` instead, which has the more intuitive argument order:
6042
- ``where(arr, value_if_true, value_if_false)``.""",
6666
+ ``where(arr, value_if_true, value_if_false)``.
6667
+
6668
+ .. deprecated:: 1.7""",
6043
6669
  group="Utility",
6044
6670
  )
6045
6671
 
6672
+
6673
def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Compute the return type of ``where()`` from its two value arguments.

    :param arg_types: Mapping from argument name to its type, or ``None`` when
        only the generic return type is requested.
    :param arg_values: Mapping from argument name to its value (unused).
    :returns: ``Any`` when ``arg_types`` is ``None``; otherwise the type of
        ``value_if_true``, except for tiles, where the first operand whose
        storage is ``"register"`` is returned (true operand checked first), and
        a new tile type created with ``owner=True`` is returned when neither is.
    :raises RuntimeError: If the two value types are not equal per ``types_equal``.
    """
    if arg_types is None:
        return Any

    on_true = arg_types["value_if_true"]
    on_false = arg_types["value_if_false"]

    if not types_equal(on_true, on_false):
        raise RuntimeError(f"where() true value type ({on_true}) must be of the same type as the false type ({on_false})")

    # non-tile operands: the (identical) operand type is the result type
    if not is_tile(on_false):
        return on_true

    # prefer a register operand, checking the true branch before the false one
    for operand in (on_true, on_false):
        if operand.storage == "register":
            return operand

    # both on_true and on_false are shared
    return tile(
        dtype=on_true.dtype,
        shape=on_true.shape,
        storage=on_true.storage,
        strides=on_true.strides,
        layout=on_true.layout,
        owner=True,
    )
6700
+
6701
+
6046
6702
  add_builtin(
6047
6703
  "where",
6048
6704
  input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
6049
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6705
+ value_func=where_value_func,
6050
6706
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
6051
6707
  group="Utility",
6052
6708
  )
@@ -6054,14 +6710,14 @@ for t in int_types:
6054
6710
  add_builtin(
6055
6711
  "where",
6056
6712
  input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
6057
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6713
+ value_func=where_value_func,
6058
6714
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
6059
6715
  group="Utility",
6060
6716
  )
6061
6717
  add_builtin(
6062
6718
  "where",
6063
6719
  input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
6064
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6720
+ value_func=where_value_func,
6065
6721
  doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
6066
6722
  group="Utility",
6067
6723
  )
@@ -6099,7 +6755,7 @@ add_builtin(
6099
6755
  group="Utility",
6100
6756
  hidden=True,
6101
6757
  export=False,
6102
- missing_grad=True,
6758
+ is_differentiable=False,
6103
6759
  )
6104
6760
 
6105
6761
 
@@ -6140,7 +6796,7 @@ add_builtin(
6140
6796
  native_func="fixedarray_t",
6141
6797
  group="Utility",
6142
6798
  export=False,
6143
- missing_grad=True,
6799
+ is_differentiable=False,
6144
6800
  hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
6145
6801
  )
6146
6802
 
@@ -6183,14 +6839,13 @@ for array_type in array_types:
6183
6839
  # does argument checking and type propagation for view()
6184
6840
  def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6185
6841
  arr_type = arg_types["arr"]
6186
- idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)
6842
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
6187
6843
 
6188
6844
  if not is_array(arr_type):
6189
6845
  raise RuntimeError("view() first argument must be an array")
6190
6846
 
6191
6847
  idx_count = len(idx_types)
6192
-
6193
- if idx_count >= arr_type.ndim:
6848
+ if idx_count > arr_type.ndim:
6194
6849
  raise RuntimeError(
6195
6850
  f"Trying to create an array view with {idx_count} indices, "
6196
6851
  f"but the array only has {arr_type.ndim} dimension(s). "
@@ -6198,14 +6853,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6198
6853
  f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
6199
6854
  )
6200
6855
 
6201
- # check index types
6202
- for t in idx_types:
6203
- if not type_is_int(t):
6204
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")
6856
+ has_slice = any(is_slice(x) for x in idx_types)
6857
+ if has_slice:
6858
+ # check index types
6859
+ for t in idx_types:
6860
+ if not (type_is_int(t) or is_slice(t)):
6861
+ raise RuntimeError(
6862
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6863
+ )
6864
+
6865
+ # Each integer index collapses one dimension.
6866
+ int_count = sum(x.step == 0 for x in idx_types)
6867
+ ndim = arr_type.ndim - int_count
6868
+ assert ndim > 0
6869
+ else:
6870
+ if idx_count == arr_type.ndim:
6871
+ raise RuntimeError("Expected to call `address()` instead of `view()`")
6872
+
6873
+ # check index types
6874
+ for t in idx_types:
6875
+ if not type_is_int(t):
6876
+ raise RuntimeError(
6877
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6878
+ )
6879
+
6880
+ # create an array view with leading dimensions removed
6881
+ ndim = arr_type.ndim - idx_count
6882
+ assert ndim > 0
6205
6883
 
6206
- # create an array view with leading dimensions removed
6207
6884
  dtype = arr_type.dtype
6208
- ndim = arr_type.ndim - idx_count
6209
6885
  if isinstance(arr_type, (fabricarray, indexedfabricarray)):
6210
6886
  # fabric array of arrays: return array attribute as a regular array
6211
6887
  return array(dtype=dtype, ndim=ndim)
@@ -6216,8 +6892,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6216
6892
  for array_type in array_types:
6217
6893
  add_builtin(
6218
6894
  "view",
6219
- input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int},
6220
- defaults={"j": None, "k": None},
6895
+ input_types={
6896
+ "arr": array_type(dtype=Any),
6897
+ "i": Any,
6898
+ "j": Any,
6899
+ "k": Any,
6900
+ "l": Any,
6901
+ },
6902
+ defaults={
6903
+ "j": None,
6904
+ "k": None,
6905
+ "l": None,
6906
+ },
6221
6907
  constraint=sametypes,
6222
6908
  hidden=True,
6223
6909
  value_func=view_value_func,
@@ -6321,6 +7007,7 @@ add_builtin(
6321
7007
  hidden=True,
6322
7008
  skip_replay=True,
6323
7009
  group="Utility",
7010
+ is_differentiable=False,
6324
7011
  )
6325
7012
 
6326
7013
 
@@ -6337,6 +7024,7 @@ add_builtin(
6337
7024
  dispatch_func=load_dispatch_func,
6338
7025
  hidden=True,
6339
7026
  group="Utility",
7027
+ is_differentiable=False,
6340
7028
  )
6341
7029
 
6342
7030
 
@@ -6412,6 +7100,13 @@ def create_atomic_op_value_func(op: str):
6412
7100
  f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
6413
7101
  f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
6414
7102
  )
7103
+ elif op in ("and", "or", "xor"):
7104
+ supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
7105
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
7106
+ raise RuntimeError(
7107
+ f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
7108
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
7109
+ )
6415
7110
  else:
6416
7111
  raise NotImplementedError
6417
7112
 
@@ -6445,7 +7140,8 @@ for array_type in array_types:
6445
7140
  value_func=create_atomic_op_value_func("add"),
6446
7141
  dispatch_func=atomic_op_dispatch_func,
6447
7142
  doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
6448
- This function is automatically invoked when using the syntax ``arr[i] += value``.""",
7143
+
7144
+ This function is automatically invoked when using the syntax ``arr[i] += value``.""",
6449
7145
  group="Utility",
6450
7146
  skip_replay=True,
6451
7147
  )
@@ -6457,7 +7153,8 @@ for array_type in array_types:
6457
7153
  value_func=create_atomic_op_value_func("add"),
6458
7154
  dispatch_func=atomic_op_dispatch_func,
6459
7155
  doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
6460
- This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
7156
+
7157
+ This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
6461
7158
  group="Utility",
6462
7159
  skip_replay=True,
6463
7160
  )
@@ -6469,7 +7166,8 @@ for array_type in array_types:
6469
7166
  value_func=create_atomic_op_value_func("add"),
6470
7167
  dispatch_func=atomic_op_dispatch_func,
6471
7168
  doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
6472
- This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
7169
+
7170
+ This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
6473
7171
  group="Utility",
6474
7172
  skip_replay=True,
6475
7173
  )
@@ -6481,7 +7179,8 @@ for array_type in array_types:
6481
7179
  value_func=create_atomic_op_value_func("add"),
6482
7180
  dispatch_func=atomic_op_dispatch_func,
6483
7181
  doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
6484
- This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
7182
+
7183
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
6485
7184
  group="Utility",
6486
7185
  skip_replay=True,
6487
7186
  )
@@ -6494,7 +7193,8 @@ for array_type in array_types:
6494
7193
  value_func=create_atomic_op_value_func("sub"),
6495
7194
  dispatch_func=atomic_op_dispatch_func,
6496
7195
  doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
6497
- This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
7196
+
7197
+ This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
6498
7198
  group="Utility",
6499
7199
  skip_replay=True,
6500
7200
  )
@@ -6506,7 +7206,8 @@ for array_type in array_types:
6506
7206
  value_func=create_atomic_op_value_func("sub"),
6507
7207
  dispatch_func=atomic_op_dispatch_func,
6508
7208
  doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
6509
- This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
7209
+
7210
+ This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
6510
7211
  group="Utility",
6511
7212
  skip_replay=True,
6512
7213
  )
@@ -6518,7 +7219,8 @@ for array_type in array_types:
6518
7219
  value_func=create_atomic_op_value_func("sub"),
6519
7220
  dispatch_func=atomic_op_dispatch_func,
6520
7221
  doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
6521
- This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
7222
+
7223
+ This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
6522
7224
  group="Utility",
6523
7225
  skip_replay=True,
6524
7226
  )
@@ -6530,7 +7232,8 @@ for array_type in array_types:
6530
7232
  value_func=create_atomic_op_value_func("sub"),
6531
7233
  dispatch_func=atomic_op_dispatch_func,
6532
7234
  doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
6533
- This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
7235
+
7236
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
6534
7237
  group="Utility",
6535
7238
  skip_replay=True,
6536
7239
  )
@@ -6653,6 +7356,7 @@ for array_type in array_types:
6653
7356
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6654
7357
  group="Utility",
6655
7358
  skip_replay=True,
7359
+ is_differentiable=False,
6656
7360
  )
6657
7361
  add_builtin(
6658
7362
  "atomic_cas",
@@ -6666,6 +7370,7 @@ for array_type in array_types:
6666
7370
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6667
7371
  group="Utility",
6668
7372
  skip_replay=True,
7373
+ is_differentiable=False,
6669
7374
  )
6670
7375
  add_builtin(
6671
7376
  "atomic_cas",
@@ -6679,6 +7384,7 @@ for array_type in array_types:
6679
7384
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6680
7385
  group="Utility",
6681
7386
  skip_replay=True,
7387
+ is_differentiable=False,
6682
7388
  )
6683
7389
  add_builtin(
6684
7390
  "atomic_cas",
@@ -6700,6 +7406,7 @@ for array_type in array_types:
6700
7406
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6701
7407
  group="Utility",
6702
7408
  skip_replay=True,
7409
+ is_differentiable=False,
6703
7410
  )
6704
7411
 
6705
7412
  add_builtin(
@@ -6714,6 +7421,7 @@ for array_type in array_types:
6714
7421
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6715
7422
  group="Utility",
6716
7423
  skip_replay=True,
7424
+ is_differentiable=False,
6717
7425
  )
6718
7426
  add_builtin(
6719
7427
  "atomic_exch",
@@ -6727,6 +7435,7 @@ for array_type in array_types:
6727
7435
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6728
7436
  group="Utility",
6729
7437
  skip_replay=True,
7438
+ is_differentiable=False,
6730
7439
  )
6731
7440
  add_builtin(
6732
7441
  "atomic_exch",
@@ -6740,6 +7449,7 @@ for array_type in array_types:
6740
7449
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6741
7450
  group="Utility",
6742
7451
  skip_replay=True,
7452
+ is_differentiable=False,
6743
7453
  )
6744
7454
  add_builtin(
6745
7455
  "atomic_exch",
@@ -6755,6 +7465,177 @@ for array_type in array_types:
6755
7465
  skip_replay=True,
6756
7466
  )
6757
7467
 
7468
+ add_builtin(
7469
+ "atomic_and",
7470
+ hidden=hidden,
7471
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7472
+ constraint=atomic_op_constraint,
7473
+ value_func=create_atomic_op_value_func("and"),
7474
+ dispatch_func=atomic_op_dispatch_func,
7475
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7476
+
7477
+ This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
7478
+ group="Utility",
7479
+ skip_replay=True,
7480
+ is_differentiable=False,
7481
+ )
7482
+ add_builtin(
7483
+ "atomic_and",
7484
+ hidden=hidden,
7485
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7486
+ constraint=atomic_op_constraint,
7487
+ value_func=create_atomic_op_value_func("and"),
7488
+ dispatch_func=atomic_op_dispatch_func,
7489
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7490
+
7491
+ This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
7492
+ group="Utility",
7493
+ skip_replay=True,
7494
+ is_differentiable=False,
7495
+ )
7496
+ add_builtin(
7497
+ "atomic_and",
7498
+ hidden=hidden,
7499
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7500
+ constraint=atomic_op_constraint,
7501
+ value_func=create_atomic_op_value_func("and"),
7502
+ dispatch_func=atomic_op_dispatch_func,
7503
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7504
+
7505
+ This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
7506
+ group="Utility",
7507
+ skip_replay=True,
7508
+ is_differentiable=False,
7509
+ )
7510
+ add_builtin(
7511
+ "atomic_and",
7512
+ hidden=hidden,
7513
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7514
+ constraint=atomic_op_constraint,
7515
+ value_func=create_atomic_op_value_func("and"),
7516
+ dispatch_func=atomic_op_dispatch_func,
7517
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7518
+
7519
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
7520
+ group="Utility",
7521
+ skip_replay=True,
7522
+ is_differentiable=False,
7523
+ )
7524
+
7525
+ add_builtin(
7526
+ "atomic_or",
7527
+ hidden=hidden,
7528
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7529
+ constraint=atomic_op_constraint,
7530
+ value_func=create_atomic_op_value_func("or"),
7531
+ dispatch_func=atomic_op_dispatch_func,
7532
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7533
+
7534
+ This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
7535
+ group="Utility",
7536
+ skip_replay=True,
7537
+ is_differentiable=False,
7538
+ )
7539
+ add_builtin(
7540
+ "atomic_or",
7541
+ hidden=hidden,
7542
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7543
+ constraint=atomic_op_constraint,
7544
+ value_func=create_atomic_op_value_func("or"),
7545
+ dispatch_func=atomic_op_dispatch_func,
7546
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7547
+
7548
+ This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
7549
+ group="Utility",
7550
+ skip_replay=True,
7551
+ is_differentiable=False,
7552
+ )
7553
+ add_builtin(
7554
+ "atomic_or",
7555
+ hidden=hidden,
7556
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7557
+ constraint=atomic_op_constraint,
7558
+ value_func=create_atomic_op_value_func("or"),
7559
+ dispatch_func=atomic_op_dispatch_func,
7560
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7561
+
7562
+ This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
7563
+ group="Utility",
7564
+ skip_replay=True,
7565
+ is_differentiable=False,
7566
+ )
7567
+ add_builtin(
7568
+ "atomic_or",
7569
+ hidden=hidden,
7570
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7571
+ constraint=atomic_op_constraint,
7572
+ value_func=create_atomic_op_value_func("or"),
7573
+ dispatch_func=atomic_op_dispatch_func,
7574
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7575
+
7576
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
7577
+ group="Utility",
7578
+ skip_replay=True,
7579
+ is_differentiable=False,
7580
+ )
7581
+
7582
+ add_builtin(
7583
+ "atomic_xor",
7584
+ hidden=hidden,
7585
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7586
+ constraint=atomic_op_constraint,
7587
+ value_func=create_atomic_op_value_func("xor"),
7588
+ dispatch_func=atomic_op_dispatch_func,
7589
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7590
+
7591
+ This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
7592
+ group="Utility",
7593
+ skip_replay=True,
7594
+ is_differentiable=False,
7595
+ )
7596
+ add_builtin(
7597
+ "atomic_xor",
7598
+ hidden=hidden,
7599
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7600
+ constraint=atomic_op_constraint,
7601
+ value_func=create_atomic_op_value_func("xor"),
7602
+ dispatch_func=atomic_op_dispatch_func,
7603
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7604
+
7605
+ This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
7606
+ group="Utility",
7607
+ skip_replay=True,
7608
+ is_differentiable=False,
7609
+ )
7610
+ add_builtin(
7611
+ "atomic_xor",
7612
+ hidden=hidden,
7613
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7614
+ constraint=atomic_op_constraint,
7615
+ value_func=create_atomic_op_value_func("xor"),
7616
+ dispatch_func=atomic_op_dispatch_func,
7617
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7618
+
7619
+ This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
7620
+ group="Utility",
7621
+ skip_replay=True,
7622
+ is_differentiable=False,
7623
+ )
7624
+ add_builtin(
7625
+ "atomic_xor",
7626
+ hidden=hidden,
7627
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7628
+ constraint=atomic_op_constraint,
7629
+ value_func=create_atomic_op_value_func("xor"),
7630
+ dispatch_func=atomic_op_dispatch_func,
7631
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7632
+
7633
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
7634
+ group="Utility",
7635
+ skip_replay=True,
7636
+ is_differentiable=False,
7637
+ )
7638
+
6758
7639
 
6759
7640
  # used to index into builtin types, i.e.: y = vec3[1]
6760
7641
  def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
@@ -6903,6 +7784,7 @@ add_builtin(
6903
7784
  hidden=True,
6904
7785
  group="Utility",
6905
7786
  skip_replay=True,
7787
+ is_differentiable=False,
6906
7788
  )
6907
7789
  # implements &quaternion[index]
6908
7790
  add_builtin(
@@ -6913,6 +7795,7 @@ add_builtin(
6913
7795
  hidden=True,
6914
7796
  group="Utility",
6915
7797
  skip_replay=True,
7798
+ is_differentiable=False,
6916
7799
  )
6917
7800
  # implements &transformation[index]
6918
7801
  add_builtin(
@@ -6923,6 +7806,7 @@ add_builtin(
6923
7806
  hidden=True,
6924
7807
  group="Utility",
6925
7808
  skip_replay=True,
7809
+ is_differentiable=False,
6926
7810
  )
6927
7811
  # implements &(*vector)[index]
6928
7812
  add_builtin(
@@ -6933,6 +7817,7 @@ add_builtin(
6933
7817
  hidden=True,
6934
7818
  group="Utility",
6935
7819
  skip_replay=True,
7820
+ is_differentiable=False,
6936
7821
  )
6937
7822
  # implements &(*matrix)[i, j]
6938
7823
  add_builtin(
@@ -6943,6 +7828,7 @@ add_builtin(
6943
7828
  hidden=True,
6944
7829
  group="Utility",
6945
7830
  skip_replay=True,
7831
+ is_differentiable=False,
6946
7832
  )
6947
7833
  # implements &(*quaternion)[index]
6948
7834
  add_builtin(
@@ -6953,6 +7839,7 @@ add_builtin(
6953
7839
  hidden=True,
6954
7840
  group="Utility",
6955
7841
  skip_replay=True,
7842
+ is_differentiable=False,
6956
7843
  )
6957
7844
  # implements &(*transformation)[index]
6958
7845
  add_builtin(
@@ -6963,6 +7850,7 @@ add_builtin(
6963
7850
  hidden=True,
6964
7851
  group="Utility",
6965
7852
  skip_replay=True,
7853
+ is_differentiable=False,
6966
7854
  )
6967
7855
 
6968
7856
 
@@ -7158,6 +8046,43 @@ add_builtin(
7158
8046
  )
7159
8047
 
7160
8048
 
8049
+ # implements vector[idx] &= scalar
8050
+ add_builtin(
8051
+ "bit_and_inplace",
8052
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8053
+ value_type=None,
8054
+ dispatch_func=vector_assign_dispatch_func,
8055
+ hidden=True,
8056
+ export=False,
8057
+ group="Utility",
8058
+ is_differentiable=False,
8059
+ )
8060
+
8061
+ # implements vector[idx] |= scalar
8062
+ add_builtin(
8063
+ "bit_or_inplace",
8064
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8065
+ value_type=None,
8066
+ dispatch_func=vector_assign_dispatch_func,
8067
+ hidden=True,
8068
+ export=False,
8069
+ group="Utility",
8070
+ is_differentiable=False,
8071
+ )
8072
+
8073
+ # implements vector[idx] ^= scalar
8074
+ add_builtin(
8075
+ "bit_xor_inplace",
8076
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8077
+ value_type=None,
8078
+ dispatch_func=vector_assign_dispatch_func,
8079
+ hidden=True,
8080
+ export=False,
8081
+ group="Utility",
8082
+ is_differentiable=False,
8083
+ )
8084
+
8085
+
7161
8086
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7162
8087
  mat_type = arg_types["a"]
7163
8088
  row_type = mat_type._wp_row_type_
@@ -7173,6 +8098,7 @@ add_builtin(
7173
8098
  hidden=True,
7174
8099
  group="Utility",
7175
8100
  skip_replay=True,
8101
+ is_differentiable=False,
7176
8102
  )
7177
8103
 
7178
8104
 
@@ -7191,6 +8117,7 @@ add_builtin(
7191
8117
  hidden=True,
7192
8118
  group="Utility",
7193
8119
  skip_replay=True,
8120
+ is_differentiable=False,
7194
8121
  )
7195
8122
 
7196
8123
 
@@ -7390,6 +8317,78 @@ add_builtin(
7390
8317
  )
7391
8318
 
7392
8319
 
8320
+ # implements matrix[i] &= value
8321
+ add_builtin(
8322
+ "bit_and_inplace",
8323
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8324
+ value_type=None,
8325
+ hidden=True,
8326
+ export=False,
8327
+ group="Utility",
8328
+ is_differentiable=False,
8329
+ )
8330
+
8331
+
8332
+ # implements matrix[i,j] &= value
8333
+ add_builtin(
8334
+ "bit_and_inplace",
8335
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8336
+ value_type=None,
8337
+ hidden=True,
8338
+ export=False,
8339
+ group="Utility",
8340
+ is_differentiable=False,
8341
+ )
8342
+
8343
+
8344
+ # implements matrix[i] |= value
8345
+ add_builtin(
8346
+ "bit_or_inplace",
8347
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8348
+ value_type=None,
8349
+ hidden=True,
8350
+ export=False,
8351
+ group="Utility",
8352
+ is_differentiable=False,
8353
+ )
8354
+
8355
+
8356
+ # implements matrix[i,j] |= value
8357
+ add_builtin(
8358
+ "bit_or_inplace",
8359
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8360
+ value_type=None,
8361
+ hidden=True,
8362
+ export=False,
8363
+ group="Utility",
8364
+ is_differentiable=False,
8365
+ )
8366
+
8367
+
8368
+ # implements matrix[i] ^= value
8369
+ add_builtin(
8370
+ "bit_xor_inplace",
8371
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8372
+ value_type=None,
8373
+ hidden=True,
8374
+ export=False,
8375
+ group="Utility",
8376
+ is_differentiable=False,
8377
+ )
8378
+
8379
+
8380
+ # implements matrix[i,j] ^= value
8381
+ add_builtin(
8382
+ "bit_xor_inplace",
8383
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8384
+ value_type=None,
8385
+ hidden=True,
8386
+ export=False,
8387
+ group="Utility",
8388
+ is_differentiable=False,
8389
+ )
8390
+
8391
+
7393
8392
  for t in scalar_types + vector_types + (bool,):
7394
8393
  if "vec" in t.__name__ or "mat" in t.__name__:
7395
8394
  continue
@@ -7401,6 +8400,7 @@ for t in scalar_types + vector_types + (bool,):
7401
8400
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7402
8401
  group="Utility",
7403
8402
  hidden=True,
8403
+ is_differentiable=False,
7404
8404
  )
7405
8405
 
7406
8406
  add_builtin(
@@ -7411,6 +8411,7 @@ for t in scalar_types + vector_types + (bool,):
7411
8411
  group="Utility",
7412
8412
  hidden=True,
7413
8413
  export=False,
8414
+ is_differentiable=False,
7414
8415
  )
7415
8416
 
7416
8417
 
@@ -7429,6 +8430,7 @@ add_builtin(
7429
8430
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7430
8431
  group="Utility",
7431
8432
  hidden=True,
8433
+ is_differentiable=False,
7432
8434
  )
7433
8435
  add_builtin(
7434
8436
  "expect_neq",
@@ -7439,6 +8441,7 @@ add_builtin(
7439
8441
  group="Utility",
7440
8442
  hidden=True,
7441
8443
  export=False,
8444
+ is_differentiable=False,
7442
8445
  )
7443
8446
 
7444
8447
  add_builtin(
@@ -7449,6 +8452,7 @@ add_builtin(
7449
8452
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7450
8453
  group="Utility",
7451
8454
  hidden=True,
8455
+ is_differentiable=False,
7452
8456
  )
7453
8457
  add_builtin(
7454
8458
  "expect_neq",
@@ -7459,6 +8463,7 @@ add_builtin(
7459
8463
  group="Utility",
7460
8464
  hidden=True,
7461
8465
  export=False,
8466
+ is_differentiable=False,
7462
8467
  )
7463
8468
 
7464
8469
  add_builtin(
@@ -7549,6 +8554,7 @@ add_builtin(
7549
8554
  value_type=None,
7550
8555
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
7551
8556
  group="Utility",
8557
+ is_differentiable=False,
7552
8558
  )
7553
8559
  add_builtin(
7554
8560
  "expect_near",
@@ -7558,6 +8564,7 @@ add_builtin(
7558
8564
  value_type=None,
7559
8565
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7560
8566
  group="Utility",
8567
+ is_differentiable=False,
7561
8568
  )
7562
8569
  add_builtin(
7563
8570
  "expect_near",
@@ -7567,6 +8574,7 @@ add_builtin(
7567
8574
  value_type=None,
7568
8575
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7569
8576
  group="Utility",
8577
+ is_differentiable=False,
7570
8578
  )
7571
8579
  add_builtin(
7572
8580
  "expect_near",
@@ -7580,6 +8588,7 @@ add_builtin(
7580
8588
  value_type=None,
7581
8589
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7582
8590
  group="Utility",
8591
+ is_differentiable=False,
7583
8592
  )
7584
8593
 
7585
8594
  # ---------------------------------
@@ -7590,6 +8599,7 @@ add_builtin(
7590
8599
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
7591
8600
  value_type=int,
7592
8601
  doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
8602
+ is_differentiable=False,
7593
8603
  )
7594
8604
 
7595
8605
  add_builtin(
@@ -7597,6 +8607,7 @@ add_builtin(
7597
8607
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
7598
8608
  value_type=int,
7599
8609
  doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
8610
+ is_differentiable=False,
7600
8611
  )
7601
8612
 
7602
8613
  # ---------------------------------
@@ -7672,12 +8683,157 @@ add_builtin(
7672
8683
  )
7673
8684
 
7674
8685
  # bitwise operators
7675
- add_builtin("bit_and", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7676
- add_builtin("bit_or", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7677
- add_builtin("bit_xor", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7678
- add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7679
- add_builtin("rshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7680
- add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int))
8686
+ add_builtin(
8687
+ "bit_and",
8688
+ input_types={"a": Int, "b": Int},
8689
+ value_func=sametypes_create_value_func(Int),
8690
+ group="Operators",
8691
+ is_differentiable=False,
8692
+ )
8693
+ add_builtin(
8694
+ "bit_and",
8695
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8696
+ constraint=sametypes,
8697
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8698
+ doc="",
8699
+ group="Operators",
8700
+ is_differentiable=False,
8701
+ )
8702
+ add_builtin(
8703
+ "bit_and",
8704
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8705
+ constraint=sametypes,
8706
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8707
+ doc="",
8708
+ group="Operators",
8709
+ is_differentiable=False,
8710
+ )
8711
+
8712
+ add_builtin(
8713
+ "bit_or",
8714
+ input_types={"a": Int, "b": Int},
8715
+ value_func=sametypes_create_value_func(Int),
8716
+ group="Operators",
8717
+ is_differentiable=False,
8718
+ )
8719
+ add_builtin(
8720
+ "bit_or",
8721
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8722
+ constraint=sametypes,
8723
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8724
+ doc="",
8725
+ group="Operators",
8726
+ is_differentiable=False,
8727
+ )
8728
+ add_builtin(
8729
+ "bit_or",
8730
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8731
+ constraint=sametypes,
8732
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8733
+ doc="",
8734
+ group="Operators",
8735
+ is_differentiable=False,
8736
+ )
8737
+
8738
+ add_builtin(
8739
+ "bit_xor",
8740
+ input_types={"a": Int, "b": Int},
8741
+ value_func=sametypes_create_value_func(Int),
8742
+ group="Operators",
8743
+ is_differentiable=False,
8744
+ )
8745
+ add_builtin(
8746
+ "bit_xor",
8747
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8748
+ constraint=sametypes,
8749
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8750
+ doc="",
8751
+ group="Operators",
8752
+ is_differentiable=False,
8753
+ )
8754
+ add_builtin(
8755
+ "bit_xor",
8756
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8757
+ constraint=sametypes,
8758
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8759
+ doc="",
8760
+ group="Operators",
8761
+ is_differentiable=False,
8762
+ )
8763
+
8764
+ add_builtin(
8765
+ "lshift",
8766
+ input_types={"a": Int, "b": Int},
8767
+ value_func=sametypes_create_value_func(Int),
8768
+ group="Operators",
8769
+ is_differentiable=False,
8770
+ )
8771
+ add_builtin(
8772
+ "lshift",
8773
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8774
+ constraint=sametypes,
8775
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8776
+ doc="",
8777
+ group="Operators",
8778
+ is_differentiable=False,
8779
+ )
8780
+ add_builtin(
8781
+ "lshift",
8782
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8783
+ constraint=sametypes,
8784
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8785
+ doc="",
8786
+ group="Operators",
8787
+ is_differentiable=False,
8788
+ )
8789
+
8790
+ add_builtin(
8791
+ "rshift",
8792
+ input_types={"a": Int, "b": Int},
8793
+ value_func=sametypes_create_value_func(Int),
8794
+ group="Operators",
8795
+ is_differentiable=False,
8796
+ )
8797
+ add_builtin(
8798
+ "rshift",
8799
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8800
+ constraint=sametypes,
8801
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8802
+ doc="",
8803
+ group="Operators",
8804
+ is_differentiable=False,
8805
+ )
8806
+ add_builtin(
8807
+ "rshift",
8808
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8809
+ constraint=sametypes,
8810
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8811
+ doc="",
8812
+ group="Operators",
8813
+ is_differentiable=False,
8814
+ )
8815
+
8816
+ add_builtin(
8817
+ "invert",
8818
+ input_types={"a": Int},
8819
+ value_func=sametypes_create_value_func(Int),
8820
+ group="Operators",
8821
+ is_differentiable=False,
8822
+ )
8823
+ add_builtin(
8824
+ "invert",
8825
+ input_types={"a": vector(length=Any, dtype=Int)},
8826
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8827
+ group="Operators",
8828
+ is_differentiable=False,
8829
+ )
8830
+ add_builtin(
8831
+ "invert",
8832
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
8833
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8834
+ group="Operators",
8835
+ is_differentiable=False,
8836
+ )
7681
8837
 
7682
8838
 
7683
8839
  add_builtin(
@@ -7878,6 +9034,7 @@ add_builtin(
7878
9034
  value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
7879
9035
  doc="Modulo operation using truncated division.",
7880
9036
  group="Operators",
9037
+ is_differentiable=False,
7881
9038
  )
7882
9039
 
7883
9040
  add_builtin(
@@ -7937,6 +9094,7 @@ add_builtin(
7937
9094
  value_func=sametypes_create_value_func(Scalar),
7938
9095
  doc="",
7939
9096
  group="Operators",
9097
+ is_differentiable=False,
7940
9098
  )
7941
9099
 
7942
9100
  add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -7984,12 +9142,28 @@ add_builtin(
7984
9142
  group="Operators",
7985
9143
  )
7986
9144
 
7987
- add_builtin("unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
9145
+ add_builtin(
9146
+ "unot",
9147
+ input_types={"a": builtins.bool},
9148
+ value_type=builtins.bool,
9149
+ doc="",
9150
+ group="Operators",
9151
+ is_differentiable=False,
9152
+ )
7988
9153
  for t in int_types:
7989
- add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
9154
+ add_builtin(
9155
+ "unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
9156
+ )
7990
9157
 
7991
9158
 
7992
- add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")
9159
+ add_builtin(
9160
+ "unot",
9161
+ input_types={"a": array(dtype=Any)},
9162
+ value_type=builtins.bool,
9163
+ doc="",
9164
+ group="Operators",
9165
+ is_differentiable=False,
9166
+ )
7993
9167
 
7994
9168
 
7995
9169
  # Tile operators
@@ -8061,6 +9235,45 @@ add_builtin(
8061
9235
  export=False,
8062
9236
  )
8063
9237
 
9238
# Bitwise operators on tiles. Each registration takes two tiles (any dtype,
# any shape), derives its return type through `tile_binary_map_value_func`,
# names its native function explicitly, is not exported (`export=False`) and
# is flagged as non-differentiable.
# NOTE(review): the commented-out `dispatch_func` / `variadic` keyword
# arguments are kept exactly as found; confirm whether they are intentionally
# disabled.

# Bitwise AND of two tiles -> native `tile_bit_and`.
add_builtin(
    "bit_and",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_bit_and",
    doc="Bitwise AND each element of two tiles together",
    group="Tile Primitives",
    export=False,
    is_differentiable=False,
)

# Bitwise OR of two tiles -> native `tile_bit_or`.
add_builtin(
    "bit_or",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_bit_or",
    doc="Bitwise OR each element of two tiles together",
    group="Tile Primitives",
    export=False,
    is_differentiable=False,
)

# Bitwise XOR of two tiles -> native `tile_bit_xor`.
add_builtin(
    "bit_xor",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_bit_xor",
    doc="Bitwise XOR each element of two tiles together",
    group="Tile Primitives",
    export=False,
    is_differentiable=False,
)
9276
+
8064
9277
 
8065
9278
  add_builtin(
8066
9279
  "mul",
@@ -8122,6 +9335,45 @@ add_builtin(
8122
9335
  )
8123
9336
 
8124
9337
 
9338
# Hidden in-place bitwise operators on tiles. Each registration takes two
# tiles (any dtype, any shape), declares no return value (`value_type=None`),
# goes through `tile_inplace_dispatch_func`, is neither exported nor shown
# (`export=False`, `hidden=True`) and is flagged as non-differentiable.
# NOTE(review): presumably these back the augmented assignments `&=`, `|=`
# and `^=` on tiles -- confirm against the code generator.

# In-place AND -> native `tile_bit_and_inplace`.
add_builtin(
    "bit_and_inplace",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_type=None,
    dispatch_func=tile_inplace_dispatch_func,
    export=False,
    hidden=True,
    native_func="tile_bit_and_inplace",
    group="Operators",
    is_differentiable=False,
)


# In-place OR -> native `tile_bit_or_inplace`.
add_builtin(
    "bit_or_inplace",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_type=None,
    dispatch_func=tile_inplace_dispatch_func,
    export=False,
    hidden=True,
    native_func="tile_bit_or_inplace",
    group="Operators",
    is_differentiable=False,
)


# In-place XOR -> native `tile_bit_xor_inplace`.
add_builtin(
    "bit_xor_inplace",
    input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
    value_type=None,
    dispatch_func=tile_inplace_dispatch_func,
    export=False,
    hidden=True,
    native_func="tile_bit_xor_inplace",
    group="Operators",
    is_differentiable=False,
)
9375
+
9376
+
8125
9377
  def tile_diag_add_value_func(arg_types, arg_values):
8126
9378
  if arg_types is None:
8127
9379
  return tile(dtype=Any, shape=Tuple[int, int])
@@ -8163,7 +9415,7 @@ def tile_diag_add_lto_dispatch_func(
8163
9415
  return_values: List[Var],
8164
9416
  arg_values: Mapping[str, Var],
8165
9417
  options: Mapping[str, Any],
8166
- builder: warp.context.ModuleBuilder,
9418
+ builder: warp._src.context.ModuleBuilder,
8167
9419
  ):
8168
9420
  a = arg_values["a"]
8169
9421
  d = arg_values["d"]
@@ -8183,6 +9435,7 @@ add_builtin(
8183
9435
  doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
8184
9436
  group="Tile Primitives",
8185
9437
  export=False,
9438
+ is_differentiable=False,
8186
9439
  )
8187
9440
 
8188
9441
 
@@ -8239,7 +9492,7 @@ def tile_matmul_lto_dispatch_func(
8239
9492
  return_values: List[Var],
8240
9493
  arg_values: Mapping[str, Var],
8241
9494
  options: Mapping[str, Any],
8242
- builder: warp.context.ModuleBuilder,
9495
+ builder: warp._src.context.ModuleBuilder,
8243
9496
  ):
8244
9497
  a = arg_values["a"]
8245
9498
  b = arg_values["b"]
@@ -8277,7 +9530,7 @@ def tile_matmul_lto_dispatch_func(
8277
9530
  num_threads = options["block_dim"]
8278
9531
  arch = options["output_arch"]
8279
9532
 
8280
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9533
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8281
9534
  # CPU/no-MathDx dispatch
8282
9535
  return ((0, 0, 0, a, b, out), template_args, [], 0)
8283
9536
  else:
@@ -8290,7 +9543,7 @@ def tile_matmul_lto_dispatch_func(
8290
9543
 
8291
9544
  # generate the LTOs
8292
9545
  # C += A * B
8293
- (fun_forward, lto_forward) = warp.build.build_lto_dot(
9546
+ (fun_forward, lto_forward) = warp._src.build.build_lto_dot(
8294
9547
  M,
8295
9548
  N,
8296
9549
  K,
@@ -8306,7 +9559,7 @@ def tile_matmul_lto_dispatch_func(
8306
9559
  )
8307
9560
  if warp.config.enable_backward:
8308
9561
  # adjA += adjC * B^T - Transpose ~= flipped layout
8309
- (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
9562
+ (fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
8310
9563
  M,
8311
9564
  K,
8312
9565
  N,
@@ -8321,7 +9574,7 @@ def tile_matmul_lto_dispatch_func(
8321
9574
  builder,
8322
9575
  )
8323
9576
  # adjB += A^T * adjC - Transpose ~= flipped layout
8324
- (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
9577
+ (fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
8325
9578
  K,
8326
9579
  N,
8327
9580
  M,
@@ -8438,7 +9691,7 @@ def tile_fft_generic_lto_dispatch_func(
8438
9691
  return_values: List[Var],
8439
9692
  arg_values: Mapping[str, Var],
8440
9693
  options: Mapping[str, Any],
8441
- builder: warp.context.ModuleBuilder,
9694
+ builder: warp._src.context.ModuleBuilder,
8442
9695
  direction: str | None = None,
8443
9696
  ):
8444
9697
  inout = arg_values["inout"]
@@ -8467,12 +9720,12 @@ def tile_fft_generic_lto_dispatch_func(
8467
9720
  arch = options["output_arch"]
8468
9721
  ept = size // num_threads
8469
9722
 
8470
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9723
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8471
9724
  # CPU/no-MathDx dispatch
8472
9725
  return ([], [], [], 0)
8473
9726
  else:
8474
9727
  # generate the LTO
8475
- lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
9728
+ lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
8476
9729
  arch, size, ept, direction, dir, precision, builder
8477
9730
  )
8478
9731
 
@@ -8510,6 +9763,7 @@ add_builtin(
8510
9763
  group="Tile Primitives",
8511
9764
  export=False,
8512
9765
  namespace="",
9766
+ is_differentiable=False,
8513
9767
  )
8514
9768
 
8515
9769
  add_builtin(
@@ -8531,6 +9785,7 @@ add_builtin(
8531
9785
  group="Tile Primitives",
8532
9786
  export=False,
8533
9787
  namespace="",
9788
+ is_differentiable=False,
8534
9789
  )
8535
9790
 
8536
9791
 
@@ -8575,7 +9830,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8575
9830
  return_values: List[Var],
8576
9831
  arg_values: Mapping[str, Var],
8577
9832
  options: Mapping[str, Any],
8578
- builder: warp.context.ModuleBuilder,
9833
+ builder: warp._src.context.ModuleBuilder,
8579
9834
  ):
8580
9835
  a = arg_values["A"]
8581
9836
  # force source tile to shared memory
@@ -8595,7 +9850,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8595
9850
 
8596
9851
  arch = options["output_arch"]
8597
9852
 
8598
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9853
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8599
9854
  # CPU/no-MathDx dispatch
8600
9855
  return ((0, a, out), [], [], 0)
8601
9856
  else:
@@ -8610,7 +9865,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8610
9865
  req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
8611
9866
 
8612
9867
  # generate the LTO
8613
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9868
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8614
9869
  M,
8615
9870
  N,
8616
9871
  1,
@@ -8655,6 +9910,7 @@ add_builtin(
8655
9910
  group="Tile Primitives",
8656
9911
  export=False,
8657
9912
  namespace="",
9913
+ is_differentiable=False,
8658
9914
  )
8659
9915
 
8660
9916
 
@@ -8698,7 +9954,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8698
9954
  return_values: List[Var],
8699
9955
  arg_values: Mapping[str, Var],
8700
9956
  options: Mapping[str, Any],
8701
- builder: warp.context.ModuleBuilder,
9957
+ builder: warp._src.context.ModuleBuilder,
8702
9958
  ):
8703
9959
  L = arg_values["L"]
8704
9960
  y = arg_values["y"]
@@ -8727,7 +9983,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8727
9983
 
8728
9984
  arch = options["output_arch"]
8729
9985
 
8730
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9986
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8731
9987
  # CPU/no-MathDx dispatch
8732
9988
  return ((0, L, y, x), [], [], 0)
8733
9989
  else:
@@ -8743,7 +9999,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8743
9999
  req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8744
10000
 
8745
10001
  # generate the LTO
8746
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10002
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8747
10003
  M,
8748
10004
  N,
8749
10005
  NRHS,
@@ -8785,6 +10041,7 @@ add_builtin(
8785
10041
  group="Tile Primitives",
8786
10042
  export=False,
8787
10043
  namespace="",
10044
+ is_differentiable=False,
8788
10045
  )
8789
10046
 
8790
10047
 
@@ -8794,7 +10051,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8794
10051
  return_values: List[Var],
8795
10052
  arg_values: Mapping[str, Var],
8796
10053
  options: Mapping[str, Any],
8797
- builder: warp.context.ModuleBuilder,
10054
+ builder: warp._src.context.ModuleBuilder,
8798
10055
  ):
8799
10056
  L = arg_values["L"]
8800
10057
  y = arg_values["y"]
@@ -8823,7 +10080,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8823
10080
 
8824
10081
  arch = options["output_arch"]
8825
10082
 
8826
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10083
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8827
10084
  # CPU/no-MathDx dispatch
8828
10085
  return ((0, L, y, z), [], [], 0)
8829
10086
  else:
@@ -8839,7 +10096,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8839
10096
  req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8840
10097
 
8841
10098
  # generate the LTO
8842
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10099
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8843
10100
  M,
8844
10101
  N,
8845
10102
  NRHS,
@@ -8917,6 +10174,7 @@ add_builtin(
8917
10174
  group="Tile Primitives",
8918
10175
  export=False,
8919
10176
  namespace="",
10177
+ is_differentiable=False,
8920
10178
  )
8921
10179
 
8922
10180
 
@@ -8926,7 +10184,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8926
10184
  return_values: List[Var],
8927
10185
  arg_values: Mapping[str, Var],
8928
10186
  options: Mapping[str, Any],
8929
- builder: warp.context.ModuleBuilder,
10187
+ builder: warp._src.context.ModuleBuilder,
8930
10188
  ):
8931
10189
  U = arg_values["U"]
8932
10190
  z = arg_values["z"]
@@ -8955,7 +10213,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8955
10213
 
8956
10214
  arch = options["output_arch"]
8957
10215
 
8958
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10216
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8959
10217
  # CPU/no-MathDx dispatch
8960
10218
  return ((0, U, z, x), [], [], 0)
8961
10219
  else:
@@ -8971,7 +10229,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8971
10229
  req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
8972
10230
 
8973
10231
  # generate the LTO
8974
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10232
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8975
10233
  M,
8976
10234
  N,
8977
10235
  NRHS,
@@ -9049,6 +10307,7 @@ add_builtin(
9049
10307
  group="Tile Primitives",
9050
10308
  export=False,
9051
10309
  namespace="",
10310
+ is_differentiable=False,
9052
10311
  )
9053
10312
 
9054
10313
 
@@ -9068,6 +10327,7 @@ add_builtin(
9068
10327
  The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
9069
10328
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
9070
10329
  group="Code Generation",
10330
+ is_differentiable=False,
9071
10331
  )
9072
10332
 
9073
10333
 
@@ -9092,6 +10352,7 @@ add_builtin(
9092
10352
  doc="Return the number of elements in a vector.",
9093
10353
  group="Utility",
9094
10354
  export=False,
10355
+ is_differentiable=False,
9095
10356
  )
9096
10357
 
9097
10358
  add_builtin(
@@ -9101,6 +10362,7 @@ add_builtin(
9101
10362
  doc="Return the number of elements in a quaternion.",
9102
10363
  group="Utility",
9103
10364
  export=False,
10365
+ is_differentiable=False,
9104
10366
  )
9105
10367
 
9106
10368
  add_builtin(
@@ -9110,6 +10372,7 @@ add_builtin(
9110
10372
  doc="Return the number of rows in a matrix.",
9111
10373
  group="Utility",
9112
10374
  export=False,
10375
+ is_differentiable=False,
9113
10376
  )
9114
10377
 
9115
10378
  add_builtin(
@@ -9119,6 +10382,7 @@ add_builtin(
9119
10382
  doc="Return the number of elements in a transformation.",
9120
10383
  group="Utility",
9121
10384
  export=False,
10385
+ is_differentiable=False,
9122
10386
  )
9123
10387
 
9124
10388
  add_builtin(
@@ -9128,6 +10392,7 @@ add_builtin(
9128
10392
  doc="Return the size of the first dimension in an array.",
9129
10393
  group="Utility",
9130
10394
  export=False,
10395
+ is_differentiable=False,
9131
10396
  )
9132
10397
 
9133
10398
  add_builtin(
@@ -9137,6 +10402,62 @@ add_builtin(
9137
10402
  doc="Return the number of rows in a tile.",
9138
10403
  group="Utility",
9139
10404
  export=False,
10405
+ is_differentiable=False,
10406
+ )
10407
+
10408
+
10409
def cast_value_func(arg_types, arg_values):
    """Resolve the return type of the ``cast`` builtin.

    Args:
        arg_types: Mapping of argument names to their types, or ``None``
            when no concrete call is being resolved (e.g. doc builds).
        arg_values: Mapping of argument names to their values; only the
            ``"dtype"`` entry is read.

    Returns:
        ``Any`` when ``arg_types`` is ``None``, otherwise the value passed
        as ``dtype``.
    """
    # Without argument types there is nothing to resolve: report a generic type.
    return Any if arg_types is None else arg_values["dtype"]
10415
+
10416
+
10417
def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Split the arguments of ``cast`` into function and template arguments.

    Args:
        input_types: Mapping of argument names to their types (unused here).
        return_type: Resolved return type of the call (unused here).
        args: Mapping of argument names to their values; the ``"a"`` and
            ``"dtype"`` entries are read, in that order.

    Returns:
        A pair ``(func_args, template_args)`` of 1-tuples: ``a`` is forwarded
        as the sole function argument and ``dtype`` as the sole template
        argument.
    """
    return ((args["a"],), (args["dtype"],))
10421
+
10422
+
10423
+ add_builtin(
10424
+ "cast",
10425
+ input_types={"a": Any, "dtype": Any},
10426
+ value_func=cast_value_func,
10427
+ dispatch_func=cast_dispatch_func,
10428
+ group="Utility",
10429
+ export=False,
10430
+ is_differentiable=False,
10431
+ doc="""Reinterpret a value as a different type while preserving its bit pattern.
10432
+
10433
+ :param a: The value to cast
10434
+ :param dtype: The target type
10435
+
10436
+ Example:
10437
+
10438
+ .. code-block:: python
10439
+
10440
+ @wp.struct
10441
+ class MyStruct:
10442
+ f: wp.float16
10443
+ i: wp.int16
10444
+
10445
+
10446
+ @wp.kernel
10447
+ def compute():
10448
+ x = wp.int32(0x40000000)
10449
+ x_casted = wp.cast(x, wp.float32)
10450
+ wp.expect_eq(x_casted, 2.0) # 0x40000000
10451
+
10452
+ s = MyStruct()
10453
+ s.f = wp.float16(2.0) # 0x4000
10454
+ s.i = wp.int16(4096) # 0x1000
10455
+ s_casted = wp.cast(s, wp.int32)
10456
+ wp.expect_eq(s_casted, 0x10004000)
10457
+
10458
+
10459
+ wp.launch(compute, dim=1)
10460
+ """,
9140
10461
  )
9141
10462
 
9142
10463
 
@@ -9163,7 +10484,7 @@ add_builtin(
9163
10484
  doc="Construct a tuple from a list of values",
9164
10485
  group="Utility",
9165
10486
  hidden=True,
9166
- missing_grad=True,
10487
+ is_differentiable=False,
9167
10488
  export=False,
9168
10489
  )
9169
10490
 
@@ -9200,7 +10521,7 @@ add_builtin(
9200
10521
  dispatch_func=tuple_extract_dispatch_func,
9201
10522
  group="Utility",
9202
10523
  hidden=True,
9203
- missing_grad=True,
10524
+ is_differentiable=False,
9204
10525
  )
9205
10526
 
9206
10527
 
@@ -9211,6 +10532,7 @@ add_builtin(
9211
10532
  doc="Return the number of elements in a tuple.",
9212
10533
  group="Utility",
9213
10534
  export=False,
10535
+ is_differentiable=False,
9214
10536
  )
9215
10537
 
9216
10538
  # ---------------------------------
@@ -9229,5 +10551,5 @@ add_builtin(
9229
10551
  export=False,
9230
10552
  group="Utility",
9231
10553
  hidden=True,
9232
- missing_grad=True,
10554
+ is_differentiable=False,
9233
10555
  )