warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's advisory page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -20,14 +20,16 @@ import functools
20
20
  import math
21
21
  from typing import Any, Callable, Mapping, Sequence
22
22
 
23
- import warp.build
24
- import warp.context
25
- import warp.utils
26
- from warp.codegen import Reference, Var, get_arg_value, strip_reference
27
- from warp.types import *
23
+ import warp._src.build
24
+ import warp._src.context
25
+ import warp._src.utils
26
+ from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
27
+ from warp._src.types import *
28
28
 
29
29
  from .context import add_builtin
30
30
 
31
+ _wp_module_name_ = "warp.builtins"
32
+
31
33
 
32
34
  def seq_check_equal(seq_1, seq_2):
33
35
  if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
@@ -61,11 +63,11 @@ def sametypes_create_value_func(default: TypeVar):
61
63
 
62
64
  def extract_tuple(arg, as_constant=False):
63
65
  if isinstance(arg, Var):
64
- if isinstance(arg.type, warp.types.tuple_t):
66
+ if isinstance(arg.type, warp._src.types.tuple_t):
65
67
  out = arg.type.values
66
68
  else:
67
69
  out = (arg,)
68
- elif isinstance(arg, warp.types.tuple_t):
70
+ elif isinstance(arg, warp._src.types.tuple_t):
69
71
  out = arg.values
70
72
  elif not isinstance(arg, Sequence):
71
73
  out = (arg,)
@@ -82,7 +84,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
82
84
  if arg_types is None:
83
85
  return int
84
86
 
85
- length = warp.types.type_length(arg_types["a"])
87
+ length = warp._src.types.type_length(arg_types["a"])
86
88
  return Var(None, type=int, constant=length)
87
89
 
88
90
 
@@ -126,7 +128,7 @@ add_builtin(
126
128
  value_func=sametypes_create_value_func(Scalar),
127
129
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
128
130
  group="Scalar Math",
129
- missing_grad=True,
131
+ is_differentiable=False,
130
132
  )
131
133
 
