warp-lang 1.9.1-py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -20,11 +20,11 @@ import functools
20
20
  import math
21
21
  from typing import Any, Callable, Mapping, Sequence
22
22
 
23
- import warp.build
24
- import warp.context
25
- import warp.utils
26
- from warp.codegen import Reference, Var, get_arg_value, strip_reference
27
- from warp.types import *
23
+ import warp._src.build
24
+ import warp._src.context
25
+ import warp._src.utils
26
+ from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
27
+ from warp._src.types import *
28
28
 
29
29
  from .context import add_builtin
30
30
 
@@ -61,11 +61,11 @@ def sametypes_create_value_func(default: TypeVar):
61
61
 
62
62
  def extract_tuple(arg, as_constant=False):
63
63
  if isinstance(arg, Var):
64
- if isinstance(arg.type, warp.types.tuple_t):
64
+ if isinstance(arg.type, warp._src.types.tuple_t):
65
65
  out = arg.type.values
66
66
  else:
67
67
  out = (arg,)
68
- elif isinstance(arg, warp.types.tuple_t):
68
+ elif isinstance(arg, warp._src.types.tuple_t):
69
69
  out = arg.values
70
70
  elif not isinstance(arg, Sequence):
71
71
  out = (arg,)
@@ -82,7 +82,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
82
82
  if arg_types is None:
83
83
  return int
84
84
 
85
- length = warp.types.type_length(arg_types["a"])
85
+ length = warp._src.types.type_length(arg_types["a"])
86
86
  return Var(None, type=int, constant=length)
87
87
 
88
88
 
@@ -126,7 +126,7 @@ add_builtin(
126
126
  value_func=sametypes_create_value_func(Scalar),
127
127
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
128
128
  group="Scalar Math",
129
- missing_grad=True,
129
+ is_differentiable=False,
130
130
  )
131
131
 
