warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the registry's listing for this release for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -20,11 +20,11 @@ import functools
20
20
  import math
21
21
  from typing import Any, Callable, Mapping, Sequence
22
22
 
23
- import warp.build
24
- import warp.context
25
- import warp.utils
26
- from warp.codegen import Reference, Var, get_arg_value, strip_reference
27
- from warp.types import *
23
+ import warp._src.build
24
+ import warp._src.context
25
+ import warp._src.utils
26
+ from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
27
+ from warp._src.types import *
28
28
 
29
29
  from .context import add_builtin
30
30
 
@@ -61,11 +61,11 @@ def sametypes_create_value_func(default: TypeVar):
61
61
 
62
62
  def extract_tuple(arg, as_constant=False):
63
63
  if isinstance(arg, Var):
64
- if isinstance(arg.type, warp.types.tuple_t):
64
+ if isinstance(arg.type, warp._src.types.tuple_t):
65
65
  out = arg.type.values
66
66
  else:
67
67
  out = (arg,)
68
- elif isinstance(arg, warp.types.tuple_t):
68
+ elif isinstance(arg, warp._src.types.tuple_t):
69
69
  out = arg.values
70
70
  elif not isinstance(arg, Sequence):
71
71
  out = (arg,)
@@ -82,7 +82,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
82
82
  if arg_types is None:
83
83
  return int
84
84
 
85
- length = warp.types.type_length(arg_types["a"])
85
+ length = warp._src.types.type_length(arg_types["a"])
86
86
  return Var(None, type=int, constant=length)
87
87
 
88
88
 
@@ -126,6 +126,7 @@ add_builtin(
126
126
  value_func=sametypes_create_value_func(Scalar),
127
127
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
128
128
  group="Scalar Math",
129
+ is_differentiable=False,
129
130
  )
130
131
 