132
134
  add_builtin(
@@ -135,7 +137,7 @@ add_builtin(
135
137
  value_func=sametypes_create_value_func(Scalar),
136
138
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
137
139
  group="Scalar Math",
138
- missing_grad=True,
140
+ is_differentiable=False,
139
141
  )
140
142
  add_builtin(
141
143
  "nonzero",
@@ -143,7 +145,7 @@ add_builtin(
143
145
  value_func=sametypes_create_value_func(Scalar),
144
146
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
145
147
  group="Scalar Math",
146
- missing_grad=True,
148
+ is_differentiable=False,
147
149
  )
148
150
 
149
151
  add_builtin(
@@ -285,7 +287,36 @@ add_builtin(
285
287
  group="Scalar Math",
286
288
  require_original_output_arg=True,
287
289
  )
288
-
290
+ add_builtin(
291
+ "erf",
292
+ input_types={"x": Float},
293
+ value_func=sametypes_create_value_func(Float),
294
+ doc="Return the error function of ``x``.",
295
+ group="Scalar Math",
296
+ )
297
+ add_builtin(
298
+ "erfc",
299
+ input_types={"x": Float},
300
+ value_func=sametypes_create_value_func(Float),
301
+ doc="Return the complementary error function of ``x``.",
302
+ group="Scalar Math",
303
+ )
304
+ add_builtin(
305
+ "erfinv",
306
+ input_types={"x": Float},
307
+ value_func=sametypes_create_value_func(Float),
308
+ doc="Return the inverse error function of ``x``.",
309
+ group="Scalar Math",
310
+ require_original_output_arg=True,
311
+ )
312
+ add_builtin(
313
+ "erfcinv",
314
+ input_types={"x": Float},
315
+ value_func=sametypes_create_value_func(Float),
316
+ doc="Return the inverse complementary error function of ``x``.",
317
+ group="Scalar Math",
318
+ require_original_output_arg=True,
319
+ )
289
320
  add_builtin(
290
321
  "round",
291
322
  input_types={"x": Float},
@@ -295,7 +326,7 @@ add_builtin(
295
326
 
296
327
  This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
297
328
  Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
298
- missing_grad=True,
329
+ is_differentiable=False,
299
330
  )
300
331
 
301
332
  add_builtin(
@@ -306,7 +337,7 @@ add_builtin(
306
337
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
307
338
 
308
339
  It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
309
- missing_grad=True,
340
+ is_differentiable=False,
310
341
  )
311
342
 
312
343
  add_builtin(
@@ -319,7 +350,7 @@ add_builtin(
319
350
  In other words, it discards the fractional part of ``x``.
320
351
  It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
321
352
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
322
- missing_grad=True,
353
+ is_differentiable=False,
323
354
  )
324
355
 
325
356
  add_builtin(
@@ -328,7 +359,7 @@ add_builtin(
328
359
  value_func=sametypes_create_value_func(Float),
329
360
  group="Scalar Math",
330
361
  doc="""Return the largest integer that is less than or equal to ``x``.""",
331
- missing_grad=True,
362
+ is_differentiable=False,
332
363
  )
333
364
 
334
365
  add_builtin(
@@ -337,7 +368,7 @@ add_builtin(
337
368
  value_func=sametypes_create_value_func(Float),
338
369
  group="Scalar Math",
339
370
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
340
- missing_grad=True,
371
+ is_differentiable=False,
341
372
  )
342
373
 
343
374
  add_builtin(
@@ -348,7 +379,7 @@ add_builtin(
348
379
  doc="""Retrieve the fractional part of ``x``.
349
380
 
350
381
  In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
351
- missing_grad=True,
382
+ is_differentiable=False,
352
383
  )
353
384
 
354
385
  add_builtin(
@@ -357,7 +388,7 @@ add_builtin(
357
388
  value_type=builtins.bool,
358
389
  group="Scalar Math",
359
390
  doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
360
- missing_grad=True,
391
+ is_differentiable=False,
361
392
  )
362
393
  add_builtin(
363
394
  "isfinite",
@@ -365,7 +396,7 @@ add_builtin(
365
396
  value_type=builtins.bool,
366
397
  group="Vector Math",
367
398
  doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
368
- missing_grad=True,
399
+ is_differentiable=False,
369
400
  )
370
401
  add_builtin(
371
402
  "isfinite",
@@ -373,7 +404,7 @@ add_builtin(
373
404
  value_type=builtins.bool,
374
405
  group="Vector Math",
375
406
  doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
376
- missing_grad=True,
407
+ is_differentiable=False,
377
408
  )
378
409
  add_builtin(
379
410
  "isfinite",
@@ -381,7 +412,7 @@ add_builtin(
381
412
  value_type=builtins.bool,
382
413
  group="Vector Math",
383
414
  doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
384
- missing_grad=True,
415
+ is_differentiable=False,
385
416
  )
386
417
 
387
418
  add_builtin(
@@ -390,7 +421,7 @@ add_builtin(
390
421
  value_type=builtins.bool,
391
422
  doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
392
423
  group="Scalar Math",
393
- missing_grad=True,
424
+ is_differentiable=False,
394
425
  )
395
426
  add_builtin(
396
427
  "isnan",
@@ -398,7 +429,7 @@ add_builtin(
398
429
  value_type=builtins.bool,
399
430
  group="Vector Math",
400
431
  doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
401
- missing_grad=True,
432
+ is_differentiable=False,
402
433
  )
403
434
  add_builtin(
404
435
  "isnan",
@@ -406,7 +437,7 @@ add_builtin(
406
437
  value_type=builtins.bool,
407
438
  group="Vector Math",
408
439
  doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
409
- missing_grad=True,
440
+ is_differentiable=False,
410
441
  )
411
442
  add_builtin(
412
443
  "isnan",
@@ -414,7 +445,7 @@ add_builtin(
414
445
  value_type=builtins.bool,
415
446
  group="Vector Math",
416
447
  doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
417
- missing_grad=True,
448
+ is_differentiable=False,
418
449
  )
419
450
 
420
451
  add_builtin(
@@ -423,7 +454,7 @@ add_builtin(
423
454
  value_type=builtins.bool,
424
455
  group="Scalar Math",
425
456
  doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
426
- missing_grad=True,
457
+ is_differentiable=False,
427
458
  )
428
459
  add_builtin(
429
460
  "isinf",
@@ -431,7 +462,7 @@ add_builtin(
431
462
  value_type=builtins.bool,
432
463
  group="Vector Math",
433
464
  doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
434
- missing_grad=True,
465
+ is_differentiable=False,
435
466
  )
436
467
  add_builtin(
437
468
  "isinf",
@@ -439,7 +470,7 @@ add_builtin(
439
470
  value_type=builtins.bool,
440
471
  group="Vector Math",
441
472
  doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
442
- missing_grad=True,
473
+ is_differentiable=False,
443
474
  )
444
475
  add_builtin(
445
476
  "isinf",
@@ -447,7 +478,7 @@ add_builtin(
447
478
  value_type=builtins.bool,
448
479
  group="Vector Math",
449
480
  doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
450
- missing_grad=True,
481
+ is_differentiable=False,
451
482
  )
452
483
 
453
484
 
@@ -555,7 +586,7 @@ add_builtin(
555
586
  value_func=lambda arg_types, arg_values: warp.uint32,
556
587
  doc="Return the index of the minimum element of a vector ``a``.",
557
588
  group="Vector Math",
558
- missing_grad=True,
589
+ is_differentiable=False,
559
590
  )
560
591
  add_builtin(
561
592
  "argmax",
@@ -563,7 +594,7 @@ add_builtin(
563
594
  value_func=lambda arg_types, arg_values: warp.uint32,
564
595
  doc="Return the index of the maximum element of a vector ``a``.",
565
596
  group="Vector Math",
566
- missing_grad=True,
597
+ is_differentiable=False,
567
598
  )
568
599
 
569
600
  add_builtin(
@@ -888,7 +919,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
888
919
 
889
920
  if dtype is None:
890
921
  dtype = value_type
891
- elif not warp.types.scalars_equal(value_type, dtype):
922
+ elif not warp._src.types.scalars_equal(value_type, dtype):
892
923
  raise RuntimeError(
893
924
  f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
894
925
  )
@@ -909,7 +940,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
909
940
 
910
941
  if dtype is None:
911
942
  dtype = value_type
912
- elif not warp.types.scalars_equal(value_type, dtype):
943
+ elif not warp._src.types.scalars_equal(value_type, dtype):
913
944
  raise RuntimeError(
914
945
  f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
915
946
  )
@@ -992,7 +1023,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
992
1023
 
993
1024
  if dtype is None:
994
1025
  dtype = value_type
995
- elif not warp.types.scalars_equal(value_type, dtype):
1026
+ elif not warp._src.types.scalars_equal(value_type, dtype):
996
1027
  raise RuntimeError(
997
1028
  f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
998
1029
  )
@@ -1002,7 +1033,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1002
1033
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
1003
1034
 
1004
1035
  if all(type_is_vector(x) for x in variadic_arg_types):
1005
- warp.utils.warn(
1036
+ warp._src.utils.warn(
1006
1037
  "the built-in `wp.matrix()` won't support taking column vectors as input "
1007
1038
  "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
1008
1039
  DeprecationWarning,
@@ -1031,7 +1062,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1031
1062
 
1032
1063
  if dtype is None:
1033
1064
  dtype = value_type
1034
- elif not warp.types.scalars_equal(value_type, dtype):
1065
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1035
1066
  raise RuntimeError(
1036
1067
  f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
1037
1068
  )
@@ -1203,49 +1234,18 @@ add_builtin(
1203
1234
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
1204
1235
  group="Vector Math",
1205
1236
  export=False,
1206
- missing_grad=True,
1237
+ is_differentiable=False,
1207
1238
  )
1208
1239
 
1209
1240
 
1210
1241
  def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1211
- warp.utils.warn(
1212
- "the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
1213
- "and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
1214
- DeprecationWarning,
1215
- )
1216
1242
  if arg_types is None:
1217
1243
  return matrix(shape=(4, 4), dtype=Float)
1218
1244
 
1219
- dtype = arg_values.get("dtype", None)
1220
-
1221
- value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1222
- try:
1223
- value_type = scalar_infer_type(value_arg_types)
1224
- except RuntimeError:
1225
- raise RuntimeError(
1226
- "all values given when constructing a transformation matrix must have the same type"
1227
- ) from None
1228
-
1229
- if dtype is None:
1230
- dtype = value_type
1231
- elif not warp.types.scalars_equal(value_type, dtype):
1232
- raise RuntimeError(
1233
- f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1234
- )
1235
-
1236
- return matrix(shape=(4, 4), dtype=dtype)
1237
-
1238
-
1239
- def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1240
- # We're in the codegen stage where we emit the code calling the built-in.
1241
- # Further validate the given argument values if needed and map them
1242
- # to the underlying C++ function's runtime and template params.
1243
-
1244
- dtype = return_type._wp_scalar_type_
1245
-
1246
- func_args = tuple(v for k, v in args.items() if k != "dtype")
1247
- template_args = (4, 4, dtype)
1248
- return (func_args, template_args)
1245
+ raise RuntimeError(
1246
+ "the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
1247
+ "and 3D scale vector has been removed in favor of `wp.transform_compose()`."
1248
+ )
1249
1249
 
1250
1250
 
1251
1251
  add_builtin(
@@ -1259,13 +1259,14 @@ add_builtin(
1259
1259
  defaults={"dtype": None},
1260
1260
  value_func=matrix_transform_value_func,
1261
1261
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1262
- dispatch_func=matrix_transform_dispatch_func,
1263
1262
  native_func="mat_t",
1264
1263
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
1265
1264
  Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
1266
1265
 
1267
- .. warning::
1268
- This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
1266
+ .. versionremoved:: 1.10
1267
+ This function has been removed in favor of :func:`warp.math.transform_compose()`.
1268
+
1269
+ .. deprecated:: 1.8""",
1269
1270
  group="Vector Math",
1270
1271
  export=False,
1271
1272
  )
@@ -1460,7 +1461,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
1460
1461
 
1461
1462
  if dtype is None:
1462
1463
  dtype = value_type
1463
- elif not warp.types.scalars_equal(value_type, dtype):
1464
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1464
1465
  raise RuntimeError(
1465
1466
  f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
1466
1467
  )
@@ -1568,7 +1569,7 @@ add_builtin(
1568
1569
  group="Quaternion Math",
1569
1570
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1570
1571
  export=True,
1571
- missing_grad=True,
1572
+ is_differentiable=False,
1572
1573
  )
1573
1574
 
1574
1575
  add_builtin(
@@ -1697,7 +1698,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1697
1698
  value_type = strip_reference(variadic_arg_types[0])
1698
1699
  if dtype is None:
1699
1700
  dtype = value_type
1700
- elif not warp.types.scalars_equal(value_type, dtype):
1701
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1701
1702
  raise RuntimeError(
1702
1703
  f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
1703
1704
  )
@@ -1710,7 +1711,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1710
1711
 
1711
1712
  if dtype is None:
1712
1713
  dtype = value_type
1713
- elif not warp.types.scalars_equal(value_type, dtype):
1714
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1714
1715
  raise RuntimeError(
1715
1716
  f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
1716
1717
  )
@@ -1735,7 +1736,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
1735
1736
  dtype = arg_values.get("dtype", None)
1736
1737
  if dtype is None:
1737
1738
  dtype = value_type
1738
- elif not warp.types.scalars_equal(value_type, dtype):
1739
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1739
1740
  raise RuntimeError(
1740
1741
  f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1741
1742
  )
@@ -1750,9 +1751,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1750
1751
 
1751
1752
  dtype = return_type._wp_scalar_type_
1752
1753
 
1753
- variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1754
+ variadic_args = args.get("args", ())
1755
+ variadic_arg_count = len(variadic_args)
1756
+
1757
+ if variadic_arg_count == 7:
1758
+ func_args = variadic_args
1759
+ else:
1760
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
1761
+ if "p" in args and "q" not in args:
1762
+ quat_ident = warp._src.codegen.Var(
1763
+ label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
1764
+ )
1765
+ func_args += (quat_ident,)
1754
1766
 
1755
- func_args = variadic_args
1756
1767
  template_args = (dtype,)
1757
1768
  return (func_args, template_args)
1758
1769
 
@@ -1760,7 +1771,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1760
1771
  add_builtin(
1761
1772
  "transformation",
1762
1773
  input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
1763
- defaults={"dtype": None},
1774
+ defaults={"q": None, "dtype": None},
1764
1775
  value_func=transformation_pq_value_func,
1765
1776
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1766
1777
  dispatch_func=transformation_dispatch_func,
@@ -1784,7 +1795,6 @@ add_builtin(
1784
1795
  doc="Construct a spatial transform vector of given dtype.",
1785
1796
  group="Spatial Math",
1786
1797
  export=False,
1787
- missing_grad=True,
1788
1798
  )
1789
1799
 
1790
1800
 
@@ -1819,7 +1829,7 @@ add_builtin(
1819
1829
  group="Transformations",
1820
1830
  doc="Construct an identity transform with zero translation and identity rotation.",
1821
1831
  export=True,
1822
- missing_grad=True,
1832
+ is_differentiable=False,
1823
1833
  )
1824
1834
 
1825
1835
  add_builtin(
@@ -1953,7 +1963,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1953
1963
 
1954
1964
  if dtype is None:
1955
1965
  dtype = value_type
1956
- elif not warp.types.scalars_equal(value_type, dtype):
1966
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1957
1967
  raise RuntimeError(
1958
1968
  f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
1959
1969
  )
@@ -2147,7 +2157,7 @@ add_builtin(
2147
2157
  value_func=tile_zeros_value_func,
2148
2158
  dispatch_func=tile_zeros_dispatch_func,
2149
2159
  variadic=False,
2150
- missing_grad=True,
2160
+ is_differentiable=False,
2151
2161
  doc="""Allocate a tile of zero-initialized items.
2152
2162
 
2153
2163
  :param shape: Shape of the output tile
@@ -2167,7 +2177,7 @@ add_builtin(
2167
2177
  value_func=tile_zeros_value_func,
2168
2178
  dispatch_func=tile_zeros_dispatch_func,
2169
2179
  variadic=False,
2170
- missing_grad=True,
2180
+ is_differentiable=False,
2171
2181
  hidden=True,
2172
2182
  group="Tile Primitives",
2173
2183
  export=False,
@@ -2219,7 +2229,7 @@ add_builtin(
2219
2229
  defaults={"storage": "register"},
2220
2230
  value_func=tile_ones_value_func,
2221
2231
  dispatch_func=tile_ones_dispatch_func,
2222
- missing_grad=True,
2232
+ is_differentiable=False,
2223
2233
  doc="""Allocate a tile of one-initialized items.
2224
2234
 
2225
2235
  :param shape: Shape of the output tile
@@ -2238,7 +2248,86 @@ add_builtin(
2238
2248
  defaults={"storage": "register"},
2239
2249
  value_func=tile_ones_value_func,
2240
2250
  dispatch_func=tile_ones_dispatch_func,
2241
- missing_grad=True,
2251
+ is_differentiable=False,
2252
+ hidden=True,
2253
+ group="Tile Primitives",
2254
+ export=False,
2255
+ )
2256
+
2257
+
2258
+ def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2259
+ # return generic type (for doc builds)
2260
+ if arg_types is None:
2261
+ return tile(dtype=Any, shape=Tuple[int, ...])
2262
+
2263
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2264
+
2265
+ if None in shape:
2266
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2267
+
2268
+ if "value" not in arg_values:
2269
+ raise TypeError("tile_full() missing required keyword argument 'value'")
2270
+
2271
+ if "dtype" not in arg_values:
2272
+ raise TypeError("tile_full() missing required keyword argument 'dtype'")
2273
+
2274
+ if "storage" not in arg_values:
2275
+ raise TypeError("tile_full() missing required keyword argument 'storage'")
2276
+
2277
+ if arg_values["storage"] not in {"shared", "register"}:
2278
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2279
+
2280
+ dtype = arg_values["dtype"]
2281
+
2282
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
2283
+
2284
+
2285
+ def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2286
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2287
+
2288
+ if None in shape:
2289
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2290
+
2291
+ dtype = arg_values["dtype"]
2292
+ value = arg_values["value"]
2293
+
2294
+ func_args = [value]
2295
+
2296
+ template_args = []
2297
+ template_args.append(dtype)
2298
+ template_args.extend(shape)
2299
+
2300
+ return (func_args, template_args)
2301
+
2302
+
2303
+ add_builtin(
2304
+ "tile_full",
2305
+ input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
2306
+ defaults={"storage": "register"},
2307
+ value_func=tile_full_value_func,
2308
+ dispatch_func=tile_full_dispatch_func,
2309
+ is_differentiable=False,
2310
+ doc="""Allocate a tile filled with the specified value.
2311
+
2312
+ :param shape: Shape of the output tile
2313
+ :param value: Value to fill the tile with
2314
+ :param dtype: Data type of output tile's elements
2315
+ :param storage: The storage location for the tile: ``"register"`` for registers
2316
+ (default) or ``"shared"`` for shared memory.
2317
+ :returns: A tile filled with the specified value""",
2318
+ group="Tile Primitives",
2319
+ export=False,
2320
+ )
2321
+
2322
+
2323
+ # overload for scalar shape
2324
+ add_builtin(
2325
+ "tile_full",
2326
+ input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
2327
+ defaults={"storage": "register"},
2328
+ value_func=tile_full_value_func,
2329
+ dispatch_func=tile_full_dispatch_func,
2330
+ is_differentiable=False,
2242
2331
  hidden=True,
2243
2332
  group="Tile Primitives",
2244
2333
  export=False,
@@ -2300,13 +2389,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2300
2389
  args = arg_values["args"]
2301
2390
 
2302
2391
  if len(args) == 1:
2303
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
2392
+ start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
2304
2393
  stop = args[0]
2305
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2394
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2306
2395
  elif len(args) == 2:
2307
2396
  start = args[0]
2308
2397
  stop = args[1]
2309
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2398
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2310
2399
  elif len(args) == 3:
2311
2400
  start = args[0]
2312
2401
  stop = args[1]
@@ -2329,7 +2418,7 @@ add_builtin(
2329
2418
  value_func=tile_arange_value_func,
2330
2419
  dispatch_func=tile_arange_dispatch_func,
2331
2420
  variadic=True,
2332
- missing_grad=True,
2421
+ is_differentiable=False,
2333
2422
  doc="""Generate a tile of linearly spaced elements.
2334
2423
 
2335
2424
  :param args: Variable-length positional arguments, interpreted as:
@@ -3124,7 +3213,7 @@ add_builtin(
3124
3213
  :param shape: Shape of the returned slice
3125
3214
  :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
3126
3215
  group="Tile Primitives",
3127
- missing_grad=True,
3216
+ is_differentiable=False,
3128
3217
  export=False,
3129
3218
  )
3130
3219
 
@@ -3371,7 +3460,32 @@ add_builtin(
3371
3460
 
3372
3461
  add_builtin(
3373
3462
  "assign",
3374
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
3463
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
3464
+ value_func=tile_assign_value_func,
3465
+ group="Tile Primitives",
3466
+ export=False,
3467
+ hidden=True,
3468
+ )
3469
+
3470
+ add_builtin(
3471
+ "assign",
3472
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
3473
+ value_func=tile_assign_value_func,
3474
+ group="Tile Primitives",
3475
+ export=False,
3476
+ hidden=True,
3477
+ )
3478
+
3479
+ add_builtin(
3480
+ "assign",
3481
+ input_types={
3482
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3483
+ "i": int,
3484
+ "j": int,
3485
+ "k": int,
3486
+ "l": int,
3487
+ "src": Any,
3488
+ },
3375
3489
  value_func=tile_assign_value_func,
3376
3490
  group="Tile Primitives",
3377
3491
  export=False,
@@ -3380,7 +3494,15 @@ add_builtin(
3380
3494
 
3381
3495
  add_builtin(
3382
3496
  "assign",
3383
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
3497
+ input_types={
3498
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3499
+ "i": int,
3500
+ "j": int,
3501
+ "k": int,
3502
+ "l": int,
3503
+ "m": int,
3504
+ "src": Any,
3505
+ },
3384
3506
  value_func=tile_assign_value_func,
3385
3507
  group="Tile Primitives",
3386
3508
  export=False,
@@ -3395,6 +3517,8 @@ add_builtin(
3395
3517
  "j": int,
3396
3518
  "k": int,
3397
3519
  "l": int,
3520
+ "m": int,
3521
+ "n": int,
3398
3522
  "src": Any,
3399
3523
  },
3400
3524
  value_func=tile_assign_value_func,
@@ -3416,7 +3540,7 @@ def tile_value_func(arg_types, arg_values):
3416
3540
 
3417
3541
  if preserve_type:
3418
3542
  dtype = arg_types["x"]
3419
- shape = (warp.codegen.options["block_dim"],)
3543
+ shape = (warp._src.codegen.options["block_dim"],)
3420
3544
 
3421
3545
  return tile(dtype=dtype, shape=shape)
3422
3546
 
@@ -3424,18 +3548,18 @@ def tile_value_func(arg_types, arg_values):
3424
3548
  if type_is_vector(arg_types["x"]):
3425
3549
  dtype = arg_types["x"]._wp_scalar_type_
3426
3550
  length = arg_types["x"]._shape_[0]
3427
- shape = (length, warp.codegen.options["block_dim"])
3551
+ shape = (length, warp._src.codegen.options["block_dim"])
3428
3552
  elif type_is_quaternion(arg_types["x"]):
3429
3553
  dtype = arg_types["x"]._wp_scalar_type_
3430
- shape = (4, warp.codegen.options["block_dim"])
3554
+ shape = (4, warp._src.codegen.options["block_dim"])
3431
3555
  elif type_is_matrix(arg_types["x"]):
3432
3556
  dtype = arg_types["x"]._wp_scalar_type_
3433
3557
  rows = arg_types["x"]._shape_[0]
3434
3558
  cols = arg_types["x"]._shape_[1]
3435
- shape = (rows, cols, warp.codegen.options["block_dim"])
3559
+ shape = (rows, cols, warp._src.codegen.options["block_dim"])
3436
3560
  else:
3437
3561
  dtype = arg_types["x"]
3438
- shape = (warp.codegen.options["block_dim"],)
3562
+ shape = (warp._src.codegen.options["block_dim"],)
3439
3563
 
3440
3564
  return tile(dtype=dtype, shape=shape)
3441
3565
 
@@ -3525,17 +3649,17 @@ def untile_value_func(arg_types, arg_values):
3525
3649
  if not is_tile(t):
3526
3650
  raise TypeError(f"untile() argument must be a tile, got {t!r}")
3527
3651
 
3528
- if t.shape[-1] != warp.codegen.options["block_dim"]:
3652
+ if t.shape[-1] != warp._src.codegen.options["block_dim"]:
3529
3653
  raise ValueError(
3530
- f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
3654
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
3531
3655
  )
3532
3656
 
3533
3657
  if len(t.shape) == 1:
3534
3658
  return t.dtype
3535
3659
  elif len(t.shape) == 2:
3536
- return warp.types.vector(t.shape[0], t.dtype)
3660
+ return warp._src.types.vector(t.shape[0], t.dtype)
3537
3661
  elif len(t.shape) == 3:
3538
- return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3662
+ return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3539
3663
  else:
3540
3664
  raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
3541
3665
 
@@ -3597,7 +3721,36 @@ def tile_extract_value_func(arg_types, arg_values):
3597
3721
  # force the input tile to shared memory
3598
3722
  arg_types["a"].storage = "shared"
3599
3723
 
3600
- return arg_types["a"].dtype
3724
+ # count the number of indices (all parameters except the tile "a")
3725
+ num_indices = len(arg_types) - 1
3726
+ tile_dtype = arg_types["a"].dtype
3727
+ tile_shape = arg_types["a"].shape
3728
+
3729
+ if type_is_vector(tile_dtype):
3730
+ if num_indices == len(tile_shape):
3731
+ return tile_dtype
3732
+ elif num_indices == len(tile_shape) + 1:
3733
+ return tile_dtype._wp_scalar_type_
3734
+ else:
3735
+ raise IndexError(
3736
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3737
+ )
3738
+ elif type_is_matrix(tile_dtype):
3739
+ if num_indices == len(tile_shape):
3740
+ return tile_dtype
3741
+ elif num_indices == len(tile_shape) + 2:
3742
+ return tile_dtype._wp_scalar_type_
3743
+ else:
3744
+ raise IndexError(
3745
+ f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
3746
+ )
3747
+ else:
3748
+ # scalar element: index count must exactly match tile rank
3749
+ if num_indices == len(tile_shape):
3750
+ return tile_dtype
3751
+ raise IndexError(
3752
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3753
+ )
3601
3754
 
3602
3755
 
3603
3756
  add_builtin(
@@ -3621,7 +3774,7 @@ add_builtin(
3621
3774
 
3622
3775
  add_builtin(
3623
3776
  "tile_extract",
3624
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
3777
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
3625
3778
  value_func=tile_extract_value_func,
3626
3779
  variadic=False,
3627
3780
  doc="""Extract a single element from the tile.
@@ -3632,7 +3785,7 @@ add_builtin(
3632
3785
 
3633
3786
  :param a: Tile to extract the element from
3634
3787
  :param i: Coordinate of element on first dimension
3635
- :param j: Coordinate of element on the second dimension
3788
+ :param j: Coordinate of element on the second dimension, or vector index
3636
3789
  :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3637
3790
  group="Tile Primitives",
3638
3791
  hidden=True,
@@ -3641,7 +3794,57 @@ add_builtin(
3641
3794
 
3642
3795
  add_builtin(
3643
3796
  "tile_extract",
3644
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
3797
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
3798
+ value_func=tile_extract_value_func,
3799
+ variadic=False,
3800
+ doc="""Extract a single element from the tile.
3801
+
3802
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3803
+
3804
+ Note that this may incur additional synchronization if the source tile is a register tile.
3805
+
3806
+ :param a: Tile to extract the element from
3807
+ :param i: Coordinate of element on first dimension
3808
+ :param j: Coordinate of element on the second dimension, or first matrix index
3809
+ :param k: Coordinate of element on the third dimension, or vector index, or second matrix index
3810
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3811
+ group="Tile Primitives",
3812
+ hidden=True,
3813
+ export=False,
3814
+ )
3815
+
3816
+ add_builtin(
3817
+ "tile_extract",
3818
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
3819
+ value_func=tile_extract_value_func,
3820
+ variadic=False,
3821
+ doc="""Extract a single element from the tile.
3822
+
3823
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3824
+
3825
+ Note that this may incur additional synchronization if the source tile is a register tile.
3826
+
3827
+ :param a: Tile to extract the element from
3828
+ :param i: Coordinate of element on first dimension
3829
+ :param j: Coordinate of element on the second dimension
3830
+ :param k: Coordinate of element on the third dimension, or first matrix index
3831
+ :param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
3832
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3833
+ group="Tile Primitives",
3834
+ hidden=True,
3835
+ export=False,
3836
+ )
3837
+
3838
+ add_builtin(
3839
+ "tile_extract",
3840
+ input_types={
3841
+ "a": tile(dtype=Any, shape=Tuple[int, ...]),
3842
+ "i": int,
3843
+ "j": int,
3844
+ "k": int,
3845
+ "l": int,
3846
+ "m": int,
3847
+ },
3645
3848
  value_func=tile_extract_value_func,
3646
3849
  variadic=False,
3647
3850
  doc="""Extract a single element from the tile.
@@ -3654,7 +3857,9 @@ add_builtin(
3654
3857
  :param i: Coordinate of element on first dimension
3655
3858
  :param j: Coordinate of element on the second dimension
3656
3859
  :param k: Coordinate of element on the third dimension
3657
- :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3860
+ :param l: Coordinate of element on the fourth dimension, or first matrix index
3861
+ :param m: Vector index, or second matrix index
3862
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3658
3863
  group="Tile Primitives",
3659
3864
  hidden=True,
3660
3865
  export=False,
@@ -3662,7 +3867,15 @@ add_builtin(
3662
3867
 
3663
3868
  add_builtin(
3664
3869
  "tile_extract",
3665
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
3870
+ input_types={
3871
+ "a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
3872
+ "i": int,
3873
+ "j": int,
3874
+ "k": int,
3875
+ "l": int,
3876
+ "m": int,
3877
+ "n": int,
3878
+ },
3666
3879
  value_func=tile_extract_value_func,
3667
3880
  variadic=False,
3668
3881
  doc="""Extract a single element from the tile.
@@ -3676,6 +3889,8 @@ add_builtin(
3676
3889
  :param j: Coordinate of element on the second dimension
3677
3890
  :param k: Coordinate of element on the third dimension
3678
3891
  :param l: Coordinate of element on the fourth dimension
3892
+ :param m: Vector index, or first matrix index
3893
+ :param n: Second matrix index
3679
3894
  :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3680
3895
  group="Tile Primitives",
3681
3896
  hidden=True,
@@ -3762,49 +3977,160 @@ add_builtin(
3762
3977
  export=False,
3763
3978
  )
3764
3979
 
3765
-
3766
- def tile_transpose_value_func(arg_types, arg_values):
3767
- # return generic type (for doc builds)
3768
- if arg_types is None:
3769
- return tile(dtype=Any, shape=Tuple[int, int])
3770
-
3771
- if len(arg_types) != 1:
3772
- raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
3773
-
3774
- t = arg_types["a"]
3775
-
3776
- if not is_tile(t):
3777
- raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
3778
-
3779
- layout = None
3780
-
3781
- # flip layout
3782
- if t.layout == "rowmajor":
3783
- layout = "colmajor"
3784
- elif t.layout == "colmajor":
3785
- layout = "rowmajor"
3786
-
3787
- # force the input tile to shared memory
3788
- t.storage = "shared"
3789
-
3790
- return tile(
3791
- dtype=t.dtype,
3792
- shape=t.shape[::-1],
3793
- storage=t.storage,
3794
- strides=t.strides[::-1],
3795
- layout=layout,
3796
- owner=False,
3797
- )
3798
-
3799
-
3800
3980
  add_builtin(
3801
- "tile_transpose",
3802
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
3803
- value_func=tile_transpose_value_func,
3804
- variadic=True,
3805
- doc="""Transpose a tile.
3806
-
3807
- For shared memory tiles, this operation will alias the input tile.
3981
+ "tile_bit_and_inplace",
3982
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3983
+ value_func=tile_inplace_value_func,
3984
+ group="Tile Primitives",
3985
+ hidden=True,
3986
+ export=False,
3987
+ is_differentiable=False,
3988
+ )
3989
+ add_builtin(
3990
+ "tile_bit_and_inplace",
3991
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3992
+ value_func=tile_inplace_value_func,
3993
+ group="Tile Primitives",
3994
+ hidden=True,
3995
+ export=False,
3996
+ is_differentiable=False,
3997
+ )
3998
+ add_builtin(
3999
+ "tile_bit_and_inplace",
4000
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4001
+ value_func=tile_inplace_value_func,
4002
+ group="Tile Primitives",
4003
+ hidden=True,
4004
+ export=False,
4005
+ is_differentiable=False,
4006
+ )
4007
+ add_builtin(
4008
+ "tile_bit_and_inplace",
4009
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4010
+ value_func=tile_inplace_value_func,
4011
+ group="Tile Primitives",
4012
+ hidden=True,
4013
+ export=False,
4014
+ is_differentiable=False,
4015
+ )
4016
+
4017
+ add_builtin(
4018
+ "tile_bit_or_inplace",
4019
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4020
+ value_func=tile_inplace_value_func,
4021
+ group="Tile Primitives",
4022
+ hidden=True,
4023
+ export=False,
4024
+ is_differentiable=False,
4025
+ )
4026
+ add_builtin(
4027
+ "tile_bit_or_inplace",
4028
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4029
+ value_func=tile_inplace_value_func,
4030
+ group="Tile Primitives",
4031
+ hidden=True,
4032
+ export=False,
4033
+ is_differentiable=False,
4034
+ )
4035
+ add_builtin(
4036
+ "tile_bit_or_inplace",
4037
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4038
+ value_func=tile_inplace_value_func,
4039
+ group="Tile Primitives",
4040
+ hidden=True,
4041
+ export=False,
4042
+ is_differentiable=False,
4043
+ )
4044
+ add_builtin(
4045
+ "tile_bit_or_inplace",
4046
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4047
+ value_func=tile_inplace_value_func,
4048
+ group="Tile Primitives",
4049
+ hidden=True,
4050
+ export=False,
4051
+ is_differentiable=False,
4052
+ )
4053
+
4054
+ add_builtin(
4055
+ "tile_bit_xor_inplace",
4056
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4057
+ value_func=tile_inplace_value_func,
4058
+ group="Tile Primitives",
4059
+ hidden=True,
4060
+ export=False,
4061
+ is_differentiable=False,
4062
+ )
4063
+ add_builtin(
4064
+ "tile_bit_xor_inplace",
4065
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4066
+ value_func=tile_inplace_value_func,
4067
+ group="Tile Primitives",
4068
+ hidden=True,
4069
+ export=False,
4070
+ is_differentiable=False,
4071
+ )
4072
+ add_builtin(
4073
+ "tile_bit_xor_inplace",
4074
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4075
+ value_func=tile_inplace_value_func,
4076
+ group="Tile Primitives",
4077
+ hidden=True,
4078
+ export=False,
4079
+ is_differentiable=False,
4080
+ )
4081
+ add_builtin(
4082
+ "tile_bit_xor_inplace",
4083
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4084
+ value_func=tile_inplace_value_func,
4085
+ group="Tile Primitives",
4086
+ hidden=True,
4087
+ export=False,
4088
+ is_differentiable=False,
4089
+ )
4090
+
4091
+
4092
+ def tile_transpose_value_func(arg_types, arg_values):
4093
+ # return generic type (for doc builds)
4094
+ if arg_types is None:
4095
+ return tile(dtype=Any, shape=Tuple[int, int])
4096
+
4097
+ if len(arg_types) != 1:
4098
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
4099
+
4100
+ t = arg_types["a"]
4101
+
4102
+ if not is_tile(t):
4103
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
4104
+
4105
+ layout = None
4106
+
4107
+ # flip layout
4108
+ if t.layout == "rowmajor":
4109
+ layout = "colmajor"
4110
+ elif t.layout == "colmajor":
4111
+ layout = "rowmajor"
4112
+
4113
+ # force the input tile to shared memory
4114
+ t.storage = "shared"
4115
+
4116
+ return tile(
4117
+ dtype=t.dtype,
4118
+ shape=t.shape[::-1],
4119
+ storage=t.storage,
4120
+ strides=t.strides[::-1],
4121
+ layout=layout,
4122
+ owner=False,
4123
+ )
4124
+
4125
+
4126
+ add_builtin(
4127
+ "tile_transpose",
4128
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
4129
+ value_func=tile_transpose_value_func,
4130
+ variadic=True,
4131
+ doc="""Transpose a tile.
4132
+
4133
+ For shared memory tiles, this operation will alias the input tile.
3808
4134
  Register tiles will first be transferred to shared memory before transposition.
3809
4135
 
3810
4136
  :param a: Tile to transpose with ``shape=(M,N)``
@@ -3935,6 +4261,80 @@ add_builtin(
3935
4261
  )
3936
4262
 
3937
4263
 
4264
+ def tile_sum_axis_value_func(arg_types, arg_values):
4265
+ if arg_types is None:
4266
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4267
+
4268
+ a = arg_types["a"]
4269
+
4270
+ if not is_tile(a):
4271
+ raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
4272
+
4273
+ # force input tile to shared
4274
+ a.storage = "shared"
4275
+
4276
+ axis = arg_values["axis"]
4277
+ shape = a.shape
4278
+
4279
+ if axis < 0 or axis >= len(shape):
4280
+ raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4281
+
4282
+ # shape is identical less the axis reduction is along
4283
+ if len(shape) > 1:
4284
+ new_shape = shape[:axis] + shape[axis + 1 :]
4285
+ else:
4286
+ new_shape = (1,)
4287
+
4288
+ return tile(dtype=a.dtype, shape=new_shape)
4289
+
4290
+
4291
+ def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
4292
+ tile = arg_values["a"]
4293
+ axis_var = arg_values["axis"]
4294
+ if not hasattr(axis_var, "constant") or axis_var.constant is None:
4295
+ raise ValueError("tile_sum() axis must be a compile-time constant")
4296
+ axis = axis_var.constant
4297
+
4298
+ return ((tile,), (axis,))
4299
+
4300
+
4301
+ add_builtin(
4302
+ "tile_sum",
4303
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4304
+ value_func=tile_sum_axis_value_func,
4305
+ dispatch_func=tile_sum_axis_dispatch_func,
4306
+ doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
4307
+
4308
+ :param a: The input tile. Must reside in shared memory.
4309
+ :param axis: The tile axis to compute the sum across. Must be a compile-time constant.
4310
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4311
+
4312
+ Example:
4313
+
4314
+ .. code-block:: python
4315
+
4316
+ @wp.kernel
4317
+ def compute():
4318
+
4319
+ t = wp.tile_ones(dtype=float, shape=(8, 8))
4320
+ s = wp.tile_sum(t, axis=0)
4321
+
4322
+ print(s)
4323
+
4324
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
4325
+
4326
+ Prints:
4327
+
4328
+ .. code-block:: text
4329
+
4330
+ [8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
4331
+
4332
+ """,
4333
+ group="Tile Primitives",
4334
+ export=False,
4335
+ )
4336
+
4337
+
3938
4338
  def tile_sort_value_func(arg_types, arg_values):
3939
4339
  # return generic type (for doc builds)
3940
4340
  if arg_types is None:
@@ -4011,7 +4411,7 @@ add_builtin(
4011
4411
  """,
4012
4412
  group="Tile Primitives",
4013
4413
  export=False,
4014
- missing_grad=True,
4414
+ is_differentiable=False,
4015
4415
  )
4016
4416
 
4017
4417
 
@@ -4065,7 +4465,7 @@ add_builtin(
4065
4465
  """,
4066
4466
  group="Tile Primitives",
4067
4467
  export=False,
4068
- missing_grad=True,
4468
+ is_differentiable=False,
4069
4469
  )
4070
4470
 
4071
4471
 
@@ -4119,7 +4519,7 @@ add_builtin(
4119
4519
  """,
4120
4520
  group="Tile Primitives",
4121
4521
  export=False,
4122
- missing_grad=True,
4522
+ is_differentiable=False,
4123
4523
  )
4124
4524
 
4125
4525
 
@@ -4172,7 +4572,7 @@ add_builtin(
4172
4572
  """,
4173
4573
  group="Tile Primitives",
4174
4574
  export=False,
4175
- missing_grad=True,
4575
+ is_differentiable=False,
4176
4576
  )
4177
4577
 
4178
4578
 
@@ -4225,11 +4625,10 @@ add_builtin(
4225
4625
  """,
4226
4626
  group="Tile Primitives",
4227
4627
  export=False,
4228
- missing_grad=True,
4628
+ is_differentiable=False,
4229
4629
  )
4230
4630
 
4231
4631
 
4232
- # does type propagation for load()
4233
4632
  def tile_reduce_value_func(arg_types, arg_values):
4234
4633
  if arg_types is None:
4235
4634
  return tile(dtype=Scalar, shape=(1,))
@@ -4283,7 +4682,88 @@ add_builtin(
4283
4682
  """,
4284
4683
  group="Tile Primitives",
4285
4684
  export=False,
4286
- missing_grad=True,
4685
+ is_differentiable=False,
4686
+ )
4687
+
4688
+
4689
+ def tile_reduce_axis_value_func(arg_types, arg_values):
4690
+ if arg_types is None:
4691
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4692
+
4693
+ a = arg_types["a"]
4694
+
4695
+ if not is_tile(a):
4696
+ raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
4697
+
4698
+ # force input tile to shared memory
4699
+ a.storage = "shared"
4700
+
4701
+ axis = arg_values["axis"]
4702
+ shape = a.shape
4703
+
4704
+ if axis < 0 or axis >= len(shape):
4705
+ raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4706
+
4707
+ # shape is identical less the axis reduction is along
4708
+ if len(shape) > 1:
4709
+ new_shape = shape[:axis] + shape[axis + 1 :]
4710
+ else:
4711
+ new_shape = (1,)
4712
+
4713
+ return tile(dtype=a.dtype, shape=new_shape)
4714
+
4715
+
4716
+ add_builtin(
4717
+ "tile_reduce",
4718
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4719
+ value_func=tile_reduce_axis_value_func,
4720
+ native_func="tile_reduce_axis",
4721
+ doc="""Apply a custom reduction operator across a tile axis.
4722
+
4723
+ This function cooperatively performs a reduction using the provided operator across an axis of the tile.
4724
+
4725
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
4726
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
4727
+ :param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
4728
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4729
+
4730
+ Example:
4731
+
4732
+ .. code-block:: python
4733
+
4734
+ TILE_M = wp.constant(4)
4735
+ TILE_N = wp.constant(2)
4736
+
4737
+ @wp.kernel
4738
+ def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
4739
+
4740
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
4741
+ b = wp.tile_reduce(wp.add, a, axis=1)
4742
+ wp.tile_store(y, b)
4743
+
4744
+ arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
4745
+
4746
+ x = wp.array(arr, dtype=float)
4747
+ y = wp.zeros(TILE_M, dtype=float)
4748
+
4749
+ wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
4750
+
4751
+ print(x.numpy())
4752
+ print(y.numpy())
4753
+
4754
+ Prints:
4755
+
4756
+ .. code-block:: text
4757
+
4758
+ [[0. 1.]
4759
+ [2. 3.]
4760
+ [4. 5.]
4761
+ [6. 7.]]
4762
+ [ 1. 5. 9. 13.]
4763
+ """,
4764
+ group="Tile Primitives",
4765
+ export=False,
4766
+ is_differentiable=False,
4287
4767
  )
4288
4768
 
4289
4769
 
@@ -4347,7 +4827,7 @@ add_builtin(
4347
4827
  """,
4348
4828
  group="Tile Primitives",
4349
4829
  export=False,
4350
- missing_grad=True,
4830
+ is_differentiable=False,
4351
4831
  )
4352
4832
 
4353
4833
 
@@ -4411,7 +4891,7 @@ add_builtin(
4411
4891
  """,
4412
4892
  group="Tile Primitives",
4413
4893
  export=False,
4414
- missing_grad=True,
4894
+ is_differentiable=False,
4415
4895
  )
4416
4896
 
4417
4897
 
@@ -4665,7 +5145,7 @@ add_builtin(
4665
5145
  doc="WIP",
4666
5146
  group="Utility",
4667
5147
  hidden=True,
4668
- missing_grad=True,
5148
+ is_differentiable=False,
4669
5149
  )
4670
5150
 
4671
5151
  add_builtin(
@@ -4681,7 +5161,7 @@ add_builtin(
4681
5161
  doc="WIP",
4682
5162
  group="Utility",
4683
5163
  hidden=True,
4684
- missing_grad=True,
5164
+ is_differentiable=False,
4685
5165
  )
4686
5166
 
4687
5167
  add_builtin(
@@ -4691,7 +5171,7 @@ add_builtin(
4691
5171
  doc="WIP",
4692
5172
  group="Utility",
4693
5173
  hidden=True,
4694
- missing_grad=True,
5174
+ is_differentiable=False,
4695
5175
  )
4696
5176
 
4697
5177
  add_builtin(
@@ -4743,7 +5223,7 @@ add_builtin(
4743
5223
  :param low: The lower bound of the bounding box in BVH space
4744
5224
  :param high: The upper bound of the bounding box in BVH space""",
4745
5225
  export=False,
4746
- missing_grad=True,
5226
+ is_differentiable=False,
4747
5227
  )
4748
5228
 
4749
5229
  add_builtin(
@@ -4759,7 +5239,7 @@ add_builtin(
4759
5239
  :param start: The start of the ray in BVH space
4760
5240
  :param dir: The direction of the ray in BVH space""",
4761
5241
  export=False,
4762
- missing_grad=True,
5242
+ is_differentiable=False,
4763
5243
  )
4764
5244
 
4765
5245
  add_builtin(
@@ -4770,7 +5250,7 @@ add_builtin(
4770
5250
  doc="""Move to the next bound returned by the query.
4771
5251
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
4772
5252
  export=False,
4773
- missing_grad=True,
5253
+ is_differentiable=False,
4774
5254
  )
4775
5255
 
4776
5256
  add_builtin(
@@ -5111,7 +5591,7 @@ add_builtin(
5111
5591
  :param low: The lower bound of the bounding box in mesh space
5112
5592
  :param high: The upper bound of the bounding box in mesh space""",
5113
5593
  export=False,
5114
- missing_grad=True,
5594
+ is_differentiable=False,
5115
5595
  )
5116
5596
 
5117
5597
  add_builtin(
@@ -5123,7 +5603,7 @@ add_builtin(
5123
5603
 
5124
5604
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
5125
5605
  export=False,
5126
- missing_grad=True,
5606
+ is_differentiable=False,
5127
5607
  )
5128
5608
 
5129
5609
  add_builtin(
@@ -5153,7 +5633,7 @@ add_builtin(
5153
5633
 
5154
5634
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
5155
5635
  export=False,
5156
- missing_grad=True,
5636
+ is_differentiable=False,
5157
5637
  )
5158
5638
 
5159
5639
  add_builtin(
@@ -5165,7 +5645,7 @@ add_builtin(
5165
5645
 
5166
5646
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
5167
5647
  export=False,
5168
- missing_grad=True,
5648
+ is_differentiable=False,
5169
5649
  )
5170
5650
 
5171
5651
  add_builtin(
@@ -5179,7 +5659,7 @@ add_builtin(
5179
5659
 
5180
5660
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
5181
5661
  export=False,
5182
- missing_grad=True,
5662
+ is_differentiable=False,
5183
5663
  )
5184
5664
 
5185
5665
  add_builtin(
@@ -5189,16 +5669,34 @@ add_builtin(
5189
5669
  group="Geometry",
5190
5670
  doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5191
5671
 
5672
+ This function works with single precision, may return incorrect results in some case.
5673
+
5674
+ Returns > 0 if triangles intersect.""",
5675
+ export=False,
5676
+ is_differentiable=False,
5677
+ )
5678
+
5679
+
5680
+ add_builtin(
5681
+ "intersect_tri_tri",
5682
+ input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
5683
+ value_type=int,
5684
+ group="Geometry",
5685
+ doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5686
+
5687
+ This function works with double precision, results are more accurate than the single precision version.
5688
+
5192
5689
  Returns > 0 if triangles intersect.""",
5193
5690
  export=False,
5194
- missing_grad=True,
5691
+ is_differentiable=False,
5195
5692
  )
5196
5693
 
5694
+
5197
5695
  add_builtin(
5198
5696
  "mesh_get",
5199
5697
  input_types={"id": uint64},
5200
5698
  value_type=Mesh,
5201
- missing_grad=True,
5699
+ is_differentiable=False,
5202
5700
  group="Geometry",
5203
5701
  doc="""Retrieves the mesh given its index.""",
5204
5702
  export=False,
@@ -5211,7 +5709,7 @@ add_builtin(
5211
5709
  group="Geometry",
5212
5710
  doc="""Evaluates the face normal the mesh given a face index.""",
5213
5711
  export=False,
5214
- missing_grad=True,
5712
+ is_differentiable=False,
5215
5713
  )
5216
5714
 
5217
5715
  add_builtin(
@@ -5221,7 +5719,7 @@ add_builtin(
5221
5719
  group="Geometry",
5222
5720
  doc="""Returns the point of the mesh given a index.""",
5223
5721
  export=False,
5224
- missing_grad=True,
5722
+ is_differentiable=False,
5225
5723
  )
5226
5724
 
5227
5725
  add_builtin(
@@ -5231,7 +5729,7 @@ add_builtin(
5231
5729
  group="Geometry",
5232
5730
  doc="""Returns the velocity of the mesh given a index.""",
5233
5731
  export=False,
5234
- missing_grad=True,
5732
+ is_differentiable=False,
5235
5733
  )
5236
5734
 
5237
5735
  add_builtin(
@@ -5241,7 +5739,7 @@ add_builtin(
5241
5739
  group="Geometry",
5242
5740
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
5243
5741
  export=False,
5244
- missing_grad=True,
5742
+ is_differentiable=False,
5245
5743
  )
5246
5744
 
5247
5745
 
@@ -5289,7 +5787,7 @@ add_builtin(
5289
5787
  group="Utility",
5290
5788
  export=False,
5291
5789
  hidden=True,
5292
- missing_grad=True,
5790
+ is_differentiable=False,
5293
5791
  )
5294
5792
  add_builtin(
5295
5793
  "iter_next",
@@ -5298,7 +5796,7 @@ add_builtin(
5298
5796
  group="Utility",
5299
5797
  export=False,
5300
5798
  hidden=True,
5301
- missing_grad=True,
5799
+ is_differentiable=False,
5302
5800
  )
5303
5801
  add_builtin(
5304
5802
  "iter_next",
@@ -5307,7 +5805,7 @@ add_builtin(
5307
5805
  group="Utility",
5308
5806
  export=False,
5309
5807
  hidden=True,
5310
- missing_grad=True,
5808
+ is_differentiable=False,
5311
5809
  )
5312
5810
 
5313
5811
  add_builtin(
@@ -5318,7 +5816,7 @@ add_builtin(
5318
5816
  group="Utility",
5319
5817
  doc="""Returns the range in reversed order.""",
5320
5818
  export=False,
5321
- missing_grad=True,
5819
+ is_differentiable=False,
5322
5820
  )
5323
5821
 
5324
5822
  # ---------------------------------
@@ -5338,8 +5836,8 @@ _volume_supported_value_types = {
5338
5836
 
5339
5837
 
5340
5838
  def _is_volume_type_supported(dtype):
5341
- for typ in _volume_supported_value_types:
5342
- if types_equal(typ, dtype):
5839
+ for value_type in _volume_supported_value_types:
5840
+ if types_equal(value_type, dtype):
5343
5841
  return True
5344
5842
  return False
5345
5843
 
@@ -5467,7 +5965,7 @@ add_builtin(
5467
5965
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
5468
5966
 
5469
5967
  If the voxel at this index does not exist, this function returns the background value.""",
5470
- missing_grad=True,
5968
+ is_differentiable=False,
5471
5969
  )
5472
5970
 
5473
5971
 
@@ -5488,7 +5986,7 @@ add_builtin(
5488
5986
  export=False,
5489
5987
  group="Volumes",
5490
5988
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5491
- missing_grad=True,
5989
+ is_differentiable=False,
5492
5990
  )
5493
5991
 
5494
5992
  add_builtin(
@@ -5519,7 +6017,7 @@ add_builtin(
5519
6017
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
5520
6018
 
5521
6019
  If the voxel at this index does not exist, this function returns the background value""",
5522
- missing_grad=True,
6020
+ is_differentiable=False,
5523
6021
  )
5524
6022
 
5525
6023
  add_builtin(
@@ -5528,7 +6026,7 @@ add_builtin(
5528
6026
  group="Volumes",
5529
6027
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5530
6028
  export=False,
5531
- missing_grad=True,
6029
+ is_differentiable=False,
5532
6030
  )
5533
6031
 
5534
6032
  add_builtin(
@@ -5549,7 +6047,7 @@ add_builtin(
5549
6047
  doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
5550
6048
 
5551
6049
  If the voxel at this index does not exist, this function returns the background value.""",
5552
- missing_grad=True,
6050
+ is_differentiable=False,
5553
6051
  )
5554
6052
 
5555
6053
  add_builtin(
@@ -5558,7 +6056,7 @@ add_builtin(
5558
6056
  group="Volumes",
5559
6057
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5560
6058
  export=False,
5561
- missing_grad=True,
6059
+ is_differentiable=False,
5562
6060
  )
5563
6061
 
5564
6062
  add_builtin(
@@ -5577,7 +6075,7 @@ add_builtin(
5577
6075
  doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
5578
6076
 
5579
6077
  If the voxel at this index does not exist, this function returns the background value.""",
5580
- missing_grad=True,
6078
+ is_differentiable=False,
5581
6079
  )
5582
6080
 
5583
6081
  add_builtin(
@@ -5586,7 +6084,7 @@ add_builtin(
5586
6084
  group="Volumes",
5587
6085
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5588
6086
  export=False,
5589
- missing_grad=True,
6087
+ is_differentiable=False,
5590
6088
  )
5591
6089
 
5592
6090
 
@@ -5668,7 +6166,7 @@ add_builtin(
5668
6166
  If the voxel at this index does not exist, this function returns -1.
5669
6167
  This function is available for both index grids and classical volumes.
5670
6168
  """,
5671
- missing_grad=True,
6169
+ is_differentiable=False,
5672
6170
  )
5673
6171
 
5674
6172
  add_builtin(
@@ -5710,7 +6208,7 @@ add_builtin(
5710
6208
  value_type=uint32,
5711
6209
  group="Random",
5712
6210
  doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
5713
- missing_grad=True,
6211
+ is_differentiable=False,
5714
6212
  )
5715
6213
 
5716
6214
  add_builtin(
@@ -5722,7 +6220,7 @@ add_builtin(
5722
6220
 
5723
6221
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
5724
6222
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
5725
- missing_grad=True,
6223
+ is_differentiable=False,
5726
6224
  )
5727
6225
 
5728
6226
  add_builtin(
@@ -5731,7 +6229,7 @@ add_builtin(
5731
6229
  value_type=int,
5732
6230
  group="Random",
5733
6231
  doc="Return a random integer in the range [-2^31, 2^31).",
5734
- missing_grad=True,
6232
+ is_differentiable=False,
5735
6233
  )
5736
6234
  add_builtin(
5737
6235
  "randi",
@@ -5739,7 +6237,7 @@ add_builtin(
5739
6237
  value_type=int,
5740
6238
  group="Random",
5741
6239
  doc="Return a random integer between [low, high).",
5742
- missing_grad=True,
6240
+ is_differentiable=False,
5743
6241
  )
5744
6242
  add_builtin(
5745
6243
  "randu",
@@ -5747,7 +6245,7 @@ add_builtin(
5747
6245
  value_type=uint32,
5748
6246
  group="Random",
5749
6247
  doc="Return a random unsigned integer in the range [0, 2^32).",
5750
- missing_grad=True,
6248
+ is_differentiable=False,
5751
6249
  )
5752
6250
  add_builtin(
5753
6251
  "randu",
@@ -5755,7 +6253,7 @@ add_builtin(
5755
6253
  value_type=uint32,
5756
6254
  group="Random",
5757
6255
  doc="Return a random unsigned integer between [low, high).",
5758
- missing_grad=True,
6256
+ is_differentiable=False,
5759
6257
  )
5760
6258
  add_builtin(
5761
6259
  "randf",
@@ -5763,7 +6261,7 @@ add_builtin(
5763
6261
  value_type=float,
5764
6262
  group="Random",
5765
6263
  doc="Return a random float between [0.0, 1.0).",
5766
- missing_grad=True,
6264
+ is_differentiable=False,
5767
6265
  )
5768
6266
  add_builtin(
5769
6267
  "randf",
@@ -5771,7 +6269,7 @@ add_builtin(
5771
6269
  value_type=float,
5772
6270
  group="Random",
5773
6271
  doc="Return a random float between [low, high).",
5774
- missing_grad=True,
6272
+ is_differentiable=False,
5775
6273
  )
5776
6274
  add_builtin(
5777
6275
  "randn",
@@ -5779,7 +6277,7 @@ add_builtin(
5779
6277
  value_type=float,
5780
6278
  group="Random",
5781
6279
  doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
5782
- missing_grad=True,
6280
+ is_differentiable=False,
5783
6281
  )
5784
6282
 
5785
6283
  add_builtin(
@@ -5788,7 +6286,7 @@ add_builtin(
5788
6286
  value_type=int,
5789
6287
  group="Random",
5790
6288
  doc="Inverse-transform sample a cumulative distribution function.",
5791
- missing_grad=True,
6289
+ is_differentiable=False,
5792
6290
  )
5793
6291
  add_builtin(
5794
6292
  "sample_triangle",
@@ -5796,7 +6294,7 @@ add_builtin(
5796
6294
  value_type=vec2,
5797
6295
  group="Random",
5798
6296
  doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
5799
- missing_grad=True,
6297
+ is_differentiable=False,
5800
6298
  )
5801
6299
  add_builtin(
5802
6300
  "sample_unit_ring",
@@ -5804,7 +6302,7 @@ add_builtin(
5804
6302
  value_type=vec2,
5805
6303
  group="Random",
5806
6304
  doc="Uniformly sample a ring in the xy plane.",
5807
- missing_grad=True,
6305
+ is_differentiable=False,
5808
6306
  )
5809
6307
  add_builtin(
5810
6308
  "sample_unit_disk",
@@ -5812,7 +6310,7 @@ add_builtin(
5812
6310
  value_type=vec2,
5813
6311
  group="Random",
5814
6312
  doc="Uniformly sample a disk in the xy plane.",
5815
- missing_grad=True,
6313
+ is_differentiable=False,
5816
6314
  )
5817
6315
  add_builtin(
5818
6316
  "sample_unit_sphere_surface",
@@ -5820,7 +6318,7 @@ add_builtin(
5820
6318
  value_type=vec3,
5821
6319
  group="Random",
5822
6320
  doc="Uniformly sample a unit sphere surface.",
5823
- missing_grad=True,
6321
+ is_differentiable=False,
5824
6322
  )
5825
6323
  add_builtin(
5826
6324
  "sample_unit_sphere",
@@ -5828,7 +6326,7 @@ add_builtin(
5828
6326
  value_type=vec3,
5829
6327
  group="Random",
5830
6328
  doc="Uniformly sample a unit sphere.",
5831
- missing_grad=True,
6329
+ is_differentiable=False,
5832
6330
  )
5833
6331
  add_builtin(
5834
6332
  "sample_unit_hemisphere_surface",
@@ -5836,7 +6334,7 @@ add_builtin(
5836
6334
  value_type=vec3,
5837
6335
  group="Random",
5838
6336
  doc="Uniformly sample a unit hemisphere surface.",
5839
- missing_grad=True,
6337
+ is_differentiable=False,
5840
6338
  )
5841
6339
  add_builtin(
5842
6340
  "sample_unit_hemisphere",
@@ -5844,7 +6342,7 @@ add_builtin(
5844
6342
  value_type=vec3,
5845
6343
  group="Random",
5846
6344
  doc="Uniformly sample a unit hemisphere.",
5847
- missing_grad=True,
6345
+ is_differentiable=False,
5848
6346
  )
5849
6347
  add_builtin(
5850
6348
  "sample_unit_square",
@@ -5852,7 +6350,7 @@ add_builtin(
5852
6350
  value_type=vec2,
5853
6351
  group="Random",
5854
6352
  doc="Uniformly sample a unit square.",
5855
- missing_grad=True,
6353
+ is_differentiable=False,
5856
6354
  )
5857
6355
  add_builtin(
5858
6356
  "sample_unit_cube",
@@ -5860,7 +6358,7 @@ add_builtin(
5860
6358
  value_type=vec3,
5861
6359
  group="Random",
5862
6360
  doc="Uniformly sample a unit cube.",
5863
- missing_grad=True,
6361
+ is_differentiable=False,
5864
6362
  )
5865
6363
 
5866
6364
  add_builtin(
@@ -5872,7 +6370,7 @@ add_builtin(
5872
6370
 
5873
6371
  :param state: RNG state
5874
6372
  :param lam: The expected value of the distribution""",
5875
- missing_grad=True,
6373
+ is_differentiable=False,
5876
6374
  )
5877
6375
 
5878
6376
  add_builtin(
@@ -5940,7 +6438,7 @@ add_builtin(
5940
6438
  value_type=vec2,
5941
6439
  group="Random",
5942
6440
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
5943
- missing_grad=True,
6441
+ is_differentiable=False,
5944
6442
  )
5945
6443
  add_builtin(
5946
6444
  "curlnoise",
@@ -5949,7 +6447,7 @@ add_builtin(
5949
6447
  value_type=vec3,
5950
6448
  group="Random",
5951
6449
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5952
- missing_grad=True,
6450
+ is_differentiable=False,
5953
6451
  )
5954
6452
  add_builtin(
5955
6453
  "curlnoise",
@@ -5958,7 +6456,7 @@ add_builtin(
5958
6456
  value_type=vec3,
5959
6457
  group="Random",
5960
6458
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5961
- missing_grad=True,
6459
+ is_differentiable=False,
5962
6460
  )
5963
6461
 
5964
6462
 
@@ -5990,7 +6488,7 @@ add_builtin(
5990
6488
  dispatch_func=printf_dispatch_func,
5991
6489
  group="Utility",
5992
6490
  doc="Allows printing formatted strings using C-style format specifiers.",
5993
- missing_grad=True,
6491
+ is_differentiable=False,
5994
6492
  )
5995
6493
 
5996
6494
  add_builtin(
@@ -6009,7 +6507,7 @@ add_builtin(
6009
6507
  group="Utility",
6010
6508
  namespace="",
6011
6509
  native_func="__debugbreak",
6012
- missing_grad=True,
6510
+ is_differentiable=False,
6013
6511
  )
6014
6512
 
6015
6513
  # helpers
@@ -6027,7 +6525,7 @@ add_builtin(
6027
6525
  This function may not be called from user-defined Warp functions.""",
6028
6526
  namespace="",
6029
6527
  native_func="builtin_tid1d",
6030
- missing_grad=True,
6528
+ is_differentiable=False,
6031
6529
  )
6032
6530
 
6033
6531
  add_builtin(
@@ -6038,7 +6536,7 @@ add_builtin(
6038
6536
  doc="Returns the number of threads in the current block.",
6039
6537
  namespace="",
6040
6538
  native_func="builtin_block_dim",
6041
- missing_grad=True,
6539
+ is_differentiable=False,
6042
6540
  )
6043
6541
 
6044
6542
  add_builtin(
@@ -6053,7 +6551,7 @@ add_builtin(
6053
6551
  This function may not be called from user-defined Warp functions.""",
6054
6552
  namespace="",
6055
6553
  native_func="builtin_tid2d",
6056
- missing_grad=True,
6554
+ is_differentiable=False,
6057
6555
  )
6058
6556
 
6059
6557
  add_builtin(
@@ -6068,7 +6566,7 @@ add_builtin(
6068
6566
  This function may not be called from user-defined Warp functions.""",
6069
6567
  namespace="",
6070
6568
  native_func="builtin_tid3d",
6071
- missing_grad=True,
6569
+ is_differentiable=False,
6072
6570
  )
6073
6571
 
6074
6572
  add_builtin(
@@ -6083,7 +6581,7 @@ add_builtin(
6083
6581
  This function may not be called from user-defined Warp functions.""",
6084
6582
  namespace="",
6085
6583
  native_func="builtin_tid4d",
6086
- missing_grad=True,
6584
+ is_differentiable=False,
6087
6585
  )
6088
6586
 
6089
6587
 
@@ -6127,56 +6625,20 @@ def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
6127
6625
  if arg_types is None:
6128
6626
  return Any
6129
6627
 
6130
- v_true = arg_types["value_if_true"]
6131
- v_false = arg_types["value_if_false"]
6628
+ raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
6132
6629
 
6133
- if not types_equal(v_true, v_false):
6134
- raise RuntimeError(
6135
- f"select() true value type ({v_true}) must be of the same type as the false type ({v_false})"
6136
- )
6137
6630
 
6138
- if is_tile(v_false):
6139
- if v_true.storage == "register":
6140
- return v_true
6141
- if v_false.storage == "register":
6142
- return v_false
6631
+ add_builtin(
6632
+ "select",
6633
+ input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
6634
+ value_func=select_value_func,
6635
+ doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6143
6636
 
6144
- # both v_true and v_false are shared
6145
- return tile(
6146
- dtype=v_true.dtype,
6147
- shape=v_true.shape,
6148
- storage=v_true.storage,
6149
- strides=v_true.strides,
6150
- layout=v_true.layout,
6151
- owner=True,
6152
- )
6153
-
6154
- return v_true
6155
-
6156
-
6157
- def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
6158
- warp.utils.warn(
6159
- "wp.select() is deprecated and will be removed in a future\n"
6160
- "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
6161
- category=DeprecationWarning,
6162
- )
6163
-
6164
- func_args = tuple(args.values())
6165
- template_args = ()
6166
-
6167
- return (func_args, template_args)
6168
-
6169
-
6170
- add_builtin(
6171
- "select",
6172
- input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
6173
- value_func=select_value_func,
6174
- dispatch_func=select_dispatch_func,
6175
- doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6176
-
6177
- .. deprecated:: 1.7
6637
+ .. versionremoved:: 1.10
6178
6638
  Use :func:`where` instead, which has the more intuitive argument order:
6179
- ``where(cond, value_if_true, value_if_false)``.""",
6639
+ ``where(cond, value_if_true, value_if_false)``.
6640
+
6641
+ .. deprecated:: 1.7""",
6180
6642
  group="Utility",
6181
6643
  )
6182
6644
  for t in int_types:
@@ -6184,24 +6646,26 @@ for t in int_types:
6184
6646
  "select",
6185
6647
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
6186
6648
  value_func=select_value_func,
6187
- dispatch_func=select_dispatch_func,
6188
6649
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6189
6650
 
6190
- .. deprecated:: 1.7
6651
+ .. versionremoved:: 1.10
6191
6652
  Use :func:`where` instead, which has the more intuitive argument order:
6192
- ``where(cond, value_if_true, value_if_false)``.""",
6653
+ ``where(cond, value_if_true, value_if_false)``.
6654
+
6655
+ .. deprecated:: 1.7""",
6193
6656
  group="Utility",
6194
6657
  )
6195
6658
  add_builtin(
6196
6659
  "select",
6197
6660
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
6198
6661
  value_func=select_value_func,
6199
- dispatch_func=select_dispatch_func,
6200
6662
  doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
6201
6663
 
6202
- .. deprecated:: 1.7
6664
+ .. versionremoved:: 1.10
6203
6665
  Use :func:`where` instead, which has the more intuitive argument order:
6204
- ``where(arr, value_if_true, value_if_false)``.""",
6666
+ ``where(arr, value_if_true, value_if_false)``.
6667
+
6668
+ .. deprecated:: 1.7""",
6205
6669
  group="Utility",
6206
6670
  )
6207
6671
 
@@ -6291,7 +6755,7 @@ add_builtin(
6291
6755
  group="Utility",
6292
6756
  hidden=True,
6293
6757
  export=False,
6294
- missing_grad=True,
6758
+ is_differentiable=False,
6295
6759
  )
6296
6760
 
6297
6761
 
@@ -6332,7 +6796,7 @@ add_builtin(
6332
6796
  native_func="fixedarray_t",
6333
6797
  group="Utility",
6334
6798
  export=False,
6335
- missing_grad=True,
6799
+ is_differentiable=False,
6336
6800
  hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
6337
6801
  )
6338
6802
 
@@ -6375,14 +6839,13 @@ for array_type in array_types:
6375
6839
  # does argument checking and type propagation for view()
6376
6840
  def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6377
6841
  arr_type = arg_types["arr"]
6378
- idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)
6842
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
6379
6843
 
6380
6844
  if not is_array(arr_type):
6381
6845
  raise RuntimeError("view() first argument must be an array")
6382
6846
 
6383
6847
  idx_count = len(idx_types)
6384
-
6385
- if idx_count >= arr_type.ndim:
6848
+ if idx_count > arr_type.ndim:
6386
6849
  raise RuntimeError(
6387
6850
  f"Trying to create an array view with {idx_count} indices, "
6388
6851
  f"but the array only has {arr_type.ndim} dimension(s). "
@@ -6390,14 +6853,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6390
6853
  f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
6391
6854
  )
6392
6855
 
6393
- # check index types
6394
- for t in idx_types:
6395
- if not type_is_int(t):
6396
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")
6856
+ has_slice = any(is_slice(x) for x in idx_types)
6857
+ if has_slice:
6858
+ # check index types
6859
+ for t in idx_types:
6860
+ if not (type_is_int(t) or is_slice(t)):
6861
+ raise RuntimeError(
6862
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6863
+ )
6864
+
6865
+ # Each integer index collapses one dimension.
6866
+ int_count = sum(x.step == 0 for x in idx_types)
6867
+ ndim = arr_type.ndim - int_count
6868
+ assert ndim > 0
6869
+ else:
6870
+ if idx_count == arr_type.ndim:
6871
+ raise RuntimeError("Expected to call `address()` instead of `view()`")
6872
+
6873
+ # check index types
6874
+ for t in idx_types:
6875
+ if not type_is_int(t):
6876
+ raise RuntimeError(
6877
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6878
+ )
6879
+
6880
+ # create an array view with leading dimensions removed
6881
+ ndim = arr_type.ndim - idx_count
6882
+ assert ndim > 0
6397
6883
 
6398
- # create an array view with leading dimensions removed
6399
6884
  dtype = arr_type.dtype
6400
- ndim = arr_type.ndim - idx_count
6401
6885
  if isinstance(arr_type, (fabricarray, indexedfabricarray)):
6402
6886
  # fabric array of arrays: return array attribute as a regular array
6403
6887
  return array(dtype=dtype, ndim=ndim)
@@ -6408,8 +6892,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6408
6892
  for array_type in array_types:
6409
6893
  add_builtin(
6410
6894
  "view",
6411
- input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int},
6412
- defaults={"j": None, "k": None},
6895
+ input_types={
6896
+ "arr": array_type(dtype=Any),
6897
+ "i": Any,
6898
+ "j": Any,
6899
+ "k": Any,
6900
+ "l": Any,
6901
+ },
6902
+ defaults={
6903
+ "j": None,
6904
+ "k": None,
6905
+ "l": None,
6906
+ },
6413
6907
  constraint=sametypes,
6414
6908
  hidden=True,
6415
6909
  value_func=view_value_func,
@@ -6513,7 +7007,7 @@ add_builtin(
6513
7007
  hidden=True,
6514
7008
  skip_replay=True,
6515
7009
  group="Utility",
6516
- missing_grad=True,
7010
+ is_differentiable=False,
6517
7011
  )
6518
7012
 
6519
7013
 
@@ -6530,7 +7024,7 @@ add_builtin(
6530
7024
  dispatch_func=load_dispatch_func,
6531
7025
  hidden=True,
6532
7026
  group="Utility",
6533
- missing_grad=True,
7027
+ is_differentiable=False,
6534
7028
  )
6535
7029
 
6536
7030
 
@@ -6606,6 +7100,13 @@ def create_atomic_op_value_func(op: str):
6606
7100
  f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
6607
7101
  f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
6608
7102
  )
7103
+ elif op in ("and", "or", "xor"):
7104
+ supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
7105
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
7106
+ raise RuntimeError(
7107
+ f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
7108
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
7109
+ )
6609
7110
  else:
6610
7111
  raise NotImplementedError
6611
7112
 
@@ -6639,7 +7140,8 @@ for array_type in array_types:
6639
7140
  value_func=create_atomic_op_value_func("add"),
6640
7141
  dispatch_func=atomic_op_dispatch_func,
6641
7142
  doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
6642
- This function is automatically invoked when using the syntax ``arr[i] += value``.""",
7143
+
7144
+ This function is automatically invoked when using the syntax ``arr[i] += value``.""",
6643
7145
  group="Utility",
6644
7146
  skip_replay=True,
6645
7147
  )
@@ -6651,7 +7153,8 @@ for array_type in array_types:
6651
7153
  value_func=create_atomic_op_value_func("add"),
6652
7154
  dispatch_func=atomic_op_dispatch_func,
6653
7155
  doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
6654
- This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
7156
+
7157
+ This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
6655
7158
  group="Utility",
6656
7159
  skip_replay=True,
6657
7160
  )
@@ -6663,7 +7166,8 @@ for array_type in array_types:
6663
7166
  value_func=create_atomic_op_value_func("add"),
6664
7167
  dispatch_func=atomic_op_dispatch_func,
6665
7168
  doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
6666
- This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
7169
+
7170
+ This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
6667
7171
  group="Utility",
6668
7172
  skip_replay=True,
6669
7173
  )
@@ -6675,7 +7179,8 @@ for array_type in array_types:
6675
7179
  value_func=create_atomic_op_value_func("add"),
6676
7180
  dispatch_func=atomic_op_dispatch_func,
6677
7181
  doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
6678
- This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
7182
+
7183
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
6679
7184
  group="Utility",
6680
7185
  skip_replay=True,
6681
7186
  )
@@ -6688,7 +7193,8 @@ for array_type in array_types:
6688
7193
  value_func=create_atomic_op_value_func("sub"),
6689
7194
  dispatch_func=atomic_op_dispatch_func,
6690
7195
  doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
6691
- This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
7196
+
7197
+ This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
6692
7198
  group="Utility",
6693
7199
  skip_replay=True,
6694
7200
  )
@@ -6700,7 +7206,8 @@ for array_type in array_types:
6700
7206
  value_func=create_atomic_op_value_func("sub"),
6701
7207
  dispatch_func=atomic_op_dispatch_func,
6702
7208
  doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
6703
- This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
7209
+
7210
+ This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
6704
7211
  group="Utility",
6705
7212
  skip_replay=True,
6706
7213
  )
@@ -6712,7 +7219,8 @@ for array_type in array_types:
6712
7219
  value_func=create_atomic_op_value_func("sub"),
6713
7220
  dispatch_func=atomic_op_dispatch_func,
6714
7221
  doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
6715
- This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
7222
+
7223
+ This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
6716
7224
  group="Utility",
6717
7225
  skip_replay=True,
6718
7226
  )
@@ -6724,7 +7232,8 @@ for array_type in array_types:
6724
7232
  value_func=create_atomic_op_value_func("sub"),
6725
7233
  dispatch_func=atomic_op_dispatch_func,
6726
7234
  doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
6727
- This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
7235
+
7236
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
6728
7237
  group="Utility",
6729
7238
  skip_replay=True,
6730
7239
  )
@@ -6847,7 +7356,7 @@ for array_type in array_types:
6847
7356
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6848
7357
  group="Utility",
6849
7358
  skip_replay=True,
6850
- missing_grad=True,
7359
+ is_differentiable=False,
6851
7360
  )
6852
7361
  add_builtin(
6853
7362
  "atomic_cas",
@@ -6861,7 +7370,7 @@ for array_type in array_types:
6861
7370
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6862
7371
  group="Utility",
6863
7372
  skip_replay=True,
6864
- missing_grad=True,
7373
+ is_differentiable=False,
6865
7374
  )
6866
7375
  add_builtin(
6867
7376
  "atomic_cas",
@@ -6875,7 +7384,7 @@ for array_type in array_types:
6875
7384
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6876
7385
  group="Utility",
6877
7386
  skip_replay=True,
6878
- missing_grad=True,
7387
+ is_differentiable=False,
6879
7388
  )
6880
7389
  add_builtin(
6881
7390
  "atomic_cas",
@@ -6897,7 +7406,7 @@ for array_type in array_types:
6897
7406
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6898
7407
  group="Utility",
6899
7408
  skip_replay=True,
6900
- missing_grad=True,
7409
+ is_differentiable=False,
6901
7410
  )
6902
7411
 
6903
7412
  add_builtin(
@@ -6912,7 +7421,7 @@ for array_type in array_types:
6912
7421
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6913
7422
  group="Utility",
6914
7423
  skip_replay=True,
6915
- missing_grad=True,
7424
+ is_differentiable=False,
6916
7425
  )
6917
7426
  add_builtin(
6918
7427
  "atomic_exch",
@@ -6926,7 +7435,7 @@ for array_type in array_types:
6926
7435
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6927
7436
  group="Utility",
6928
7437
  skip_replay=True,
6929
- missing_grad=True,
7438
+ is_differentiable=False,
6930
7439
  )
6931
7440
  add_builtin(
6932
7441
  "atomic_exch",
@@ -6940,7 +7449,7 @@ for array_type in array_types:
6940
7449
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6941
7450
  group="Utility",
6942
7451
  skip_replay=True,
6943
- missing_grad=True,
7452
+ is_differentiable=False,
6944
7453
  )
6945
7454
  add_builtin(
6946
7455
  "atomic_exch",
@@ -6956,6 +7465,177 @@ for array_type in array_types:
6956
7465
  skip_replay=True,
6957
7466
  )
6958
7467
 
7468
+ add_builtin(
7469
+ "atomic_and",
7470
+ hidden=hidden,
7471
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7472
+ constraint=atomic_op_constraint,
7473
+ value_func=create_atomic_op_value_func("and"),
7474
+ dispatch_func=atomic_op_dispatch_func,
7475
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7476
+
7477
+ This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
7478
+ group="Utility",
7479
+ skip_replay=True,
7480
+ is_differentiable=False,
7481
+ )
7482
+ add_builtin(
7483
+ "atomic_and",
7484
+ hidden=hidden,
7485
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7486
+ constraint=atomic_op_constraint,
7487
+ value_func=create_atomic_op_value_func("and"),
7488
+ dispatch_func=atomic_op_dispatch_func,
7489
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7490
+
7491
+ This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
7492
+ group="Utility",
7493
+ skip_replay=True,
7494
+ is_differentiable=False,
7495
+ )
7496
+ add_builtin(
7497
+ "atomic_and",
7498
+ hidden=hidden,
7499
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7500
+ constraint=atomic_op_constraint,
7501
+ value_func=create_atomic_op_value_func("and"),
7502
+ dispatch_func=atomic_op_dispatch_func,
7503
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7504
+
7505
+ This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
7506
+ group="Utility",
7507
+ skip_replay=True,
7508
+ is_differentiable=False,
7509
+ )
7510
+ add_builtin(
7511
+ "atomic_and",
7512
+ hidden=hidden,
7513
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7514
+ constraint=atomic_op_constraint,
7515
+ value_func=create_atomic_op_value_func("and"),
7516
+ dispatch_func=atomic_op_dispatch_func,
7517
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7518
+
7519
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
7520
+ group="Utility",
7521
+ skip_replay=True,
7522
+ is_differentiable=False,
7523
+ )
7524
+
7525
+ add_builtin(
7526
+ "atomic_or",
7527
+ hidden=hidden,
7528
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7529
+ constraint=atomic_op_constraint,
7530
+ value_func=create_atomic_op_value_func("or"),
7531
+ dispatch_func=atomic_op_dispatch_func,
7532
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7533
+
7534
+ This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
7535
+ group="Utility",
7536
+ skip_replay=True,
7537
+ is_differentiable=False,
7538
+ )
7539
+ add_builtin(
7540
+ "atomic_or",
7541
+ hidden=hidden,
7542
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7543
+ constraint=atomic_op_constraint,
7544
+ value_func=create_atomic_op_value_func("or"),
7545
+ dispatch_func=atomic_op_dispatch_func,
7546
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7547
+
7548
+ This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
7549
+ group="Utility",
7550
+ skip_replay=True,
7551
+ is_differentiable=False,
7552
+ )
7553
+ add_builtin(
7554
+ "atomic_or",
7555
+ hidden=hidden,
7556
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7557
+ constraint=atomic_op_constraint,
7558
+ value_func=create_atomic_op_value_func("or"),
7559
+ dispatch_func=atomic_op_dispatch_func,
7560
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7561
+
7562
+ This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
7563
+ group="Utility",
7564
+ skip_replay=True,
7565
+ is_differentiable=False,
7566
+ )
7567
+ add_builtin(
7568
+ "atomic_or",
7569
+ hidden=hidden,
7570
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7571
+ constraint=atomic_op_constraint,
7572
+ value_func=create_atomic_op_value_func("or"),
7573
+ dispatch_func=atomic_op_dispatch_func,
7574
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7575
+
7576
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
7577
+ group="Utility",
7578
+ skip_replay=True,
7579
+ is_differentiable=False,
7580
+ )
7581
+
7582
+ add_builtin(
7583
+ "atomic_xor",
7584
+ hidden=hidden,
7585
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7586
+ constraint=atomic_op_constraint,
7587
+ value_func=create_atomic_op_value_func("xor"),
7588
+ dispatch_func=atomic_op_dispatch_func,
7589
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7590
+
7591
+ This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
7592
+ group="Utility",
7593
+ skip_replay=True,
7594
+ is_differentiable=False,
7595
+ )
7596
+ add_builtin(
7597
+ "atomic_xor",
7598
+ hidden=hidden,
7599
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7600
+ constraint=atomic_op_constraint,
7601
+ value_func=create_atomic_op_value_func("xor"),
7602
+ dispatch_func=atomic_op_dispatch_func,
7603
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7604
+
7605
+ This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
7606
+ group="Utility",
7607
+ skip_replay=True,
7608
+ is_differentiable=False,
7609
+ )
7610
+ add_builtin(
7611
+ "atomic_xor",
7612
+ hidden=hidden,
7613
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7614
+ constraint=atomic_op_constraint,
7615
+ value_func=create_atomic_op_value_func("xor"),
7616
+ dispatch_func=atomic_op_dispatch_func,
7617
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7618
+
7619
+ This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
7620
+ group="Utility",
7621
+ skip_replay=True,
7622
+ is_differentiable=False,
7623
+ )
7624
+ add_builtin(
7625
+ "atomic_xor",
7626
+ hidden=hidden,
7627
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7628
+ constraint=atomic_op_constraint,
7629
+ value_func=create_atomic_op_value_func("xor"),
7630
+ dispatch_func=atomic_op_dispatch_func,
7631
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7632
+
7633
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
7634
+ group="Utility",
7635
+ skip_replay=True,
7636
+ is_differentiable=False,
7637
+ )
7638
+
6959
7639
 
6960
7640
  # used to index into builtin types, i.e.: y = vec3[1]
6961
7641
  def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
@@ -7104,7 +7784,7 @@ add_builtin(
7104
7784
  hidden=True,
7105
7785
  group="Utility",
7106
7786
  skip_replay=True,
7107
- missing_grad=True,
7787
+ is_differentiable=False,
7108
7788
  )
7109
7789
  # implements &quaternion[index]
7110
7790
  add_builtin(
@@ -7115,7 +7795,7 @@ add_builtin(
7115
7795
  hidden=True,
7116
7796
  group="Utility",
7117
7797
  skip_replay=True,
7118
- missing_grad=True,
7798
+ is_differentiable=False,
7119
7799
  )
7120
7800
  # implements &transformation[index]
7121
7801
  add_builtin(
@@ -7126,7 +7806,7 @@ add_builtin(
7126
7806
  hidden=True,
7127
7807
  group="Utility",
7128
7808
  skip_replay=True,
7129
- missing_grad=True,
7809
+ is_differentiable=False,
7130
7810
  )
7131
7811
  # implements &(*vector)[index]
7132
7812
  add_builtin(
@@ -7137,7 +7817,7 @@ add_builtin(
7137
7817
  hidden=True,
7138
7818
  group="Utility",
7139
7819
  skip_replay=True,
7140
- missing_grad=True,
7820
+ is_differentiable=False,
7141
7821
  )
7142
7822
  # implements &(*matrix)[i, j]
7143
7823
  add_builtin(
@@ -7148,7 +7828,7 @@ add_builtin(
7148
7828
  hidden=True,
7149
7829
  group="Utility",
7150
7830
  skip_replay=True,
7151
- missing_grad=True,
7831
+ is_differentiable=False,
7152
7832
  )
7153
7833
  # implements &(*quaternion)[index]
7154
7834
  add_builtin(
@@ -7159,7 +7839,7 @@ add_builtin(
7159
7839
  hidden=True,
7160
7840
  group="Utility",
7161
7841
  skip_replay=True,
7162
- missing_grad=True,
7842
+ is_differentiable=False,
7163
7843
  )
7164
7844
  # implements &(*transformation)[index]
7165
7845
  add_builtin(
@@ -7170,7 +7850,7 @@ add_builtin(
7170
7850
  hidden=True,
7171
7851
  group="Utility",
7172
7852
  skip_replay=True,
7173
- missing_grad=True,
7853
+ is_differentiable=False,
7174
7854
  )
7175
7855
 
7176
7856
 
@@ -7366,6 +8046,43 @@ add_builtin(
7366
8046
  )
7367
8047
 
7368
8048
 
8049
+ # implements vector[idx] &= scalar
8050
+ add_builtin(
8051
+ "bit_and_inplace",
8052
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8053
+ value_type=None,
8054
+ dispatch_func=vector_assign_dispatch_func,
8055
+ hidden=True,
8056
+ export=False,
8057
+ group="Utility",
8058
+ is_differentiable=False,
8059
+ )
8060
+
8061
+ # implements vector[idx] |= scalar
8062
+ add_builtin(
8063
+ "bit_or_inplace",
8064
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8065
+ value_type=None,
8066
+ dispatch_func=vector_assign_dispatch_func,
8067
+ hidden=True,
8068
+ export=False,
8069
+ group="Utility",
8070
+ is_differentiable=False,
8071
+ )
8072
+
8073
+ # implements vector[idx] ^= scalar
8074
+ add_builtin(
8075
+ "bit_xor_inplace",
8076
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8077
+ value_type=None,
8078
+ dispatch_func=vector_assign_dispatch_func,
8079
+ hidden=True,
8080
+ export=False,
8081
+ group="Utility",
8082
+ is_differentiable=False,
8083
+ )
8084
+
8085
+
7369
8086
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7370
8087
  mat_type = arg_types["a"]
7371
8088
  row_type = mat_type._wp_row_type_
@@ -7381,7 +8098,7 @@ add_builtin(
7381
8098
  hidden=True,
7382
8099
  group="Utility",
7383
8100
  skip_replay=True,
7384
- missing_grad=True,
8101
+ is_differentiable=False,
7385
8102
  )
7386
8103
 
7387
8104
 
@@ -7400,7 +8117,7 @@ add_builtin(
7400
8117
  hidden=True,
7401
8118
  group="Utility",
7402
8119
  skip_replay=True,
7403
- missing_grad=True,
8120
+ is_differentiable=False,
7404
8121
  )
7405
8122
 
7406
8123
 
@@ -7600,6 +8317,78 @@ add_builtin(
7600
8317
  )
7601
8318
 
7602
8319
 
8320
+ # implements matrix[i] &= value
8321
+ add_builtin(
8322
+ "bit_and_inplace",
8323
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8324
+ value_type=None,
8325
+ hidden=True,
8326
+ export=False,
8327
+ group="Utility",
8328
+ is_differentiable=False,
8329
+ )
8330
+
8331
+
8332
+ # implements matrix[i,j] &= value
8333
+ add_builtin(
8334
+ "bit_and_inplace",
8335
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8336
+ value_type=None,
8337
+ hidden=True,
8338
+ export=False,
8339
+ group="Utility",
8340
+ is_differentiable=False,
8341
+ )
8342
+
8343
+
8344
+ # implements matrix[i] |= value
8345
+ add_builtin(
8346
+ "bit_or_inplace",
8347
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8348
+ value_type=None,
8349
+ hidden=True,
8350
+ export=False,
8351
+ group="Utility",
8352
+ is_differentiable=False,
8353
+ )
8354
+
8355
+
8356
+ # implements matrix[i,j] |= value
8357
+ add_builtin(
8358
+ "bit_or_inplace",
8359
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8360
+ value_type=None,
8361
+ hidden=True,
8362
+ export=False,
8363
+ group="Utility",
8364
+ is_differentiable=False,
8365
+ )
8366
+
8367
+
8368
+ # implements matrix[i] ^= value
8369
+ add_builtin(
8370
+ "bit_xor_inplace",
8371
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8372
+ value_type=None,
8373
+ hidden=True,
8374
+ export=False,
8375
+ group="Utility",
8376
+ is_differentiable=False,
8377
+ )
8378
+
8379
+
8380
+ # implements matrix[i,j] ^= value
8381
+ add_builtin(
8382
+ "bit_xor_inplace",
8383
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8384
+ value_type=None,
8385
+ hidden=True,
8386
+ export=False,
8387
+ group="Utility",
8388
+ is_differentiable=False,
8389
+ )
8390
+
8391
+
7603
8392
  for t in scalar_types + vector_types + (bool,):
7604
8393
  if "vec" in t.__name__ or "mat" in t.__name__:
7605
8394
  continue
@@ -7611,7 +8400,7 @@ for t in scalar_types + vector_types + (bool,):
7611
8400
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7612
8401
  group="Utility",
7613
8402
  hidden=True,
7614
- missing_grad=True,
8403
+ is_differentiable=False,
7615
8404
  )
7616
8405
 
7617
8406
  add_builtin(
@@ -7622,7 +8411,7 @@ for t in scalar_types + vector_types + (bool,):
7622
8411
  group="Utility",
7623
8412
  hidden=True,
7624
8413
  export=False,
7625
- missing_grad=True,
8414
+ is_differentiable=False,
7626
8415
  )
7627
8416
 
7628
8417
 
@@ -7641,7 +8430,7 @@ add_builtin(
7641
8430
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7642
8431
  group="Utility",
7643
8432
  hidden=True,
7644
- missing_grad=True,
8433
+ is_differentiable=False,
7645
8434
  )
7646
8435
  add_builtin(
7647
8436
  "expect_neq",
@@ -7652,7 +8441,7 @@ add_builtin(
7652
8441
  group="Utility",
7653
8442
  hidden=True,
7654
8443
  export=False,
7655
- missing_grad=True,
8444
+ is_differentiable=False,
7656
8445
  )
7657
8446
 
7658
8447
  add_builtin(
@@ -7663,7 +8452,7 @@ add_builtin(
7663
8452
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7664
8453
  group="Utility",
7665
8454
  hidden=True,
7666
- missing_grad=True,
8455
+ is_differentiable=False,
7667
8456
  )
7668
8457
  add_builtin(
7669
8458
  "expect_neq",
@@ -7674,7 +8463,7 @@ add_builtin(
7674
8463
  group="Utility",
7675
8464
  hidden=True,
7676
8465
  export=False,
7677
- missing_grad=True,
8466
+ is_differentiable=False,
7678
8467
  )
7679
8468
 
7680
8469
  add_builtin(
@@ -7765,7 +8554,7 @@ add_builtin(
7765
8554
  value_type=None,
7766
8555
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
7767
8556
  group="Utility",
7768
- missing_grad=True,
8557
+ is_differentiable=False,
7769
8558
  )
7770
8559
  add_builtin(
7771
8560
  "expect_near",
@@ -7775,7 +8564,7 @@ add_builtin(
7775
8564
  value_type=None,
7776
8565
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7777
8566
  group="Utility",
7778
- missing_grad=True,
8567
+ is_differentiable=False,
7779
8568
  )
7780
8569
  add_builtin(
7781
8570
  "expect_near",
@@ -7785,7 +8574,7 @@ add_builtin(
7785
8574
  value_type=None,
7786
8575
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7787
8576
  group="Utility",
7788
- missing_grad=True,
8577
+ is_differentiable=False,
7789
8578
  )
7790
8579
  add_builtin(
7791
8580
  "expect_near",
@@ -7799,7 +8588,7 @@ add_builtin(
7799
8588
  value_type=None,
7800
8589
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7801
8590
  group="Utility",
7802
- missing_grad=True,
8591
+ is_differentiable=False,
7803
8592
  )
7804
8593
 
7805
8594
  # ---------------------------------
@@ -7810,7 +8599,7 @@ add_builtin(
7810
8599
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
7811
8600
  value_type=int,
7812
8601
  doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
7813
- missing_grad=True,
8602
+ is_differentiable=False,
7814
8603
  )
7815
8604
 
7816
8605
  add_builtin(
@@ -7818,7 +8607,7 @@ add_builtin(
7818
8607
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
7819
8608
  value_type=int,
7820
8609
  doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
7821
- missing_grad=True,
8610
+ is_differentiable=False,
7822
8611
  )
7823
8612
 
7824
8613
  # ---------------------------------
@@ -7899,31 +8688,153 @@ add_builtin(
7899
8688
  input_types={"a": Int, "b": Int},
7900
8689
  value_func=sametypes_create_value_func(Int),
7901
8690
  group="Operators",
7902
- missing_grad=True,
8691
+ is_differentiable=False,
7903
8692
  )
8693
+ add_builtin(
8694
+ "bit_and",
8695
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8696
+ constraint=sametypes,
8697
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8698
+ doc="",
8699
+ group="Operators",
8700
+ is_differentiable=False,
8701
+ )
8702
+ add_builtin(
8703
+ "bit_and",
8704
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8705
+ constraint=sametypes,
8706
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8707
+ doc="",
8708
+ group="Operators",
8709
+ is_differentiable=False,
8710
+ )
8711
+
7904
8712
  add_builtin(
7905
8713
  "bit_or",
7906
8714
  input_types={"a": Int, "b": Int},
7907
8715
  value_func=sametypes_create_value_func(Int),
7908
8716
  group="Operators",
7909
- missing_grad=True,
8717
+ is_differentiable=False,
7910
8718
  )
8719
+ add_builtin(
8720
+ "bit_or",
8721
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8722
+ constraint=sametypes,
8723
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8724
+ doc="",
8725
+ group="Operators",
8726
+ is_differentiable=False,
8727
+ )
8728
+ add_builtin(
8729
+ "bit_or",
8730
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8731
+ constraint=sametypes,
8732
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8733
+ doc="",
8734
+ group="Operators",
8735
+ is_differentiable=False,
8736
+ )
8737
+
7911
8738
  add_builtin(
7912
8739
  "bit_xor",
7913
8740
  input_types={"a": Int, "b": Int},
7914
8741
  value_func=sametypes_create_value_func(Int),
7915
8742
  group="Operators",
7916
- missing_grad=True,
8743
+ is_differentiable=False,
8744
+ )
8745
+ add_builtin(
8746
+ "bit_xor",
8747
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8748
+ constraint=sametypes,
8749
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8750
+ doc="",
8751
+ group="Operators",
8752
+ is_differentiable=False,
8753
+ )
8754
+ add_builtin(
8755
+ "bit_xor",
8756
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8757
+ constraint=sametypes,
8758
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8759
+ doc="",
8760
+ group="Operators",
8761
+ is_differentiable=False,
8762
+ )
8763
+
8764
+ add_builtin(
8765
+ "lshift",
8766
+ input_types={"a": Int, "b": Int},
8767
+ value_func=sametypes_create_value_func(Int),
8768
+ group="Operators",
8769
+ is_differentiable=False,
8770
+ )
8771
+ add_builtin(
8772
+ "lshift",
8773
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8774
+ constraint=sametypes,
8775
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8776
+ doc="",
8777
+ group="Operators",
8778
+ is_differentiable=False,
7917
8779
  )
7918
- add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
8780
+ add_builtin(
8781
+ "lshift",
8782
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8783
+ constraint=sametypes,
8784
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8785
+ doc="",
8786
+ group="Operators",
8787
+ is_differentiable=False,
8788
+ )
8789
+
7919
8790
  add_builtin(
7920
8791
  "rshift",
7921
8792
  input_types={"a": Int, "b": Int},
7922
8793
  value_func=sametypes_create_value_func(Int),
7923
8794
  group="Operators",
7924
- missing_grad=True,
8795
+ is_differentiable=False,
8796
+ )
8797
+ add_builtin(
8798
+ "rshift",
8799
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8800
+ constraint=sametypes,
8801
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8802
+ doc="",
8803
+ group="Operators",
8804
+ is_differentiable=False,
8805
+ )
8806
+ add_builtin(
8807
+ "rshift",
8808
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8809
+ constraint=sametypes,
8810
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8811
+ doc="",
8812
+ group="Operators",
8813
+ is_differentiable=False,
8814
+ )
8815
+
8816
+ add_builtin(
8817
+ "invert",
8818
+ input_types={"a": Int},
8819
+ value_func=sametypes_create_value_func(Int),
8820
+ group="Operators",
8821
+ is_differentiable=False,
8822
+ )
8823
+ add_builtin(
8824
+ "invert",
8825
+ input_types={"a": vector(length=Any, dtype=Int)},
8826
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8827
+ group="Operators",
8828
+ is_differentiable=False,
7925
8829
  )
7926
- add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
8830
+ add_builtin(
8831
+ "invert",
8832
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
8833
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8834
+ group="Operators",
8835
+ is_differentiable=False,
8836
+ )
8837
+
7927
8838
 
7928
8839
  add_builtin(
7929
8840
  "mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
@@ -8123,7 +9034,7 @@ add_builtin(
8123
9034
  value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
8124
9035
  doc="Modulo operation using truncated division.",
8125
9036
  group="Operators",
8126
- missing_grad=True,
9037
+ is_differentiable=False,
8127
9038
  )
8128
9039
 
8129
9040
  add_builtin(
@@ -8183,7 +9094,7 @@ add_builtin(
8183
9094
  value_func=sametypes_create_value_func(Scalar),
8184
9095
  doc="",
8185
9096
  group="Operators",
8186
- missing_grad=True,
9097
+ is_differentiable=False,
8187
9098
  )
8188
9099
 
8189
9100
  add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -8232,14 +9143,26 @@ add_builtin(
8232
9143
  )
8233
9144
 
8234
9145
  add_builtin(
8235
- "unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
9146
+ "unot",
9147
+ input_types={"a": builtins.bool},
9148
+ value_type=builtins.bool,
9149
+ doc="",
9150
+ group="Operators",
9151
+ is_differentiable=False,
8236
9152
  )
8237
9153
  for t in int_types:
8238
- add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True)
9154
+ add_builtin(
9155
+ "unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
9156
+ )
8239
9157
 
8240
9158
 
8241
9159
  add_builtin(
8242
- "unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
9160
+ "unot",
9161
+ input_types={"a": array(dtype=Any)},
9162
+ value_type=builtins.bool,
9163
+ doc="",
9164
+ group="Operators",
9165
+ is_differentiable=False,
8243
9166
  )
8244
9167
 
8245
9168
 
@@ -8312,6 +9235,45 @@ add_builtin(
8312
9235
  export=False,
8313
9236
  )
8314
9237
 
9238
+ add_builtin(
9239
+ "bit_and",
9240
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9241
+ value_func=tile_binary_map_value_func,
9242
+ # dispatch_func=tile_map_dispatch_func,
9243
+ # variadic=True,
9244
+ native_func="tile_bit_and",
9245
+ doc="Bitwise AND each element of two tiles together",
9246
+ group="Tile Primitives",
9247
+ export=False,
9248
+ is_differentiable=False,
9249
+ )
9250
+
9251
+ add_builtin(
9252
+ "bit_or",
9253
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9254
+ value_func=tile_binary_map_value_func,
9255
+ # dispatch_func=tile_map_dispatch_func,
9256
+ # variadic=True,
9257
+ native_func="tile_bit_or",
9258
+ doc="Bitwise OR each element of two tiles together",
9259
+ group="Tile Primitives",
9260
+ export=False,
9261
+ is_differentiable=False,
9262
+ )
9263
+
9264
+ add_builtin(
9265
+ "bit_xor",
9266
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9267
+ value_func=tile_binary_map_value_func,
9268
+ # dispatch_func=tile_map_dispatch_func,
9269
+ # variadic=True,
9270
+ native_func="tile_bit_xor",
9271
+ doc="Bitwise XOR each element of two tiles together",
9272
+ group="Tile Primitives",
9273
+ export=False,
9274
+ is_differentiable=False,
9275
+ )
9276
+
8315
9277
 
8316
9278
  add_builtin(
8317
9279
  "mul",
@@ -8373,6 +9335,45 @@ add_builtin(
8373
9335
  )
8374
9336
 
8375
9337
 
9338
+ add_builtin(
9339
+ "bit_and_inplace",
9340
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9341
+ value_type=None,
9342
+ dispatch_func=tile_inplace_dispatch_func,
9343
+ export=False,
9344
+ hidden=True,
9345
+ native_func="tile_bit_and_inplace",
9346
+ group="Operators",
9347
+ is_differentiable=False,
9348
+ )
9349
+
9350
+
9351
+ add_builtin(
9352
+ "bit_or_inplace",
9353
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9354
+ value_type=None,
9355
+ dispatch_func=tile_inplace_dispatch_func,
9356
+ export=False,
9357
+ hidden=True,
9358
+ native_func="tile_bit_or_inplace",
9359
+ group="Operators",
9360
+ is_differentiable=False,
9361
+ )
9362
+
9363
+
9364
+ add_builtin(
9365
+ "bit_xor_inplace",
9366
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9367
+ value_type=None,
9368
+ dispatch_func=tile_inplace_dispatch_func,
9369
+ export=False,
9370
+ hidden=True,
9371
+ native_func="tile_bit_xor_inplace",
9372
+ group="Operators",
9373
+ is_differentiable=False,
9374
+ )
9375
+
9376
+
8376
9377
  def tile_diag_add_value_func(arg_types, arg_values):
8377
9378
  if arg_types is None:
8378
9379
  return tile(dtype=Any, shape=Tuple[int, int])
@@ -8414,7 +9415,7 @@ def tile_diag_add_lto_dispatch_func(
8414
9415
  return_values: List[Var],
8415
9416
  arg_values: Mapping[str, Var],
8416
9417
  options: Mapping[str, Any],
8417
- builder: warp.context.ModuleBuilder,
9418
+ builder: warp._src.context.ModuleBuilder,
8418
9419
  ):
8419
9420
  a = arg_values["a"]
8420
9421
  d = arg_values["d"]
@@ -8434,7 +9435,7 @@ add_builtin(
8434
9435
  doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
8435
9436
  group="Tile Primitives",
8436
9437
  export=False,
8437
- missing_grad=True,
9438
+ is_differentiable=False,
8438
9439
  )
8439
9440
 
8440
9441
 
@@ -8491,7 +9492,7 @@ def tile_matmul_lto_dispatch_func(
8491
9492
  return_values: List[Var],
8492
9493
  arg_values: Mapping[str, Var],
8493
9494
  options: Mapping[str, Any],
8494
- builder: warp.context.ModuleBuilder,
9495
+ builder: warp._src.context.ModuleBuilder,
8495
9496
  ):
8496
9497
  a = arg_values["a"]
8497
9498
  b = arg_values["b"]
@@ -8529,7 +9530,7 @@ def tile_matmul_lto_dispatch_func(
8529
9530
  num_threads = options["block_dim"]
8530
9531
  arch = options["output_arch"]
8531
9532
 
8532
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9533
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8533
9534
  # CPU/no-MathDx dispatch
8534
9535
  return ((0, 0, 0, a, b, out), template_args, [], 0)
8535
9536
  else:
@@ -8542,7 +9543,7 @@ def tile_matmul_lto_dispatch_func(
8542
9543
 
8543
9544
  # generate the LTOs
8544
9545
  # C += A * B
8545
- (fun_forward, lto_forward) = warp.build.build_lto_dot(
9546
+ (fun_forward, lto_forward) = warp._src.build.build_lto_dot(
8546
9547
  M,
8547
9548
  N,
8548
9549
  K,
@@ -8558,7 +9559,7 @@ def tile_matmul_lto_dispatch_func(
8558
9559
  )
8559
9560
  if warp.config.enable_backward:
8560
9561
  # adjA += adjC * B^T - Transpose ~= flipped layout
8561
- (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
9562
+ (fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
8562
9563
  M,
8563
9564
  K,
8564
9565
  N,
@@ -8573,7 +9574,7 @@ def tile_matmul_lto_dispatch_func(
8573
9574
  builder,
8574
9575
  )
8575
9576
  # adjB += A^T * adjC - Transpose ~= flipped layout
8576
- (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
9577
+ (fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
8577
9578
  K,
8578
9579
  N,
8579
9580
  M,
@@ -8690,7 +9691,7 @@ def tile_fft_generic_lto_dispatch_func(
8690
9691
  return_values: List[Var],
8691
9692
  arg_values: Mapping[str, Var],
8692
9693
  options: Mapping[str, Any],
8693
- builder: warp.context.ModuleBuilder,
9694
+ builder: warp._src.context.ModuleBuilder,
8694
9695
  direction: str | None = None,
8695
9696
  ):
8696
9697
  inout = arg_values["inout"]
@@ -8719,12 +9720,12 @@ def tile_fft_generic_lto_dispatch_func(
8719
9720
  arch = options["output_arch"]
8720
9721
  ept = size // num_threads
8721
9722
 
8722
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9723
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8723
9724
  # CPU/no-MathDx dispatch
8724
9725
  return ([], [], [], 0)
8725
9726
  else:
8726
9727
  # generate the LTO
8727
- lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
9728
+ lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
8728
9729
  arch, size, ept, direction, dir, precision, builder
8729
9730
  )
8730
9731
 
@@ -8762,7 +9763,7 @@ add_builtin(
8762
9763
  group="Tile Primitives",
8763
9764
  export=False,
8764
9765
  namespace="",
8765
- missing_grad=True,
9766
+ is_differentiable=False,
8766
9767
  )
8767
9768
 
8768
9769
  add_builtin(
@@ -8784,7 +9785,7 @@ add_builtin(
8784
9785
  group="Tile Primitives",
8785
9786
  export=False,
8786
9787
  namespace="",
8787
- missing_grad=True,
9788
+ is_differentiable=False,
8788
9789
  )
8789
9790
 
8790
9791
 
@@ -8829,7 +9830,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8829
9830
  return_values: List[Var],
8830
9831
  arg_values: Mapping[str, Var],
8831
9832
  options: Mapping[str, Any],
8832
- builder: warp.context.ModuleBuilder,
9833
+ builder: warp._src.context.ModuleBuilder,
8833
9834
  ):
8834
9835
  a = arg_values["A"]
8835
9836
  # force source tile to shared memory
@@ -8849,7 +9850,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8849
9850
 
8850
9851
  arch = options["output_arch"]
8851
9852
 
8852
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9853
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8853
9854
  # CPU/no-MathDx dispatch
8854
9855
  return ((0, a, out), [], [], 0)
8855
9856
  else:
@@ -8864,7 +9865,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8864
9865
  req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
8865
9866
 
8866
9867
  # generate the LTO
8867
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9868
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8868
9869
  M,
8869
9870
  N,
8870
9871
  1,
@@ -8909,7 +9910,7 @@ add_builtin(
8909
9910
  group="Tile Primitives",
8910
9911
  export=False,
8911
9912
  namespace="",
8912
- missing_grad=True,
9913
+ is_differentiable=False,
8913
9914
  )
8914
9915
 
8915
9916
 
@@ -8953,7 +9954,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8953
9954
  return_values: List[Var],
8954
9955
  arg_values: Mapping[str, Var],
8955
9956
  options: Mapping[str, Any],
8956
- builder: warp.context.ModuleBuilder,
9957
+ builder: warp._src.context.ModuleBuilder,
8957
9958
  ):
8958
9959
  L = arg_values["L"]
8959
9960
  y = arg_values["y"]
@@ -8982,7 +9983,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8982
9983
 
8983
9984
  arch = options["output_arch"]
8984
9985
 
8985
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9986
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8986
9987
  # CPU/no-MathDx dispatch
8987
9988
  return ((0, L, y, x), [], [], 0)
8988
9989
  else:
@@ -8998,7 +9999,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8998
9999
  req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8999
10000
 
9000
10001
  # generate the LTO
9001
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10002
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9002
10003
  M,
9003
10004
  N,
9004
10005
  NRHS,
@@ -9040,7 +10041,7 @@ add_builtin(
9040
10041
  group="Tile Primitives",
9041
10042
  export=False,
9042
10043
  namespace="",
9043
- missing_grad=True,
10044
+ is_differentiable=False,
9044
10045
  )
9045
10046
 
9046
10047
 
@@ -9050,7 +10051,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9050
10051
  return_values: List[Var],
9051
10052
  arg_values: Mapping[str, Var],
9052
10053
  options: Mapping[str, Any],
9053
- builder: warp.context.ModuleBuilder,
10054
+ builder: warp._src.context.ModuleBuilder,
9054
10055
  ):
9055
10056
  L = arg_values["L"]
9056
10057
  y = arg_values["y"]
@@ -9079,7 +10080,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9079
10080
 
9080
10081
  arch = options["output_arch"]
9081
10082
 
9082
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10083
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
9083
10084
  # CPU/no-MathDx dispatch
9084
10085
  return ((0, L, y, z), [], [], 0)
9085
10086
  else:
@@ -9095,7 +10096,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9095
10096
  req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
9096
10097
 
9097
10098
  # generate the LTO
9098
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10099
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9099
10100
  M,
9100
10101
  N,
9101
10102
  NRHS,
@@ -9173,7 +10174,7 @@ add_builtin(
9173
10174
  group="Tile Primitives",
9174
10175
  export=False,
9175
10176
  namespace="",
9176
- missing_grad=True,
10177
+ is_differentiable=False,
9177
10178
  )
9178
10179
 
9179
10180
 
@@ -9183,7 +10184,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9183
10184
  return_values: List[Var],
9184
10185
  arg_values: Mapping[str, Var],
9185
10186
  options: Mapping[str, Any],
9186
- builder: warp.context.ModuleBuilder,
10187
+ builder: warp._src.context.ModuleBuilder,
9187
10188
  ):
9188
10189
  U = arg_values["U"]
9189
10190
  z = arg_values["z"]
@@ -9212,7 +10213,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9212
10213
 
9213
10214
  arch = options["output_arch"]
9214
10215
 
9215
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10216
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
9216
10217
  # CPU/no-MathDx dispatch
9217
10218
  return ((0, U, z, x), [], [], 0)
9218
10219
  else:
@@ -9228,7 +10229,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9228
10229
  req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
9229
10230
 
9230
10231
  # generate the LTO
9231
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10232
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9232
10233
  M,
9233
10234
  N,
9234
10235
  NRHS,
@@ -9306,7 +10307,7 @@ add_builtin(
9306
10307
  group="Tile Primitives",
9307
10308
  export=False,
9308
10309
  namespace="",
9309
- missing_grad=True,
10310
+ is_differentiable=False,
9310
10311
  )
9311
10312
 
9312
10313
 
@@ -9326,7 +10327,7 @@ add_builtin(
9326
10327
  The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
9327
10328
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
9328
10329
  group="Code Generation",
9329
- missing_grad=True,
10330
+ is_differentiable=False,
9330
10331
  )
9331
10332
 
9332
10333
 
@@ -9351,7 +10352,7 @@ add_builtin(
9351
10352
  doc="Return the number of elements in a vector.",
9352
10353
  group="Utility",
9353
10354
  export=False,
9354
- missing_grad=True,
10355
+ is_differentiable=False,
9355
10356
  )
9356
10357
 
9357
10358
  add_builtin(
@@ -9361,7 +10362,7 @@ add_builtin(
9361
10362
  doc="Return the number of elements in a quaternion.",
9362
10363
  group="Utility",
9363
10364
  export=False,
9364
- missing_grad=True,
10365
+ is_differentiable=False,
9365
10366
  )
9366
10367
 
9367
10368
  add_builtin(
@@ -9371,7 +10372,7 @@ add_builtin(
9371
10372
  doc="Return the number of rows in a matrix.",
9372
10373
  group="Utility",
9373
10374
  export=False,
9374
- missing_grad=True,
10375
+ is_differentiable=False,
9375
10376
  )
9376
10377
 
9377
10378
  add_builtin(
@@ -9381,7 +10382,7 @@ add_builtin(
9381
10382
  doc="Return the number of elements in a transformation.",
9382
10383
  group="Utility",
9383
10384
  export=False,
9384
- missing_grad=True,
10385
+ is_differentiable=False,
9385
10386
  )
9386
10387
 
9387
10388
  add_builtin(
@@ -9391,7 +10392,7 @@ add_builtin(
9391
10392
  doc="Return the size of the first dimension in an array.",
9392
10393
  group="Utility",
9393
10394
  export=False,
9394
- missing_grad=True,
10395
+ is_differentiable=False,
9395
10396
  )
9396
10397
 
9397
10398
  add_builtin(
@@ -9401,7 +10402,62 @@ add_builtin(
9401
10402
  doc="Return the number of rows in a tile.",
9402
10403
  group="Utility",
9403
10404
  export=False,
9404
- missing_grad=True,
10405
+ is_differentiable=False,
10406
+ )
10407
+
10408
+
10409
+ def cast_value_func(arg_types, arg_values):
10410
+ # Return generic type for doc builds.
10411
+ if arg_types is None:
10412
+ return Any
10413
+
10414
+ return arg_values["dtype"]
10415
+
10416
+
10417
+ def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
10418
+ func_args = (args["a"],)
10419
+ template_args = (args["dtype"],)
10420
+ return (func_args, template_args)
10421
+
10422
+
10423
+ add_builtin(
10424
+ "cast",
10425
+ input_types={"a": Any, "dtype": Any},
10426
+ value_func=cast_value_func,
10427
+ dispatch_func=cast_dispatch_func,
10428
+ group="Utility",
10429
+ export=False,
10430
+ is_differentiable=False,
10431
+ doc="""Reinterpret a value as a different type while preserving its bit pattern.
10432
+
10433
+ :param a: The value to cast
10434
+ :param dtype: The target type
10435
+
10436
+ Example:
10437
+
10438
+ .. code-block:: python
10439
+
10440
+ @wp.struct
10441
+ class MyStruct:
10442
+ f: wp.float16
10443
+ i: wp.int16
10444
+
10445
+
10446
+ @wp.kernel
10447
+ def compute():
10448
+ x = wp.int32(0x40000000)
10449
+ x_casted = wp.cast(x, wp.float32)
10450
+ wp.expect_eq(x_casted, 2.0) # 0x40000000
10451
+
10452
+ s = MyStruct()
10453
+ s.f = wp.float16(2.0) # 0x4000
10454
+ s.i = wp.int16(4096) # 0x1000
10455
+ s_casted = wp.cast(s, wp.int32)
10456
+ wp.expect_eq(s_casted, 0x10004000)
10457
+
10458
+
10459
+ wp.launch(compute, dim=1)
10460
+ """,
9405
10461
  )
9406
10462
 
9407
10463
 
@@ -9428,7 +10484,7 @@ add_builtin(
9428
10484
  doc="Construct a tuple from a list of values",
9429
10485
  group="Utility",
9430
10486
  hidden=True,
9431
- missing_grad=True,
10487
+ is_differentiable=False,
9432
10488
  export=False,
9433
10489
  )
9434
10490
 
@@ -9465,7 +10521,7 @@ add_builtin(
9465
10521
  dispatch_func=tuple_extract_dispatch_func,
9466
10522
  group="Utility",
9467
10523
  hidden=True,
9468
- missing_grad=True,
10524
+ is_differentiable=False,
9469
10525
  )
9470
10526
 
9471
10527
 
@@ -9476,7 +10532,7 @@ add_builtin(
9476
10532
  doc="Return the number of elements in a tuple.",
9477
10533
  group="Utility",
9478
10534
  export=False,
9479
- missing_grad=True,
10535
+ is_differentiable=False,
9480
10536
  )
9481
10537
 
9482
10538
  # ---------------------------------
@@ -9495,5 +10551,5 @@ add_builtin(
9495
10551
  export=False,
9496
10552
  group="Utility",
9497
10553
  hidden=True,
9498
- missing_grad=True,
10554
+ is_differentiable=False,
9499
10555
  )