132
132
  add_builtin(
@@ -135,7 +135,7 @@ add_builtin(
135
135
  value_func=sametypes_create_value_func(Scalar),
136
136
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
137
137
  group="Scalar Math",
138
- missing_grad=True,
138
+ is_differentiable=False,
139
139
  )
140
140
  add_builtin(
141
141
  "nonzero",
@@ -143,7 +143,7 @@ add_builtin(
143
143
  value_func=sametypes_create_value_func(Scalar),
144
144
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
145
145
  group="Scalar Math",
146
- missing_grad=True,
146
+ is_differentiable=False,
147
147
  )
148
148
 
149
149
  add_builtin(
@@ -285,7 +285,36 @@ add_builtin(
285
285
  group="Scalar Math",
286
286
  require_original_output_arg=True,
287
287
  )
288
-
288
+ add_builtin(
289
+ "erf",
290
+ input_types={"x": Float},
291
+ value_func=sametypes_create_value_func(Float),
292
+ doc="Return the error function of ``x``.",
293
+ group="Scalar Math",
294
+ )
295
+ add_builtin(
296
+ "erfc",
297
+ input_types={"x": Float},
298
+ value_func=sametypes_create_value_func(Float),
299
+ doc="Return the complementary error function of ``x``.",
300
+ group="Scalar Math",
301
+ )
302
+ add_builtin(
303
+ "erfinv",
304
+ input_types={"x": Float},
305
+ value_func=sametypes_create_value_func(Float),
306
+ doc="Return the inverse error function of ``x``.",
307
+ group="Scalar Math",
308
+ require_original_output_arg=True,
309
+ )
310
+ add_builtin(
311
+ "erfcinv",
312
+ input_types={"x": Float},
313
+ value_func=sametypes_create_value_func(Float),
314
+ doc="Return the inverse complementary error function of ``x``.",
315
+ group="Scalar Math",
316
+ require_original_output_arg=True,
317
+ )
289
318
  add_builtin(
290
319
  "round",
291
320
  input_types={"x": Float},
@@ -295,7 +324,7 @@ add_builtin(
295
324
 
296
325
  This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
297
326
  Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
298
- missing_grad=True,
327
+ is_differentiable=False,
299
328
  )
300
329
 
301
330
  add_builtin(
@@ -306,7 +335,7 @@ add_builtin(
306
335
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
307
336
 
308
337
  It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
309
- missing_grad=True,
338
+ is_differentiable=False,
310
339
  )
311
340
 
312
341
  add_builtin(
@@ -319,7 +348,7 @@ add_builtin(
319
348
  In other words, it discards the fractional part of ``x``.
320
349
  It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
321
350
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
322
- missing_grad=True,
351
+ is_differentiable=False,
323
352
  )
324
353
 
325
354
  add_builtin(
@@ -328,7 +357,7 @@ add_builtin(
328
357
  value_func=sametypes_create_value_func(Float),
329
358
  group="Scalar Math",
330
359
  doc="""Return the largest integer that is less than or equal to ``x``.""",
331
- missing_grad=True,
360
+ is_differentiable=False,
332
361
  )
333
362
 
334
363
  add_builtin(
@@ -337,7 +366,7 @@ add_builtin(
337
366
  value_func=sametypes_create_value_func(Float),
338
367
  group="Scalar Math",
339
368
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
340
- missing_grad=True,
369
+ is_differentiable=False,
341
370
  )
342
371
 
343
372
  add_builtin(
@@ -348,7 +377,7 @@ add_builtin(
348
377
  doc="""Retrieve the fractional part of ``x``.
349
378
 
350
379
  In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
351
- missing_grad=True,
380
+ is_differentiable=False,
352
381
  )
353
382
 
354
383
  add_builtin(
@@ -357,7 +386,7 @@ add_builtin(
357
386
  value_type=builtins.bool,
358
387
  group="Scalar Math",
359
388
  doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
360
- missing_grad=True,
389
+ is_differentiable=False,
361
390
  )
362
391
  add_builtin(
363
392
  "isfinite",
@@ -365,7 +394,7 @@ add_builtin(
365
394
  value_type=builtins.bool,
366
395
  group="Vector Math",
367
396
  doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
368
- missing_grad=True,
397
+ is_differentiable=False,
369
398
  )
370
399
  add_builtin(
371
400
  "isfinite",
@@ -373,7 +402,7 @@ add_builtin(
373
402
  value_type=builtins.bool,
374
403
  group="Vector Math",
375
404
  doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
376
- missing_grad=True,
405
+ is_differentiable=False,
377
406
  )
378
407
  add_builtin(
379
408
  "isfinite",
@@ -381,7 +410,7 @@ add_builtin(
381
410
  value_type=builtins.bool,
382
411
  group="Vector Math",
383
412
  doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
384
- missing_grad=True,
413
+ is_differentiable=False,
385
414
  )
386
415
 
387
416
  add_builtin(
@@ -390,7 +419,7 @@ add_builtin(
390
419
  value_type=builtins.bool,
391
420
  doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
392
421
  group="Scalar Math",
393
- missing_grad=True,
422
+ is_differentiable=False,
394
423
  )
395
424
  add_builtin(
396
425
  "isnan",
@@ -398,7 +427,7 @@ add_builtin(
398
427
  value_type=builtins.bool,
399
428
  group="Vector Math",
400
429
  doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
401
- missing_grad=True,
430
+ is_differentiable=False,
402
431
  )
403
432
  add_builtin(
404
433
  "isnan",
@@ -406,7 +435,7 @@ add_builtin(
406
435
  value_type=builtins.bool,
407
436
  group="Vector Math",
408
437
  doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
409
- missing_grad=True,
438
+ is_differentiable=False,
410
439
  )
411
440
  add_builtin(
412
441
  "isnan",
@@ -414,7 +443,7 @@ add_builtin(
414
443
  value_type=builtins.bool,
415
444
  group="Vector Math",
416
445
  doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
417
- missing_grad=True,
446
+ is_differentiable=False,
418
447
  )
419
448
 
420
449
  add_builtin(
@@ -423,7 +452,7 @@ add_builtin(
423
452
  value_type=builtins.bool,
424
453
  group="Scalar Math",
425
454
  doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
426
- missing_grad=True,
455
+ is_differentiable=False,
427
456
  )
428
457
  add_builtin(
429
458
  "isinf",
@@ -431,7 +460,7 @@ add_builtin(
431
460
  value_type=builtins.bool,
432
461
  group="Vector Math",
433
462
  doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
434
- missing_grad=True,
463
+ is_differentiable=False,
435
464
  )
436
465
  add_builtin(
437
466
  "isinf",
@@ -439,7 +468,7 @@ add_builtin(
439
468
  value_type=builtins.bool,
440
469
  group="Vector Math",
441
470
  doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
442
- missing_grad=True,
471
+ is_differentiable=False,
443
472
  )
444
473
  add_builtin(
445
474
  "isinf",
@@ -447,7 +476,7 @@ add_builtin(
447
476
  value_type=builtins.bool,
448
477
  group="Vector Math",
449
478
  doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
450
- missing_grad=True,
479
+ is_differentiable=False,
451
480
  )
452
481
 
453
482
 
@@ -555,7 +584,7 @@ add_builtin(
555
584
  value_func=lambda arg_types, arg_values: warp.uint32,
556
585
  doc="Return the index of the minimum element of a vector ``a``.",
557
586
  group="Vector Math",
558
- missing_grad=True,
587
+ is_differentiable=False,
559
588
  )
560
589
  add_builtin(
561
590
  "argmax",
@@ -563,7 +592,7 @@ add_builtin(
563
592
  value_func=lambda arg_types, arg_values: warp.uint32,
564
593
  doc="Return the index of the maximum element of a vector ``a``.",
565
594
  group="Vector Math",
566
- missing_grad=True,
595
+ is_differentiable=False,
567
596
  )
568
597
 
569
598
  add_builtin(
@@ -888,7 +917,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
888
917
 
889
918
  if dtype is None:
890
919
  dtype = value_type
891
- elif not warp.types.scalars_equal(value_type, dtype):
920
+ elif not warp._src.types.scalars_equal(value_type, dtype):
892
921
  raise RuntimeError(
893
922
  f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
894
923
  )
@@ -909,7 +938,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
909
938
 
910
939
  if dtype is None:
911
940
  dtype = value_type
912
- elif not warp.types.scalars_equal(value_type, dtype):
941
+ elif not warp._src.types.scalars_equal(value_type, dtype):
913
942
  raise RuntimeError(
914
943
  f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
915
944
  )
@@ -992,7 +1021,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
992
1021
 
993
1022
  if dtype is None:
994
1023
  dtype = value_type
995
- elif not warp.types.scalars_equal(value_type, dtype):
1024
+ elif not warp._src.types.scalars_equal(value_type, dtype):
996
1025
  raise RuntimeError(
997
1026
  f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
998
1027
  )
@@ -1002,7 +1031,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1002
1031
  raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
1003
1032
 
1004
1033
  if all(type_is_vector(x) for x in variadic_arg_types):
1005
- warp.utils.warn(
1034
+ warp._src.utils.warn(
1006
1035
  "the built-in `wp.matrix()` won't support taking column vectors as input "
1007
1036
  "in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
1008
1037
  DeprecationWarning,
@@ -1031,7 +1060,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
1031
1060
 
1032
1061
  if dtype is None:
1033
1062
  dtype = value_type
1034
- elif not warp.types.scalars_equal(value_type, dtype):
1063
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1035
1064
  raise RuntimeError(
1036
1065
  f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
1037
1066
  )
@@ -1203,49 +1232,18 @@ add_builtin(
1203
1232
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
1204
1233
  group="Vector Math",
1205
1234
  export=False,
1206
- missing_grad=True,
1235
+ is_differentiable=False,
1207
1236
  )
1208
1237
 
1209
1238
 
1210
1239
  def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1211
- warp.utils.warn(
1212
- "the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
1213
- "and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
1214
- DeprecationWarning,
1215
- )
1216
1240
  if arg_types is None:
1217
1241
  return matrix(shape=(4, 4), dtype=Float)
1218
1242
 
1219
- dtype = arg_values.get("dtype", None)
1220
-
1221
- value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1222
- try:
1223
- value_type = scalar_infer_type(value_arg_types)
1224
- except RuntimeError:
1225
- raise RuntimeError(
1226
- "all values given when constructing a transformation matrix must have the same type"
1227
- ) from None
1228
-
1229
- if dtype is None:
1230
- dtype = value_type
1231
- elif not warp.types.scalars_equal(value_type, dtype):
1232
- raise RuntimeError(
1233
- f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1234
- )
1235
-
1236
- return matrix(shape=(4, 4), dtype=dtype)
1237
-
1238
-
1239
- def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1240
- # We're in the codegen stage where we emit the code calling the built-in.
1241
- # Further validate the given argument values if needed and map them
1242
- # to the underlying C++ function's runtime and template params.
1243
-
1244
- dtype = return_type._wp_scalar_type_
1245
-
1246
- func_args = tuple(v for k, v in args.items() if k != "dtype")
1247
- template_args = (4, 4, dtype)
1248
- return (func_args, template_args)
1243
+ raise RuntimeError(
1244
+ "the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
1245
+ "and 3D scale vector has been removed in favor of `wp.transform_compose()`."
1246
+ )
1249
1247
 
1250
1248
 
1251
1249
  add_builtin(
@@ -1259,13 +1257,14 @@ add_builtin(
1259
1257
  defaults={"dtype": None},
1260
1258
  value_func=matrix_transform_value_func,
1261
1259
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1262
- dispatch_func=matrix_transform_dispatch_func,
1263
1260
  native_func="mat_t",
1264
1261
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
1265
1262
  Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
1266
1263
 
1267
- .. warning::
1268
- This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
1264
+ .. versionremoved:: 1.10
1265
+ This function has been removed in favor of :func:`warp.math.transform_compose()`.
1266
+
1267
+ .. deprecated:: 1.8""",
1269
1268
  group="Vector Math",
1270
1269
  export=False,
1271
1270
  )
@@ -1460,7 +1459,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
1460
1459
 
1461
1460
  if dtype is None:
1462
1461
  dtype = value_type
1463
- elif not warp.types.scalars_equal(value_type, dtype):
1462
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1464
1463
  raise RuntimeError(
1465
1464
  f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
1466
1465
  )
@@ -1568,7 +1567,7 @@ add_builtin(
1568
1567
  group="Quaternion Math",
1569
1568
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1570
1569
  export=True,
1571
- missing_grad=True,
1570
+ is_differentiable=False,
1572
1571
  )
1573
1572
 
1574
1573
  add_builtin(
@@ -1697,7 +1696,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1697
1696
  value_type = strip_reference(variadic_arg_types[0])
1698
1697
  if dtype is None:
1699
1698
  dtype = value_type
1700
- elif not warp.types.scalars_equal(value_type, dtype):
1699
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1701
1700
  raise RuntimeError(
1702
1701
  f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
1703
1702
  )
@@ -1710,7 +1709,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1710
1709
 
1711
1710
  if dtype is None:
1712
1711
  dtype = value_type
1713
- elif not warp.types.scalars_equal(value_type, dtype):
1712
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1714
1713
  raise RuntimeError(
1715
1714
  f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
1716
1715
  )
@@ -1735,7 +1734,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
1735
1734
  dtype = arg_values.get("dtype", None)
1736
1735
  if dtype is None:
1737
1736
  dtype = value_type
1738
- elif not warp.types.scalars_equal(value_type, dtype):
1737
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1739
1738
  raise RuntimeError(
1740
1739
  f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1741
1740
  )
@@ -1750,9 +1749,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1750
1749
 
1751
1750
  dtype = return_type._wp_scalar_type_
1752
1751
 
1753
- variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1752
+ variadic_args = args.get("args", ())
1753
+ variadic_arg_count = len(variadic_args)
1754
+
1755
+ if variadic_arg_count == 7:
1756
+ func_args = variadic_args
1757
+ else:
1758
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
1759
+ if "p" in args and "q" not in args:
1760
+ quat_ident = warp._src.codegen.Var(
1761
+ label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
1762
+ )
1763
+ func_args += (quat_ident,)
1754
1764
 
1755
- func_args = variadic_args
1756
1765
  template_args = (dtype,)
1757
1766
  return (func_args, template_args)
1758
1767
 
@@ -1760,7 +1769,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
1760
1769
  add_builtin(
1761
1770
  "transformation",
1762
1771
  input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
1763
- defaults={"dtype": None},
1772
+ defaults={"q": None, "dtype": None},
1764
1773
  value_func=transformation_pq_value_func,
1765
1774
  export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1766
1775
  dispatch_func=transformation_dispatch_func,
@@ -1784,7 +1793,6 @@ add_builtin(
1784
1793
  doc="Construct a spatial transform vector of given dtype.",
1785
1794
  group="Spatial Math",
1786
1795
  export=False,
1787
- missing_grad=True,
1788
1796
  )
1789
1797
 
1790
1798
 
@@ -1819,7 +1827,7 @@ add_builtin(
1819
1827
  group="Transformations",
1820
1828
  doc="Construct an identity transform with zero translation and identity rotation.",
1821
1829
  export=True,
1822
- missing_grad=True,
1830
+ is_differentiable=False,
1823
1831
  )
1824
1832
 
1825
1833
  add_builtin(
@@ -1953,7 +1961,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
1953
1961
 
1954
1962
  if dtype is None:
1955
1963
  dtype = value_type
1956
- elif not warp.types.scalars_equal(value_type, dtype):
1964
+ elif not warp._src.types.scalars_equal(value_type, dtype):
1957
1965
  raise RuntimeError(
1958
1966
  f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
1959
1967
  )
@@ -2147,7 +2155,7 @@ add_builtin(
2147
2155
  value_func=tile_zeros_value_func,
2148
2156
  dispatch_func=tile_zeros_dispatch_func,
2149
2157
  variadic=False,
2150
- missing_grad=True,
2158
+ is_differentiable=False,
2151
2159
  doc="""Allocate a tile of zero-initialized items.
2152
2160
 
2153
2161
  :param shape: Shape of the output tile
@@ -2167,7 +2175,7 @@ add_builtin(
2167
2175
  value_func=tile_zeros_value_func,
2168
2176
  dispatch_func=tile_zeros_dispatch_func,
2169
2177
  variadic=False,
2170
- missing_grad=True,
2178
+ is_differentiable=False,
2171
2179
  hidden=True,
2172
2180
  group="Tile Primitives",
2173
2181
  export=False,
@@ -2219,7 +2227,7 @@ add_builtin(
2219
2227
  defaults={"storage": "register"},
2220
2228
  value_func=tile_ones_value_func,
2221
2229
  dispatch_func=tile_ones_dispatch_func,
2222
- missing_grad=True,
2230
+ is_differentiable=False,
2223
2231
  doc="""Allocate a tile of one-initialized items.
2224
2232
 
2225
2233
  :param shape: Shape of the output tile
@@ -2238,7 +2246,86 @@ add_builtin(
2238
2246
  defaults={"storage": "register"},
2239
2247
  value_func=tile_ones_value_func,
2240
2248
  dispatch_func=tile_ones_dispatch_func,
2241
- missing_grad=True,
2249
+ is_differentiable=False,
2250
+ hidden=True,
2251
+ group="Tile Primitives",
2252
+ export=False,
2253
+ )
2254
+
2255
+
2256
+ def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2257
+ # return generic type (for doc builds)
2258
+ if arg_types is None:
2259
+ return tile(dtype=Any, shape=Tuple[int, ...])
2260
+
2261
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2262
+
2263
+ if None in shape:
2264
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2265
+
2266
+ if "value" not in arg_values:
2267
+ raise TypeError("tile_full() missing required keyword argument 'value'")
2268
+
2269
+ if "dtype" not in arg_values:
2270
+ raise TypeError("tile_full() missing required keyword argument 'dtype'")
2271
+
2272
+ if "storage" not in arg_values:
2273
+ raise TypeError("tile_full() missing required keyword argument 'storage'")
2274
+
2275
+ if arg_values["storage"] not in {"shared", "register"}:
2276
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
2277
+
2278
+ dtype = arg_values["dtype"]
2279
+
2280
+ return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
2281
+
2282
+
2283
+ def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2284
+ shape = extract_tuple(arg_values["shape"], as_constant=True)
2285
+
2286
+ if None in shape:
2287
+ raise ValueError("Tile functions require shape to be a compile time constant.")
2288
+
2289
+ dtype = arg_values["dtype"]
2290
+ value = arg_values["value"]
2291
+
2292
+ func_args = [value]
2293
+
2294
+ template_args = []
2295
+ template_args.append(dtype)
2296
+ template_args.extend(shape)
2297
+
2298
+ return (func_args, template_args)
2299
+
2300
+
2301
+ add_builtin(
2302
+ "tile_full",
2303
+ input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
2304
+ defaults={"storage": "register"},
2305
+ value_func=tile_full_value_func,
2306
+ dispatch_func=tile_full_dispatch_func,
2307
+ is_differentiable=False,
2308
+ doc="""Allocate a tile filled with the specified value.
2309
+
2310
+ :param shape: Shape of the output tile
2311
+ :param value: Value to fill the tile with
2312
+ :param dtype: Data type of output tile's elements
2313
+ :param storage: The storage location for the tile: ``"register"`` for registers
2314
+ (default) or ``"shared"`` for shared memory.
2315
+ :returns: A tile filled with the specified value""",
2316
+ group="Tile Primitives",
2317
+ export=False,
2318
+ )
2319
+
2320
+
2321
+ # overload for scalar shape
2322
+ add_builtin(
2323
+ "tile_full",
2324
+ input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
2325
+ defaults={"storage": "register"},
2326
+ value_func=tile_full_value_func,
2327
+ dispatch_func=tile_full_dispatch_func,
2328
+ is_differentiable=False,
2242
2329
  hidden=True,
2243
2330
  group="Tile Primitives",
2244
2331
  export=False,
@@ -2300,13 +2387,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
2300
2387
  args = arg_values["args"]
2301
2388
 
2302
2389
  if len(args) == 1:
2303
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
2390
+ start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
2304
2391
  stop = args[0]
2305
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2392
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2306
2393
  elif len(args) == 2:
2307
2394
  start = args[0]
2308
2395
  stop = args[1]
2309
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
2396
+ step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
2310
2397
  elif len(args) == 3:
2311
2398
  start = args[0]
2312
2399
  stop = args[1]
@@ -2329,7 +2416,7 @@ add_builtin(
2329
2416
  value_func=tile_arange_value_func,
2330
2417
  dispatch_func=tile_arange_dispatch_func,
2331
2418
  variadic=True,
2332
- missing_grad=True,
2419
+ is_differentiable=False,
2333
2420
  doc="""Generate a tile of linearly spaced elements.
2334
2421
 
2335
2422
  :param args: Variable-length positional arguments, interpreted as:
@@ -3124,7 +3211,7 @@ add_builtin(
3124
3211
  :param shape: Shape of the returned slice
3125
3212
  :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
3126
3213
  group="Tile Primitives",
3127
- missing_grad=True,
3214
+ is_differentiable=False,
3128
3215
  export=False,
3129
3216
  )
3130
3217
 
@@ -3371,7 +3458,32 @@ add_builtin(
3371
3458
 
3372
3459
  add_builtin(
3373
3460
  "assign",
3374
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
3461
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
3462
+ value_func=tile_assign_value_func,
3463
+ group="Tile Primitives",
3464
+ export=False,
3465
+ hidden=True,
3466
+ )
3467
+
3468
+ add_builtin(
3469
+ "assign",
3470
+ input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
3471
+ value_func=tile_assign_value_func,
3472
+ group="Tile Primitives",
3473
+ export=False,
3474
+ hidden=True,
3475
+ )
3476
+
3477
+ add_builtin(
3478
+ "assign",
3479
+ input_types={
3480
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3481
+ "i": int,
3482
+ "j": int,
3483
+ "k": int,
3484
+ "l": int,
3485
+ "src": Any,
3486
+ },
3375
3487
  value_func=tile_assign_value_func,
3376
3488
  group="Tile Primitives",
3377
3489
  export=False,
@@ -3380,7 +3492,15 @@ add_builtin(
3380
3492
 
3381
3493
  add_builtin(
3382
3494
  "assign",
3383
- input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
3495
+ input_types={
3496
+ "dst": tile(dtype=Any, shape=Tuple[int, ...]),
3497
+ "i": int,
3498
+ "j": int,
3499
+ "k": int,
3500
+ "l": int,
3501
+ "m": int,
3502
+ "src": Any,
3503
+ },
3384
3504
  value_func=tile_assign_value_func,
3385
3505
  group="Tile Primitives",
3386
3506
  export=False,
@@ -3395,6 +3515,8 @@ add_builtin(
3395
3515
  "j": int,
3396
3516
  "k": int,
3397
3517
  "l": int,
3518
+ "m": int,
3519
+ "n": int,
3398
3520
  "src": Any,
3399
3521
  },
3400
3522
  value_func=tile_assign_value_func,
@@ -3416,7 +3538,7 @@ def tile_value_func(arg_types, arg_values):
3416
3538
 
3417
3539
  if preserve_type:
3418
3540
  dtype = arg_types["x"]
3419
- shape = (warp.codegen.options["block_dim"],)
3541
+ shape = (warp._src.codegen.options["block_dim"],)
3420
3542
 
3421
3543
  return tile(dtype=dtype, shape=shape)
3422
3544
 
@@ -3424,18 +3546,18 @@ def tile_value_func(arg_types, arg_values):
3424
3546
  if type_is_vector(arg_types["x"]):
3425
3547
  dtype = arg_types["x"]._wp_scalar_type_
3426
3548
  length = arg_types["x"]._shape_[0]
3427
- shape = (length, warp.codegen.options["block_dim"])
3549
+ shape = (length, warp._src.codegen.options["block_dim"])
3428
3550
  elif type_is_quaternion(arg_types["x"]):
3429
3551
  dtype = arg_types["x"]._wp_scalar_type_
3430
- shape = (4, warp.codegen.options["block_dim"])
3552
+ shape = (4, warp._src.codegen.options["block_dim"])
3431
3553
  elif type_is_matrix(arg_types["x"]):
3432
3554
  dtype = arg_types["x"]._wp_scalar_type_
3433
3555
  rows = arg_types["x"]._shape_[0]
3434
3556
  cols = arg_types["x"]._shape_[1]
3435
- shape = (rows, cols, warp.codegen.options["block_dim"])
3557
+ shape = (rows, cols, warp._src.codegen.options["block_dim"])
3436
3558
  else:
3437
3559
  dtype = arg_types["x"]
3438
- shape = (warp.codegen.options["block_dim"],)
3560
+ shape = (warp._src.codegen.options["block_dim"],)
3439
3561
 
3440
3562
  return tile(dtype=dtype, shape=shape)
3441
3563
 
@@ -3525,17 +3647,17 @@ def untile_value_func(arg_types, arg_values):
3525
3647
  if not is_tile(t):
3526
3648
  raise TypeError(f"untile() argument must be a tile, got {t!r}")
3527
3649
 
3528
- if t.shape[-1] != warp.codegen.options["block_dim"]:
3650
+ if t.shape[-1] != warp._src.codegen.options["block_dim"]:
3529
3651
  raise ValueError(
3530
- f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
3652
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
3531
3653
  )
3532
3654
 
3533
3655
  if len(t.shape) == 1:
3534
3656
  return t.dtype
3535
3657
  elif len(t.shape) == 2:
3536
- return warp.types.vector(t.shape[0], t.dtype)
3658
+ return warp._src.types.vector(t.shape[0], t.dtype)
3537
3659
  elif len(t.shape) == 3:
3538
- return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3660
+ return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
3539
3661
  else:
3540
3662
  raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
3541
3663
 
@@ -3597,7 +3719,36 @@ def tile_extract_value_func(arg_types, arg_values):
3597
3719
  # force the input tile to shared memory
3598
3720
  arg_types["a"].storage = "shared"
3599
3721
 
3600
- return arg_types["a"].dtype
3722
+ # count the number of indices (all parameters except the tile "a")
3723
+ num_indices = len(arg_types) - 1
3724
+ tile_dtype = arg_types["a"].dtype
3725
+ tile_shape = arg_types["a"].shape
3726
+
3727
+ if type_is_vector(tile_dtype):
3728
+ if num_indices == len(tile_shape):
3729
+ return tile_dtype
3730
+ elif num_indices == len(tile_shape) + 1:
3731
+ return tile_dtype._wp_scalar_type_
3732
+ else:
3733
+ raise IndexError(
3734
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3735
+ )
3736
+ elif type_is_matrix(tile_dtype):
3737
+ if num_indices == len(tile_shape):
3738
+ return tile_dtype
3739
+ elif num_indices == len(tile_shape) + 2:
3740
+ return tile_dtype._wp_scalar_type_
3741
+ else:
3742
+ raise IndexError(
3743
+ f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
3744
+ )
3745
+ else:
3746
+ # scalar element: index count must exactly match tile rank
3747
+ if num_indices == len(tile_shape):
3748
+ return tile_dtype
3749
+ raise IndexError(
3750
+ f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
3751
+ )
3601
3752
 
3602
3753
 
3603
3754
  add_builtin(
@@ -3621,7 +3772,7 @@ add_builtin(
3621
3772
 
3622
3773
  add_builtin(
3623
3774
  "tile_extract",
3624
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
3775
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
3625
3776
  value_func=tile_extract_value_func,
3626
3777
  variadic=False,
3627
3778
  doc="""Extract a single element from the tile.
@@ -3632,7 +3783,28 @@ add_builtin(
3632
3783
 
3633
3784
  :param a: Tile to extract the element from
3634
3785
  :param i: Coordinate of element on first dimension
3635
- :param j: Coordinate of element on the second dimension
3786
+ :param j: Coordinate of element on the second dimension, or vector index
3787
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3788
+ group="Tile Primitives",
3789
+ hidden=True,
3790
+ export=False,
3791
+ )
3792
+
3793
+ add_builtin(
3794
+ "tile_extract",
3795
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
3796
+ value_func=tile_extract_value_func,
3797
+ variadic=False,
3798
+ doc="""Extract a single element from the tile.
3799
+
3800
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3801
+
3802
+ Note that this may incur additional synchronization if the source tile is a register tile.
3803
+
3804
+ :param a: Tile to extract the element from
3805
+ :param i: Coordinate of element on first dimension
3806
+ :param j: Coordinate of element on the second dimension, or first matrix index
3807
+ :param k: Coordinate of element on the third dimension, or vector index, or second matrix index
3636
3808
  :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3637
3809
  group="Tile Primitives",
3638
3810
  hidden=True,
@@ -3641,7 +3813,36 @@ add_builtin(
3641
3813
 
3642
3814
  add_builtin(
3643
3815
  "tile_extract",
3644
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
3816
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
3817
+ value_func=tile_extract_value_func,
3818
+ variadic=False,
3819
+ doc="""Extract a single element from the tile.
3820
+
3821
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
3822
+
3823
+ Note that this may incur additional synchronization if the source tile is a register tile.
3824
+
3825
+ :param a: Tile to extract the element from
3826
+ :param i: Coordinate of element on first dimension
3827
+ :param j: Coordinate of element on the second dimension
3828
+ :param k: Coordinate of element on the third dimension, or first matrix index
3829
+ :param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
3830
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3831
+ group="Tile Primitives",
3832
+ hidden=True,
3833
+ export=False,
3834
+ )
3835
+
3836
+ add_builtin(
3837
+ "tile_extract",
3838
+ input_types={
3839
+ "a": tile(dtype=Any, shape=Tuple[int, ...]),
3840
+ "i": int,
3841
+ "j": int,
3842
+ "k": int,
3843
+ "l": int,
3844
+ "m": int,
3845
+ },
3645
3846
  value_func=tile_extract_value_func,
3646
3847
  variadic=False,
3647
3848
  doc="""Extract a single element from the tile.
@@ -3654,7 +3855,9 @@ add_builtin(
3654
3855
  :param i: Coordinate of element on first dimension
3655
3856
  :param j: Coordinate of element on the second dimension
3656
3857
  :param k: Coordinate of element on the third dimension
3657
- :returns: The value of the element at the specified tile location with the same data type as the input tile""",
3858
+ :param l: Coordinate of element on the fourth dimension, or first matrix index
3859
+ :param m: Vector index, or second matrix index
3860
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3658
3861
  group="Tile Primitives",
3659
3862
  hidden=True,
3660
3863
  export=False,
@@ -3662,7 +3865,15 @@ add_builtin(
3662
3865
 
3663
3866
  add_builtin(
3664
3867
  "tile_extract",
3665
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
3868
+ input_types={
3869
+ "a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
3870
+ "i": int,
3871
+ "j": int,
3872
+ "k": int,
3873
+ "l": int,
3874
+ "m": int,
3875
+ "n": int,
3876
+ },
3666
3877
  value_func=tile_extract_value_func,
3667
3878
  variadic=False,
3668
3879
  doc="""Extract a single element from the tile.
@@ -3676,6 +3887,8 @@ add_builtin(
3676
3887
  :param j: Coordinate of element on the second dimension
3677
3888
  :param k: Coordinate of element on the third dimension
3678
3889
  :param l: Coordinate of element on the fourth dimension
3890
+ :param m: Vector index, or first matrix index
3891
+ :param n: Second matrix index
3679
3892
  :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
3680
3893
  group="Tile Primitives",
3681
3894
  hidden=True,
@@ -3762,50 +3975,161 @@ add_builtin(
3762
3975
  export=False,
3763
3976
  )
3764
3977
 
3765
-
3766
- def tile_transpose_value_func(arg_types, arg_values):
3767
- # return generic type (for doc builds)
3768
- if arg_types is None:
3769
- return tile(dtype=Any, shape=Tuple[int, int])
3770
-
3771
- if len(arg_types) != 1:
3772
- raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
3773
-
3774
- t = arg_types["a"]
3775
-
3776
- if not is_tile(t):
3777
- raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
3778
-
3779
- layout = None
3780
-
3781
- # flip layout
3782
- if t.layout == "rowmajor":
3783
- layout = "colmajor"
3784
- elif t.layout == "colmajor":
3785
- layout = "rowmajor"
3786
-
3787
- # force the input tile to shared memory
3788
- t.storage = "shared"
3789
-
3790
- return tile(
3791
- dtype=t.dtype,
3792
- shape=t.shape[::-1],
3793
- storage=t.storage,
3794
- strides=t.strides[::-1],
3795
- layout=layout,
3796
- owner=False,
3797
- )
3798
-
3799
-
3800
3978
  add_builtin(
3801
- "tile_transpose",
3802
- input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
3803
- value_func=tile_transpose_value_func,
3804
- variadic=True,
3805
- doc="""Transpose a tile.
3806
-
3807
- For shared memory tiles, this operation will alias the input tile.
3808
- Register tiles will first be transferred to shared memory before transposition.
3979
+ "tile_bit_and_inplace",
3980
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
3981
+ value_func=tile_inplace_value_func,
3982
+ group="Tile Primitives",
3983
+ hidden=True,
3984
+ export=False,
3985
+ is_differentiable=False,
3986
+ )
3987
+ add_builtin(
3988
+ "tile_bit_and_inplace",
3989
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
3990
+ value_func=tile_inplace_value_func,
3991
+ group="Tile Primitives",
3992
+ hidden=True,
3993
+ export=False,
3994
+ is_differentiable=False,
3995
+ )
3996
+ add_builtin(
3997
+ "tile_bit_and_inplace",
3998
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
3999
+ value_func=tile_inplace_value_func,
4000
+ group="Tile Primitives",
4001
+ hidden=True,
4002
+ export=False,
4003
+ is_differentiable=False,
4004
+ )
4005
+ add_builtin(
4006
+ "tile_bit_and_inplace",
4007
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4008
+ value_func=tile_inplace_value_func,
4009
+ group="Tile Primitives",
4010
+ hidden=True,
4011
+ export=False,
4012
+ is_differentiable=False,
4013
+ )
4014
+
4015
+ add_builtin(
4016
+ "tile_bit_or_inplace",
4017
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4018
+ value_func=tile_inplace_value_func,
4019
+ group="Tile Primitives",
4020
+ hidden=True,
4021
+ export=False,
4022
+ is_differentiable=False,
4023
+ )
4024
+ add_builtin(
4025
+ "tile_bit_or_inplace",
4026
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4027
+ value_func=tile_inplace_value_func,
4028
+ group="Tile Primitives",
4029
+ hidden=True,
4030
+ export=False,
4031
+ is_differentiable=False,
4032
+ )
4033
+ add_builtin(
4034
+ "tile_bit_or_inplace",
4035
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4036
+ value_func=tile_inplace_value_func,
4037
+ group="Tile Primitives",
4038
+ hidden=True,
4039
+ export=False,
4040
+ is_differentiable=False,
4041
+ )
4042
+ add_builtin(
4043
+ "tile_bit_or_inplace",
4044
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4045
+ value_func=tile_inplace_value_func,
4046
+ group="Tile Primitives",
4047
+ hidden=True,
4048
+ export=False,
4049
+ is_differentiable=False,
4050
+ )
4051
+
4052
+ add_builtin(
4053
+ "tile_bit_xor_inplace",
4054
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
4055
+ value_func=tile_inplace_value_func,
4056
+ group="Tile Primitives",
4057
+ hidden=True,
4058
+ export=False,
4059
+ is_differentiable=False,
4060
+ )
4061
+ add_builtin(
4062
+ "tile_bit_xor_inplace",
4063
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
4064
+ value_func=tile_inplace_value_func,
4065
+ group="Tile Primitives",
4066
+ hidden=True,
4067
+ export=False,
4068
+ is_differentiable=False,
4069
+ )
4070
+ add_builtin(
4071
+ "tile_bit_xor_inplace",
4072
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
4073
+ value_func=tile_inplace_value_func,
4074
+ group="Tile Primitives",
4075
+ hidden=True,
4076
+ export=False,
4077
+ is_differentiable=False,
4078
+ )
4079
+ add_builtin(
4080
+ "tile_bit_xor_inplace",
4081
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
4082
+ value_func=tile_inplace_value_func,
4083
+ group="Tile Primitives",
4084
+ hidden=True,
4085
+ export=False,
4086
+ is_differentiable=False,
4087
+ )
4088
+
4089
+
4090
+ def tile_transpose_value_func(arg_types, arg_values):
4091
+ # return generic type (for doc builds)
4092
+ if arg_types is None:
4093
+ return tile(dtype=Any, shape=Tuple[int, int])
4094
+
4095
+ if len(arg_types) != 1:
4096
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
4097
+
4098
+ t = arg_types["a"]
4099
+
4100
+ if not is_tile(t):
4101
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
4102
+
4103
+ layout = None
4104
+
4105
+ # flip layout
4106
+ if t.layout == "rowmajor":
4107
+ layout = "colmajor"
4108
+ elif t.layout == "colmajor":
4109
+ layout = "rowmajor"
4110
+
4111
+ # force the input tile to shared memory
4112
+ t.storage = "shared"
4113
+
4114
+ return tile(
4115
+ dtype=t.dtype,
4116
+ shape=t.shape[::-1],
4117
+ storage=t.storage,
4118
+ strides=t.strides[::-1],
4119
+ layout=layout,
4120
+ owner=False,
4121
+ )
4122
+
4123
+
4124
+ add_builtin(
4125
+ "tile_transpose",
4126
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
4127
+ value_func=tile_transpose_value_func,
4128
+ variadic=True,
4129
+ doc="""Transpose a tile.
4130
+
4131
+ For shared memory tiles, this operation will alias the input tile.
4132
+ Register tiles will first be transferred to shared memory before transposition.
3809
4133
 
3810
4134
  :param a: Tile to transpose with ``shape=(M,N)``
3811
4135
  :returns: Tile with ``shape=(N,M)``""",
@@ -3935,6 +4259,80 @@ add_builtin(
3935
4259
  )
3936
4260
 
3937
4261
 
4262
+ def tile_sum_axis_value_func(arg_types, arg_values):
4263
+ if arg_types is None:
4264
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4265
+
4266
+ a = arg_types["a"]
4267
+
4268
+ if not is_tile(a):
4269
+ raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
4270
+
4271
+ # force input tile to shared
4272
+ a.storage = "shared"
4273
+
4274
+ axis = arg_values["axis"]
4275
+ shape = a.shape
4276
+
4277
+ if axis < 0 or axis >= len(shape):
4278
+ raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4279
+
4280
+ # shape is identical less the axis reduction is along
4281
+ if len(shape) > 1:
4282
+ new_shape = shape[:axis] + shape[axis + 1 :]
4283
+ else:
4284
+ new_shape = (1,)
4285
+
4286
+ return tile(dtype=a.dtype, shape=new_shape)
4287
+
4288
+
4289
+ def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
4290
+ tile = arg_values["a"]
4291
+ axis_var = arg_values["axis"]
4292
+ if not hasattr(axis_var, "constant") or axis_var.constant is None:
4293
+ raise ValueError("tile_sum() axis must be a compile-time constant")
4294
+ axis = axis_var.constant
4295
+
4296
+ return ((tile,), (axis,))
4297
+
4298
+
4299
+ add_builtin(
4300
+ "tile_sum",
4301
+ input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4302
+ value_func=tile_sum_axis_value_func,
4303
+ dispatch_func=tile_sum_axis_dispatch_func,
4304
+ doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
4305
+
4306
+ :param a: The input tile. Must reside in shared memory.
4307
+ :param axis: The tile axis to compute the sum across. Must be a compile-time constant.
4308
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4309
+
4310
+ Example:
4311
+
4312
+ .. code-block:: python
4313
+
4314
+ @wp.kernel
4315
+ def compute():
4316
+
4317
+ t = wp.tile_ones(dtype=float, shape=(8, 8))
4318
+ s = wp.tile_sum(t, axis=0)
4319
+
4320
+ print(s)
4321
+
4322
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
4323
+
4324
+ Prints:
4325
+
4326
+ .. code-block:: text
4327
+
4328
+ [8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
4329
+
4330
+ """,
4331
+ group="Tile Primitives",
4332
+ export=False,
4333
+ )
4334
+
4335
+
3938
4336
  def tile_sort_value_func(arg_types, arg_values):
3939
4337
  # return generic type (for doc builds)
3940
4338
  if arg_types is None:
@@ -4011,7 +4409,7 @@ add_builtin(
4011
4409
  """,
4012
4410
  group="Tile Primitives",
4013
4411
  export=False,
4014
- missing_grad=True,
4412
+ is_differentiable=False,
4015
4413
  )
4016
4414
 
4017
4415
 
@@ -4065,7 +4463,7 @@ add_builtin(
4065
4463
  """,
4066
4464
  group="Tile Primitives",
4067
4465
  export=False,
4068
- missing_grad=True,
4466
+ is_differentiable=False,
4069
4467
  )
4070
4468
 
4071
4469
 
@@ -4119,7 +4517,7 @@ add_builtin(
4119
4517
  """,
4120
4518
  group="Tile Primitives",
4121
4519
  export=False,
4122
- missing_grad=True,
4520
+ is_differentiable=False,
4123
4521
  )
4124
4522
 
4125
4523
 
@@ -4172,7 +4570,7 @@ add_builtin(
4172
4570
  """,
4173
4571
  group="Tile Primitives",
4174
4572
  export=False,
4175
- missing_grad=True,
4573
+ is_differentiable=False,
4176
4574
  )
4177
4575
 
4178
4576
 
@@ -4225,11 +4623,10 @@ add_builtin(
4225
4623
  """,
4226
4624
  group="Tile Primitives",
4227
4625
  export=False,
4228
- missing_grad=True,
4626
+ is_differentiable=False,
4229
4627
  )
4230
4628
 
4231
4629
 
4232
- # does type propagation for load()
4233
4630
  def tile_reduce_value_func(arg_types, arg_values):
4234
4631
  if arg_types is None:
4235
4632
  return tile(dtype=Scalar, shape=(1,))
@@ -4283,7 +4680,88 @@ add_builtin(
4283
4680
  """,
4284
4681
  group="Tile Primitives",
4285
4682
  export=False,
4286
- missing_grad=True,
4683
+ is_differentiable=False,
4684
+ )
4685
+
4686
+
4687
+ def tile_reduce_axis_value_func(arg_types, arg_values):
4688
+ if arg_types is None:
4689
+ return tile(dtype=Scalar, shape=Tuple[int, ...])
4690
+
4691
+ a = arg_types["a"]
4692
+
4693
+ if not is_tile(a):
4694
+ raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
4695
+
4696
+ # force input tile to shared memory
4697
+ a.storage = "shared"
4698
+
4699
+ axis = arg_values["axis"]
4700
+ shape = a.shape
4701
+
4702
+ if axis < 0 or axis >= len(shape):
4703
+ raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
4704
+
4705
+ # shape is identical less the axis reduction is along
4706
+ if len(shape) > 1:
4707
+ new_shape = shape[:axis] + shape[axis + 1 :]
4708
+ else:
4709
+ new_shape = (1,)
4710
+
4711
+ return tile(dtype=a.dtype, shape=new_shape)
4712
+
4713
+
4714
+ add_builtin(
4715
+ "tile_reduce",
4716
+ input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
4717
+ value_func=tile_reduce_axis_value_func,
4718
+ native_func="tile_reduce_axis",
4719
+ doc="""Apply a custom reduction operator across a tile axis.
4720
+
4721
+ This function cooperatively performs a reduction using the provided operator across an axis of the tile.
4722
+
4723
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
4724
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
4725
+ :param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
4726
+ :returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
4727
+
4728
+ Example:
4729
+
4730
+ .. code-block:: python
4731
+
4732
+ TILE_M = wp.constant(4)
4733
+ TILE_N = wp.constant(2)
4734
+
4735
+ @wp.kernel
4736
+ def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
4737
+
4738
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
4739
+ b = wp.tile_reduce(wp.add, a, axis=1)
4740
+ wp.tile_store(y, b)
4741
+
4742
+ arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
4743
+
4744
+ x = wp.array(arr, dtype=float)
4745
+ y = wp.zeros(TILE_M, dtype=float)
4746
+
4747
+ wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
4748
+
4749
+ print(x.numpy())
4750
+ print(y.numpy())
4751
+
4752
+ Prints:
4753
+
4754
+ .. code-block:: text
4755
+
4756
+ [[0. 1.]
4757
+ [2. 3.]
4758
+ [4. 5.]
4759
+ [6. 7.]]
4760
+ [ 1. 5. 9. 13.]
4761
+ """,
4762
+ group="Tile Primitives",
4763
+ export=False,
4764
+ is_differentiable=False,
4287
4765
  )
4288
4766
 
4289
4767
 
@@ -4347,7 +4825,7 @@ add_builtin(
4347
4825
  """,
4348
4826
  group="Tile Primitives",
4349
4827
  export=False,
4350
- missing_grad=True,
4828
+ is_differentiable=False,
4351
4829
  )
4352
4830
 
4353
4831
 
@@ -4411,7 +4889,7 @@ add_builtin(
4411
4889
  """,
4412
4890
  group="Tile Primitives",
4413
4891
  export=False,
4414
- missing_grad=True,
4892
+ is_differentiable=False,
4415
4893
  )
4416
4894
 
4417
4895
 
@@ -4665,7 +5143,7 @@ add_builtin(
4665
5143
  doc="WIP",
4666
5144
  group="Utility",
4667
5145
  hidden=True,
4668
- missing_grad=True,
5146
+ is_differentiable=False,
4669
5147
  )
4670
5148
 
4671
5149
  add_builtin(
@@ -4681,7 +5159,7 @@ add_builtin(
4681
5159
  doc="WIP",
4682
5160
  group="Utility",
4683
5161
  hidden=True,
4684
- missing_grad=True,
5162
+ is_differentiable=False,
4685
5163
  )
4686
5164
 
4687
5165
  add_builtin(
@@ -4691,7 +5169,7 @@ add_builtin(
4691
5169
  doc="WIP",
4692
5170
  group="Utility",
4693
5171
  hidden=True,
4694
- missing_grad=True,
5172
+ is_differentiable=False,
4695
5173
  )
4696
5174
 
4697
5175
  add_builtin(
@@ -4743,7 +5221,7 @@ add_builtin(
4743
5221
  :param low: The lower bound of the bounding box in BVH space
4744
5222
  :param high: The upper bound of the bounding box in BVH space""",
4745
5223
  export=False,
4746
- missing_grad=True,
5224
+ is_differentiable=False,
4747
5225
  )
4748
5226
 
4749
5227
  add_builtin(
@@ -4759,7 +5237,7 @@ add_builtin(
4759
5237
  :param start: The start of the ray in BVH space
4760
5238
  :param dir: The direction of the ray in BVH space""",
4761
5239
  export=False,
4762
- missing_grad=True,
5240
+ is_differentiable=False,
4763
5241
  )
4764
5242
 
4765
5243
  add_builtin(
@@ -4770,7 +5248,7 @@ add_builtin(
4770
5248
  doc="""Move to the next bound returned by the query.
4771
5249
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
4772
5250
  export=False,
4773
- missing_grad=True,
5251
+ is_differentiable=False,
4774
5252
  )
4775
5253
 
4776
5254
  add_builtin(
@@ -5111,7 +5589,7 @@ add_builtin(
5111
5589
  :param low: The lower bound of the bounding box in mesh space
5112
5590
  :param high: The upper bound of the bounding box in mesh space""",
5113
5591
  export=False,
5114
- missing_grad=True,
5592
+ is_differentiable=False,
5115
5593
  )
5116
5594
 
5117
5595
  add_builtin(
@@ -5123,7 +5601,7 @@ add_builtin(
5123
5601
 
5124
5602
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
5125
5603
  export=False,
5126
- missing_grad=True,
5604
+ is_differentiable=False,
5127
5605
  )
5128
5606
 
5129
5607
  add_builtin(
@@ -5153,7 +5631,7 @@ add_builtin(
5153
5631
 
5154
5632
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
5155
5633
  export=False,
5156
- missing_grad=True,
5634
+ is_differentiable=False,
5157
5635
  )
5158
5636
 
5159
5637
  add_builtin(
@@ -5165,7 +5643,7 @@ add_builtin(
5165
5643
 
5166
5644
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
5167
5645
  export=False,
5168
- missing_grad=True,
5646
+ is_differentiable=False,
5169
5647
  )
5170
5648
 
5171
5649
  add_builtin(
@@ -5179,7 +5657,7 @@ add_builtin(
5179
5657
 
5180
5658
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
5181
5659
  export=False,
5182
- missing_grad=True,
5660
+ is_differentiable=False,
5183
5661
  )
5184
5662
 
5185
5663
  add_builtin(
@@ -5189,16 +5667,34 @@ add_builtin(
5189
5667
  group="Geometry",
5190
5668
  doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5191
5669
 
5670
+ This function works with single precision, may return incorrect results in some case.
5671
+
5672
+ Returns > 0 if triangles intersect.""",
5673
+ export=False,
5674
+ is_differentiable=False,
5675
+ )
5676
+
5677
+
5678
+ add_builtin(
5679
+ "intersect_tri_tri",
5680
+ input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
5681
+ value_type=int,
5682
+ group="Geometry",
5683
+ doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
5684
+
5685
+ This function works with double precision, results are more accurate than the single precision version.
5686
+
5192
5687
  Returns > 0 if triangles intersect.""",
5193
5688
  export=False,
5194
- missing_grad=True,
5689
+ is_differentiable=False,
5195
5690
  )
5196
5691
 
5692
+
5197
5693
  add_builtin(
5198
5694
  "mesh_get",
5199
5695
  input_types={"id": uint64},
5200
5696
  value_type=Mesh,
5201
- missing_grad=True,
5697
+ is_differentiable=False,
5202
5698
  group="Geometry",
5203
5699
  doc="""Retrieves the mesh given its index.""",
5204
5700
  export=False,
@@ -5211,7 +5707,7 @@ add_builtin(
5211
5707
  group="Geometry",
5212
5708
  doc="""Evaluates the face normal the mesh given a face index.""",
5213
5709
  export=False,
5214
- missing_grad=True,
5710
+ is_differentiable=False,
5215
5711
  )
5216
5712
 
5217
5713
  add_builtin(
@@ -5221,7 +5717,7 @@ add_builtin(
5221
5717
  group="Geometry",
5222
5718
  doc="""Returns the point of the mesh given a index.""",
5223
5719
  export=False,
5224
- missing_grad=True,
5720
+ is_differentiable=False,
5225
5721
  )
5226
5722
 
5227
5723
  add_builtin(
@@ -5231,7 +5727,7 @@ add_builtin(
5231
5727
  group="Geometry",
5232
5728
  doc="""Returns the velocity of the mesh given a index.""",
5233
5729
  export=False,
5234
- missing_grad=True,
5730
+ is_differentiable=False,
5235
5731
  )
5236
5732
 
5237
5733
  add_builtin(
@@ -5241,7 +5737,7 @@ add_builtin(
5241
5737
  group="Geometry",
5242
5738
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
5243
5739
  export=False,
5244
- missing_grad=True,
5740
+ is_differentiable=False,
5245
5741
  )
5246
5742
 
5247
5743
 
@@ -5289,7 +5785,7 @@ add_builtin(
5289
5785
  group="Utility",
5290
5786
  export=False,
5291
5787
  hidden=True,
5292
- missing_grad=True,
5788
+ is_differentiable=False,
5293
5789
  )
5294
5790
  add_builtin(
5295
5791
  "iter_next",
@@ -5298,7 +5794,7 @@ add_builtin(
5298
5794
  group="Utility",
5299
5795
  export=False,
5300
5796
  hidden=True,
5301
- missing_grad=True,
5797
+ is_differentiable=False,
5302
5798
  )
5303
5799
  add_builtin(
5304
5800
  "iter_next",
@@ -5307,7 +5803,7 @@ add_builtin(
5307
5803
  group="Utility",
5308
5804
  export=False,
5309
5805
  hidden=True,
5310
- missing_grad=True,
5806
+ is_differentiable=False,
5311
5807
  )
5312
5808
 
5313
5809
  add_builtin(
@@ -5318,7 +5814,7 @@ add_builtin(
5318
5814
  group="Utility",
5319
5815
  doc="""Returns the range in reversed order.""",
5320
5816
  export=False,
5321
- missing_grad=True,
5817
+ is_differentiable=False,
5322
5818
  )
5323
5819
 
5324
5820
  # ---------------------------------
@@ -5338,8 +5834,8 @@ _volume_supported_value_types = {
5338
5834
 
5339
5835
 
5340
5836
  def _is_volume_type_supported(dtype):
5341
- for typ in _volume_supported_value_types:
5342
- if types_equal(typ, dtype):
5837
+ for value_type in _volume_supported_value_types:
5838
+ if types_equal(value_type, dtype):
5343
5839
  return True
5344
5840
  return False
5345
5841
 
@@ -5467,7 +5963,7 @@ add_builtin(
5467
5963
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
5468
5964
 
5469
5965
  If the voxel at this index does not exist, this function returns the background value.""",
5470
- missing_grad=True,
5966
+ is_differentiable=False,
5471
5967
  )
5472
5968
 
5473
5969
 
@@ -5488,7 +5984,7 @@ add_builtin(
5488
5984
  export=False,
5489
5985
  group="Volumes",
5490
5986
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5491
- missing_grad=True,
5987
+ is_differentiable=False,
5492
5988
  )
5493
5989
 
5494
5990
  add_builtin(
@@ -5519,7 +6015,7 @@ add_builtin(
5519
6015
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
5520
6016
 
5521
6017
  If the voxel at this index does not exist, this function returns the background value""",
5522
- missing_grad=True,
6018
+ is_differentiable=False,
5523
6019
  )
5524
6020
 
5525
6021
  add_builtin(
@@ -5528,7 +6024,7 @@ add_builtin(
5528
6024
  group="Volumes",
5529
6025
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5530
6026
  export=False,
5531
- missing_grad=True,
6027
+ is_differentiable=False,
5532
6028
  )
5533
6029
 
5534
6030
  add_builtin(
@@ -5549,7 +6045,7 @@ add_builtin(
5549
6045
  doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
5550
6046
 
5551
6047
  If the voxel at this index does not exist, this function returns the background value.""",
5552
- missing_grad=True,
6048
+ is_differentiable=False,
5553
6049
  )
5554
6050
 
5555
6051
  add_builtin(
@@ -5558,7 +6054,7 @@ add_builtin(
5558
6054
  group="Volumes",
5559
6055
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5560
6056
  export=False,
5561
- missing_grad=True,
6057
+ is_differentiable=False,
5562
6058
  )
5563
6059
 
5564
6060
  add_builtin(
@@ -5577,7 +6073,7 @@ add_builtin(
5577
6073
  doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
5578
6074
 
5579
6075
  If the voxel at this index does not exist, this function returns the background value.""",
5580
- missing_grad=True,
6076
+ is_differentiable=False,
5581
6077
  )
5582
6078
 
5583
6079
  add_builtin(
@@ -5586,7 +6082,7 @@ add_builtin(
5586
6082
  group="Volumes",
5587
6083
  doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
5588
6084
  export=False,
5589
- missing_grad=True,
6085
+ is_differentiable=False,
5590
6086
  )
5591
6087
 
5592
6088
 
@@ -5668,7 +6164,7 @@ add_builtin(
5668
6164
  If the voxel at this index does not exist, this function returns -1.
5669
6165
  This function is available for both index grids and classical volumes.
5670
6166
  """,
5671
- missing_grad=True,
6167
+ is_differentiable=False,
5672
6168
  )
5673
6169
 
5674
6170
  add_builtin(
@@ -5710,7 +6206,7 @@ add_builtin(
5710
6206
  value_type=uint32,
5711
6207
  group="Random",
5712
6208
  doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
5713
- missing_grad=True,
6209
+ is_differentiable=False,
5714
6210
  )
5715
6211
 
5716
6212
  add_builtin(
@@ -5722,7 +6218,7 @@ add_builtin(
5722
6218
 
5723
6219
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
5724
6220
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
5725
- missing_grad=True,
6221
+ is_differentiable=False,
5726
6222
  )
5727
6223
 
5728
6224
  add_builtin(
@@ -5731,7 +6227,7 @@ add_builtin(
5731
6227
  value_type=int,
5732
6228
  group="Random",
5733
6229
  doc="Return a random integer in the range [-2^31, 2^31).",
5734
- missing_grad=True,
6230
+ is_differentiable=False,
5735
6231
  )
5736
6232
  add_builtin(
5737
6233
  "randi",
@@ -5739,7 +6235,7 @@ add_builtin(
5739
6235
  value_type=int,
5740
6236
  group="Random",
5741
6237
  doc="Return a random integer between [low, high).",
5742
- missing_grad=True,
6238
+ is_differentiable=False,
5743
6239
  )
5744
6240
  add_builtin(
5745
6241
  "randu",
@@ -5747,7 +6243,7 @@ add_builtin(
5747
6243
  value_type=uint32,
5748
6244
  group="Random",
5749
6245
  doc="Return a random unsigned integer in the range [0, 2^32).",
5750
- missing_grad=True,
6246
+ is_differentiable=False,
5751
6247
  )
5752
6248
  add_builtin(
5753
6249
  "randu",
@@ -5755,7 +6251,7 @@ add_builtin(
5755
6251
  value_type=uint32,
5756
6252
  group="Random",
5757
6253
  doc="Return a random unsigned integer between [low, high).",
5758
- missing_grad=True,
6254
+ is_differentiable=False,
5759
6255
  )
5760
6256
  add_builtin(
5761
6257
  "randf",
@@ -5763,7 +6259,7 @@ add_builtin(
5763
6259
  value_type=float,
5764
6260
  group="Random",
5765
6261
  doc="Return a random float between [0.0, 1.0).",
5766
- missing_grad=True,
6262
+ is_differentiable=False,
5767
6263
  )
5768
6264
  add_builtin(
5769
6265
  "randf",
@@ -5771,7 +6267,7 @@ add_builtin(
5771
6267
  value_type=float,
5772
6268
  group="Random",
5773
6269
  doc="Return a random float between [low, high).",
5774
- missing_grad=True,
6270
+ is_differentiable=False,
5775
6271
  )
5776
6272
  add_builtin(
5777
6273
  "randn",
@@ -5779,7 +6275,7 @@ add_builtin(
5779
6275
  value_type=float,
5780
6276
  group="Random",
5781
6277
  doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
5782
- missing_grad=True,
6278
+ is_differentiable=False,
5783
6279
  )
5784
6280
 
5785
6281
  add_builtin(
@@ -5788,7 +6284,7 @@ add_builtin(
5788
6284
  value_type=int,
5789
6285
  group="Random",
5790
6286
  doc="Inverse-transform sample a cumulative distribution function.",
5791
- missing_grad=True,
6287
+ is_differentiable=False,
5792
6288
  )
5793
6289
  add_builtin(
5794
6290
  "sample_triangle",
@@ -5796,7 +6292,7 @@ add_builtin(
5796
6292
  value_type=vec2,
5797
6293
  group="Random",
5798
6294
  doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
5799
- missing_grad=True,
6295
+ is_differentiable=False,
5800
6296
  )
5801
6297
  add_builtin(
5802
6298
  "sample_unit_ring",
@@ -5804,7 +6300,7 @@ add_builtin(
5804
6300
  value_type=vec2,
5805
6301
  group="Random",
5806
6302
  doc="Uniformly sample a ring in the xy plane.",
5807
- missing_grad=True,
6303
+ is_differentiable=False,
5808
6304
  )
5809
6305
  add_builtin(
5810
6306
  "sample_unit_disk",
@@ -5812,7 +6308,7 @@ add_builtin(
5812
6308
  value_type=vec2,
5813
6309
  group="Random",
5814
6310
  doc="Uniformly sample a disk in the xy plane.",
5815
- missing_grad=True,
6311
+ is_differentiable=False,
5816
6312
  )
5817
6313
  add_builtin(
5818
6314
  "sample_unit_sphere_surface",
@@ -5820,7 +6316,7 @@ add_builtin(
5820
6316
  value_type=vec3,
5821
6317
  group="Random",
5822
6318
  doc="Uniformly sample a unit sphere surface.",
5823
- missing_grad=True,
6319
+ is_differentiable=False,
5824
6320
  )
5825
6321
  add_builtin(
5826
6322
  "sample_unit_sphere",
@@ -5828,7 +6324,7 @@ add_builtin(
5828
6324
  value_type=vec3,
5829
6325
  group="Random",
5830
6326
  doc="Uniformly sample a unit sphere.",
5831
- missing_grad=True,
6327
+ is_differentiable=False,
5832
6328
  )
5833
6329
  add_builtin(
5834
6330
  "sample_unit_hemisphere_surface",
@@ -5836,7 +6332,7 @@ add_builtin(
5836
6332
  value_type=vec3,
5837
6333
  group="Random",
5838
6334
  doc="Uniformly sample a unit hemisphere surface.",
5839
- missing_grad=True,
6335
+ is_differentiable=False,
5840
6336
  )
5841
6337
  add_builtin(
5842
6338
  "sample_unit_hemisphere",
@@ -5844,7 +6340,7 @@ add_builtin(
5844
6340
  value_type=vec3,
5845
6341
  group="Random",
5846
6342
  doc="Uniformly sample a unit hemisphere.",
5847
- missing_grad=True,
6343
+ is_differentiable=False,
5848
6344
  )
5849
6345
  add_builtin(
5850
6346
  "sample_unit_square",
@@ -5852,7 +6348,7 @@ add_builtin(
5852
6348
  value_type=vec2,
5853
6349
  group="Random",
5854
6350
  doc="Uniformly sample a unit square.",
5855
- missing_grad=True,
6351
+ is_differentiable=False,
5856
6352
  )
5857
6353
  add_builtin(
5858
6354
  "sample_unit_cube",
@@ -5860,7 +6356,7 @@ add_builtin(
5860
6356
  value_type=vec3,
5861
6357
  group="Random",
5862
6358
  doc="Uniformly sample a unit cube.",
5863
- missing_grad=True,
6359
+ is_differentiable=False,
5864
6360
  )
5865
6361
 
5866
6362
  add_builtin(
@@ -5872,7 +6368,7 @@ add_builtin(
5872
6368
 
5873
6369
  :param state: RNG state
5874
6370
  :param lam: The expected value of the distribution""",
5875
- missing_grad=True,
6371
+ is_differentiable=False,
5876
6372
  )
5877
6373
 
5878
6374
  add_builtin(
@@ -5940,7 +6436,7 @@ add_builtin(
5940
6436
  value_type=vec2,
5941
6437
  group="Random",
5942
6438
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
5943
- missing_grad=True,
6439
+ is_differentiable=False,
5944
6440
  )
5945
6441
  add_builtin(
5946
6442
  "curlnoise",
@@ -5949,7 +6445,7 @@ add_builtin(
5949
6445
  value_type=vec3,
5950
6446
  group="Random",
5951
6447
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5952
- missing_grad=True,
6448
+ is_differentiable=False,
5953
6449
  )
5954
6450
  add_builtin(
5955
6451
  "curlnoise",
@@ -5958,7 +6454,7 @@ add_builtin(
5958
6454
  value_type=vec3,
5959
6455
  group="Random",
5960
6456
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
5961
- missing_grad=True,
6457
+ is_differentiable=False,
5962
6458
  )
5963
6459
 
5964
6460
 
@@ -5990,7 +6486,7 @@ add_builtin(
5990
6486
  dispatch_func=printf_dispatch_func,
5991
6487
  group="Utility",
5992
6488
  doc="Allows printing formatted strings using C-style format specifiers.",
5993
- missing_grad=True,
6489
+ is_differentiable=False,
5994
6490
  )
5995
6491
 
5996
6492
  add_builtin(
@@ -6009,7 +6505,7 @@ add_builtin(
6009
6505
  group="Utility",
6010
6506
  namespace="",
6011
6507
  native_func="__debugbreak",
6012
- missing_grad=True,
6508
+ is_differentiable=False,
6013
6509
  )
6014
6510
 
6015
6511
  # helpers
@@ -6027,7 +6523,7 @@ add_builtin(
6027
6523
  This function may not be called from user-defined Warp functions.""",
6028
6524
  namespace="",
6029
6525
  native_func="builtin_tid1d",
6030
- missing_grad=True,
6526
+ is_differentiable=False,
6031
6527
  )
6032
6528
 
6033
6529
  add_builtin(
@@ -6038,7 +6534,7 @@ add_builtin(
6038
6534
  doc="Returns the number of threads in the current block.",
6039
6535
  namespace="",
6040
6536
  native_func="builtin_block_dim",
6041
- missing_grad=True,
6537
+ is_differentiable=False,
6042
6538
  )
6043
6539
 
6044
6540
  add_builtin(
@@ -6053,7 +6549,7 @@ add_builtin(
6053
6549
  This function may not be called from user-defined Warp functions.""",
6054
6550
  namespace="",
6055
6551
  native_func="builtin_tid2d",
6056
- missing_grad=True,
6552
+ is_differentiable=False,
6057
6553
  )
6058
6554
 
6059
6555
  add_builtin(
@@ -6068,7 +6564,7 @@ add_builtin(
6068
6564
  This function may not be called from user-defined Warp functions.""",
6069
6565
  namespace="",
6070
6566
  native_func="builtin_tid3d",
6071
- missing_grad=True,
6567
+ is_differentiable=False,
6072
6568
  )
6073
6569
 
6074
6570
  add_builtin(
@@ -6083,7 +6579,7 @@ add_builtin(
6083
6579
  This function may not be called from user-defined Warp functions.""",
6084
6580
  namespace="",
6085
6581
  native_func="builtin_tid4d",
6086
- missing_grad=True,
6582
+ is_differentiable=False,
6087
6583
  )
6088
6584
 
6089
6585
 
@@ -6127,56 +6623,20 @@ def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
6127
6623
  if arg_types is None:
6128
6624
  return Any
6129
6625
 
6130
- v_true = arg_types["value_if_true"]
6131
- v_false = arg_types["value_if_false"]
6132
-
6133
- if not types_equal(v_true, v_false):
6134
- raise RuntimeError(
6135
- f"select() true value type ({v_true}) must be of the same type as the false type ({v_false})"
6136
- )
6137
-
6138
- if is_tile(v_false):
6139
- if v_true.storage == "register":
6140
- return v_true
6141
- if v_false.storage == "register":
6142
- return v_false
6143
-
6144
- # both v_true and v_false are shared
6145
- return tile(
6146
- dtype=v_true.dtype,
6147
- shape=v_true.shape,
6148
- storage=v_true.storage,
6149
- strides=v_true.strides,
6150
- layout=v_true.layout,
6151
- owner=True,
6152
- )
6153
-
6154
- return v_true
6155
-
6156
-
6157
- def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
6158
- warp.utils.warn(
6159
- "wp.select() is deprecated and will be removed in a future\n"
6160
- "version. Use wp.where(cond, value_if_true, value_if_false) instead.",
6161
- category=DeprecationWarning,
6162
- )
6163
-
6164
- func_args = tuple(args.values())
6165
- template_args = ()
6166
-
6167
- return (func_args, template_args)
6626
+ raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
6168
6627
 
6169
6628
 
6170
6629
  add_builtin(
6171
6630
  "select",
6172
6631
  input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
6173
6632
  value_func=select_value_func,
6174
- dispatch_func=select_dispatch_func,
6175
6633
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6176
6634
 
6177
- .. deprecated:: 1.7
6635
+ .. versionremoved:: 1.10
6178
6636
  Use :func:`where` instead, which has the more intuitive argument order:
6179
- ``where(cond, value_if_true, value_if_false)``.""",
6637
+ ``where(cond, value_if_true, value_if_false)``.
6638
+
6639
+ .. deprecated:: 1.7""",
6180
6640
  group="Utility",
6181
6641
  )
6182
6642
  for t in int_types:
@@ -6184,24 +6644,26 @@ for t in int_types:
6184
6644
  "select",
6185
6645
  input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
6186
6646
  value_func=select_value_func,
6187
- dispatch_func=select_dispatch_func,
6188
6647
  doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
6189
6648
 
6190
- .. deprecated:: 1.7
6649
+ .. versionremoved:: 1.10
6191
6650
  Use :func:`where` instead, which has the more intuitive argument order:
6192
- ``where(cond, value_if_true, value_if_false)``.""",
6651
+ ``where(cond, value_if_true, value_if_false)``.
6652
+
6653
+ .. deprecated:: 1.7""",
6193
6654
  group="Utility",
6194
6655
  )
6195
6656
  add_builtin(
6196
6657
  "select",
6197
6658
  input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
6198
6659
  value_func=select_value_func,
6199
- dispatch_func=select_dispatch_func,
6200
6660
  doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
6201
6661
 
6202
- .. deprecated:: 1.7
6662
+ .. versionremoved:: 1.10
6203
6663
  Use :func:`where` instead, which has the more intuitive argument order:
6204
- ``where(arr, value_if_true, value_if_false)``.""",
6664
+ ``where(arr, value_if_true, value_if_false)``.
6665
+
6666
+ .. deprecated:: 1.7""",
6205
6667
  group="Utility",
6206
6668
  )
6207
6669
 
@@ -6291,7 +6753,7 @@ add_builtin(
6291
6753
  group="Utility",
6292
6754
  hidden=True,
6293
6755
  export=False,
6294
- missing_grad=True,
6756
+ is_differentiable=False,
6295
6757
  )
6296
6758
 
6297
6759
 
@@ -6332,7 +6794,7 @@ add_builtin(
6332
6794
  native_func="fixedarray_t",
6333
6795
  group="Utility",
6334
6796
  export=False,
6335
- missing_grad=True,
6797
+ is_differentiable=False,
6336
6798
  hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
6337
6799
  )
6338
6800
 
@@ -6375,14 +6837,13 @@ for array_type in array_types:
6375
6837
  # does argument checking and type propagation for view()
6376
6838
  def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
6377
6839
  arr_type = arg_types["arr"]
6378
- idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)
6840
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
6379
6841
 
6380
6842
  if not is_array(arr_type):
6381
6843
  raise RuntimeError("view() first argument must be an array")
6382
6844
 
6383
6845
  idx_count = len(idx_types)
6384
-
6385
- if idx_count >= arr_type.ndim:
6846
+ if idx_count > arr_type.ndim:
6386
6847
  raise RuntimeError(
6387
6848
  f"Trying to create an array view with {idx_count} indices, "
6388
6849
  f"but the array only has {arr_type.ndim} dimension(s). "
@@ -6390,14 +6851,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6390
6851
  f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
6391
6852
  )
6392
6853
 
6393
- # check index types
6394
- for t in idx_types:
6395
- if not type_is_int(t):
6396
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")
6854
+ has_slice = any(is_slice(x) for x in idx_types)
6855
+ if has_slice:
6856
+ # check index types
6857
+ for t in idx_types:
6858
+ if not (type_is_int(t) or is_slice(t)):
6859
+ raise RuntimeError(
6860
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6861
+ )
6862
+
6863
+ # Each integer index collapses one dimension.
6864
+ int_count = sum(x.step == 0 for x in idx_types)
6865
+ ndim = arr_type.ndim - int_count
6866
+ assert ndim > 0
6867
+ else:
6868
+ if idx_count == arr_type.ndim:
6869
+ raise RuntimeError("Expected to call `address()` instead of `view()`")
6870
+
6871
+ # check index types
6872
+ for t in idx_types:
6873
+ if not type_is_int(t):
6874
+ raise RuntimeError(
6875
+ f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
6876
+ )
6877
+
6878
+ # create an array view with leading dimensions removed
6879
+ ndim = arr_type.ndim - idx_count
6880
+ assert ndim > 0
6397
6881
 
6398
- # create an array view with leading dimensions removed
6399
6882
  dtype = arr_type.dtype
6400
- ndim = arr_type.ndim - idx_count
6401
6883
  if isinstance(arr_type, (fabricarray, indexedfabricarray)):
6402
6884
  # fabric array of arrays: return array attribute as a regular array
6403
6885
  return array(dtype=dtype, ndim=ndim)
@@ -6408,8 +6890,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
6408
6890
  for array_type in array_types:
6409
6891
  add_builtin(
6410
6892
  "view",
6411
- input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int},
6412
- defaults={"j": None, "k": None},
6893
+ input_types={
6894
+ "arr": array_type(dtype=Any),
6895
+ "i": Any,
6896
+ "j": Any,
6897
+ "k": Any,
6898
+ "l": Any,
6899
+ },
6900
+ defaults={
6901
+ "j": None,
6902
+ "k": None,
6903
+ "l": None,
6904
+ },
6413
6905
  constraint=sametypes,
6414
6906
  hidden=True,
6415
6907
  value_func=view_value_func,
@@ -6513,7 +7005,7 @@ add_builtin(
6513
7005
  hidden=True,
6514
7006
  skip_replay=True,
6515
7007
  group="Utility",
6516
- missing_grad=True,
7008
+ is_differentiable=False,
6517
7009
  )
6518
7010
 
6519
7011
 
@@ -6530,7 +7022,7 @@ add_builtin(
6530
7022
  dispatch_func=load_dispatch_func,
6531
7023
  hidden=True,
6532
7024
  group="Utility",
6533
- missing_grad=True,
7025
+ is_differentiable=False,
6534
7026
  )
6535
7027
 
6536
7028
 
@@ -6606,6 +7098,13 @@ def create_atomic_op_value_func(op: str):
6606
7098
  f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
6607
7099
  f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
6608
7100
  )
7101
+ elif op in ("and", "or", "xor"):
7102
+ supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
7103
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
7104
+ raise RuntimeError(
7105
+ f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
7106
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
7107
+ )
6609
7108
  else:
6610
7109
  raise NotImplementedError
6611
7110
 
@@ -6847,7 +7346,7 @@ for array_type in array_types:
6847
7346
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6848
7347
  group="Utility",
6849
7348
  skip_replay=True,
6850
- missing_grad=True,
7349
+ is_differentiable=False,
6851
7350
  )
6852
7351
  add_builtin(
6853
7352
  "atomic_cas",
@@ -6861,7 +7360,7 @@ for array_type in array_types:
6861
7360
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6862
7361
  group="Utility",
6863
7362
  skip_replay=True,
6864
- missing_grad=True,
7363
+ is_differentiable=False,
6865
7364
  )
6866
7365
  add_builtin(
6867
7366
  "atomic_cas",
@@ -6875,7 +7374,7 @@ for array_type in array_types:
6875
7374
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6876
7375
  group="Utility",
6877
7376
  skip_replay=True,
6878
- missing_grad=True,
7377
+ is_differentiable=False,
6879
7378
  )
6880
7379
  add_builtin(
6881
7380
  "atomic_cas",
@@ -6897,7 +7396,7 @@ for array_type in array_types:
6897
7396
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6898
7397
  group="Utility",
6899
7398
  skip_replay=True,
6900
- missing_grad=True,
7399
+ is_differentiable=False,
6901
7400
  )
6902
7401
 
6903
7402
  add_builtin(
@@ -6912,7 +7411,7 @@ for array_type in array_types:
6912
7411
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6913
7412
  group="Utility",
6914
7413
  skip_replay=True,
6915
- missing_grad=True,
7414
+ is_differentiable=False,
6916
7415
  )
6917
7416
  add_builtin(
6918
7417
  "atomic_exch",
@@ -6926,34 +7425,193 @@ for array_type in array_types:
6926
7425
  The operation is only atomic on a per-component basis for vectors and matrices.""",
6927
7426
  group="Utility",
6928
7427
  skip_replay=True,
6929
- missing_grad=True,
7428
+ is_differentiable=False,
7429
+ )
7430
+ add_builtin(
7431
+ "atomic_exch",
7432
+ hidden=hidden,
7433
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7434
+ constraint=atomic_op_constraint,
7435
+ value_func=create_atomic_op_value_func("exch"),
7436
+ dispatch_func=atomic_op_dispatch_func,
7437
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
7438
+
7439
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
7440
+ group="Utility",
7441
+ skip_replay=True,
7442
+ is_differentiable=False,
7443
+ )
7444
+ add_builtin(
7445
+ "atomic_exch",
7446
+ hidden=hidden,
7447
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7448
+ constraint=atomic_op_constraint,
7449
+ value_func=create_atomic_op_value_func("exch"),
7450
+ dispatch_func=atomic_op_dispatch_func,
7451
+ doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
7452
+
7453
+ The operation is only atomic on a per-component basis for vectors and matrices.""",
7454
+ group="Utility",
7455
+ skip_replay=True,
7456
+ )
7457
+
7458
+ add_builtin(
7459
+ "atomic_and",
7460
+ hidden=hidden,
7461
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7462
+ constraint=atomic_op_constraint,
7463
+ value_func=create_atomic_op_value_func("and"),
7464
+ dispatch_func=atomic_op_dispatch_func,
7465
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7466
+ This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
7467
+ group="Utility",
7468
+ skip_replay=True,
7469
+ is_differentiable=False,
7470
+ )
7471
+ add_builtin(
7472
+ "atomic_and",
7473
+ hidden=hidden,
7474
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7475
+ constraint=atomic_op_constraint,
7476
+ value_func=create_atomic_op_value_func("and"),
7477
+ dispatch_func=atomic_op_dispatch_func,
7478
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7479
+ This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
7480
+ group="Utility",
7481
+ skip_replay=True,
7482
+ is_differentiable=False,
7483
+ )
7484
+ add_builtin(
7485
+ "atomic_and",
7486
+ hidden=hidden,
7487
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7488
+ constraint=atomic_op_constraint,
7489
+ value_func=create_atomic_op_value_func("and"),
7490
+ dispatch_func=atomic_op_dispatch_func,
7491
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7492
+ This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
7493
+ group="Utility",
7494
+ skip_replay=True,
7495
+ is_differentiable=False,
7496
+ )
7497
+ add_builtin(
7498
+ "atomic_and",
7499
+ hidden=hidden,
7500
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7501
+ constraint=atomic_op_constraint,
7502
+ value_func=create_atomic_op_value_func("and"),
7503
+ dispatch_func=atomic_op_dispatch_func,
7504
+ doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7505
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
7506
+ group="Utility",
7507
+ skip_replay=True,
7508
+ is_differentiable=False,
7509
+ )
7510
+
7511
+ add_builtin(
7512
+ "atomic_or",
7513
+ hidden=hidden,
7514
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7515
+ constraint=atomic_op_constraint,
7516
+ value_func=create_atomic_op_value_func("or"),
7517
+ dispatch_func=atomic_op_dispatch_func,
7518
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7519
+ This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
7520
+ group="Utility",
7521
+ skip_replay=True,
7522
+ is_differentiable=False,
7523
+ )
7524
+ add_builtin(
7525
+ "atomic_or",
7526
+ hidden=hidden,
7527
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7528
+ constraint=atomic_op_constraint,
7529
+ value_func=create_atomic_op_value_func("or"),
7530
+ dispatch_func=atomic_op_dispatch_func,
7531
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7532
+ This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
7533
+ group="Utility",
7534
+ skip_replay=True,
7535
+ is_differentiable=False,
7536
+ )
7537
+ add_builtin(
7538
+ "atomic_or",
7539
+ hidden=hidden,
7540
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
7541
+ constraint=atomic_op_constraint,
7542
+ value_func=create_atomic_op_value_func("or"),
7543
+ dispatch_func=atomic_op_dispatch_func,
7544
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7545
+ This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
7546
+ group="Utility",
7547
+ skip_replay=True,
7548
+ is_differentiable=False,
7549
+ )
7550
+ add_builtin(
7551
+ "atomic_or",
7552
+ hidden=hidden,
7553
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
7554
+ constraint=atomic_op_constraint,
7555
+ value_func=create_atomic_op_value_func("or"),
7556
+ dispatch_func=atomic_op_dispatch_func,
7557
+ doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7558
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
7559
+ group="Utility",
7560
+ skip_replay=True,
7561
+ is_differentiable=False,
7562
+ )
7563
+
7564
+ add_builtin(
7565
+ "atomic_xor",
7566
+ hidden=hidden,
7567
+ input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
7568
+ constraint=atomic_op_constraint,
7569
+ value_func=create_atomic_op_value_func("xor"),
7570
+ dispatch_func=atomic_op_dispatch_func,
7571
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
7572
+ This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
7573
+ group="Utility",
7574
+ skip_replay=True,
7575
+ is_differentiable=False,
7576
+ )
7577
+ add_builtin(
7578
+ "atomic_xor",
7579
+ hidden=hidden,
7580
+ input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
7581
+ constraint=atomic_op_constraint,
7582
+ value_func=create_atomic_op_value_func("xor"),
7583
+ dispatch_func=atomic_op_dispatch_func,
7584
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
7585
+ This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
7586
+ group="Utility",
7587
+ skip_replay=True,
7588
+ is_differentiable=False,
6930
7589
  )
6931
7590
  add_builtin(
6932
- "atomic_exch",
7591
+ "atomic_xor",
6933
7592
  hidden=hidden,
6934
7593
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
6935
7594
  constraint=atomic_op_constraint,
6936
- value_func=create_atomic_op_value_func("exch"),
7595
+ value_func=create_atomic_op_value_func("xor"),
6937
7596
  dispatch_func=atomic_op_dispatch_func,
6938
- doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
6939
-
6940
- The operation is only atomic on a per-component basis for vectors and matrices.""",
7597
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
7598
+ This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
6941
7599
  group="Utility",
6942
7600
  skip_replay=True,
6943
- missing_grad=True,
7601
+ is_differentiable=False,
6944
7602
  )
6945
7603
  add_builtin(
6946
- "atomic_exch",
7604
+ "atomic_xor",
6947
7605
  hidden=hidden,
6948
7606
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
6949
7607
  constraint=atomic_op_constraint,
6950
- value_func=create_atomic_op_value_func("exch"),
7608
+ value_func=create_atomic_op_value_func("xor"),
6951
7609
  dispatch_func=atomic_op_dispatch_func,
6952
- doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
6953
-
6954
- The operation is only atomic on a per-component basis for vectors and matrices.""",
7610
+ doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
7611
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
6955
7612
  group="Utility",
6956
7613
  skip_replay=True,
7614
+ is_differentiable=False,
6957
7615
  )
6958
7616
 
6959
7617
 
@@ -7104,7 +7762,7 @@ add_builtin(
7104
7762
  hidden=True,
7105
7763
  group="Utility",
7106
7764
  skip_replay=True,
7107
- missing_grad=True,
7765
+ is_differentiable=False,
7108
7766
  )
7109
7767
  # implements &quaternion[index]
7110
7768
  add_builtin(
@@ -7115,7 +7773,7 @@ add_builtin(
7115
7773
  hidden=True,
7116
7774
  group="Utility",
7117
7775
  skip_replay=True,
7118
- missing_grad=True,
7776
+ is_differentiable=False,
7119
7777
  )
7120
7778
  # implements &transformation[index]
7121
7779
  add_builtin(
@@ -7126,7 +7784,7 @@ add_builtin(
7126
7784
  hidden=True,
7127
7785
  group="Utility",
7128
7786
  skip_replay=True,
7129
- missing_grad=True,
7787
+ is_differentiable=False,
7130
7788
  )
7131
7789
  # implements &(*vector)[index]
7132
7790
  add_builtin(
@@ -7137,7 +7795,7 @@ add_builtin(
7137
7795
  hidden=True,
7138
7796
  group="Utility",
7139
7797
  skip_replay=True,
7140
- missing_grad=True,
7798
+ is_differentiable=False,
7141
7799
  )
7142
7800
  # implements &(*matrix)[i, j]
7143
7801
  add_builtin(
@@ -7148,7 +7806,7 @@ add_builtin(
7148
7806
  hidden=True,
7149
7807
  group="Utility",
7150
7808
  skip_replay=True,
7151
- missing_grad=True,
7809
+ is_differentiable=False,
7152
7810
  )
7153
7811
  # implements &(*quaternion)[index]
7154
7812
  add_builtin(
@@ -7159,7 +7817,7 @@ add_builtin(
7159
7817
  hidden=True,
7160
7818
  group="Utility",
7161
7819
  skip_replay=True,
7162
- missing_grad=True,
7820
+ is_differentiable=False,
7163
7821
  )
7164
7822
  # implements &(*transformation)[index]
7165
7823
  add_builtin(
@@ -7170,7 +7828,7 @@ add_builtin(
7170
7828
  hidden=True,
7171
7829
  group="Utility",
7172
7830
  skip_replay=True,
7173
- missing_grad=True,
7831
+ is_differentiable=False,
7174
7832
  )
7175
7833
 
7176
7834
 
@@ -7366,6 +8024,43 @@ add_builtin(
7366
8024
  )
7367
8025
 
7368
8026
 
8027
+ # implements vector[idx] &= scalar
8028
+ add_builtin(
8029
+ "bit_and_inplace",
8030
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8031
+ value_type=None,
8032
+ dispatch_func=vector_assign_dispatch_func,
8033
+ hidden=True,
8034
+ export=False,
8035
+ group="Utility",
8036
+ is_differentiable=False,
8037
+ )
8038
+
8039
+ # implements vector[idx] |= scalar
8040
+ add_builtin(
8041
+ "bit_or_inplace",
8042
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8043
+ value_type=None,
8044
+ dispatch_func=vector_assign_dispatch_func,
8045
+ hidden=True,
8046
+ export=False,
8047
+ group="Utility",
8048
+ is_differentiable=False,
8049
+ )
8050
+
8051
+ # implements vector[idx] ^= scalar
8052
+ add_builtin(
8053
+ "bit_xor_inplace",
8054
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
8055
+ value_type=None,
8056
+ dispatch_func=vector_assign_dispatch_func,
8057
+ hidden=True,
8058
+ export=False,
8059
+ group="Utility",
8060
+ is_differentiable=False,
8061
+ )
8062
+
8063
+
7369
8064
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
7370
8065
  mat_type = arg_types["a"]
7371
8066
  row_type = mat_type._wp_row_type_
@@ -7381,7 +8076,7 @@ add_builtin(
7381
8076
  hidden=True,
7382
8077
  group="Utility",
7383
8078
  skip_replay=True,
7384
- missing_grad=True,
8079
+ is_differentiable=False,
7385
8080
  )
7386
8081
 
7387
8082
 
@@ -7400,7 +8095,7 @@ add_builtin(
7400
8095
  hidden=True,
7401
8096
  group="Utility",
7402
8097
  skip_replay=True,
7403
- missing_grad=True,
8098
+ is_differentiable=False,
7404
8099
  )
7405
8100
 
7406
8101
 
@@ -7600,6 +8295,78 @@ add_builtin(
7600
8295
  )
7601
8296
 
7602
8297
 
8298
+ # implements matrix[i] &= value
8299
+ add_builtin(
8300
+ "bit_and_inplace",
8301
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8302
+ value_type=None,
8303
+ hidden=True,
8304
+ export=False,
8305
+ group="Utility",
8306
+ is_differentiable=False,
8307
+ )
8308
+
8309
+
8310
+ # implements matrix[i,j] &= value
8311
+ add_builtin(
8312
+ "bit_and_inplace",
8313
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8314
+ value_type=None,
8315
+ hidden=True,
8316
+ export=False,
8317
+ group="Utility",
8318
+ is_differentiable=False,
8319
+ )
8320
+
8321
+
8322
+ # implements matrix[i] |= value
8323
+ add_builtin(
8324
+ "bit_or_inplace",
8325
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8326
+ value_type=None,
8327
+ hidden=True,
8328
+ export=False,
8329
+ group="Utility",
8330
+ is_differentiable=False,
8331
+ )
8332
+
8333
+
8334
+ # implements matrix[i,j] |= value
8335
+ add_builtin(
8336
+ "bit_or_inplace",
8337
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8338
+ value_type=None,
8339
+ hidden=True,
8340
+ export=False,
8341
+ group="Utility",
8342
+ is_differentiable=False,
8343
+ )
8344
+
8345
+
8346
+ # implements matrix[i] ^= value
8347
+ add_builtin(
8348
+ "bit_xor_inplace",
8349
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
8350
+ value_type=None,
8351
+ hidden=True,
8352
+ export=False,
8353
+ group="Utility",
8354
+ is_differentiable=False,
8355
+ )
8356
+
8357
+
8358
+ # implements matrix[i,j] ^= value
8359
+ add_builtin(
8360
+ "bit_xor_inplace",
8361
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
8362
+ value_type=None,
8363
+ hidden=True,
8364
+ export=False,
8365
+ group="Utility",
8366
+ is_differentiable=False,
8367
+ )
8368
+
8369
+
7603
8370
  for t in scalar_types + vector_types + (bool,):
7604
8371
  if "vec" in t.__name__ or "mat" in t.__name__:
7605
8372
  continue
@@ -7611,7 +8378,7 @@ for t in scalar_types + vector_types + (bool,):
7611
8378
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7612
8379
  group="Utility",
7613
8380
  hidden=True,
7614
- missing_grad=True,
8381
+ is_differentiable=False,
7615
8382
  )
7616
8383
 
7617
8384
  add_builtin(
@@ -7622,7 +8389,7 @@ for t in scalar_types + vector_types + (bool,):
7622
8389
  group="Utility",
7623
8390
  hidden=True,
7624
8391
  export=False,
7625
- missing_grad=True,
8392
+ is_differentiable=False,
7626
8393
  )
7627
8394
 
7628
8395
 
@@ -7641,7 +8408,7 @@ add_builtin(
7641
8408
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7642
8409
  group="Utility",
7643
8410
  hidden=True,
7644
- missing_grad=True,
8411
+ is_differentiable=False,
7645
8412
  )
7646
8413
  add_builtin(
7647
8414
  "expect_neq",
@@ -7652,7 +8419,7 @@ add_builtin(
7652
8419
  group="Utility",
7653
8420
  hidden=True,
7654
8421
  export=False,
7655
- missing_grad=True,
8422
+ is_differentiable=False,
7656
8423
  )
7657
8424
 
7658
8425
  add_builtin(
@@ -7663,7 +8430,7 @@ add_builtin(
7663
8430
  doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
7664
8431
  group="Utility",
7665
8432
  hidden=True,
7666
- missing_grad=True,
8433
+ is_differentiable=False,
7667
8434
  )
7668
8435
  add_builtin(
7669
8436
  "expect_neq",
@@ -7674,7 +8441,7 @@ add_builtin(
7674
8441
  group="Utility",
7675
8442
  hidden=True,
7676
8443
  export=False,
7677
- missing_grad=True,
8444
+ is_differentiable=False,
7678
8445
  )
7679
8446
 
7680
8447
  add_builtin(
@@ -7765,7 +8532,7 @@ add_builtin(
7765
8532
  value_type=None,
7766
8533
  doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
7767
8534
  group="Utility",
7768
- missing_grad=True,
8535
+ is_differentiable=False,
7769
8536
  )
7770
8537
  add_builtin(
7771
8538
  "expect_near",
@@ -7775,7 +8542,7 @@ add_builtin(
7775
8542
  value_type=None,
7776
8543
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7777
8544
  group="Utility",
7778
- missing_grad=True,
8545
+ is_differentiable=False,
7779
8546
  )
7780
8547
  add_builtin(
7781
8548
  "expect_near",
@@ -7785,7 +8552,7 @@ add_builtin(
7785
8552
  value_type=None,
7786
8553
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7787
8554
  group="Utility",
7788
- missing_grad=True,
8555
+ is_differentiable=False,
7789
8556
  )
7790
8557
  add_builtin(
7791
8558
  "expect_near",
@@ -7799,7 +8566,7 @@ add_builtin(
7799
8566
  value_type=None,
7800
8567
  doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
7801
8568
  group="Utility",
7802
- missing_grad=True,
8569
+ is_differentiable=False,
7803
8570
  )
7804
8571
 
7805
8572
  # ---------------------------------
@@ -7810,7 +8577,7 @@ add_builtin(
7810
8577
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
7811
8578
  value_type=int,
7812
8579
  doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
7813
- missing_grad=True,
8580
+ is_differentiable=False,
7814
8581
  )
7815
8582
 
7816
8583
  add_builtin(
@@ -7818,7 +8585,7 @@ add_builtin(
7818
8585
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
7819
8586
  value_type=int,
7820
8587
  doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
7821
- missing_grad=True,
8588
+ is_differentiable=False,
7822
8589
  )
7823
8590
 
7824
8591
  # ---------------------------------
@@ -7899,31 +8666,153 @@ add_builtin(
7899
8666
  input_types={"a": Int, "b": Int},
7900
8667
  value_func=sametypes_create_value_func(Int),
7901
8668
  group="Operators",
7902
- missing_grad=True,
8669
+ is_differentiable=False,
8670
+ )
8671
+ add_builtin(
8672
+ "bit_and",
8673
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8674
+ constraint=sametypes,
8675
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8676
+ doc="",
8677
+ group="Operators",
8678
+ is_differentiable=False,
8679
+ )
8680
+ add_builtin(
8681
+ "bit_and",
8682
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8683
+ constraint=sametypes,
8684
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8685
+ doc="",
8686
+ group="Operators",
8687
+ is_differentiable=False,
8688
+ )
8689
+
8690
+ add_builtin(
8691
+ "bit_or",
8692
+ input_types={"a": Int, "b": Int},
8693
+ value_func=sametypes_create_value_func(Int),
8694
+ group="Operators",
8695
+ is_differentiable=False,
7903
8696
  )
7904
8697
  add_builtin(
7905
8698
  "bit_or",
8699
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8700
+ constraint=sametypes,
8701
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8702
+ doc="",
8703
+ group="Operators",
8704
+ is_differentiable=False,
8705
+ )
8706
+ add_builtin(
8707
+ "bit_or",
8708
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8709
+ constraint=sametypes,
8710
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8711
+ doc="",
8712
+ group="Operators",
8713
+ is_differentiable=False,
8714
+ )
8715
+
8716
+ add_builtin(
8717
+ "bit_xor",
7906
8718
  input_types={"a": Int, "b": Int},
7907
8719
  value_func=sametypes_create_value_func(Int),
7908
8720
  group="Operators",
7909
- missing_grad=True,
8721
+ is_differentiable=False,
8722
+ )
8723
+ add_builtin(
8724
+ "bit_xor",
8725
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8726
+ constraint=sametypes,
8727
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8728
+ doc="",
8729
+ group="Operators",
8730
+ is_differentiable=False,
7910
8731
  )
7911
8732
  add_builtin(
7912
8733
  "bit_xor",
8734
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8735
+ constraint=sametypes,
8736
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8737
+ doc="",
8738
+ group="Operators",
8739
+ is_differentiable=False,
8740
+ )
8741
+
8742
+ add_builtin(
8743
+ "lshift",
7913
8744
  input_types={"a": Int, "b": Int},
7914
8745
  value_func=sametypes_create_value_func(Int),
7915
8746
  group="Operators",
7916
- missing_grad=True,
8747
+ is_differentiable=False,
8748
+ )
8749
+ add_builtin(
8750
+ "lshift",
8751
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8752
+ constraint=sametypes,
8753
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8754
+ doc="",
8755
+ group="Operators",
8756
+ is_differentiable=False,
8757
+ )
8758
+ add_builtin(
8759
+ "lshift",
8760
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8761
+ constraint=sametypes,
8762
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8763
+ doc="",
8764
+ group="Operators",
8765
+ is_differentiable=False,
7917
8766
  )
7918
- add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
8767
+
7919
8768
  add_builtin(
7920
8769
  "rshift",
7921
8770
  input_types={"a": Int, "b": Int},
7922
8771
  value_func=sametypes_create_value_func(Int),
7923
8772
  group="Operators",
7924
- missing_grad=True,
8773
+ is_differentiable=False,
8774
+ )
8775
+ add_builtin(
8776
+ "rshift",
8777
+ input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
8778
+ constraint=sametypes,
8779
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8780
+ doc="",
8781
+ group="Operators",
8782
+ is_differentiable=False,
8783
+ )
8784
+ add_builtin(
8785
+ "rshift",
8786
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
8787
+ constraint=sametypes,
8788
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8789
+ doc="",
8790
+ group="Operators",
8791
+ is_differentiable=False,
8792
+ )
8793
+
8794
+ add_builtin(
8795
+ "invert",
8796
+ input_types={"a": Int},
8797
+ value_func=sametypes_create_value_func(Int),
8798
+ group="Operators",
8799
+ is_differentiable=False,
8800
+ )
8801
+ add_builtin(
8802
+ "invert",
8803
+ input_types={"a": vector(length=Any, dtype=Int)},
8804
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
8805
+ group="Operators",
8806
+ is_differentiable=False,
7925
8807
  )
7926
- add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int), group="Operators")
8808
+ add_builtin(
8809
+ "invert",
8810
+ input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
8811
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
8812
+ group="Operators",
8813
+ is_differentiable=False,
8814
+ )
8815
+
7927
8816
 
7928
8817
  add_builtin(
7929
8818
  "mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
@@ -8123,7 +9012,7 @@ add_builtin(
8123
9012
  value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
8124
9013
  doc="Modulo operation using truncated division.",
8125
9014
  group="Operators",
8126
- missing_grad=True,
9015
+ is_differentiable=False,
8127
9016
  )
8128
9017
 
8129
9018
  add_builtin(
@@ -8183,7 +9072,7 @@ add_builtin(
8183
9072
  value_func=sametypes_create_value_func(Scalar),
8184
9073
  doc="",
8185
9074
  group="Operators",
8186
- missing_grad=True,
9075
+ is_differentiable=False,
8187
9076
  )
8188
9077
 
8189
9078
  add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
@@ -8232,14 +9121,26 @@ add_builtin(
8232
9121
  )
8233
9122
 
8234
9123
  add_builtin(
8235
- "unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
9124
+ "unot",
9125
+ input_types={"a": builtins.bool},
9126
+ value_type=builtins.bool,
9127
+ doc="",
9128
+ group="Operators",
9129
+ is_differentiable=False,
8236
9130
  )
8237
9131
  for t in int_types:
8238
- add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True)
9132
+ add_builtin(
9133
+ "unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
9134
+ )
8239
9135
 
8240
9136
 
8241
9137
  add_builtin(
8242
- "unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators", missing_grad=True
9138
+ "unot",
9139
+ input_types={"a": array(dtype=Any)},
9140
+ value_type=builtins.bool,
9141
+ doc="",
9142
+ group="Operators",
9143
+ is_differentiable=False,
8243
9144
  )
8244
9145
 
8245
9146
 
@@ -8312,6 +9213,45 @@ add_builtin(
8312
9213
  export=False,
8313
9214
  )
8314
9215
 
9216
+ add_builtin(
9217
+ "bit_and",
9218
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9219
+ value_func=tile_binary_map_value_func,
9220
+ # dispatch_func=tile_map_dispatch_func,
9221
+ # variadic=True,
9222
+ native_func="tile_bit_and",
9223
+ doc="Bitwise AND each element of two tiles together",
9224
+ group="Tile Primitives",
9225
+ export=False,
9226
+ is_differentiable=False,
9227
+ )
9228
+
9229
+ add_builtin(
9230
+ "bit_or",
9231
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9232
+ value_func=tile_binary_map_value_func,
9233
+ # dispatch_func=tile_map_dispatch_func,
9234
+ # variadic=True,
9235
+ native_func="tile_bit_or",
9236
+ doc="Bitwise OR each element of two tiles together",
9237
+ group="Tile Primitives",
9238
+ export=False,
9239
+ is_differentiable=False,
9240
+ )
9241
+
9242
+ add_builtin(
9243
+ "bit_xor",
9244
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9245
+ value_func=tile_binary_map_value_func,
9246
+ # dispatch_func=tile_map_dispatch_func,
9247
+ # variadic=True,
9248
+ native_func="tile_bit_xor",
9249
+ doc="Bitwise XOR each element of two tiles together",
9250
+ group="Tile Primitives",
9251
+ export=False,
9252
+ is_differentiable=False,
9253
+ )
9254
+
8315
9255
 
8316
9256
  add_builtin(
8317
9257
  "mul",
@@ -8373,6 +9313,45 @@ add_builtin(
8373
9313
  )
8374
9314
 
8375
9315
 
9316
+ add_builtin(
9317
+ "bit_and_inplace",
9318
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9319
+ value_type=None,
9320
+ dispatch_func=tile_inplace_dispatch_func,
9321
+ export=False,
9322
+ hidden=True,
9323
+ native_func="tile_bit_and_inplace",
9324
+ group="Operators",
9325
+ is_differentiable=False,
9326
+ )
9327
+
9328
+
9329
+ add_builtin(
9330
+ "bit_or_inplace",
9331
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9332
+ value_type=None,
9333
+ dispatch_func=tile_inplace_dispatch_func,
9334
+ export=False,
9335
+ hidden=True,
9336
+ native_func="tile_bit_or_inplace",
9337
+ group="Operators",
9338
+ is_differentiable=False,
9339
+ )
9340
+
9341
+
9342
+ add_builtin(
9343
+ "bit_xor_inplace",
9344
+ input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
9345
+ value_type=None,
9346
+ dispatch_func=tile_inplace_dispatch_func,
9347
+ export=False,
9348
+ hidden=True,
9349
+ native_func="tile_bit_xor_inplace",
9350
+ group="Operators",
9351
+ is_differentiable=False,
9352
+ )
9353
+
9354
+
8376
9355
  def tile_diag_add_value_func(arg_types, arg_values):
8377
9356
  if arg_types is None:
8378
9357
  return tile(dtype=Any, shape=Tuple[int, int])
@@ -8414,7 +9393,7 @@ def tile_diag_add_lto_dispatch_func(
8414
9393
  return_values: List[Var],
8415
9394
  arg_values: Mapping[str, Var],
8416
9395
  options: Mapping[str, Any],
8417
- builder: warp.context.ModuleBuilder,
9396
+ builder: warp._src.context.ModuleBuilder,
8418
9397
  ):
8419
9398
  a = arg_values["a"]
8420
9399
  d = arg_values["d"]
@@ -8434,7 +9413,7 @@ add_builtin(
8434
9413
  doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
8435
9414
  group="Tile Primitives",
8436
9415
  export=False,
8437
- missing_grad=True,
9416
+ is_differentiable=False,
8438
9417
  )
8439
9418
 
8440
9419
 
@@ -8491,7 +9470,7 @@ def tile_matmul_lto_dispatch_func(
8491
9470
  return_values: List[Var],
8492
9471
  arg_values: Mapping[str, Var],
8493
9472
  options: Mapping[str, Any],
8494
- builder: warp.context.ModuleBuilder,
9473
+ builder: warp._src.context.ModuleBuilder,
8495
9474
  ):
8496
9475
  a = arg_values["a"]
8497
9476
  b = arg_values["b"]
@@ -8529,7 +9508,7 @@ def tile_matmul_lto_dispatch_func(
8529
9508
  num_threads = options["block_dim"]
8530
9509
  arch = options["output_arch"]
8531
9510
 
8532
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9511
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8533
9512
  # CPU/no-MathDx dispatch
8534
9513
  return ((0, 0, 0, a, b, out), template_args, [], 0)
8535
9514
  else:
@@ -8542,7 +9521,7 @@ def tile_matmul_lto_dispatch_func(
8542
9521
 
8543
9522
  # generate the LTOs
8544
9523
  # C += A * B
8545
- (fun_forward, lto_forward) = warp.build.build_lto_dot(
9524
+ (fun_forward, lto_forward) = warp._src.build.build_lto_dot(
8546
9525
  M,
8547
9526
  N,
8548
9527
  K,
@@ -8558,7 +9537,7 @@ def tile_matmul_lto_dispatch_func(
8558
9537
  )
8559
9538
  if warp.config.enable_backward:
8560
9539
  # adjA += adjC * B^T - Transpose ~= flipped layout
8561
- (fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
9540
+ (fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
8562
9541
  M,
8563
9542
  K,
8564
9543
  N,
@@ -8573,7 +9552,7 @@ def tile_matmul_lto_dispatch_func(
8573
9552
  builder,
8574
9553
  )
8575
9554
  # adjB += A^T * adjC - Transpose ~= flipped layout
8576
- (fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
9555
+ (fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
8577
9556
  K,
8578
9557
  N,
8579
9558
  M,
@@ -8690,7 +9669,7 @@ def tile_fft_generic_lto_dispatch_func(
8690
9669
  return_values: List[Var],
8691
9670
  arg_values: Mapping[str, Var],
8692
9671
  options: Mapping[str, Any],
8693
- builder: warp.context.ModuleBuilder,
9672
+ builder: warp._src.context.ModuleBuilder,
8694
9673
  direction: str | None = None,
8695
9674
  ):
8696
9675
  inout = arg_values["inout"]
@@ -8719,12 +9698,12 @@ def tile_fft_generic_lto_dispatch_func(
8719
9698
  arch = options["output_arch"]
8720
9699
  ept = size // num_threads
8721
9700
 
8722
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9701
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8723
9702
  # CPU/no-MathDx dispatch
8724
9703
  return ([], [], [], 0)
8725
9704
  else:
8726
9705
  # generate the LTO
8727
- lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
9706
+ lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
8728
9707
  arch, size, ept, direction, dir, precision, builder
8729
9708
  )
8730
9709
 
@@ -8762,7 +9741,7 @@ add_builtin(
8762
9741
  group="Tile Primitives",
8763
9742
  export=False,
8764
9743
  namespace="",
8765
- missing_grad=True,
9744
+ is_differentiable=False,
8766
9745
  )
8767
9746
 
8768
9747
  add_builtin(
@@ -8784,7 +9763,7 @@ add_builtin(
8784
9763
  group="Tile Primitives",
8785
9764
  export=False,
8786
9765
  namespace="",
8787
- missing_grad=True,
9766
+ is_differentiable=False,
8788
9767
  )
8789
9768
 
8790
9769
 
@@ -8829,7 +9808,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8829
9808
  return_values: List[Var],
8830
9809
  arg_values: Mapping[str, Var],
8831
9810
  options: Mapping[str, Any],
8832
- builder: warp.context.ModuleBuilder,
9811
+ builder: warp._src.context.ModuleBuilder,
8833
9812
  ):
8834
9813
  a = arg_values["A"]
8835
9814
  # force source tile to shared memory
@@ -8849,7 +9828,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8849
9828
 
8850
9829
  arch = options["output_arch"]
8851
9830
 
8852
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9831
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8853
9832
  # CPU/no-MathDx dispatch
8854
9833
  return ((0, a, out), [], [], 0)
8855
9834
  else:
@@ -8864,7 +9843,7 @@ def tile_cholesky_generic_lto_dispatch_func(
8864
9843
  req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
8865
9844
 
8866
9845
  # generate the LTO
8867
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9846
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
8868
9847
  M,
8869
9848
  N,
8870
9849
  1,
@@ -8909,7 +9888,7 @@ add_builtin(
8909
9888
  group="Tile Primitives",
8910
9889
  export=False,
8911
9890
  namespace="",
8912
- missing_grad=True,
9891
+ is_differentiable=False,
8913
9892
  )
8914
9893
 
8915
9894
 
@@ -8953,7 +9932,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8953
9932
  return_values: List[Var],
8954
9933
  arg_values: Mapping[str, Var],
8955
9934
  options: Mapping[str, Any],
8956
- builder: warp.context.ModuleBuilder,
9935
+ builder: warp._src.context.ModuleBuilder,
8957
9936
  ):
8958
9937
  L = arg_values["L"]
8959
9938
  y = arg_values["y"]
@@ -8982,7 +9961,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8982
9961
 
8983
9962
  arch = options["output_arch"]
8984
9963
 
8985
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
9964
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
8986
9965
  # CPU/no-MathDx dispatch
8987
9966
  return ((0, L, y, x), [], [], 0)
8988
9967
  else:
@@ -8998,7 +9977,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
8998
9977
  req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
8999
9978
 
9000
9979
  # generate the LTO
9001
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
9980
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9002
9981
  M,
9003
9982
  N,
9004
9983
  NRHS,
@@ -9040,7 +10019,7 @@ add_builtin(
9040
10019
  group="Tile Primitives",
9041
10020
  export=False,
9042
10021
  namespace="",
9043
- missing_grad=True,
10022
+ is_differentiable=False,
9044
10023
  )
9045
10024
 
9046
10025
 
@@ -9050,7 +10029,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9050
10029
  return_values: List[Var],
9051
10030
  arg_values: Mapping[str, Var],
9052
10031
  options: Mapping[str, Any],
9053
- builder: warp.context.ModuleBuilder,
10032
+ builder: warp._src.context.ModuleBuilder,
9054
10033
  ):
9055
10034
  L = arg_values["L"]
9056
10035
  y = arg_values["y"]
@@ -9079,7 +10058,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9079
10058
 
9080
10059
  arch = options["output_arch"]
9081
10060
 
9082
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10061
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
9083
10062
  # CPU/no-MathDx dispatch
9084
10063
  return ((0, L, y, z), [], [], 0)
9085
10064
  else:
@@ -9095,7 +10074,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
9095
10074
  req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
9096
10075
 
9097
10076
  # generate the LTO
9098
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10077
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9099
10078
  M,
9100
10079
  N,
9101
10080
  NRHS,
@@ -9173,7 +10152,7 @@ add_builtin(
9173
10152
  group="Tile Primitives",
9174
10153
  export=False,
9175
10154
  namespace="",
9176
- missing_grad=True,
10155
+ is_differentiable=False,
9177
10156
  )
9178
10157
 
9179
10158
 
@@ -9183,7 +10162,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9183
10162
  return_values: List[Var],
9184
10163
  arg_values: Mapping[str, Var],
9185
10164
  options: Mapping[str, Any],
9186
- builder: warp.context.ModuleBuilder,
10165
+ builder: warp._src.context.ModuleBuilder,
9187
10166
  ):
9188
10167
  U = arg_values["U"]
9189
10168
  z = arg_values["z"]
@@ -9212,7 +10191,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9212
10191
 
9213
10192
  arch = options["output_arch"]
9214
10193
 
9215
- if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
10194
+ if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
9216
10195
  # CPU/no-MathDx dispatch
9217
10196
  return ((0, U, z, x), [], [], 0)
9218
10197
  else:
@@ -9228,7 +10207,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
9228
10207
  req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
9229
10208
 
9230
10209
  # generate the LTO
9231
- lto_symbol, lto_code_data = warp.build.build_lto_solver(
10210
+ lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
9232
10211
  M,
9233
10212
  N,
9234
10213
  NRHS,
@@ -9306,7 +10285,7 @@ add_builtin(
9306
10285
  group="Tile Primitives",
9307
10286
  export=False,
9308
10287
  namespace="",
9309
- missing_grad=True,
10288
+ is_differentiable=False,
9310
10289
  )
9311
10290
 
9312
10291
 
@@ -9326,7 +10305,7 @@ add_builtin(
9326
10305
  The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
9327
10306
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
9328
10307
  group="Code Generation",
9329
- missing_grad=True,
10308
+ is_differentiable=False,
9330
10309
  )
9331
10310
 
9332
10311
 
@@ -9351,7 +10330,7 @@ add_builtin(
9351
10330
  doc="Return the number of elements in a vector.",
9352
10331
  group="Utility",
9353
10332
  export=False,
9354
- missing_grad=True,
10333
+ is_differentiable=False,
9355
10334
  )
9356
10335
 
9357
10336
  add_builtin(
@@ -9361,7 +10340,7 @@ add_builtin(
9361
10340
  doc="Return the number of elements in a quaternion.",
9362
10341
  group="Utility",
9363
10342
  export=False,
9364
- missing_grad=True,
10343
+ is_differentiable=False,
9365
10344
  )
9366
10345
 
9367
10346
  add_builtin(
@@ -9371,7 +10350,7 @@ add_builtin(
9371
10350
  doc="Return the number of rows in a matrix.",
9372
10351
  group="Utility",
9373
10352
  export=False,
9374
- missing_grad=True,
10353
+ is_differentiable=False,
9375
10354
  )
9376
10355
 
9377
10356
  add_builtin(
@@ -9381,7 +10360,7 @@ add_builtin(
9381
10360
  doc="Return the number of elements in a transformation.",
9382
10361
  group="Utility",
9383
10362
  export=False,
9384
- missing_grad=True,
10363
+ is_differentiable=False,
9385
10364
  )
9386
10365
 
9387
10366
  add_builtin(
@@ -9391,7 +10370,7 @@ add_builtin(
9391
10370
  doc="Return the size of the first dimension in an array.",
9392
10371
  group="Utility",
9393
10372
  export=False,
9394
- missing_grad=True,
10373
+ is_differentiable=False,
9395
10374
  )
9396
10375
 
9397
10376
  add_builtin(
@@ -9401,7 +10380,33 @@ add_builtin(
9401
10380
  doc="Return the number of rows in a tile.",
9402
10381
  group="Utility",
9403
10382
  export=False,
9404
- missing_grad=True,
10383
+ is_differentiable=False,
10384
+ )
10385
+
10386
+
10387
+ def cast_value_func(arg_types, arg_values):
10388
+ # Return generic type for doc builds.
10389
+ if arg_types is None:
10390
+ return Any
10391
+
10392
+ return arg_values["dtype"]
10393
+
10394
+
10395
+ def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
10396
+ func_args = (args["a"],)
10397
+ template_args = (args["dtype"],)
10398
+ return (func_args, template_args)
10399
+
10400
+
10401
+ add_builtin(
10402
+ "cast",
10403
+ input_types={"a": Any, "dtype": Any},
10404
+ value_func=cast_value_func,
10405
+ dispatch_func=cast_dispatch_func,
10406
+ doc="Reinterpret a value as a different type while preserving its bit pattern.",
10407
+ group="Utility",
10408
+ export=False,
10409
+ is_differentiable=False,
9405
10410
  )
9406
10411
 
9407
10412
 
@@ -9428,7 +10433,7 @@ add_builtin(
9428
10433
  doc="Construct a tuple from a list of values",
9429
10434
  group="Utility",
9430
10435
  hidden=True,
9431
- missing_grad=True,
10436
+ is_differentiable=False,
9432
10437
  export=False,
9433
10438
  )
9434
10439
 
@@ -9465,7 +10470,7 @@ add_builtin(
9465
10470
  dispatch_func=tuple_extract_dispatch_func,
9466
10471
  group="Utility",
9467
10472
  hidden=True,
9468
- missing_grad=True,
10473
+ is_differentiable=False,
9469
10474
  )
9470
10475
 
9471
10476
 
@@ -9476,7 +10481,7 @@ add_builtin(
9476
10481
  doc="Return the number of elements in a tuple.",
9477
10482
  group="Utility",
9478
10483
  export=False,
9479
- missing_grad=True,
10484
+ is_differentiable=False,
9480
10485
  )
9481
10486
 
9482
10487
  # ---------------------------------
@@ -9495,5 +10500,5 @@ add_builtin(
9495
10500
  export=False,
9496
10501
  group="Utility",
9497
10502
  hidden=True,
9498
- missing_grad=True,
10503
+ is_differentiable=False,
9499
10504
  )