131
132
  add_builtin(
@@ -134,6 +135,7 @@ add_builtin(
134
135
  value_func=sametypes_create_value_func(Scalar),
135
136
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
136
137
  group="Scalar Math",
138
+ is_differentiable=False,
137
139
  )
138
140
  add_builtin(
139
141
  "nonzero",
@@ -141,6 +143,7 @@ add_builtin(
141
143
  value_func=sametypes_create_value_func(Scalar),
142
144
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
143
145
  group="Scalar Math",
146
+ is_differentiable=False,
144
147
  )
145
148
 
146
149
  add_builtin(
@@ -282,7 +285,36 @@ add_builtin(
282
285
  group="Scalar Math",
283
286
  require_original_output_arg=True,
284
287
  )
285
-
288
+ add_builtin(
289
+ "erf",
290
+ input_types={"x": Float},
291
+ value_func=sametypes_create_value_func(Float),
292
+ doc="Return the error function of ``x``.",
293
+ group="Scalar Math",
294
+ )
295
+ add_builtin(
296
+ "erfc",
297
+ input_types={"x": Float},
298
+ value_func=sametypes_create_value_func(Float),
299
+ doc="Return the complementary error function of ``x``.",
300
+ group="Scalar Math",
301
+ )
302
+ add_builtin(
303
+ "erfinv",
304
+ input_types={"x": Float},
305
+ value_func=sametypes_create_value_func(Float),
306
+ doc="Return the inverse error function of ``x``.",
307
+ group="Scalar Math",
308
+ require_original_output_arg=True,
309
+ )
310
+ add_builtin(
311
+ "erfcinv",
312
+ input_types={"x": Float},
313
+ value_func=sametypes_create_value_func(Float),
314
+ doc="Return the inverse complementary error function of ``x``.",
315
+ group="Scalar Math",
316
+ require_original_output_arg=True,
317
+ )
286
318
  add_builtin(
287
319
  "round",
288
320
  input_types={"x": Float},
@@ -292,6 +324,7 @@ add_builtin(
292
324
 
293
325
  This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
294
326
  Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
327
+ is_differentiable=False,
295
328
  )
296
329
 
297
330
  add_builtin(
@@ -302,6 +335,7 @@ add_builtin(
302
335
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
303
336
 
304
337
  It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
338
+ is_differentiable=False,
305
339
  )
306
340
 
307
341
  add_builtin(
@@ -314,6 +348,7 @@ add_builtin(
314
348
  In other words, it discards the fractional part of ``x``.
315
349
  It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
316
350
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
351
+ is_differentiable=False,
317
352
  )
318
353
 
319
354
  add_builtin(
@@ -322,6 +357,7 @@ add_builtin(
322
357
  value_func=sametypes_create_value_func(Float),
323
358
  group="Scalar Math",
324
359
  doc="""Return the largest integer that is less than or equal to ``x``.""",
360
+ is_differentiable=False,
325
361
  )
326
362
 
327
363
  add_builtin(
@@ -330,6 +366,7 @@ add_builtin(
330
366
  value_func=sametypes_create_value_func(Float),
331
367
  group="Scalar Math",
332
368
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
369
+ is_differentiable=False,
333
370
  )
334
371
 
335
372
  add_builtin(
@@ -340,6 +377,7 @@ add_builtin(
340
377
  doc="""Retrieve the fractional part of ``x``.
341
378
 
342
379
  In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
380
+ is_differentiable=False,
343
381
  )
344
382
 
345
383
  add_builtin(
@@ -348,6 +386,7 @@ add_builtin(
348
386
  value_type=builtins.bool,
349
387
  group="Scalar Math",
350
388
  doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
389
+ is_differentiable=False,
351
390
  )
352
391
  add_builtin(
353
392
  "isfinite",
@@ -355,6 +394,7 @@ add_builtin(
355
394
  value_type=builtins.bool,
356
395
  group="Vector Math",
357
396
  doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
397
+ is_differentiable=False,
358
398
  )
359
399
  add_builtin(
360
400
  "isfinite",
@@ -362,6 +402,7 @@ add_builtin(
362
402
  value_type=builtins.bool,
363
403
  group="Vector Math",
364
404
  doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
405
+ is_differentiable=False,
365
406
  )
366
407
  add_builtin(
367
408
  "isfinite",
@@ -369,6 +410,7 @@ add_builtin(
369
410
  value_type=builtins.bool,
370
411
  group="Vector Math",
371
412
  doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
413
+ is_differentiable=False,
372
414
  )
373
415
 
374
416
  add_builtin(
@@ -377,6 +419,7 @@ add_builtin(
377
419
  value_type=builtins.bool,
378
420
  doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
379
421
  group="Scalar Math",
422
+ is_differentiable=False,
380
423
  )
381
424
  add_builtin(
382
425
  "isnan",
@@ -384,6 +427,7 @@ add_builtin(
384
427
  value_type=builtins.bool,
385
428
  group="Vector Math",
386
429
  doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
430
+ is_differentiable=False,
387
431
  )
388
432
  add_builtin(
389
433
  "isnan",
@@ -391,6 +435,7 @@ add_builtin(
391
435
  value_type=builtins.bool,
392
436
  group="Vector Math",
393
437
  doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
438
+ is_differentiable=False,
394
439
  )
395
440
  add_builtin(
396
441
  "isnan",
@@ -398,6 +443,7 @@ add_builtin(
398
443
  value_type=builtins.bool,
399
444
  group="Vector Math",
400
445
  doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
446
+ is_differentiable=False,
401
447
  )
402
448
 
403
449
  add_builtin(
@@ -406,6 +452,7 @@ add_builtin(
406
452
  value_type=builtins.bool,
407
453
  group="Scalar Math",
408
454
  doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
455
+ is_differentiable=False,
409
456
  )
410
457
  add_builtin(
411
458
  "isinf",
@@ -413,6 +460,7 @@ add_builtin(
413
460
  value_type=builtins.bool,
414
461
  group="Vector Math",
415
462
  doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
463
+ is_differentiable=False,
416
464
  )
417
465
  add_builtin(
418
466
  "isinf",
@@ -420,6 +468,7 @@ add_builtin(
420
468
  value_type=builtins.bool,
421
469
  group="Vector Math",
422
470
  doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
471
+ is_differentiable=False,
423
472
  )
424
473
  add_builtin(
425
474
  "isinf",
@@ -427,6 +476,7 @@ add_builtin(
427
476
  value_type=builtins.bool,
428
477
  group="Vector Math",
429
478
  doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
479
+ is_differentiable=False,
430
480
  )
431
481
 
432
482
 
@@ -534,7 +584,7 @@ add_builtin(
534
584
  value_func=lambda arg_types, arg_values: warp.uint32,
535
585
  doc="Return the index of the minimum element of a vector ``a``.",
536
586
  group="Vector Math",
537
- missing_grad=True,
587
+ is_differentiable=False,
538
588
  )
539
589
  add_builtin(
540
590
  "argmax",
@@ -542,7 +592,7 @@ add_builtin(
542
592
  value_func=lambda arg_types, arg_values: warp.uint32,
543
593
  doc="Return the index of the maximum element of a vector ``a``.",
544
594
  group="Vector Math",
545
- missing_grad=True,
595
+ is_differentiable=False,
546
596
  )
547
597
 
548
598
  add_builtin(
@@ -867,7 +917,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
867
917
 
868
918
  if dtype is None:
869
919
  dtype = value_type
870
- elif not warp.types.scalars_equal(value_type, dtype):
920
+ elif not warp._src.types.scalars_equal(value_type, dtype):
871
921
  raise RuntimeError(
872
922
  f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
873
923
  )
@@ -888,7 +938,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
888
938
 
889
939
  if dtype is None:
890
940
  dtype = value_type
891
- elif not warp.types.scalars_equal(value_type, dtype):
941
+ elif not warp._src.types.scalars_equal(value_type, dtype):
892
942
  raise RuntimeError(
893
943
  f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
894
944
  )
@@ -971,7 +1021,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
971
1021
 
972
1022
  if dtype is None:
973
1023
  dtype = value_type
974
- elif not warp.types.scalars_equal(value_type, dtype):
1024
+ elif not warp._src.types.scalars_equal(value_type, dtype):
975
1025
  raise RuntimeError(
976
1026
  f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
977
1027
  )
@@ -981,7 +1031,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
981
1031
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
982
1032
 
983
1033
  if all(type_is_vector(x) for x in variadic_arg_types):
984
- warp.utils.warn(
1034
+ warp._src.utils.warn(
985
1035
  "the built-in `wp.matrix()` won't support taking column vectors as input "
986
1036
  "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
987
1037
  DeprecationWarning,
@@ -1010,7 +1060,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1010
1060
 
1011
1061
  if dtype is None:
1012
1062
  dtype = value_type
1013
- elif not warp.types.scalars_equal(value_type, dtype):
1063
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1014
1064
  raise RuntimeError(
1015
1065
  f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
1016
1066
  )
@@ -1182,48 +1232,18 @@ add_builtin(
1182
1232
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
1183
1233
  group="Vector Math",
1184
1234
  export=False,
1235
+ is_differentiable=False,
1185
1236
  )
1186
1237
 
1187
1238
 
1188
1239
  def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1189
- warp.utils.warn(
1190
- "the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
1191
- "and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
1192
- DeprecationWarning,
1193
- )
1194
1240
  if arg_types is None:
1195
1241
  return matrix(shape=(4, 4), dtype=Float)
1196
1242
 
1197
- dtype = arg_values.get("dtype", None)
1198
-
1199
- value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1200
- try:
1201
- value_type = scalar_infer_type(value_arg_types)
1202
- except RuntimeError:
1203
- raise RuntimeError(
1204
- "all values given when constructing a transformation matrix must have the same type"
1205
- ) from None
1206
-
1207
- if dtype is None:
1208
- dtype = value_type
1209
- elif not warp.types.scalars_equal(value_type, dtype):
1210
- raise RuntimeError(
1211
- f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1212
- )
1213
-
1214
- return matrix(shape=(4, 4), dtype=dtype)
1215
-
1216
-
1217
- def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1218
- # We're in the codegen stage where we emit the code calling the built-in.
1219
- # Further validate the given argument values if needed and map them
1220
- # to the underlying C++ function's runtime and template params.
1221
-
1222
- dtype = return_type._wp_scalar_type_
1223
-
1224
- func_args = tuple(v for k, v in args.items() if k != "dtype")
1225
- template_args = (4, 4, dtype)
1226
- return (func_args, template_args)
1243
+ raise RuntimeError(
1244
+ "the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
1245
+ "and 3D scale vector has been removed in favor of `wp.transform_compose()`."
1246
+ )
1227
1247
 
1228
1248
 
1229
1249
  add_builtin(
@@ -1237,13 +1257,14 @@ add_builtin(
1237
1257
  defaults={"dtype": None},
1238
1258
  value_func=matrix_transform_value_func,
1239
1259
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1240
- dispatch_func=matrix_transform_dispatch_func,
1241
1260
  native_func="mat_t",
1242
1261
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
1243
1262
  Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
1244
1263
 
1245
- .. warning::
1246
- This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
1264
+ .. versionremoved:: 1.10
1265
+ This function has been removed in favor of :func:`warp.math.transform_compose()`.
1266
+
1267
+ .. deprecated:: 1.8""",
1247
1268
  group="Vector Math",
1248
1269
  export=False,
1249
1270
  )
@@ -1438,7 +1459,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
1438
1459
 
1439
1460
  if dtype is None:
1440
1461
  dtype = value_type
1441
- elif not warp.types.scalars_equal(value_type, dtype):
1462
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1442
1463
  raise RuntimeError(
1443
1464
  f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
1444
1465
  )
@@ -1546,6 +1567,7 @@ add_builtin(
1546
1567
  group="Quaternion Math",
1547
1568
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1548
1569
  export=True,
1570
+ is_differentiable=False,
1549
1571
  )
1550
1572
 
1551
1573
  add_builtin(
@@ -1674,7 +1696,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1674
1696
  value_type = strip_reference(variadic_arg_types[0])
1675
1697
  if dtype is None:
1676
1698
  dtype = value_type
1677
- elif not warp.types.scalars_equal(value_type, dtype):
1699
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1678
1700
  raise RuntimeError(
1679
1701
  f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
1680
1702
  )
@@ -1687,7 +1709,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1687
1709
 
1688
1710
  if dtype is None:
1689
1711
  dtype = value_type
1690
- elif not warp.types.scalars_equal(value_type, dtype):
1712
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1691
1713
  raise RuntimeError(
1692
1714
  f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
1693
1715
  )
@@ -1712,7 +1734,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
1712
1734
  dtype = arg_values.get("dtype", None)
1713
1735
  if dtype is None:
1714
1736
  dtype = value_type
1715
- elif not warp.types.scalars_equal(value_type, dtype):
1737
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1716
1738
  raise RuntimeError(
1717
1739
  f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1718
1740
  )
@@ -1727,9 +1749,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1727
1749
 
1728
1750
  dtype = return_type._wp_scalar_type_
1729
1751
 
1730
- variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1752
+ variadic_args = args.get("args", ())
1753
+ variadic_arg_count = len(variadic_args)
1754
+
1755
+ if variadic_arg_count == 7:
1756
+ func_args = variadic_args
1757
+ else:
1758
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
1759
+ if "p" in args and "q" not in args:
1760
+ quat_ident = warp._src.codegen.Var(
1761
+ label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
1762
+ )
1763
+ func_args += (quat_ident,)
1731
1764
 
1732
- func_args = variadic_args
1733
1765
  template_args = (dtype,)
1734
1766
  return (func_args, template_args)
1735
1767
 
@@ -1737,7 +1769,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1737
1769
  add_builtin(
1738
1770
  "transformation",
1739
1771
  input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
1740
- defaults={"dtype": None},
1772
+ defaults={"q": None, "dtype": None},
1741
1773
  value_func=transformation_pq_value_func,
1742
1774
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1743
1775
  dispatch_func=transformation_dispatch_func,
@@ -1795,6 +1827,7 @@ add_builtin(
1795
1827
  group="Transformations",
1796
1828
  doc="Construct an identity transform with zero translation and identity rotation.",
1797
1829
  export=True,
1830
+ is_differentiable=False,
1798
1831
  )
1799
1832
 
1800
1833
  add_builtin(
@@ -1928,7 +1961,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1928
1961
 
1929
1962
  if dtype is None:
1930
1963
  dtype = value_type
1931
- elif not warp.types.scalars_equal(value_type, dtype):
1964
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1932
1965
  raise RuntimeError(
1933
1966
  f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
1934
1967
  )
@@ -2122,7 +2155,7 @@ add_builtin(
2122
2155
  value_func=tile_zeros_value_func,
2123
2156
  dispatch_func=tile_zeros_dispatch_func,
2124
2157
  variadic=False,
2125
- missing_grad=True,
2158
+ is_differentiable=False,
2126
2159
  doc="""Allocate a tile of zero-initialized items.
2127
2160
 
2128
2161
  :param shape: Shape of the output tile
@@ -2142,7 +2175,7 @@ add_builtin(
2142
2175
  value_func=tile_zeros_value_func,
2143
2176
  dispatch_func=tile_zeros_dispatch_func,
2144
2177
  variadic=False,
2145
- missing_grad=True,
2178
+ is_differentiable=False,
2146
2179
  hidden=True,
2147
2180
  group="Tile Primitives",
2148
2181
  export=False,
@@ -2194,7 +2227,7 @@ add_builtin(
2194
2227
  defaults={"storage": "register"},
2195
2228
  value_func=tile_ones_value_func,
2196
2229
  dispatch_func=tile_ones_dispatch_func,
2197
- missing_grad=True,
2230
+ is_differentiable=False,
2198
2231
  doc="""Allocate a tile of one-initialized items.
2199
2232
 
2200
2233
  :param shape: Shape of the output tile
@@ -2213,7 +2246,86 @@ add_builtin(
2213
2246
  defaults={"storage": "register"},
2214
2247
  value_func=tile_ones_value_func,
2215
2248
  dispatch_func=tile_ones_dispatch_func,
2216
- missing_grad=True,
2249
+ is_differentiable=False,
2250
+ hidden=True,
2251
+ group="Tile Primitives",
2252
+ export=False,
2253
+ )
2254
+
2255
+
2256
+ def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2257
+ # return generic type (for doc builds)
2258
+ if arg_types is None:
2259
+ return tile(dtype=Any, shape=Tuple[int, ...])
2260
+
2261
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2262
+
2263
+ if None in shape:
2264
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2265
+
2266
+ if "value" not in arg_values:
2267
+ raise TypeError("tile_full() missing required keyword argument 'value'")
2268
+
2269
+ if "dtype" not in arg_values:
2270
+ raise TypeError("tile_full() missing required keyword argument 'dtype'")
2271
+
2272
+ if "storage" not in arg_values:
2273
+ raise TypeError("tile_full() missing required keyword argument 'storage'")
2274
+
2275
+ if arg_values["storage"] not in {"shared", "register"}:
2276
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2277
+
2278
+ dtype = arg_values["dtype"]
2279
+
2280
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
2281
+
2282
+
2283
+ def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2284
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2285
+
2286
+ if None in shape:
2287
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2288
+
2289
+ dtype = arg_values["dtype"]
2290
+ value = arg_values["value"]
2291
+
2292
+ func_args = [value]
2293
+
2294
+ template_args = []
2295
+ template_args.append(dtype)
2296
+ template_args.extend(shape)
2297
+
2298
+ return (func_args, template_args)
2299
+
2300
+
2301
+ add_builtin(
2302
+ "tile_full",
2303
+ input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
2304
+ defaults={"storage": "register"},
2305
+ value_func=tile_full_value_func,
2306
+ dispatch_func=tile_full_dispatch_func,
2307
+ is_differentiable=False,
2308
+ doc="""Allocate a tile filled with the specified value.
2309
+
2310
+ :param shape: Shape of the output tile
2311
+ :param value: Value to fill the tile with
2312
+ :param dtype: Data type of output tile's elements
2313
+ :param storage: The storage location for the tile: ``"register"`` for registers
2314
+ (default) or ``"shared"`` for shared memory.
2315
+ :returns: A tile filled with the specified value""",
2316
+ group="Tile Primitives",
2317
+ export=False,
2318
+ )
2319
+
2320
+
2321
+ # overload for scalar shape
2322
+ add_builtin(
2323
+ "tile_full",
2324
+ input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
2325
+ defaults={"storage": "register"},
2326
+ value_func=tile_full_value_func,
2327
+ dispatch_func=tile_full_dispatch_func,
2328
+ is_differentiable=False,
2217
2329
  hidden=True,
2218
2330
  group="Tile Primitives",
2219
2331
  export=False,
@@ -2275,13 +2387,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2275
2387
  args = arg_values["args"]
2276
2388
 
2277
2389
  if len(args) == 1:
2278
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
2390
+ start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
2279
2391
  stop = args[0]
2280
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2392
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2281
2393
  elif len(args) == 2:
2282
2394
  start = args[0]
2283
2395
  stop = args[1]
2284
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2396
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2285
2397
  elif len(args) == 3:
2286
2398
  start = args[0]
2287
2399
  stop = args[1]
@@ -2304,7 +2416,7 @@ add_builtin(
2304
2416
  value_func=tile_arange_value_func,
2305
2417
  dispatch_func=tile_arange_dispatch_func,
2306
2418
  variadic=True,
2307
- missing_grad=True,
2419
+ is_differentiable=False,
2308
2420
  doc="""Generate a tile of linearly spaced elements.
2309
2421
 
2310
2422
  :param args: Variable-length positional arguments, interpreted as:
@@ -3099,7 +3211,7 @@ add_builtin(
3099
3211
  :param shape: Shape of the returned slice
3100
3212
  :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
3101
3213
  group="Tile Primitives",
3102
- missing_grad=True,
3214
+ is_differentiable=False,
3103
3215
  export=False,
3104
3216
  )
3105
3217
 
@@ -3346,7 +3458,32 @@ add_builtin(
3346
3458
 
3347
3459
  add_builtin(
3348
3460
  "assign",
3349
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
3461
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
3462
+ value_func=tile_assign_value_func,
3463
+ group="Tile Primitives",
3464
+ export=False,
3465
+ hidden=True,
3466
+ )
3467
+
3468
+ add_builtin(
3469
+ "assign",
3470
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
3471
+ value_func=tile_assign_value_func,
3472
+ group="Tile Primitives",
3473
+ export=False,
3474
+ hidden=True,
3475
+ )
3476
+
3477
+ add_builtin(
3478
+ "assign",
3479
+ input_types={
3480
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3481
+ "i": int,
3482
+ "j": int,
3483
+ "k": int,
3484
+ "l": int,
3485
+ "src": Any,
3486
+ },
3350
3487
  value_func=tile_assign_value_func,
3351
3488
  group="Tile Primitives",
3352
3489
  export=False,
@@ -3355,7 +3492,15 @@ add_builtin(
3355
3492
 
3356
3493
  add_builtin(
3357
3494
  "assign",
3358
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
3495
+ input_types={
3496
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3497
+ "i": int,
3498
+ "j": int,
3499
+ "k": int,
3500
+ "l": int,
3501
+ "m": int,
3502
+ "src": Any,
3503
+ },
3359
3504
  value_func=tile_assign_value_func,
3360
3505
  group="Tile Primitives",
3361
3506
  export=False,
@@ -3370,6 +3515,8 @@ add_builtin(
3370
3515
  "j": int,
3371
3516
  "k": int,
3372
3517
  "l": int,
3518
+ "m": int,
3519
+ "n": int,
3373
3520
  "src": Any,
3374
3521
  },
3375
3522
  value_func=tile_assign_value_func,
@@ -3391,7 +3538,7 @@ def tile_value_func(arg_types, arg_values):
3391
3538
 
3392
3539
  if preserve_type:
3393
3540
  dtype = arg_types["x"]
3394
- shape = (warp.codegen.options["block_dim"],)
3541
+ shape = (warp._src.codegen.options["block_dim"],)
3395
3542
 
3396
3543
  return tile(dtype=dtype, shape=shape)
3397
3544
 
@@ -3399,18 +3546,18 @@ def tile_value_func(arg_types, arg_values):
3399
3546
  if type_is_vector(arg_types["x"]):
3400
3547
  dtype = arg_types["x"]._wp_scalar_type_
3401
3548
  length = arg_types["x"]._shape_[0]
3402
- shape = (length, warp.codegen.options["block_dim"])
3549
+ shape = (length, warp._src.codegen.options["block_dim"])
3403
3550
  elif type_is_quaternion(arg_types["x"]):
3404
3551
  dtype = arg_types["x"]._wp_scalar_type_
3405
- shape = (4, warp.codegen.options["block_dim"])
3552
+ shape = (4, warp._src.codegen.options["block_dim"])
3406
3553
  elif type_is_matrix(arg_types["x"]):
3407
3554
  dtype = arg_types["x"]._wp_scalar_type_
3408
3555
  rows = arg_types["x"]._shape_[0]
3409
3556
  cols = arg_types["x"]._shape_[1]
3410
- shape = (rows, cols, warp.codegen.options["block_dim"])
3557
+ shape = (rows, cols, warp._src.codegen.options["block_dim"])
3411
3558
  else:
3412
3559
  dtype = arg_types["x"]
3413
- shape = (warp.codegen.options["block_dim"],)
3560
+ shape = (warp._src.codegen.options["block_dim"],)
3414
3561
 
3415
3562
  return tile(dtype=dtype, shape=shape)
3416
3563
 
@@ -3500,17 +3647,17 @@ def untile_value_func(arg_types, arg_values):
3500
3647
  if not is_tile(t):
3501
3648
  raise TypeError(f"untile() argument must be a tile, got {t!r}")
3502
3649
 
3503
- if t.shape[-1] != warp.codegen.options["block_dim"]:
3650
+ if t.shape[-1] != warp._src.codegen.options["block_dim"]:
3504
3651
  raise ValueError(
3505
- f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
3652
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
3506
3653
  )
3507
3654
 
3508
3655
  if len(t.shape) == 1:
3509
3656
  return t.dtype
3510
3657
  elif len(t.shape) == 2:
3511
- return warp.types.vector(t.shape[0], t.dtype)
3658
+ return warp._src.types.vector(t.shape[0], t.dtype)
3512
3659
  elif len(t.shape) == 3:
3513
- return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3660
+ return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3514
3661
  else:
3515
3662
  raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
3516
3663
 
@@ -3572,7 +3719,36 @@ def tile_extract_value_func(arg_types, arg_values):
3572
3719
  # force the input tile to shared memory
3573
3720
  arg_types["a"].storage = "shared"
3574
3721
 
3575
- return arg_types["a"].dtype
3722
+ # count the number of indices (all parameters except the tile "a")
3723
+ num_indices = len(arg_types) - 1
3724
+ tile_dtype = arg_types["a"].dtype
3725
+ tile_shape = arg_types["a"].shape
3726
+
3727
+ if type_is_vector(tile_dtype):
3728
+ if num_indices == len(tile_shape):
3729
+ return tile_dtype
3730
+ elif num_indices == len(tile_shape) + 1:
3731
+ return tile_dtype._wp_scalar_type_
3732
+ else:
3733
+ raise IndexError(
3734
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3735
+ )
3736
+ elif type_is_matrix(tile_dtype):
3737
+ if num_indices == len(tile_shape):
3738
+ return tile_dtype
3739
+ elif num_indices == len(tile_shape) + 2:
3740
+ return tile_dtype._wp_scalar_type_
3741
+ else:
3742
+ raise IndexError(
3743
+ f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
3744
+ )
3745
+ else:
3746
+ # scalar element: index count must exactly match tile rank
3747
+ if num_indices == len(tile_shape):
3748
+ return tile_dtype
3749
+ raise IndexError(
3750
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3751
+ )
3576
3752
 
3577
3753
 
3578
3754
  add_builtin(
@@ -3596,7 +3772,7 @@ add_builtin(
3596
3772
 
3597
3773
  add_builtin(
3598
3774
  "tile_extract",
3599
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
3775
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
3600
3776
  value_func=tile_extract_value_func,
3601
3777
  variadic=False,
3602
3778
  doc="""Extract a single element from the tile.
@@ -3607,7 +3783,28 @@ add_builtin(
3607
3783
 
3608
3784
  :param a: Tile to extract the element from
3609
3785
  :param i: Coordinate of element on first dimension
3610
- :param j: Coordinate of element on the second dimension
3786
+ :param j: Coordinate of element on the second dimension, or vector index
3787
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3788
+ group="Tile Primitives",
3789
+ hidden=True,
3790
+ export=False,
3791
+ )
3792
+
3793
+ add_builtin(
3794
+ "tile_extract",
3795
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
3796
+ value_func=tile_extract_value_func,
3797
+ variadic=False,
3798
+ doc="""Extract a single element from the tile.
3799
+
3800
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3801
+
3802
+ Note that this may incur additional synchronization if the source tile is a register tile.
3803
+
3804
+ :param a: Tile to extract the element from
3805
+ :param i: Coordinate of element on first dimension
3806
+ :param j: Coordinate of element on the second dimension, or first matrix index
3807
+ :param k: Coordinate of element on the third dimension, or vector index, or second matrix index
3611
3808
  :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3612
3809
  group="Tile Primitives",
3613
3810
  hidden=True,
@@ -3616,7 +3813,36 @@ add_builtin(
3616
3813
 
3617
3814
  add_builtin(
3618
3815
  "tile_extract",
3619
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
3816
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
3817
+ value_func=tile_extract_value_func,
3818
+ variadic=False,
3819
+ doc="""Extract a single element from the tile.
3820
+
3821
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3822
+
3823
+ Note that this may incur additional synchronization if the source tile is a register tile.
3824
+
3825
+ :param a: Tile to extract the element from
3826
+ :param i: Coordinate of element on first dimension
3827
+ :param j: Coordinate of element on the second dimension
3828
+ :param k: Coordinate of element on the third dimension, or first matrix index
3829
+ :param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
3830
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3831
+ group="Tile Primitives",
3832
+ hidden=True,
3833
+ export=False,
3834
+ )
3835
+
3836
+ add_builtin(
3837
+ "tile_extract",
3838
+ input_types={
3839
+ "a": tile(dtype=Any, shape=Tuple[int, ...]),
3840
+ "i": int,
3841
+ "j": int,
3842
+ "k": int,
3843
+ "l": int,
3844
+ "m": int,
3845
+ },
3620
3846
  value_func=tile_extract_value_func,
3621
3847
  variadic=False,
3622
3848
  doc="""Extract a single element from the tile.
@@ -3629,7 +3855,9 @@ add_builtin(
3629
3855
  :param i: Coordinate of element on first dimension
3630
3856
  :param j: Coordinate of element on the second dimension
3631
3857
  :param k: Coordinate of element on the third dimension
3632
- :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3858
+ :param l: Coordinate of element on the fourth dimension, or first matrix index
3859
+ :param m: Vector index, or second matrix index
3860
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3633
3861
  group="Tile Primitives",
3634
3862
  hidden=True,
3635
3863
  export=False,
@@ -3637,7 +3865,15 @@ add_builtin(
3637
3865
 
3638
3866
  add_builtin(
3639
3867
  "tile_extract",
3640
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
3868
+ input_types={
3869
+ "a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
3870
+ "i": int,
3871
+ "j": int,
3872
+ "k": int,
3873
+ "l": int,
3874
+ "m": int,
3875
+ "n": int,
3876
+ },
3641
3877
  value_func=tile_extract_value_func,
3642
3878
  variadic=False,
3643
3879
  doc="""Extract a single element from the tile.
@@ -3651,6 +3887,8 @@ add_builtin(
3651
3887
  :param j: Coordinate of element on the second dimension
3652
3888
  :param k: Coordinate of element on the third dimension
3653
3889
  :param l: Coordinate of element on the fourth dimension
3890
+ :param m: Vector index, or first matrix index
3891
+ :param n: Second matrix index
3654
3892
  :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3655
3893
  group="Tile Primitives",
3656
3894
  hidden=True,
@@ -3737,50 +3975,161 @@ add_builtin(
3737
3975
  export=False,
3738
3976
  )
3739
3977
 
3740
-
3741
- def tile_transpose_value_func(arg_types, arg_values):
3742
- # return generic type (for doc builds)
3743
- if arg_types is None:
3744
- return tile(dtype=Any, shape=Tuple[int, int])
3745
-
3746
- if len(arg_types) != 1:
3747
- raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
3748
-
3749
- t = arg_types["a"]
3750
-
3751
- if not is_tile(t):
3752
- raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
3753
-
3754
- layout = None
3755
-
3756
- # flip layout
3757
- if t.layout == "rowmajor":
3758
- layout = "colmajor"
3759
- elif t.layout == "colmajor":
3760
- layout = "rowmajor"
3761
-
3762
- # force the input tile to shared memory
3763
- t.storage = "shared"
3764
-
3765
- return tile(
3766
- dtype=t.dtype,
3767
- shape=t.shape[::-1],
3768
- storage=t.storage,
3769
- strides=t.strides[::-1],
3770
- layout=layout,
3771
- owner=False,
3772
- )
3773
-
3774
-
3775
3978
  add_builtin(
3776
- "tile_transpose",
3777
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
3778
- value_func=tile_transpose_value_func,
3779
- variadic=True,
3780
- doc="""Transpose a tile.
3781
-
3782
- For shared memory tiles, this operation will alias the input tile.
3783
- Register tiles will first be transferred to shared memory before transposition.
3979
+ "tile_bit_and_inplace",
3980
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3981
+ value_func=tile_inplace_value_func,
3982
+ group="Tile Primitives",
3983
+ hidden=True,
3984
+ export=False,
3985
+ is_differentiable=False,
3986
+ )
3987
+ add_builtin(
3988
+ "tile_bit_and_inplace",
3989
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3990
+ value_func=tile_inplace_value_func,
3991
+ group="Tile Primitives",
3992
+ hidden=True,
3993
+ export=False,
3994
+ is_differentiable=False,
3995
+ )
3996
+ add_builtin(
3997
+ "tile_bit_and_inplace",
3998
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
3999
+ value_func=tile_inplace_value_func,
4000
+ group="Tile Primitives",
4001
+ hidden=True,
4002
+ export=False,
4003
+ is_differentiable=False,
4004
+ )
4005
+ add_builtin(
4006
+ "tile_bit_and_inplace",
4007
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4008
+ value_func=tile_inplace_value_func,
4009
+ group="Tile Primitives",
4010
+ hidden=True,
4011
+ export=False,
4012
+ is_differentiable=False,
4013
+ )
4014
+
4015
+ add_builtin(
4016
+ "tile_bit_or_inplace",
4017
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4018
+ value_func=tile_inplace_value_func,
4019
+ group="Tile Primitives",
4020
+ hidden=True,
4021
+ export=False,
4022
+ is_differentiable=False,
4023
+ )
4024
+ add_builtin(
4025
+ "tile_bit_or_inplace",
4026
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4027
+ value_func=tile_inplace_value_func,
4028
+ group="Tile Primitives",
4029
+ hidden=True,
4030
+ export=False,
4031
+ is_differentiable=False,
4032
+ )
4033
+ add_builtin(
4034
+ "tile_bit_or_inplace",
4035
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4036
+ value_func=tile_inplace_value_func,
4037
+ group="Tile Primitives",
4038
+ hidden=True,
4039
+ export=False,
4040
+ is_differentiable=False,
4041
+ )
4042
+ add_builtin(
4043
+ "tile_bit_or_inplace",
4044
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4045
+ value_func=tile_inplace_value_func,
4046
+ group="Tile Primitives",
4047
+ hidden=True,
4048
+ export=False,
4049
+ is_differentiable=False,
4050
+ )
4051
+
4052
+ add_builtin(
4053
+ "tile_bit_xor_inplace",
4054
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4055
+ value_func=tile_inplace_value_func,
4056
+ group="Tile Primitives",
4057
+ hidden=True,
4058
+ export=False,
4059
+ is_differentiable=False,
4060
+ )
4061
+ add_builtin(
4062
+ "tile_bit_xor_inplace",
4063
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4064
+ value_func=tile_inplace_value_func,
4065
+ group="Tile Primitives",
4066
+ hidden=True,
4067
+ export=False,
4068
+ is_differentiable=False,
4069
+ )
4070
+ add_builtin(
4071
+ "tile_bit_xor_inplace",
4072
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4073
+ value_func=tile_inplace_value_func,
4074
+ group="Tile Primitives",
4075
+ hidden=True,
4076
+ export=False,
4077
+ is_differentiable=False,
4078
+ )
4079
+ add_builtin(
4080
+ "tile_bit_xor_inplace",
4081
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4082
+ value_func=tile_inplace_value_func,
4083
+ group="Tile Primitives",
4084
+ hidden=True,
4085
+ export=False,
4086
+ is_differentiable=False,
4087
+ )
4088
+
4089
+
4090
+ def tile_transpose_value_func(arg_types, arg_values):
4091
+ # return generic type (for doc builds)
4092
+ if arg_types is None:
4093
+ return tile(dtype=Any, shape=Tuple[int, int])
4094
+
4095
+ if len(arg_types) != 1:
4096
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
4097
+
4098
+ t = arg_types["a"]
4099
+
4100
+ if not is_tile(t):
4101
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
4102
+
4103
+ layout = None
4104
+
4105
+ # flip layout
4106
+ if t.layout == "rowmajor":
4107
+ layout = "colmajor"
4108
+ elif t.layout == "colmajor":
4109
+ layout = "rowmajor"
4110
+
4111
+ # force the input tile to shared memory
4112
+ t.storage = "shared"
4113
+
4114
+ return tile(
4115
+ dtype=t.dtype,
4116
+ shape=t.shape[::-1],
4117
+ storage=t.storage,
4118
+ strides=t.strides[::-1],
4119
+ layout=layout,
4120
+ owner=False,
4121
+ )
4122
+
4123
+
4124
+ add_builtin(
4125
+ "tile_transpose",
4126
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
4127
+ value_func=tile_transpose_value_func,
4128
+ variadic=True,
4129
+ doc="""Transpose a tile.
4130
+
4131
+ For shared memory tiles, this operation will alias the input tile.
4132
+ Register tiles will first be transferred to shared memory before transposition.
3784
4133
 
3785
4134
  :param a: Tile to transpose with ``shape=(M,N)``
3786
4135
  :returns: Tile with ``shape=(N,M)``""",
@@ -3910,6 +4259,80 @@ add_builtin(
3910
4259
  )
3911
4260
 
3912
4261
 
4262
+ def tile_sum_axis_value_func(arg_types, arg_values):
4263
+ if arg_types is None:
4264
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4265
+
4266
+ a = arg_types["a"]
4267
+
4268
+ if not is_tile(a):
4269
+ raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
4270
+
4271
+ # force input tile to shared
4272
+ a.storage = "shared"
4273
+
4274
+ axis = arg_values["axis"]
4275
+ shape = a.shape
4276
+
4277
+ if axis < 0 or axis >= len(shape):
4278
+ raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4279
+
4280
+ # shape is identical less the axis reduction is along
4281
+ if len(shape) > 1:
4282
+ new_shape = shape[:axis] + shape[axis + 1 :]
4283
+ else:
4284
+ new_shape = (1,)
4285
+
4286
+ return tile(dtype=a.dtype, shape=new_shape)
4287
+
4288
+
4289
+ def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
4290
+ tile = arg_values["a"]
4291
+ axis_var = arg_values["axis"]
4292
+ if not hasattr(axis_var, "constant") or axis_var.constant is None:
4293
+ raise ValueError("tile_sum() axis must be a compile-time constant")
4294
+ axis = axis_var.constant
4295
+
4296
+ return ((tile,), (axis,))
4297
+
4298
+
4299
+ add_builtin(
4300
+ "tile_sum",
4301
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4302
+ value_func=tile_sum_axis_value_func,
4303
+ dispatch_func=tile_sum_axis_dispatch_func,
4304
+ doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
4305
+
4306
+ :param a: The input tile. Must reside in shared memory.
4307
+ :param axis: The tile axis to compute the sum across. Must be a compile-time constant.
4308
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4309
+
4310
+ Example:
4311
+
4312
+ .. code-block:: python
4313
+
4314
+ @wp.kernel
4315
+ def compute():
4316
+
4317
+ t = wp.tile_ones(dtype=float, shape=(8, 8))
4318
+ s = wp.tile_sum(t, axis=0)
4319
+
4320
+ print(s)
4321
+
4322
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
4323
+
4324
+ Prints:
4325
+
4326
+ .. code-block:: text
4327
+
4328
+ [8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
4329
+
4330
+ """,
4331
+ group="Tile Primitives",
4332
+ export=False,
4333
+ )
4334
+
4335
+
3913
4336
  def tile_sort_value_func(arg_types, arg_values):
3914
4337
  # return generic type (for doc builds)
3915
4338
  if arg_types is None:
@@ -3986,6 +4409,7 @@ add_builtin(
3986
4409
  """,
3987
4410
  group="Tile Primitives",
3988
4411
  export=False,
4412
+ is_differentiable=False,
3989
4413
  )
3990
4414
 
3991
4415
 
@@ -4039,6 +4463,7 @@ add_builtin(
4039
4463
  """,
4040
4464
  group="Tile Primitives",
4041
4465
  export=False,
4466
+ is_differentiable=False,
4042
4467
  )
4043
4468
 
4044
4469
 
@@ -4092,6 +4517,7 @@ add_builtin(
4092
4517
  """,
4093
4518
  group="Tile Primitives",
4094
4519
  export=False,
4520
+ is_differentiable=False,
4095
4521
  )
4096
4522
 
4097
4523
 
@@ -4144,6 +4570,7 @@ add_builtin(
4144
4570
  """,
4145
4571
  group="Tile Primitives",
4146
4572
  export=False,
4573
+ is_differentiable=False,
4147
4574
  )
4148
4575
 
4149
4576
 
@@ -4196,10 +4623,10 @@ add_builtin(
4196
4623
  """,
4197
4624
  group="Tile Primitives",
4198
4625
  export=False,
4626
+ is_differentiable=False,
4199
4627
  )
4200
4628
 
4201
4629
 
4202
- # does type propagation for load()
4203
4630
  def tile_reduce_value_func(arg_types, arg_values):
4204
4631
  if arg_types is None:
4205
4632
  return tile(dtype=Scalar, shape=(1,))
@@ -4253,6 +4680,88 @@ add_builtin(
4253
4680
  """,
4254
4681
  group="Tile Primitives",
4255
4682
  export=False,
4683
+ is_differentiable=False,
4684
+ )
4685
+
4686
+
4687
+ def tile_reduce_axis_value_func(arg_types, arg_values):
4688
+ if arg_types is None:
4689
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4690
+
4691
+ a = arg_types["a"]
4692
+
4693
+ if not is_tile(a):
4694
+ raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
4695
+
4696
+ # force input tile to shared memory
4697
+ a.storage = "shared"
4698
+
4699
+ axis = arg_values["axis"]
4700
+ shape = a.shape
4701
+
4702
+ if axis < 0 or axis >= len(shape):
4703
+ raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4704
+
4705
+ # shape is identical less the axis reduction is along
4706
+ if len(shape) > 1:
4707
+ new_shape = shape[:axis] + shape[axis + 1 :]
4708
+ else:
4709
+ new_shape = (1,)
4710
+
4711
+ return tile(dtype=a.dtype, shape=new_shape)
4712
+
4713
+
4714
+ add_builtin(
4715
+ "tile_reduce",
4716
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4717
+ value_func=tile_reduce_axis_value_func,
4718
+ native_func="tile_reduce_axis",
4719
+ doc="""Apply a custom reduction operator across a tile axis.
4720
+
4721
+ This function cooperatively performs a reduction using the provided operator across an axis of the tile.
4722
+
4723
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
4724
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
4725
+ :param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
4726
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4727
+
4728
+ Example:
4729
+
4730
+ .. code-block:: python
4731
+
4732
+ TILE_M = wp.constant(4)
4733
+ TILE_N = wp.constant(2)
4734
+
4735
+ @wp.kernel
4736
+ def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
4737
+
4738
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
4739
+ b = wp.tile_reduce(wp.add, a, axis=1)
4740
+ wp.tile_store(y, b)
4741
+
4742
+ arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
4743
+
4744
+ x = wp.array(arr, dtype=float)
4745
+ y = wp.zeros(TILE_M, dtype=float)
4746
+
4747
+ wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
4748
+
4749
+ print(x.numpy())
4750
+ print(y.numpy())
4751
+
4752
+ Prints:
4753
+
4754
+ .. code-block:: text
4755
+
4756
+ [[0. 1.]
4757
+ [2. 3.]
4758
+ [4. 5.]
4759
+ [6. 7.]]
4760
+ [ 1. 5. 9. 13.]
4761
+ """,
4762
+ group="Tile Primitives",
4763
+ export=False,
4764
+ is_differentiable=False,
4256
4765
  )
4257
4766
 
4258
4767
 
@@ -4316,6 +4825,7 @@ add_builtin(
4316
4825
  """,
4317
4826
  group="Tile Primitives",
4318
4827
  export=False,
4828
+ is_differentiable=False,
4319
4829
  )
4320
4830
 
4321
4831
 
@@ -4379,6 +4889,7 @@ add_builtin(
4379
4889
  """,
4380
4890
  group="Tile Primitives",
4381
4891
  export=False,
4892
+ is_differentiable=False,
4382
4893
  )
4383
4894
 
4384
4895
 
@@ -4632,6 +5143,7 @@ add_builtin(
4632
5143
  doc="WIP",
4633
5144
  group="Utility",
4634
5145
  hidden=True,
5146
+ is_differentiable=False,
4635
5147
  )
4636
5148
 
4637
5149
  add_builtin(
@@ -4647,6 +5159,7 @@ add_builtin(
4647
5159
  doc="WIP",
4648
5160
  group="Utility",
4649
5161
  hidden=True,
5162
+ is_differentiable=False,
4650
5163
  )
4651
5164
 
4652
5165
  add_builtin(
@@ -4656,6 +5169,7 @@ add_builtin(
4656
5169
  doc="WIP",
4657
5170
  group="Utility",
4658
5171
  hidden=True,
5172
+ is_differentiable=False,
4659
5173
  )
4660
5174
 
4661
5175
  add_builtin(
@@ -4707,6 +5221,7 @@ add_builtin(
4707
5221
  :param low: The lower bound of the bounding box in BVH space
4708
5222
  :param high: The upper bound of the bounding box in BVH space""",
4709
5223
  export=False,
5224
+ is_differentiable=False,
4710
5225
  )
4711
5226
 
4712
5227
  add_builtin(
@@ -4722,6 +5237,7 @@ add_builtin(
4722
5237
  :param start: The start of the ray in BVH space
4723
5238
  :param dir: The direction of the ray in BVH space""",
4724
5239
  export=False,
5240
+ is_differentiable=False,
4725
5241
  )
4726
5242
 
4727
5243
  add_builtin(
@@ -4732,6 +5248,7 @@ add_builtin(
4732
5248
  doc="""Move to the next bound returned by the query.
4733
5249
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
4734
5250
  export=False,
5251
+ is_differentiable=False,
4735
5252
  )
4736
5253
 
4737
5254
  add_builtin(
@@ -5066,12 +5583,13 @@ add_builtin(
5066
5583
  group="Geometry",
5067
5584
  doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
5068
5585
 
5069
- This query can be used to iterate over all triangles inside a volume.
5586
+ This query can be used to iterate over all bounding boxes of the triangles inside a volume.
5070
5587
 
5071
5588
  :param id: The mesh identifier
5072
5589
  :param low: The lower bound of the bounding box in mesh space
5073
5590
  :param high: The upper bound of the bounding box in mesh space""",
5074
5591
  export=False,
5592
+ is_differentiable=False,
5075
5593
  )
5076
5594
 
5077
5595
  add_builtin(
@@ -5079,10 +5597,11 @@ add_builtin(
5079
5597
  input_types={"query": MeshQueryAABB, "index": int},
5080
5598
  value_type=builtins.bool,
5081
5599
  group="Geometry",
5082
- doc="""Move to the next triangle overlapping the query bounding box.
5600
+ doc="""Move to the next triangle whose bounding box overlaps the query bounding box.
5083
5601
 
5084
5602
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
5085
5603
  export=False,
5604
+ is_differentiable=False,
5086
5605
  )
5087
5606
 
5088
5607
  add_builtin(
@@ -5112,6 +5631,7 @@ add_builtin(
5112
5631
 
5113
5632
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
5114
5633
  export=False,
5634
+ is_differentiable=False,
5115
5635
  )
5116
5636
 
5117
5637
  add_builtin(
@@ -5123,6 +5643,7 @@ add_builtin(
5123
5643
 
5124
5644
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
5125
5645
  export=False,
5646
+ is_differentiable=False,
5126
5647
  )
5127
5648
 
5128
5649
  add_builtin(
@@ -5136,6 +5657,7 @@ add_builtin(
5136
5657
 
5137
5658
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
5138
5659
  export=False,
5660
+ is_differentiable=False,
5139
5661
  )
5140
5662
 
5141
5663
  add_builtin(
@@ -5145,15 +5667,34 @@ add_builtin(
5145
5667
  group="Geometry",
5146
5668
  doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5147
5669
 
5670
+ This function works with single precision, may return incorrect results in some case.
5671
+
5148
5672
  Returns > 0 if triangles intersect.""",
5149
5673
  export=False,
5674
+ is_differentiable=False,
5150
5675
  )
5151
5676
 
5677
+
5678
+ add_builtin(
5679
+ "intersect_tri_tri",
5680
+ input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
5681
+ value_type=int,
5682
+ group="Geometry",
5683
+ doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5684
+
5685
+ This function works with double precision, results are more accurate than the single precision version.
5686
+
5687
+ Returns > 0 if triangles intersect.""",
5688
+ export=False,
5689
+ is_differentiable=False,
5690
+ )
5691
+
5692
+
5152
5693
  add_builtin(
5153
5694
  "mesh_get",
5154
5695
  input_types={"id": uint64},
5155
5696
  value_type=Mesh,
5156
- missing_grad=True,
5697
+ is_differentiable=False,
5157
5698
  group="Geometry",
5158
5699
  doc="""Retrieves the mesh given its index.""",
5159
5700
  export=False,
@@ -5166,6 +5707,7 @@ add_builtin(
5166
5707
  group="Geometry",
5167
5708
  doc="""Evaluates the face normal the mesh given a face index.""",
5168
5709
  export=False,
5710
+ is_differentiable=False,
5169
5711
  )
5170
5712
 
5171
5713
  add_builtin(
@@ -5175,6 +5717,7 @@ add_builtin(
5175
5717
  group="Geometry",
5176
5718
  doc="""Returns the point of the mesh given a index.""",
5177
5719
  export=False,
5720
+ is_differentiable=False,
5178
5721
  )
5179
5722
 
5180
5723
  add_builtin(
@@ -5184,6 +5727,7 @@ add_builtin(
5184
5727
  group="Geometry",
5185
5728
  doc="""Returns the velocity of the mesh given a index.""",
5186
5729
  export=False,
5730
+ is_differentiable=False,
5187
5731
  )
5188
5732
 
5189
5733
  add_builtin(
@@ -5193,6 +5737,7 @@ add_builtin(
5193
5737
  group="Geometry",
5194
5738
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
5195
5739
  export=False,
5740
+ is_differentiable=False,
5196
5741
  )
5197
5742
 
5198
5743
 
@@ -5233,12 +5778,32 @@ add_builtin(
5233
5778
  # ---------------------------------
5234
5779
  # Iterators
5235
5780
 
5236
- add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
5237
5781
  add_builtin(
5238
- "iter_next", input_types={"query": HashGridQuery}, value_type=int, group="Utility", export=False, hidden=True
5782
+ "iter_next",
5783
+ input_types={"range": range_t},
5784
+ value_type=int,
5785
+ group="Utility",
5786
+ export=False,
5787
+ hidden=True,
5788
+ is_differentiable=False,
5789
+ )
5790
+ add_builtin(
5791
+ "iter_next",
5792
+ input_types={"query": HashGridQuery},
5793
+ value_type=int,
5794
+ group="Utility",
5795
+ export=False,
5796
+ hidden=True,
5797
+ is_differentiable=False,
5239
5798
  )
5240
5799
  add_builtin(
5241
- "iter_next", input_types={"query": MeshQueryAABB}, value_type=int, group="Utility", export=False, hidden=True
5800
+ "iter_next",
5801
+ input_types={"query": MeshQueryAABB},
5802
+ value_type=int,
5803
+ group="Utility",
5804
+ export=False,
5805
+ hidden=True,
5806
+ is_differentiable=False,
5242
5807
  )
5243
5808
 
5244
5809
  add_builtin(
@@ -5249,6 +5814,7 @@ add_builtin(
5249
5814
  group="Utility",
5250
5815
  doc="""Returns the range in reversed order.""",
5251
5816
  export=False,
5817
+ is_differentiable=False,
5252
5818
  )
5253
5819
 
5254
5820
  # ---------------------------------
@@ -5268,8 +5834,8 @@ _volume_supported_value_types = {
5268
5834
 
5269
5835
 
5270
5836
  def _is_volume_type_supported(dtype):
5271
- for typ in _volume_supported_value_types:
5272
- if types_equal(typ, dtype):
5837
+ for value_type in _volume_supported_value_types:
5838
+ if types_equal(value_type, dtype):
5273
5839
  return True
5274
5840
  return False
5275
5841
 
@@ -5397,6 +5963,7 @@ add_builtin(
5397
5963
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
5398
5964
 
5399
5965
  If the voxel at this index does not exist, this function returns the background value.""",
5966
+ is_differentiable=False,
5400
5967
  )
5401
5968
 
5402
5969
 
@@ -5417,6 +5984,7 @@ add_builtin(
5417
5984
  export=False,
5418
5985
  group="Volumes",
5419
5986
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5987
+ is_differentiable=False,
5420
5988
  )
5421
5989
 
5422
5990
  add_builtin(
@@ -5447,6 +6015,7 @@ add_builtin(
5447
6015
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
5448
6016
 
5449
6017
  If the voxel at this index does not exist, this function returns the background value""",
6018
+ is_differentiable=False,
5450
6019
  )
5451
6020
 
5452
6021
  add_builtin(
@@ -5455,6 +6024,7 @@ add_builtin(
5455
6024
  group="Volumes",
5456
6025
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5457
6026
  export=False,
6027
+ is_differentiable=False,
5458
6028
  )
5459
6029
 
5460
6030
  add_builtin(
@@ -5475,6 +6045,7 @@ add_builtin(
5475
6045
  doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
5476
6046
 
5477
6047
  If the voxel at this index does not exist, this function returns the background value.""",
6048
+ is_differentiable=False,
5478
6049
  )
5479
6050
 
5480
6051
  add_builtin(
@@ -5483,6 +6054,7 @@ add_builtin(
5483
6054
  group="Volumes",
5484
6055
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5485
6056
  export=False,
6057
+ is_differentiable=False,
5486
6058
  )
5487
6059
 
5488
6060
  add_builtin(
@@ -5501,6 +6073,7 @@ add_builtin(
5501
6073
  doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
5502
6074
 
5503
6075
  If the voxel at this index does not exist, this function returns the background value.""",
6076
+ is_differentiable=False,
5504
6077
  )
5505
6078
 
5506
6079
  add_builtin(
@@ -5509,6 +6082,7 @@ add_builtin(
5509
6082
  group="Volumes",
5510
6083
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5511
6084
  export=False,
6085
+ is_differentiable=False,
5512
6086
  )
5513
6087
 
5514
6088
 
@@ -5590,6 +6164,7 @@ add_builtin(
5590
6164
  If the voxel at this index does not exist, this function returns -1.
5591
6165
  This function is available for both index grids and classical volumes.
5592
6166
  """,
6167
+ is_differentiable=False,
5593
6168
  )
5594
6169
 
5595
6170
  add_builtin(
@@ -5631,6 +6206,7 @@ add_builtin(
5631
6206
  value_type=uint32,
5632
6207
  group="Random",
5633
6208
  doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
6209
+ is_differentiable=False,
5634
6210
  )
5635
6211
 
5636
6212
  add_builtin(
@@ -5642,6 +6218,7 @@ add_builtin(
5642
6218
 
5643
6219
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
5644
6220
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
6221
+ is_differentiable=False,
5645
6222
  )
5646
6223
 
5647
6224
  add_builtin(
@@ -5650,6 +6227,7 @@ add_builtin(
5650
6227
  value_type=int,
5651
6228
  group="Random",
5652
6229
  doc="Return a random integer in the range [-2^31, 2^31).",
6230
+ is_differentiable=False,
5653
6231
  )
5654
6232
  add_builtin(
5655
6233
  "randi",
@@ -5657,6 +6235,7 @@ add_builtin(
5657
6235
  value_type=int,
5658
6236
  group="Random",
5659
6237
  doc="Return a random integer between [low, high).",
6238
+ is_differentiable=False,
5660
6239
  )
5661
6240
  add_builtin(
5662
6241
  "randu",
@@ -5664,6 +6243,7 @@ add_builtin(
5664
6243
  value_type=uint32,
5665
6244
  group="Random",
5666
6245
  doc="Return a random unsigned integer in the range [0, 2^32).",
6246
+ is_differentiable=False,
5667
6247
  )
5668
6248
  add_builtin(
5669
6249
  "randu",
@@ -5671,6 +6251,7 @@ add_builtin(
5671
6251
  value_type=uint32,
5672
6252
  group="Random",
5673
6253
  doc="Return a random unsigned integer between [low, high).",
6254
+ is_differentiable=False,
5674
6255
  )
5675
6256
  add_builtin(
5676
6257
  "randf",
@@ -5678,6 +6259,7 @@ add_builtin(
5678
6259
  value_type=float,
5679
6260
  group="Random",
5680
6261
  doc="Return a random float between [0.0, 1.0).",
6262
+ is_differentiable=False,
5681
6263
  )
5682
6264
  add_builtin(
5683
6265
  "randf",
@@ -5685,6 +6267,7 @@ add_builtin(
5685
6267
  value_type=float,
5686
6268
  group="Random",
5687
6269
  doc="Return a random float between [low, high).",
6270
+ is_differentiable=False,
5688
6271
  )
5689
6272
  add_builtin(
5690
6273
  "randn",
@@ -5692,6 +6275,7 @@ add_builtin(
5692
6275
  value_type=float,
5693
6276
  group="Random",
5694
6277
  doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
6278
+ is_differentiable=False,
5695
6279
  )
5696
6280
 
5697
6281
  add_builtin(
@@ -5700,6 +6284,7 @@ add_builtin(
5700
6284
  value_type=int,
5701
6285
  group="Random",
5702
6286
  doc="Inverse-transform sample a cumulative distribution function.",
6287
+ is_differentiable=False,
5703
6288
  )
5704
6289
  add_builtin(
5705
6290
  "sample_triangle",
@@ -5707,6 +6292,7 @@ add_builtin(
5707
6292
  value_type=vec2,
5708
6293
  group="Random",
5709
6294
  doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
6295
+ is_differentiable=False,
5710
6296
  )
5711
6297
  add_builtin(
5712
6298
  "sample_unit_ring",
@@ -5714,6 +6300,7 @@ add_builtin(
5714
6300
  value_type=vec2,
5715
6301
  group="Random",
5716
6302
  doc="Uniformly sample a ring in the xy plane.",
6303
+ is_differentiable=False,
5717
6304
  )
5718
6305
  add_builtin(
5719
6306
  "sample_unit_disk",
@@ -5721,6 +6308,7 @@ add_builtin(
5721
6308
  value_type=vec2,
5722
6309
  group="Random",
5723
6310
  doc="Uniformly sample a disk in the xy plane.",
6311
+ is_differentiable=False,
5724
6312
  )
5725
6313
  add_builtin(
5726
6314
  "sample_unit_sphere_surface",
@@ -5728,6 +6316,7 @@ add_builtin(
5728
6316
  value_type=vec3,
5729
6317
  group="Random",
5730
6318
  doc="Uniformly sample a unit sphere surface.",
6319
+ is_differentiable=False,
5731
6320
  )
5732
6321
  add_builtin(
5733
6322
  "sample_unit_sphere",
@@ -5735,6 +6324,7 @@ add_builtin(
5735
6324
  value_type=vec3,
5736
6325
  group="Random",
5737
6326
  doc="Uniformly sample a unit sphere.",
6327
+ is_differentiable=False,
5738
6328
  )
5739
6329
  add_builtin(
5740
6330
  "sample_unit_hemisphere_surface",
@@ -5742,6 +6332,7 @@ add_builtin(
5742
6332
  value_type=vec3,
5743
6333
  group="Random",
5744
6334
  doc="Uniformly sample a unit hemisphere surface.",
6335
+ is_differentiable=False,
5745
6336
  )
5746
6337
  add_builtin(
5747
6338
  "sample_unit_hemisphere",
@@ -5749,6 +6340,7 @@ add_builtin(
5749
6340
  value_type=vec3,
5750
6341
  group="Random",
5751
6342
  doc="Uniformly sample a unit hemisphere.",
6343
+ is_differentiable=False,
5752
6344
  )
5753
6345
  add_builtin(
5754
6346
  "sample_unit_square",
@@ -5756,6 +6348,7 @@ add_builtin(
5756
6348
  value_type=vec2,
5757
6349
  group="Random",
5758
6350
  doc="Uniformly sample a unit square.",
6351
+ is_differentiable=False,
5759
6352
  )
5760
6353
  add_builtin(
5761
6354
  "sample_unit_cube",
@@ -5763,6 +6356,7 @@ add_builtin(
5763
6356
  value_type=vec3,
5764
6357
  group="Random",
5765
6358
  doc="Uniformly sample a unit cube.",
6359
+ is_differentiable=False,
5766
6360
  )
5767
6361
 
5768
6362
  add_builtin(
@@ -5774,6 +6368,7 @@ add_builtin(
5774
6368
 
5775
6369
  :param state: RNG state
5776
6370
  :param lam: The expected value of the distribution""",
6371
+ is_differentiable=False,
5777
6372
  )
5778
6373
 
5779
6374
  add_builtin(
@@ -5841,7 +6436,7 @@ add_builtin(
5841
6436
  value_type=vec2,
5842
6437
  group="Random",
5843
6438
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
5844
- missing_grad=True,
6439
+ is_differentiable=False,
5845
6440
  )
5846
6441
  add_builtin(
5847
6442
  "curlnoise",
@@ -5850,7 +6445,7 @@ add_builtin(
5850
6445
  value_type=vec3,
5851
6446
  group="Random",
5852
6447
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5853
- missing_grad=True,
6448
+ is_differentiable=False,
5854
6449
  )
5855
6450
  add_builtin(
5856
6451
  "curlnoise",
@@ -5859,7 +6454,7 @@ add_builtin(
5859
6454
  value_type=vec3,
5860
6455
  group="Random",
5861
6456
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5862
- missing_grad=True,
6457
+ is_differentiable=False,
5863
6458
  )
5864
6459
 
5865
6460
 
@@ -5891,9 +6486,16 @@ add_builtin(
5891
6486
  dispatch_func=printf_dispatch_func,
5892
6487
  group="Utility",
5893
6488
  doc="Allows printing formatted strings using C-style format specifiers.",
6489
+ is_differentiable=False,
5894
6490
  )
5895
6491
 
5896
- add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
6492
+ add_builtin(
6493
+ "print",
6494
+ input_types={"value": Any},
6495
+ doc="Print variable to stdout",
6496
+ export=False,
6497
+ group="Utility",
6498
+ )
5897
6499
 
5898
6500
  add_builtin(
5899
6501
  "breakpoint",
@@ -5903,6 +6505,7 @@ add_builtin(
5903
6505
  group="Utility",
5904
6506
  namespace="",
5905
6507
  native_func="__debugbreak",
6508
+ is_differentiable=False,
5906
6509
  )
5907
6510
 
5908
6511
  # helpers
@@ -5920,6 +6523,7 @@ add_builtin(
5920
6523
  This function may not be called from user-defined Warp functions.""",
5921
6524
  namespace="",
5922
6525
  native_func="builtin_tid1d",
6526
+ is_differentiable=False,
5923
6527
  )
5924
6528
 
5925
6529
  add_builtin(
@@ -5930,6 +6534,7 @@ add_builtin(
5930
6534
  doc="Returns the number of threads in the current block.",
5931
6535
  namespace="",
5932
6536
  native_func="builtin_block_dim",
6537
+ is_differentiable=False,
5933
6538
  )
5934
6539
 
5935
6540
  add_builtin(
@@ -5944,6 +6549,7 @@ add_builtin(
5944
6549
  This function may not be called from user-defined Warp functions.""",
5945
6550
  namespace="",
5946
6551
  native_func="builtin_tid2d",
6552
+ is_differentiable=False,
5947
6553
  )
5948
6554
 
5949
6555
  add_builtin(
@@ -5958,6 +6564,7 @@ add_builtin(
5958
6564
  This function may not be called from user-defined Warp functions.""",
5959
6565
  namespace="",
5960
6566
  native_func="builtin_tid3d",
6567
+ is_differentiable=False,
5961
6568
  )
5962
6569
 
5963
6570
  add_builtin(
@@ -5972,17 +6579,37 @@ add_builtin(
5972
6579
  This function may not be called from user-defined Warp functions.""",
5973
6580
  namespace="",
5974
6581
  native_func="builtin_tid4d",
6582
+ is_differentiable=False,
5975
6583
  )
5976
6584
 
5977
6585
 
6586
+ def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6587
+ a = arg_types["a"]
6588
+
6589
+ # if the input is a shared tile, we force a copy
6590
+ if is_tile(a) and a.storage == "shared":
6591
+ return tile(
6592
+ dtype=a.dtype,
6593
+ shape=a.shape,
6594
+ storage=a.storage,
6595
+ strides=a.strides,
6596
+ layout=a.layout,
6597
+ owner=True,
6598
+ )
6599
+
6600
+ return a
6601
+
6602
+
5978
6603
  add_builtin(
5979
6604
  "copy",
5980
6605
  input_types={"a": Any},
5981
- value_func=lambda arg_types, arg_values: arg_types["a"],
6606
+ value_func=copy_value_func,
5982
6607
  hidden=True,
5983
6608
  export=False,
5984
6609
  group="Utility",
5985
6610
  )
6611
+
6612
+
5986
6613
  add_builtin(
5987
6614
  "assign",
5988
6615
  input_types={"dest": Any, "src": Any},
@@ -5992,61 +6619,88 @@ add_builtin(
5992
6619
  )
5993
6620
 
5994
6621
 
5995
- def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
5996
- warp.utils.warn(
5997
- "wp.select() is deprecated and will be removed in a future\n"
5998
- "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
5999
- category=DeprecationWarning,
6000
- )
6001
-
6002
- func_args = tuple(args.values())
6003
- template_args = ()
6622
+ def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6623
+ if arg_types is None:
6624
+ return Any
6004
6625
 
6005
- return (func_args, template_args)
6626
+ raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
6006
6627
 
6007
6628
 
6008
6629
  add_builtin(
6009
6630
  "select",
6010
6631
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
6011
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6012
- dispatch_func=select_dispatch_func,
6632
+ value_func=select_value_func,
6013
6633
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6014
6634
 
6015
- .. deprecated:: 1.7
6635
+ .. versionremoved:: 1.10
6016
6636
  Use :func:`where` instead, which has the more intuitive argument order:
6017
- ``where(cond, value_if_true, value_if_false)``.""",
6637
+ ``where(cond, value_if_true, value_if_false)``.
6638
+
6639
+ .. deprecated:: 1.7""",
6018
6640
  group="Utility",
6019
6641
  )
6020
6642
  for t in int_types:
6021
6643
  add_builtin(
6022
6644
  "select",
6023
6645
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
6024
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6025
- dispatch_func=select_dispatch_func,
6646
+ value_func=select_value_func,
6026
6647
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6027
6648
 
6028
- .. deprecated:: 1.7
6649
+ .. versionremoved:: 1.10
6029
6650
  Use :func:`where` instead, which has the more intuitive argument order:
6030
- ``where(cond, value_if_true, value_if_false)``.""",
6651
+ ``where(cond, value_if_true, value_if_false)``.
6652
+
6653
+ .. deprecated:: 1.7""",
6031
6654
  group="Utility",
6032
6655
  )
6033
6656
  add_builtin(
6034
6657
  "select",
6035
6658
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
6036
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6037
- dispatch_func=select_dispatch_func,
6659
+ value_func=select_value_func,
6038
6660
  doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
6039
6661
 
6040
- .. deprecated:: 1.7
6662
+ .. versionremoved:: 1.10
6041
6663
  Use :func:`where` instead, which has the more intuitive argument order:
6042
- ``where(arr, value_if_true, value_if_false)``.""",
6664
+ ``where(arr, value_if_true, value_if_false)``.
6665
+
6666
+ .. deprecated:: 1.7""",
6043
6667
  group="Utility",
6044
6668
  )
6045
6669
 
6670
+
6671
+ def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6672
+ if arg_types is None:
6673
+ return Any
6674
+
6675
+ v_true = arg_types["value_if_true"]
6676
+ v_false = arg_types["value_if_false"]
6677
+
6678
+ if not types_equal(v_true, v_false):
6679
+ raise RuntimeError(f"where() true value type ({v_true}) must be of the same type as the false type ({v_false})")
6680
+
6681
+ if is_tile(v_false):
6682
+ if v_true.storage == "register":
6683
+ return v_true
6684
+ if v_false.storage == "register":
6685
+ return v_false
6686
+
6687
+ # both v_true and v_false are shared
6688
+ return tile(
6689
+ dtype=v_true.dtype,
6690
+ shape=v_true.shape,
6691
+ storage=v_true.storage,
6692
+ strides=v_true.strides,
6693
+ layout=v_true.layout,
6694
+ owner=True,
6695
+ )
6696
+
6697
+ return v_true
6698
+
6699
+
6046
6700
  add_builtin(
6047
6701
  "where",
6048
6702
  input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
6049
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6703
+ value_func=where_value_func,
6050
6704
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
6051
6705
  group="Utility",
6052
6706
  )
@@ -6054,14 +6708,14 @@ for t in int_types:
6054
6708
  add_builtin(
6055
6709
  "where",
6056
6710
  input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
6057
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6711
+ value_func=where_value_func,
6058
6712
  doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
6059
6713
  group="Utility",
6060
6714
  )
6061
6715
  add_builtin(
6062
6716
  "where",
6063
6717
  input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
6064
- value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
6718
+ value_func=where_value_func,
6065
6719
  doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
6066
6720
  group="Utility",
6067
6721
  )
@@ -6099,7 +6753,7 @@ add_builtin(
6099
6753
  group="Utility",
6100
6754
  hidden=True,
6101
6755
  export=False,
6102
- missing_grad=True,
6756
+ is_differentiable=False,
6103
6757
  )
6104
6758
 
6105
6759
 
@@ -6140,7 +6794,7 @@ add_builtin(
6140
6794
  native_func="fixedarray_t",
6141
6795
  group="Utility",
6142
6796
  export=False,
6143
- missing_grad=True,
6797
+ is_differentiable=False,
6144
6798
  hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
6145
6799
  )
6146
6800
 
@@ -6183,14 +6837,13 @@ for array_type in array_types:
6183
6837
  # does argument checking and type propagation for view()
6184
6838
  def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6185
6839
  arr_type = arg_types["arr"]
6186
- idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)
6840
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
6187
6841
 
6188
6842
  if not is_array(arr_type):
6189
6843
  raise RuntimeError("view() first argument must be an array")
6190
6844
 
6191
6845
  idx_count = len(idx_types)
6192
-
6193
- if idx_count >= arr_type.ndim:
6846
+ if idx_count > arr_type.ndim:
6194
6847
  raise RuntimeError(
6195
6848
  f"Trying to create an array view with {idx_count} indices, "
6196
6849
  f"but the array only has {arr_type.ndim} dimension(s). "
@@ -6198,14 +6851,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6198
6851
  f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
6199
6852
  )
6200
6853
 
6201
- # check index types
6202
- for t in idx_types:
6203
- if not type_is_int(t):
6204
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")
6854
+ has_slice = any(is_slice(x) for x in idx_types)
6855
+ if has_slice:
6856
+ # check index types
6857
+ for t in idx_types:
6858
+ if not (type_is_int(t) or is_slice(t)):
6859
+ raise RuntimeError(
6860
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6861
+ )
6862
+
6863
+ # Each integer index collapses one dimension.
6864
+ int_count = sum(x.step == 0 for x in idx_types)
6865
+ ndim = arr_type.ndim - int_count
6866
+ assert ndim > 0
6867
+ else:
6868
+ if idx_count == arr_type.ndim:
6869
+ raise RuntimeError("Expected to call `address()` instead of `view()`")
6870
+
6871
+ # check index types
6872
+ for t in idx_types:
6873
+ if not type_is_int(t):
6874
+ raise RuntimeError(
6875
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6876
+ )
6877
+
6878
+ # create an array view with leading dimensions removed
6879
+ ndim = arr_type.ndim - idx_count
6880
+ assert ndim > 0
6205
6881
 
6206
- # create an array view with leading dimensions removed
6207
6882
  dtype = arr_type.dtype
6208
- ndim = arr_type.ndim - idx_count
6209
6883
  if isinstance(arr_type, (fabricarray, indexedfabricarray)):
6210
6884
  # fabric array of arrays: return array attribute as a regular array
6211
6885
  return array(dtype=dtype, ndim=ndim)
@@ -6216,8 +6890,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6216
6890
  for array_type in array_types:
6217
6891
  add_builtin(
6218
6892
  "view",
6219
- input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int},
6220
- defaults={"j": None, "k": None},
6893
+ input_types={
6894
+ "arr": array_type(dtype=Any),
6895
+ "i": Any,
6896
+ "j": Any,
6897
+ "k": Any,
6898
+ "l": Any,
6899
+ },
6900
+ defaults={
6901
+ "j": None,
6902
+ "k": None,
6903
+ "l": None,
6904
+ },
6221
6905
  constraint=sametypes,
6222
6906
  hidden=True,
6223
6907
  value_func=view_value_func,
@@ -6321,6 +7005,7 @@ add_builtin(
6321
7005
  hidden=True,
6322
7006
  skip_replay=True,
6323
7007
  group="Utility",
7008
+ is_differentiable=False,
6324
7009
  )
6325
7010
 
6326
7011
 
@@ -6337,6 +7022,7 @@ add_builtin(
6337
7022
  dispatch_func=load_dispatch_func,
6338
7023
  hidden=True,
6339
7024
  group="Utility",
7025
+ is_differentiable=False,
6340
7026
  )
6341
7027
 
6342
7028
 
@@ -6412,6 +7098,13 @@ def create_atomic_op_value_func(op: str):
6412
7098
  f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
6413
7099
  f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
6414
7100
  )
7101
+ elif op in ("and", "or", "xor"):
7102
+ supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
7103
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
7104
+ raise RuntimeError(
7105
+ f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
7106
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
7107
+ )
6415
7108
  else:
6416
7109
  raise NotImplementedError
6417
7110
 
@@ -6653,6 +7346,7 @@ for array_type in array_types:
6653
7346
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6654
7347
  group="Utility",
6655
7348
  skip_replay=True,
7349
+ is_differentiable=False,
6656
7350
  )
6657
7351
  add_builtin(
6658
7352
  "atomic_cas",
@@ -6666,6 +7360,7 @@ for array_type in array_types:
6666
7360
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6667
7361
  group="Utility",
6668
7362
  skip_replay=True,
7363
+ is_differentiable=False,
6669
7364
  )
6670
7365
  add_builtin(
6671
7366
  "atomic_cas",
@@ -6679,6 +7374,7 @@ for array_type in array_types:
6679
7374
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6680
7375
  group="Utility",
6681
7376
  skip_replay=True,
7377
+ is_differentiable=False,
6682
7378
  )
6683
7379
  add_builtin(
6684
7380
  "atomic_cas",
@@ -6700,6 +7396,7 @@ for array_type in array_types:
6700
7396
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6701
7397
  group="Utility",
6702
7398
  skip_replay=True,
7399
+ is_differentiable=False,
6703
7400
  )
6704
7401
 
6705
7402
  add_builtin(
@@ -6714,6 +7411,7 @@ for array_type in array_types:
6714
7411
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6715
7412
  group="Utility",
6716
7413
  skip_replay=True,
7414
+ is_differentiable=False,
6717
7415
  )
6718
7416
  add_builtin(
6719
7417
  "atomic_exch",
@@ -6727,32 +7425,193 @@ for array_type in array_types:
6727
7425
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6728
7426
  group="Utility",
6729
7427
  skip_replay=True,
7428
+ is_differentiable=False,
7429
+ )
7430
+ add_builtin(
7431
+ "atomic_exch",
7432
+ hidden=hidden,
7433
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7434
+ constraint=atomic_op_constraint,
7435
+ value_func=create_atomic_op_value_func("exch"),
7436
+ dispatch_func=atomic_op_dispatch_func,
7437
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
7438
+
7439
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
7440
+ group="Utility",
7441
+ skip_replay=True,
7442
+ is_differentiable=False,
7443
+ )
7444
+ add_builtin(
7445
+ "atomic_exch",
7446
+ hidden=hidden,
7447
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7448
+ constraint=atomic_op_constraint,
7449
+ value_func=create_atomic_op_value_func("exch"),
7450
+ dispatch_func=atomic_op_dispatch_func,
7451
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
7452
+
7453
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
7454
+ group="Utility",
7455
+ skip_replay=True,
7456
+ )
7457
+
7458
+ add_builtin(
7459
+ "atomic_and",
7460
+ hidden=hidden,
7461
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7462
+ constraint=atomic_op_constraint,
7463
+ value_func=create_atomic_op_value_func("and"),
7464
+ dispatch_func=atomic_op_dispatch_func,
7465
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7466
+ This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
7467
+ group="Utility",
7468
+ skip_replay=True,
7469
+ is_differentiable=False,
7470
+ )
7471
+ add_builtin(
7472
+ "atomic_and",
7473
+ hidden=hidden,
7474
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7475
+ constraint=atomic_op_constraint,
7476
+ value_func=create_atomic_op_value_func("and"),
7477
+ dispatch_func=atomic_op_dispatch_func,
7478
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7479
+ This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
7480
+ group="Utility",
7481
+ skip_replay=True,
7482
+ is_differentiable=False,
7483
+ )
7484
+ add_builtin(
7485
+ "atomic_and",
7486
+ hidden=hidden,
7487
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7488
+ constraint=atomic_op_constraint,
7489
+ value_func=create_atomic_op_value_func("and"),
7490
+ dispatch_func=atomic_op_dispatch_func,
7491
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7492
+ This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
7493
+ group="Utility",
7494
+ skip_replay=True,
7495
+ is_differentiable=False,
7496
+ )
7497
+ add_builtin(
7498
+ "atomic_and",
7499
+ hidden=hidden,
7500
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7501
+ constraint=atomic_op_constraint,
7502
+ value_func=create_atomic_op_value_func("and"),
7503
+ dispatch_func=atomic_op_dispatch_func,
7504
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7505
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
7506
+ group="Utility",
7507
+ skip_replay=True,
7508
+ is_differentiable=False,
7509
+ )
7510
+
7511
+ add_builtin(
7512
+ "atomic_or",
7513
+ hidden=hidden,
7514
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7515
+ constraint=atomic_op_constraint,
7516
+ value_func=create_atomic_op_value_func("or"),
7517
+ dispatch_func=atomic_op_dispatch_func,
7518
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7519
+ This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
7520
+ group="Utility",
7521
+ skip_replay=True,
7522
+ is_differentiable=False,
7523
+ )
7524
+ add_builtin(
7525
+ "atomic_or",
7526
+ hidden=hidden,
7527
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7528
+ constraint=atomic_op_constraint,
7529
+ value_func=create_atomic_op_value_func("or"),
7530
+ dispatch_func=atomic_op_dispatch_func,
7531
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7532
+ This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
7533
+ group="Utility",
7534
+ skip_replay=True,
7535
+ is_differentiable=False,
7536
+ )
7537
+ add_builtin(
7538
+ "atomic_or",
7539
+ hidden=hidden,
7540
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7541
+ constraint=atomic_op_constraint,
7542
+ value_func=create_atomic_op_value_func("or"),
7543
+ dispatch_func=atomic_op_dispatch_func,
7544
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7545
+ This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
7546
+ group="Utility",
7547
+ skip_replay=True,
7548
+ is_differentiable=False,
7549
+ )
7550
+ add_builtin(
7551
+ "atomic_or",
7552
+ hidden=hidden,
7553
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7554
+ constraint=atomic_op_constraint,
7555
+ value_func=create_atomic_op_value_func("or"),
7556
+ dispatch_func=atomic_op_dispatch_func,
7557
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7558
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
7559
+ group="Utility",
7560
+ skip_replay=True,
7561
+ is_differentiable=False,
7562
+ )
7563
+
7564
+ add_builtin(
7565
+ "atomic_xor",
7566
+ hidden=hidden,
7567
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7568
+ constraint=atomic_op_constraint,
7569
+ value_func=create_atomic_op_value_func("xor"),
7570
+ dispatch_func=atomic_op_dispatch_func,
7571
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7572
+ This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
7573
+ group="Utility",
7574
+ skip_replay=True,
7575
+ is_differentiable=False,
7576
+ )
7577
+ add_builtin(
7578
+ "atomic_xor",
7579
+ hidden=hidden,
7580
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7581
+ constraint=atomic_op_constraint,
7582
+ value_func=create_atomic_op_value_func("xor"),
7583
+ dispatch_func=atomic_op_dispatch_func,
7584
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7585
+ This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
7586
+ group="Utility",
7587
+ skip_replay=True,
7588
+ is_differentiable=False,
6730
7589
  )
6731
7590
  add_builtin(
6732
- "atomic_exch",
7591
+ "atomic_xor",
6733
7592
  hidden=hidden,
6734
7593
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
6735
7594
  constraint=atomic_op_constraint,
6736
- value_func=create_atomic_op_value_func("exch"),
7595
+ value_func=create_atomic_op_value_func("xor"),
6737
7596
  dispatch_func=atomic_op_dispatch_func,
6738
- doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
6739
-
6740
- The operation is only atomic on a per-component basis for vectors and matrices.""",
7597
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7598
+ This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
6741
7599
  group="Utility",
6742
7600
  skip_replay=True,
7601
+ is_differentiable=False,
6743
7602
  )
6744
7603
  add_builtin(
6745
- "atomic_exch",
7604
+ "atomic_xor",
6746
7605
  hidden=hidden,
6747
7606
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
6748
7607
  constraint=atomic_op_constraint,
6749
- value_func=create_atomic_op_value_func("exch"),
7608
+ value_func=create_atomic_op_value_func("xor"),
6750
7609
  dispatch_func=atomic_op_dispatch_func,
6751
- doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
6752
-
6753
- The operation is only atomic on a per-component basis for vectors and matrices.""",
7610
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7611
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
6754
7612
  group="Utility",
6755
7613
  skip_replay=True,
7614
+ is_differentiable=False,
6756
7615
  )
6757
7616
 
6758
7617
 
@@ -6903,6 +7762,7 @@ add_builtin(
6903
7762
  hidden=True,
6904
7763
  group="Utility",
6905
7764
  skip_replay=True,
7765
+ is_differentiable=False,
6906
7766
  )
6907
7767
  # implements &quaternion[index]
6908
7768
  add_builtin(
@@ -6913,6 +7773,7 @@ add_builtin(
6913
7773
  hidden=True,
6914
7774
  group="Utility",
6915
7775
  skip_replay=True,
7776
+ is_differentiable=False,
6916
7777
  )
6917
7778
  # implements &transformation[index]
6918
7779
  add_builtin(
@@ -6923,6 +7784,7 @@ add_builtin(
6923
7784
  hidden=True,
6924
7785
  group="Utility",
6925
7786
  skip_replay=True,
7787
+ is_differentiable=False,
6926
7788
  )
6927
7789
  # implements &(*vector)[index]
6928
7790
  add_builtin(
@@ -6933,6 +7795,7 @@ add_builtin(
6933
7795
  hidden=True,
6934
7796
  group="Utility",
6935
7797
  skip_replay=True,
7798
+ is_differentiable=False,
6936
7799
  )
6937
7800
  # implements &(*matrix)[i, j]
6938
7801
  add_builtin(
@@ -6943,6 +7806,7 @@ add_builtin(
6943
7806
  hidden=True,
6944
7807
  group="Utility",
6945
7808
  skip_replay=True,
7809
+ is_differentiable=False,
6946
7810
  )
6947
7811
  # implements &(*quaternion)[index]
6948
7812
  add_builtin(
@@ -6953,6 +7817,7 @@ add_builtin(
6953
7817
  hidden=True,
6954
7818
  group="Utility",
6955
7819
  skip_replay=True,
7820
+ is_differentiable=False,
6956
7821
  )
6957
7822
  # implements &(*transformation)[index]
6958
7823
  add_builtin(
@@ -6963,6 +7828,7 @@ add_builtin(
6963
7828
  hidden=True,
6964
7829
  group="Utility",
6965
7830
  skip_replay=True,
7831
+ is_differentiable=False,
6966
7832
  )
6967
7833
 
6968
7834
 
@@ -7158,6 +8024,43 @@ add_builtin(
7158
8024
  )
7159
8025
 
7160
8026
 
8027
+ # implements vector[idx] &= scalar
8028
+ add_builtin(
8029
+ "bit_and_inplace",
8030
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8031
+ value_type=None,
8032
+ dispatch_func=vector_assign_dispatch_func,
8033
+ hidden=True,
8034
+ export=False,
8035
+ group="Utility",
8036
+ is_differentiable=False,
8037
+ )
8038
+
8039
+ # implements vector[idx] |= scalar
8040
+ add_builtin(
8041
+ "bit_or_inplace",
8042
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8043
+ value_type=None,
8044
+ dispatch_func=vector_assign_dispatch_func,
8045
+ hidden=True,
8046
+ export=False,
8047
+ group="Utility",
8048
+ is_differentiable=False,
8049
+ )
8050
+
8051
+ # implements vector[idx] ^= scalar
8052
+ add_builtin(
8053
+ "bit_xor_inplace",
8054
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8055
+ value_type=None,
8056
+ dispatch_func=vector_assign_dispatch_func,
8057
+ hidden=True,
8058
+ export=False,
8059
+ group="Utility",
8060
+ is_differentiable=False,
8061
+ )
8062
+
8063
+
7161
8064
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7162
8065
  mat_type = arg_types["a"]
7163
8066
  row_type = mat_type._wp_row_type_
@@ -7173,6 +8076,7 @@ add_builtin(
7173
8076
  hidden=True,
7174
8077
  group="Utility",
7175
8078
  skip_replay=True,
8079
+ is_differentiable=False,
7176
8080
  )
7177
8081
 
7178
8082
 
@@ -7191,6 +8095,7 @@ add_builtin(
7191
8095
  hidden=True,
7192
8096
  group="Utility",
7193
8097
  skip_replay=True,
8098
+ is_differentiable=False,
7194
8099
  )
7195
8100
 
7196
8101
 
@@ -7390,6 +8295,78 @@ add_builtin(
7390
8295
  )
7391
8296
 
7392
8297
 
8298
+ # implements matrix[i] &= value
8299
+ add_builtin(
8300
+ "bit_and_inplace",
8301
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8302
+ value_type=None,
8303
+ hidden=True,
8304
+ export=False,
8305
+ group="Utility",
8306
+ is_differentiable=False,
8307
+ )
8308
+
8309
+
8310
+ # implements matrix[i,j] &= value
8311
+ add_builtin(
8312
+ "bit_and_inplace",
8313
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8314
+ value_type=None,
8315
+ hidden=True,
8316
+ export=False,
8317
+ group="Utility",
8318
+ is_differentiable=False,
8319
+ )
8320
+
8321
+
8322
+ # implements matrix[i] |= value
8323
+ add_builtin(
8324
+ "bit_or_inplace",
8325
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8326
+ value_type=None,
8327
+ hidden=True,
8328
+ export=False,
8329
+ group="Utility",
8330
+ is_differentiable=False,
8331
+ )
8332
+
8333
+
8334
+ # implements matrix[i,j] |= value
8335
+ add_builtin(
8336
+ "bit_or_inplace",
8337
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8338
+ value_type=None,
8339
+ hidden=True,
8340
+ export=False,
8341
+ group="Utility",
8342
+ is_differentiable=False,
8343
+ )
8344
+
8345
+
8346
+ # implements matrix[i] ^= value
8347
+ add_builtin(
8348
+ "bit_xor_inplace",
8349
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8350
+ value_type=None,
8351
+ hidden=True,
8352
+ export=False,
8353
+ group="Utility",
8354
+ is_differentiable=False,
8355
+ )
8356
+
8357
+
8358
+ # implements matrix[i,j] ^= value
8359
+ add_builtin(
8360
+ "bit_xor_inplace",
8361
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8362
+ value_type=None,
8363
+ hidden=True,
8364
+ export=False,
8365
+ group="Utility",
8366
+ is_differentiable=False,
8367
+ )
8368
+
8369
+
7393
8370
  for t in scalar_types + vector_types + (bool,):
7394
8371
  if "vec" in t.__name__ or "mat" in t.__name__:
7395
8372
  continue
@@ -7401,6 +8378,7 @@ for t in scalar_types + vector_types + (bool,):
7401
8378
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7402
8379
  group="Utility",
7403
8380
  hidden=True,
8381
+ is_differentiable=False,
7404
8382
  )
7405
8383
 
7406
8384
  add_builtin(
@@ -7411,6 +8389,7 @@ for t in scalar_types + vector_types + (bool,):
7411
8389
  group="Utility",
7412
8390
  hidden=True,
7413
8391
  export=False,
8392
+ is_differentiable=False,
7414
8393
  )
7415
8394
 
7416
8395
 
@@ -7429,6 +8408,7 @@ add_builtin(
7429
8408
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7430
8409
  group="Utility",
7431
8410
  hidden=True,
8411
+ is_differentiable=False,
7432
8412
  )
7433
8413
  add_builtin(
7434
8414
  "expect_neq",
@@ -7439,6 +8419,7 @@ add_builtin(
7439
8419
  group="Utility",
7440
8420
  hidden=True,
7441
8421
  export=False,
8422
+ is_differentiable=False,
7442
8423
  )
7443
8424
 
7444
8425
  add_builtin(
@@ -7449,6 +8430,7 @@ add_builtin(
7449
8430
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7450
8431
  group="Utility",
7451
8432
  hidden=True,
8433
+ is_differentiable=False,
7452
8434
  )
7453
8435
  add_builtin(
7454
8436
  "expect_neq",
@@ -7459,6 +8441,7 @@ add_builtin(
7459
8441
  group="Utility",
7460
8442
  hidden=True,
7461
8443
  export=False,
8444
+ is_differentiable=False,
7462
8445
  )
7463
8446
 
7464
8447
  add_builtin(
@@ -7549,6 +8532,7 @@ add_builtin(
7549
8532
  value_type=None,
7550
8533
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
7551
8534
  group="Utility",
8535
+ is_differentiable=False,
7552
8536
  )
7553
8537
  add_builtin(
7554
8538
  "expect_near",
@@ -7558,6 +8542,7 @@ add_builtin(
7558
8542
  value_type=None,
7559
8543
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7560
8544
  group="Utility",
8545
+ is_differentiable=False,
7561
8546
  )
7562
8547
  add_builtin(
7563
8548
  "expect_near",
@@ -7567,6 +8552,7 @@ add_builtin(
7567
8552
  value_type=None,
7568
8553
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7569
8554
  group="Utility",
8555
+ is_differentiable=False,
7570
8556
  )
7571
8557
  add_builtin(
7572
8558
  "expect_near",
@@ -7580,6 +8566,7 @@ add_builtin(
7580
8566
  value_type=None,
7581
8567
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7582
8568
  group="Utility",
8569
+ is_differentiable=False,
7583
8570
  )
7584
8571
 
7585
8572
  # ---------------------------------
@@ -7590,6 +8577,7 @@ add_builtin(
7590
8577
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
7591
8578
  value_type=int,
7592
8579
  doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
8580
+ is_differentiable=False,
7593
8581
  )
7594
8582
 
7595
8583
  add_builtin(
@@ -7597,6 +8585,7 @@ add_builtin(
7597
8585
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
7598
8586
  value_type=int,
7599
8587
  doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
8588
+ is_differentiable=False,
7600
8589
  )
7601
8590
 
7602
8591
  # ---------------------------------
@@ -7672,12 +8661,157 @@ add_builtin(
7672
8661
  )
7673
8662
 
7674
8663
  # bitwise operators
7675
- add_builtin("bit_and", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7676
- add_builtin("bit_or", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7677
- add_builtin("bit_xor", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7678
- add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7679
- add_builtin("rshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
7680
- add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int))
8664
+ add_builtin(
8665
+ "bit_and",
8666
+ input_types={"a": Int, "b": Int},
8667
+ value_func=sametypes_create_value_func(Int),
8668
+ group="Operators",
8669
+ is_differentiable=False,
8670
+ )
8671
+ add_builtin(
8672
+ "bit_and",
8673
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8674
+ constraint=sametypes,
8675
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8676
+ doc="",
8677
+ group="Operators",
8678
+ is_differentiable=False,
8679
+ )
8680
+ add_builtin(
8681
+ "bit_and",
8682
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8683
+ constraint=sametypes,
8684
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8685
+ doc="",
8686
+ group="Operators",
8687
+ is_differentiable=False,
8688
+ )
8689
+
8690
+ add_builtin(
8691
+ "bit_or",
8692
+ input_types={"a": Int, "b": Int},
8693
+ value_func=sametypes_create_value_func(Int),
8694
+ group="Operators",
8695
+ is_differentiable=False,
8696
+ )
8697
+ add_builtin(
8698
+ "bit_or",
8699
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8700
+ constraint=sametypes,
8701
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8702
+ doc="",
8703
+ group="Operators",
8704
+ is_differentiable=False,
8705
+ )
8706
+ add_builtin(
8707
+ "bit_or",
8708
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8709
+ constraint=sametypes,
8710
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8711
+ doc="",
8712
+ group="Operators",
8713
+ is_differentiable=False,
8714
+ )
8715
+
8716
+ add_builtin(
8717
+ "bit_xor",
8718
+ input_types={"a": Int, "b": Int},
8719
+ value_func=sametypes_create_value_func(Int),
8720
+ group="Operators",
8721
+ is_differentiable=False,
8722
+ )
8723
+ add_builtin(
8724
+ "bit_xor",
8725
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8726
+ constraint=sametypes,
8727
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8728
+ doc="",
8729
+ group="Operators",
8730
+ is_differentiable=False,
8731
+ )
8732
+ add_builtin(
8733
+ "bit_xor",
8734
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8735
+ constraint=sametypes,
8736
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8737
+ doc="",
8738
+ group="Operators",
8739
+ is_differentiable=False,
8740
+ )
8741
+
8742
+ add_builtin(
8743
+ "lshift",
8744
+ input_types={"a": Int, "b": Int},
8745
+ value_func=sametypes_create_value_func(Int),
8746
+ group="Operators",
8747
+ is_differentiable=False,
8748
+ )
8749
+ add_builtin(
8750
+ "lshift",
8751
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8752
+ constraint=sametypes,
8753
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8754
+ doc="",
8755
+ group="Operators",
8756
+ is_differentiable=False,
8757
+ )
8758
+ add_builtin(
8759
+ "lshift",
8760
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8761
+ constraint=sametypes,
8762
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8763
+ doc="",
8764
+ group="Operators",
8765
+ is_differentiable=False,
8766
+ )
8767
+
8768
+ add_builtin(
8769
+ "rshift",
8770
+ input_types={"a": Int, "b": Int},
8771
+ value_func=sametypes_create_value_func(Int),
8772
+ group="Operators",
8773
+ is_differentiable=False,
8774
+ )
8775
+ add_builtin(
8776
+ "rshift",
8777
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8778
+ constraint=sametypes,
8779
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8780
+ doc="",
8781
+ group="Operators",
8782
+ is_differentiable=False,
8783
+ )
8784
+ add_builtin(
8785
+ "rshift",
8786
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8787
+ constraint=sametypes,
8788
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8789
+ doc="",
8790
+ group="Operators",
8791
+ is_differentiable=False,
8792
+ )
8793
+
8794
+ add_builtin(
8795
+ "invert",
8796
+ input_types={"a": Int},
8797
+ value_func=sametypes_create_value_func(Int),
8798
+ group="Operators",
8799
+ is_differentiable=False,
8800
+ )
8801
+ add_builtin(
8802
+ "invert",
8803
+ input_types={"a": vector(length=Any, dtype=Int)},
8804
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8805
+ group="Operators",
8806
+ is_differentiable=False,
8807
+ )
8808
+ add_builtin(
8809
+ "invert",
8810
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
8811
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8812
+ group="Operators",
8813
+ is_differentiable=False,
8814
+ )
7681
8815
 
7682
8816
 
7683
8817
  add_builtin(
@@ -7878,6 +9012,7 @@ add_builtin(
7878
9012
  value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
7879
9013
  doc="Modulo operation using truncated division.",
7880
9014
  group="Operators",
9015
+ is_differentiable=False,
7881
9016
  )
7882
9017
 
7883
9018
  add_builtin(
@@ -7937,6 +9072,7 @@ add_builtin(
7937
9072
  value_func=sametypes_create_value_func(Scalar),
7938
9073
  doc="",
7939
9074
  group="Operators",
9075
+ is_differentiable=False,
7940
9076
  )
7941
9077
 
7942
9078
  add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -7984,12 +9120,28 @@ add_builtin(
7984
9120
  group="Operators",
7985
9121
  )
7986
9122
 
7987
- add_builtin("unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
9123
+ add_builtin(
9124
+ "unot",
9125
+ input_types={"a": builtins.bool},
9126
+ value_type=builtins.bool,
9127
+ doc="",
9128
+ group="Operators",
9129
+ is_differentiable=False,
9130
+ )
7988
9131
  for t in int_types:
7989
- add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
9132
+ add_builtin(
9133
+ "unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
9134
+ )
7990
9135
 
7991
9136
 
7992
- add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")
9137
+ add_builtin(
9138
+ "unot",
9139
+ input_types={"a": array(dtype=Any)},
9140
+ value_type=builtins.bool,
9141
+ doc="",
9142
+ group="Operators",
9143
+ is_differentiable=False,
9144
+ )
7993
9145
 
7994
9146
 
7995
9147
  # Tile operators
@@ -8061,6 +9213,45 @@ add_builtin(
8061
9213
  export=False,
8062
9214
  )
8063
9215
 
9216
+ add_builtin(
9217
+ "bit_and",
9218
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9219
+ value_func=tile_binary_map_value_func,
9220
+ # dispatch_func=tile_map_dispatch_func,
9221
+ # variadic=True,
9222
+ native_func="tile_bit_and",
9223
+ doc="Bitwise AND each element of two tiles together",
9224
+ group="Tile Primitives",
9225
+ export=False,
9226
+ is_differentiable=False,
9227
+ )
9228
+
9229
+ add_builtin(
9230
+ "bit_or",
9231
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9232
+ value_func=tile_binary_map_value_func,
9233
+ # dispatch_func=tile_map_dispatch_func,
9234
+ # variadic=True,
9235
+ native_func="tile_bit_or",
9236
+ doc="Bitwise OR each element of two tiles together",
9237
+ group="Tile Primitives",
9238
+ export=False,
9239
+ is_differentiable=False,
9240
+ )
9241
+
9242
+ add_builtin(
9243
+ "bit_xor",
9244
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9245
+ value_func=tile_binary_map_value_func,
9246
+ # dispatch_func=tile_map_dispatch_func,
9247
+ # variadic=True,
9248
+ native_func="tile_bit_xor",
9249
+ doc="Bitwise XOR each element of two tiles together",
9250
+ group="Tile Primitives",
9251
+ export=False,
9252
+ is_differentiable=False,
9253
+ )
9254
+
8064
9255
 
8065
9256
  add_builtin(
8066
9257
  "mul",
@@ -8122,6 +9313,45 @@ add_builtin(
8122
9313
  )
8123
9314
 
8124
9315
 
9316
+ add_builtin(
9317
+ "bit_and_inplace",
9318
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9319
+ value_type=None,
9320
+ dispatch_func=tile_inplace_dispatch_func,
9321
+ export=False,
9322
+ hidden=True,
9323
+ native_func="tile_bit_and_inplace",
9324
+ group="Operators",
9325
+ is_differentiable=False,
9326
+ )
9327
+
9328
+
9329
+ add_builtin(
9330
+ "bit_or_inplace",
9331
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9332
+ value_type=None,
9333
+ dispatch_func=tile_inplace_dispatch_func,
9334
+ export=False,
9335
+ hidden=True,
9336
+ native_func="tile_bit_or_inplace",
9337
+ group="Operators",
9338
+ is_differentiable=False,
9339
+ )
9340
+
9341
+
9342
+ add_builtin(
9343
+ "bit_xor_inplace",
9344
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9345
+ value_type=None,
9346
+ dispatch_func=tile_inplace_dispatch_func,
9347
+ export=False,
9348
+ hidden=True,
9349
+ native_func="tile_bit_xor_inplace",
9350
+ group="Operators",
9351
+ is_differentiable=False,
9352
+ )
9353
+
9354
+
8125
9355
  def tile_diag_add_value_func(arg_types, arg_values):
8126
9356
  if arg_types is None:
8127
9357
  return tile(dtype=Any, shape=Tuple[int, int])
@@ -8163,7 +9393,7 @@ def tile_diag_add_lto_dispatch_func(
8163
9393
  return_values: List[Var],
8164
9394
  arg_values: Mapping[str, Var],
8165
9395
  options: Mapping[str, Any],
8166
- builder: warp.context.ModuleBuilder,
9396
+ builder: warp._src.context.ModuleBuilder,
8167
9397
  ):
8168
9398
  a = arg_values["a"]
8169
9399
  d = arg_values["d"]
@@ -8183,6 +9413,7 @@ add_builtin(
8183
9413
  doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
8184
9414
  group="Tile Primitives",
8185
9415
  export=False,
9416
+ is_differentiable=False,
8186
9417
  )
8187
9418
 
8188
9419
 
@@ -8239,7 +9470,7 @@ def tile_matmul_lto_dispatch_func(
8239
9470
  return_values: List[Var],
8240
9471
  arg_values: Mapping[str, Var],
8241
9472
  options: Mapping[str, Any],
8242
- builder: warp.context.ModuleBuilder,
9473
+ builder: warp._src.context.ModuleBuilder,
8243
9474
  ):
8244
9475
  a = arg_values["a"]
8245
9476
  b = arg_values["b"]
@@ -8277,7 +9508,7 @@ def tile_matmul_lto_dispatch_func(
8277
9508
  num_threads = options["block_dim"]
8278
9509
  arch = options["output_arch"]
8279
9510
 
8280
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9511
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8281
9512
  # CPU/no-MathDx dispatch
8282
9513
  return ((0, 0, 0, a, b, out), template_args, [], 0)
8283
9514
  else:
@@ -8290,7 +9521,7 @@ def tile_matmul_lto_dispatch_func(
8290
9521
 
8291
9522
  # generate the LTOs
8292
9523
  # C += A * B
8293
- (fun_forward, lto_forward) = warp.build.build_lto_dot(
9524
+ (fun_forward, lto_forward) = warp._src.build.build_lto_dot(
8294
9525
  M,
8295
9526
  N,
8296
9527
  K,
@@ -8306,7 +9537,7 @@ def tile_matmul_lto_dispatch_func(
8306
9537
  )
8307
9538
  if warp.config.enable_backward:
8308
9539
  # adjA += adjC * B^T - Transpose ~= flipped layout
8309
- (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
9540
+ (fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
8310
9541
  M,
8311
9542
  K,
8312
9543
  N,
@@ -8321,7 +9552,7 @@ def tile_matmul_lto_dispatch_func(
8321
9552
  builder,
8322
9553
  )
8323
9554
  # adjB += A^T * adjC - Transpose ~= flipped layout
8324
- (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
9555
+ (fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
8325
9556
  K,
8326
9557
  N,
8327
9558
  M,
@@ -8438,7 +9669,7 @@ def tile_fft_generic_lto_dispatch_func(
8438
9669
  return_values: List[Var],
8439
9670
  arg_values: Mapping[str, Var],
8440
9671
  options: Mapping[str, Any],
8441
- builder: warp.context.ModuleBuilder,
9672
+ builder: warp._src.context.ModuleBuilder,
8442
9673
  direction: str | None = None,
8443
9674
  ):
8444
9675
  inout = arg_values["inout"]
@@ -8467,12 +9698,12 @@ def tile_fft_generic_lto_dispatch_func(
8467
9698
  arch = options["output_arch"]
8468
9699
  ept = size // num_threads
8469
9700
 
8470
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9701
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8471
9702
  # CPU/no-MathDx dispatch
8472
9703
  return ([], [], [], 0)
8473
9704
  else:
8474
9705
  # generate the LTO
8475
- lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
9706
+ lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
8476
9707
  arch, size, ept, direction, dir, precision, builder
8477
9708
  )
8478
9709
 
@@ -8510,6 +9741,7 @@ add_builtin(
8510
9741
  group="Tile Primitives",
8511
9742
  export=False,
8512
9743
  namespace="",
9744
+ is_differentiable=False,
8513
9745
  )
8514
9746
 
8515
9747
  add_builtin(
@@ -8531,6 +9763,7 @@ add_builtin(
8531
9763
  group="Tile Primitives",
8532
9764
  export=False,
8533
9765
  namespace="",
9766
+ is_differentiable=False,
8534
9767
  )
8535
9768
 
8536
9769
 
@@ -8575,7 +9808,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8575
9808
  return_values: List[Var],
8576
9809
  arg_values: Mapping[str, Var],
8577
9810
  options: Mapping[str, Any],
8578
- builder: warp.context.ModuleBuilder,
9811
+ builder: warp._src.context.ModuleBuilder,
8579
9812
  ):
8580
9813
  a = arg_values["A"]
8581
9814
  # force source tile to shared memory
@@ -8595,7 +9828,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8595
9828
 
8596
9829
  arch = options["output_arch"]
8597
9830
 
8598
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9831
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8599
9832
  # CPU/no-MathDx dispatch
8600
9833
  return ((0, a, out), [], [], 0)
8601
9834
  else:
@@ -8610,7 +9843,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8610
9843
  req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
8611
9844
 
8612
9845
  # generate the LTO
8613
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9846
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8614
9847
  M,
8615
9848
  N,
8616
9849
  1,
@@ -8655,6 +9888,7 @@ add_builtin(
8655
9888
  group="Tile Primitives",
8656
9889
  export=False,
8657
9890
  namespace="",
9891
+ is_differentiable=False,
8658
9892
  )
8659
9893
 
8660
9894
 
@@ -8698,7 +9932,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8698
9932
  return_values: List[Var],
8699
9933
  arg_values: Mapping[str, Var],
8700
9934
  options: Mapping[str, Any],
8701
- builder: warp.context.ModuleBuilder,
9935
+ builder: warp._src.context.ModuleBuilder,
8702
9936
  ):
8703
9937
  L = arg_values["L"]
8704
9938
  y = arg_values["y"]
@@ -8727,7 +9961,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8727
9961
 
8728
9962
  arch = options["output_arch"]
8729
9963
 
8730
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9964
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8731
9965
  # CPU/no-MathDx dispatch
8732
9966
  return ((0, L, y, x), [], [], 0)
8733
9967
  else:
@@ -8743,7 +9977,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8743
9977
  req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8744
9978
 
8745
9979
  # generate the LTO
8746
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9980
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8747
9981
  M,
8748
9982
  N,
8749
9983
  NRHS,
@@ -8785,6 +10019,7 @@ add_builtin(
8785
10019
  group="Tile Primitives",
8786
10020
  export=False,
8787
10021
  namespace="",
10022
+ is_differentiable=False,
8788
10023
  )
8789
10024
 
8790
10025
 
@@ -8794,7 +10029,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8794
10029
  return_values: List[Var],
8795
10030
  arg_values: Mapping[str, Var],
8796
10031
  options: Mapping[str, Any],
8797
- builder: warp.context.ModuleBuilder,
10032
+ builder: warp._src.context.ModuleBuilder,
8798
10033
  ):
8799
10034
  L = arg_values["L"]
8800
10035
  y = arg_values["y"]
@@ -8823,7 +10058,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8823
10058
 
8824
10059
  arch = options["output_arch"]
8825
10060
 
8826
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10061
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8827
10062
  # CPU/no-MathDx dispatch
8828
10063
  return ((0, L, y, z), [], [], 0)
8829
10064
  else:
@@ -8839,7 +10074,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
8839
10074
  req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8840
10075
 
8841
10076
  # generate the LTO
8842
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10077
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8843
10078
  M,
8844
10079
  N,
8845
10080
  NRHS,
@@ -8917,6 +10152,7 @@ add_builtin(
8917
10152
  group="Tile Primitives",
8918
10153
  export=False,
8919
10154
  namespace="",
10155
+ is_differentiable=False,
8920
10156
  )
8921
10157
 
8922
10158
 
@@ -8926,7 +10162,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8926
10162
  return_values: List[Var],
8927
10163
  arg_values: Mapping[str, Var],
8928
10164
  options: Mapping[str, Any],
8929
- builder: warp.context.ModuleBuilder,
10165
+ builder: warp._src.context.ModuleBuilder,
8930
10166
  ):
8931
10167
  U = arg_values["U"]
8932
10168
  z = arg_values["z"]
@@ -8955,7 +10191,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8955
10191
 
8956
10192
  arch = options["output_arch"]
8957
10193
 
8958
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10194
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8959
10195
  # CPU/no-MathDx dispatch
8960
10196
  return ((0, U, z, x), [], [], 0)
8961
10197
  else:
@@ -8971,7 +10207,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
8971
10207
  req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
8972
10208
 
8973
10209
  # generate the LTO
8974
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10210
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8975
10211
  M,
8976
10212
  N,
8977
10213
  NRHS,
@@ -9049,6 +10285,7 @@ add_builtin(
9049
10285
  group="Tile Primitives",
9050
10286
  export=False,
9051
10287
  namespace="",
10288
+ is_differentiable=False,
9052
10289
  )
9053
10290
 
9054
10291
 
@@ -9068,6 +10305,7 @@ add_builtin(
9068
10305
  The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
9069
10306
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
9070
10307
  group="Code Generation",
10308
+ is_differentiable=False,
9071
10309
  )
9072
10310
 
9073
10311
 
@@ -9092,6 +10330,7 @@ add_builtin(
9092
10330
  doc="Return the number of elements in a vector.",
9093
10331
  group="Utility",
9094
10332
  export=False,
10333
+ is_differentiable=False,
9095
10334
  )
9096
10335
 
9097
10336
  add_builtin(
@@ -9101,6 +10340,7 @@ add_builtin(
9101
10340
  doc="Return the number of elements in a quaternion.",
9102
10341
  group="Utility",
9103
10342
  export=False,
10343
+ is_differentiable=False,
9104
10344
  )
9105
10345
 
9106
10346
  add_builtin(
@@ -9110,6 +10350,7 @@ add_builtin(
9110
10350
  doc="Return the number of rows in a matrix.",
9111
10351
  group="Utility",
9112
10352
  export=False,
10353
+ is_differentiable=False,
9113
10354
  )
9114
10355
 
9115
10356
  add_builtin(
@@ -9119,6 +10360,7 @@ add_builtin(
9119
10360
  doc="Return the number of elements in a transformation.",
9120
10361
  group="Utility",
9121
10362
  export=False,
10363
+ is_differentiable=False,
9122
10364
  )
9123
10365
 
9124
10366
  add_builtin(
@@ -9128,6 +10370,7 @@ add_builtin(
9128
10370
  doc="Return the size of the first dimension in an array.",
9129
10371
  group="Utility",
9130
10372
  export=False,
10373
+ is_differentiable=False,
9131
10374
  )
9132
10375
 
9133
10376
  add_builtin(
@@ -9137,6 +10380,33 @@ add_builtin(
9137
10380
  doc="Return the number of rows in a tile.",
9138
10381
  group="Utility",
9139
10382
  export=False,
10383
+ is_differentiable=False,
10384
+ )
10385
+
10386
+
10387
+ def cast_value_func(arg_types, arg_values):
10388
+ # Return generic type for doc builds.
10389
+ if arg_types is None:
10390
+ return Any
10391
+
10392
+ return arg_values["dtype"]
10393
+
10394
+
10395
+ def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
10396
+ func_args = (args["a"],)
10397
+ template_args = (args["dtype"],)
10398
+ return (func_args, template_args)
10399
+
10400
+
10401
+ add_builtin(
10402
+ "cast",
10403
+ input_types={"a": Any, "dtype": Any},
10404
+ value_func=cast_value_func,
10405
+ dispatch_func=cast_dispatch_func,
10406
+ doc="Reinterpret a value as a different type while preserving its bit pattern.",
10407
+ group="Utility",
10408
+ export=False,
10409
+ is_differentiable=False,
9140
10410
  )
9141
10411
 
9142
10412
 
@@ -9163,7 +10433,7 @@ add_builtin(
9163
10433
  doc="Construct a tuple from a list of values",
9164
10434
  group="Utility",
9165
10435
  hidden=True,
9166
- missing_grad=True,
10436
+ is_differentiable=False,
9167
10437
  export=False,
9168
10438
  )
9169
10439
 
@@ -9200,7 +10470,7 @@ add_builtin(
9200
10470
  dispatch_func=tuple_extract_dispatch_func,
9201
10471
  group="Utility",
9202
10472
  hidden=True,
9203
- missing_grad=True,
10473
+ is_differentiable=False,
9204
10474
  )
9205
10475
 
9206
10476
 
@@ -9211,6 +10481,7 @@ add_builtin(
9211
10481
  doc="Return the number of elements in a tuple.",
9212
10482
  group="Utility",
9213
10483
  export=False,
10484
+ is_differentiable=False,
9214
10485
  )
9215
10486
 
9216
10487
  # ---------------------------------
@@ -9229,5 +10500,5 @@ add_builtin(
9229
10500
  export=False,
9230
10501
  group="Utility",
9231
10502
  hidden=True,
9232
- missing_grad=True,
10503
+ is_differentiable=False,
9233
10504
  )