warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -5,24 +5,22 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- from .context import add_builtin
8
+ import builtins
9
+ from typing import Any, Callable, Tuple
9
10
 
11
+ from warp.codegen import Reference
10
12
  from warp.types import *
11
13
 
12
- from typing import Tuple
13
- from typing import List
14
- from typing import Dict
15
- from typing import Any
16
- from typing import Callable
14
+ from .context import add_builtin
17
15
 
18
16
 
19
17
  def sametype_value_func(default):
20
- def fn(args, kwds, _):
21
- if args is None:
18
+ def fn(arg_types, kwds, _):
19
+ if arg_types is None:
22
20
  return default
23
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
24
- raise RuntimeError(f"Input types must be the same, found: {[type_repr(a.type) for a in args]}")
25
- return args[0].type
21
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
22
+ raise RuntimeError(f"Input types must be the same, found: {[type_repr(t) for t in arg_types]}")
23
+ return arg_types[0]
26
24
 
27
25
  return fn
28
26
 
@@ -50,7 +48,7 @@ add_builtin(
50
48
  "clamp",
51
49
  input_types={"x": Scalar, "a": Scalar, "b": Scalar},
52
50
  value_func=sametype_value_func(Scalar),
53
- doc="Clamp the value of x to the range [a, b].",
51
+ doc="Clamp the value of ``x`` to the range [a, b].",
54
52
  group="Scalar Math",
55
53
  )
56
54
 
@@ -58,14 +56,14 @@ add_builtin(
58
56
  "abs",
59
57
  input_types={"x": Scalar},
60
58
  value_func=sametype_value_func(Scalar),
61
- doc="Return the absolute value of x.",
59
+ doc="Return the absolute value of ``x``.",
62
60
  group="Scalar Math",
63
61
  )
64
62
  add_builtin(
65
63
  "sign",
66
64
  input_types={"x": Scalar},
67
65
  value_func=sametype_value_func(Scalar),
68
- doc="Return -1 if x < 0, return 1 otherwise.",
66
+ doc="Return -1 if ``x`` < 0, return 1 otherwise.",
69
67
  group="Scalar Math",
70
68
  )
71
69
 
@@ -73,14 +71,14 @@ add_builtin(
73
71
  "step",
74
72
  input_types={"x": Scalar},
75
73
  value_func=sametype_value_func(Scalar),
76
- doc="Return 1.0 if x < 0.0, return 0.0 otherwise.",
74
+ doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
77
75
  group="Scalar Math",
78
76
  )
79
77
  add_builtin(
80
78
  "nonzero",
81
79
  input_types={"x": Scalar},
82
80
  value_func=sametype_value_func(Scalar),
83
- doc="Return 1.0 if x is not equal to zero, return 0.0 otherwise.",
81
+ doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
84
82
  group="Scalar Math",
85
83
  )
86
84
 
@@ -88,91 +86,101 @@ add_builtin(
88
86
  "sin",
89
87
  input_types={"x": Float},
90
88
  value_func=sametype_value_func(Float),
91
- doc="Return the sine of x in radians.",
89
+ doc="Return the sine of ``x`` in radians.",
92
90
  group="Scalar Math",
93
91
  )
94
92
  add_builtin(
95
93
  "cos",
96
94
  input_types={"x": Float},
97
95
  value_func=sametype_value_func(Float),
98
- doc="Return the cosine of x in radians.",
96
+ doc="Return the cosine of ``x`` in radians.",
99
97
  group="Scalar Math",
100
98
  )
101
99
  add_builtin(
102
100
  "acos",
103
101
  input_types={"x": Float},
104
102
  value_func=sametype_value_func(Float),
105
- doc="Return arccos of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
103
+ doc="Return arccos of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
106
104
  group="Scalar Math",
107
105
  )
108
106
  add_builtin(
109
107
  "asin",
110
108
  input_types={"x": Float},
111
109
  value_func=sametype_value_func(Float),
112
- doc="Return arcsin of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
110
+ doc="Return arcsin of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
113
111
  group="Scalar Math",
114
112
  )
115
113
  add_builtin(
116
114
  "sqrt",
117
115
  input_types={"x": Float},
118
116
  value_func=sametype_value_func(Float),
119
- doc="Return the sqrt of x, where x is positive.",
117
+ doc="Return the square root of ``x``, where ``x`` is positive.",
120
118
  group="Scalar Math",
119
+ require_original_output_arg=True,
120
+ )
121
+ add_builtin(
122
+ "cbrt",
123
+ input_types={"x": Float},
124
+ value_func=sametype_value_func(Float),
125
+ doc="Return the cube root of ``x``.",
126
+ group="Scalar Math",
127
+ require_original_output_arg=True,
121
128
  )
122
129
  add_builtin(
123
130
  "tan",
124
131
  input_types={"x": Float},
125
132
  value_func=sametype_value_func(Float),
126
- doc="Return tangent of x in radians.",
133
+ doc="Return the tangent of ``x`` in radians.",
127
134
  group="Scalar Math",
128
135
  )
129
136
  add_builtin(
130
137
  "atan",
131
138
  input_types={"x": Float},
132
139
  value_func=sametype_value_func(Float),
133
- doc="Return arctan of x.",
140
+ doc="Return the arctangent of ``x`` in radians.",
134
141
  group="Scalar Math",
135
142
  )
136
143
  add_builtin(
137
144
  "atan2",
138
145
  input_types={"y": Float, "x": Float},
139
146
  value_func=sametype_value_func(Float),
140
- doc="Return atan2 of x.",
147
+ doc="Return the 2-argument arctangent, atan2, of the point ``(x, y)`` in radians.",
141
148
  group="Scalar Math",
142
149
  )
143
150
  add_builtin(
144
151
  "sinh",
145
152
  input_types={"x": Float},
146
153
  value_func=sametype_value_func(Float),
147
- doc="Return the sinh of x.",
154
+ doc="Return the sinh of ``x``.",
148
155
  group="Scalar Math",
149
156
  )
150
157
  add_builtin(
151
158
  "cosh",
152
159
  input_types={"x": Float},
153
160
  value_func=sametype_value_func(Float),
154
- doc="Return the cosh of x.",
161
+ doc="Return the cosh of ``x``.",
155
162
  group="Scalar Math",
156
163
  )
157
164
  add_builtin(
158
165
  "tanh",
159
166
  input_types={"x": Float},
160
167
  value_func=sametype_value_func(Float),
161
- doc="Return the tanh of x.",
168
+ doc="Return the tanh of ``x``.",
162
169
  group="Scalar Math",
170
+ require_original_output_arg=True,
163
171
  )
164
172
  add_builtin(
165
173
  "degrees",
166
174
  input_types={"x": Float},
167
175
  value_func=sametype_value_func(Float),
168
- doc="Convert radians into degrees.",
176
+ doc="Convert ``x`` from radians into degrees.",
169
177
  group="Scalar Math",
170
178
  )
171
179
  add_builtin(
172
180
  "radians",
173
181
  input_types={"x": Float},
174
182
  value_func=sametype_value_func(Float),
175
- doc="Convert degrees into radians.",
183
+ doc="Convert ``x`` from degrees into radians.",
176
184
  group="Scalar Math",
177
185
  )
178
186
 
@@ -180,36 +188,38 @@ add_builtin(
180
188
  "log",
181
189
  input_types={"x": Float},
182
190
  value_func=sametype_value_func(Float),
183
- doc="Return the natural log (base-e) of x, where x is positive.",
191
+ doc="Return the natural logarithm (base-e) of ``x``, where ``x`` is positive.",
184
192
  group="Scalar Math",
185
193
  )
186
194
  add_builtin(
187
195
  "log2",
188
196
  input_types={"x": Float},
189
197
  value_func=sametype_value_func(Float),
190
- doc="Return the natural log (base-2) of x, where x is positive.",
198
+ doc="Return the binary logarithm (base-2) of ``x``, where ``x`` is positive.",
191
199
  group="Scalar Math",
192
200
  )
193
201
  add_builtin(
194
202
  "log10",
195
203
  input_types={"x": Float},
196
204
  value_func=sametype_value_func(Float),
197
- doc="Return the natural log (base-10) of x, where x is positive.",
205
+ doc="Return the common logarithm (base-10) of ``x``, where ``x`` is positive.",
198
206
  group="Scalar Math",
199
207
  )
200
208
  add_builtin(
201
209
  "exp",
202
210
  input_types={"x": Float},
203
211
  value_func=sametype_value_func(Float),
204
- doc="Return base-e exponential, e^x.",
212
+ doc="Return the value of the exponential function :math:`e^x`.",
205
213
  group="Scalar Math",
214
+ require_original_output_arg=True,
206
215
  )
207
216
  add_builtin(
208
217
  "pow",
209
218
  input_types={"x": Float, "y": Float},
210
219
  value_func=sametype_value_func(Float),
211
- doc="Return the result of x raised to power of y.",
220
+ doc="Return the result of ``x`` raised to power of ``y``.",
212
221
  group="Scalar Math",
222
+ require_original_output_arg=True,
213
223
  )
214
224
 
215
225
  add_builtin(
@@ -217,9 +227,9 @@ add_builtin(
217
227
  input_types={"x": Float},
218
228
  value_func=sametype_value_func(Float),
219
229
  group="Scalar Math",
220
- doc="""Calculate the nearest integer value, rounding halfway cases away from zero.
221
- This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like ``warp.rint()``.
222
- Differs from ``numpy.round()``, which behaves the same way as ``numpy.rint()``.""",
230
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases away from zero.
231
+ This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
232
+ Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
223
233
  )
224
234
 
225
235
  add_builtin(
@@ -227,9 +237,8 @@ add_builtin(
227
237
  input_types={"x": Float},
228
238
  value_func=sametype_value_func(Float),
229
239
  group="Scalar Math",
230
- doc="""Calculate the nearest integer value, rounding halfway cases to nearest even integer.
231
- It is generally faster than ``warp.round()``.
232
- Equivalent to ``numpy.rint()``.""",
240
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
241
+ It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
233
242
  )
234
243
 
235
244
  add_builtin(
@@ -237,10 +246,10 @@ add_builtin(
237
246
  input_types={"x": Float},
238
247
  value_func=sametype_value_func(Float),
239
248
  group="Scalar Math",
240
- doc="""Calculate the nearest integer that is closer to zero than x.
241
- In other words, it discards the fractional part of x.
242
- It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
243
- Equivalent to ``numpy.trunc()`` and ``numpy.fix()``.""",
249
+ doc="""Return the nearest integer that is closer to zero than ``x``.
250
+ In other words, it discards the fractional part of ``x``.
251
+ It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
252
+ Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
244
253
  )
245
254
 
246
255
  add_builtin(
@@ -248,7 +257,7 @@ add_builtin(
248
257
  input_types={"x": Float},
249
258
  value_func=sametype_value_func(Float),
250
259
  group="Scalar Math",
251
- doc="""Calculate the largest integer that is less than or equal to x.""",
260
+ doc="""Return the largest integer that is less than or equal to ``x``.""",
252
261
  )
253
262
 
254
263
  add_builtin(
@@ -256,22 +265,31 @@ add_builtin(
256
265
  input_types={"x": Float},
257
266
  value_func=sametype_value_func(Float),
258
267
  group="Scalar Math",
259
- doc="""Calculate the smallest integer that is greater than or equal to x.""",
268
+ doc="""Return the smallest integer that is greater than or equal to ``x``.""",
269
+ )
270
+
271
+ add_builtin(
272
+ "frac",
273
+ input_types={"x": Float},
274
+ value_func=sametype_value_func(Float),
275
+ group="Scalar Math",
276
+ doc="""Retrieve the fractional part of x.
277
+ In other words, it discards the integer part of x and is equivalent to ``x - trunc(x)``.""",
260
278
  )
261
279
 
262
280
 
263
- def infer_scalar_type(args):
264
- if args is None:
281
+ def infer_scalar_type(arg_types):
282
+ if arg_types is None:
265
283
  return Scalar
266
284
 
267
- def iterate_scalar_types(args):
268
- for a in args:
269
- if hasattr(a.type, "_wp_scalar_type_"):
270
- yield a.type._wp_scalar_type_
271
- elif a.type in scalar_types:
272
- yield a.type
285
+ def iterate_scalar_types(arg_types):
286
+ for t in arg_types:
287
+ if hasattr(t, "_wp_scalar_type_"):
288
+ yield t._wp_scalar_type_
289
+ elif t in scalar_types:
290
+ yield t
273
291
 
274
- scalarTypes = set(iterate_scalar_types(args))
292
+ scalarTypes = set(iterate_scalar_types(arg_types))
275
293
  if len(scalarTypes) > 1:
276
294
  raise RuntimeError(
277
295
  f"Couldn't figure out return type as arguments have multiple precisions: {list(scalarTypes)}"
@@ -279,13 +297,13 @@ def infer_scalar_type(args):
279
297
  return list(scalarTypes)[0]
280
298
 
281
299
 
282
- def sametype_scalar_value_func(args, kwds, _):
283
- if args is None:
300
+ def sametype_scalar_value_func(arg_types, kwds, _):
301
+ if arg_types is None:
284
302
  return Scalar
285
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
286
- raise RuntimeError(f"Input types must be exactly the same, {[a.type for a in args]}")
303
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
304
+ raise RuntimeError(f"Input types must be exactly the same, {[t for t in arg_types]}")
287
305
 
288
- return infer_scalar_type(args)
306
+ return infer_scalar_type(arg_types)
289
307
 
290
308
 
291
309
  # ---------------------------------
@@ -310,14 +328,14 @@ add_builtin(
310
328
  "min",
311
329
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
312
330
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
313
- doc="Return the element wise minimum of two vectors.",
331
+ doc="Return the element-wise minimum of two vectors.",
314
332
  group="Vector Math",
315
333
  )
316
334
  add_builtin(
317
335
  "max",
318
336
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
319
337
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
320
- doc="Return the element wise maximum of two vectors.",
338
+ doc="Return the element-wise maximum of two vectors.",
321
339
  group="Vector Math",
322
340
  )
323
341
 
@@ -325,41 +343,41 @@ add_builtin(
325
343
  "min",
326
344
  input_types={"v": vector(length=Any, dtype=Scalar)},
327
345
  value_func=sametype_scalar_value_func,
328
- doc="Return the minimum element of a vector.",
346
+ doc="Return the minimum element of a vector ``v``.",
329
347
  group="Vector Math",
330
348
  )
331
349
  add_builtin(
332
350
  "max",
333
351
  input_types={"v": vector(length=Any, dtype=Scalar)},
334
352
  value_func=sametype_scalar_value_func,
335
- doc="Return the maximum element of a vector.",
353
+ doc="Return the maximum element of a vector ``v``.",
336
354
  group="Vector Math",
337
355
  )
338
356
 
339
357
  add_builtin(
340
358
  "argmin",
341
359
  input_types={"v": vector(length=Any, dtype=Scalar)},
342
- value_func=lambda args, kwds, _: warp.uint32,
343
- doc="Return the index of the minimum element of a vector.",
360
+ value_func=lambda arg_types, kwds, _: warp.uint32,
361
+ doc="Return the index of the minimum element of a vector ``v``.",
344
362
  group="Vector Math",
345
363
  missing_grad=True,
346
364
  )
347
365
  add_builtin(
348
366
  "argmax",
349
367
  input_types={"v": vector(length=Any, dtype=Scalar)},
350
- value_func=lambda args, kwds, _: warp.uint32,
351
- doc="Return the index of the maximum element of a vector.",
368
+ value_func=lambda arg_types, kwds, _: warp.uint32,
369
+ doc="Return the index of the maximum element of a vector ``v``.",
352
370
  group="Vector Math",
353
371
  missing_grad=True,
354
372
  )
355
373
 
356
374
 
357
- def value_func_outer(args, kwds, _):
358
- if args is None:
375
+ def value_func_outer(arg_types, kwds, _):
376
+ if arg_types is None:
359
377
  return matrix(shape=(Any, Any), dtype=Scalar)
360
378
 
361
- scalarType = infer_scalar_type(args)
362
- vectorLengths = [i.type._length_ for i in args]
379
+ scalarType = infer_scalar_type(arg_types)
380
+ vectorLengths = [t._length_ for t in arg_types]
363
381
  return matrix(shape=(vectorLengths), dtype=scalarType)
364
382
 
365
383
 
@@ -368,7 +386,7 @@ add_builtin(
368
386
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
369
387
  value_func=value_func_outer,
370
388
  group="Vector Math",
371
- doc="Compute the outer product x*y^T for two vec2 objects.",
389
+ doc="Compute the outer product ``x*y^T`` for two vectors.",
372
390
  )
373
391
 
374
392
  add_builtin(
@@ -376,14 +394,14 @@ add_builtin(
376
394
  input_types={"x": vector(length=3, dtype=Scalar), "y": vector(length=3, dtype=Scalar)},
377
395
  value_func=sametype_value_func(vector(length=3, dtype=Scalar)),
378
396
  group="Vector Math",
379
- doc="Compute the cross product of two 3d vectors.",
397
+ doc="Compute the cross product of two 3D vectors.",
380
398
  )
381
399
  add_builtin(
382
400
  "skew",
383
401
  input_types={"x": vector(length=3, dtype=Scalar)},
384
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=args[0].type._wp_scalar_type_),
402
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=arg_types[0]._wp_scalar_type_),
385
403
  group="Vector Math",
386
- doc="Compute the skew symmetric matrix for a 3d vector.",
404
+ doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``x``.",
387
405
  )
388
406
 
389
407
  add_builtin(
@@ -391,59 +409,62 @@ add_builtin(
391
409
  input_types={"x": vector(length=Any, dtype=Float)},
392
410
  value_func=sametype_scalar_value_func,
393
411
  group="Vector Math",
394
- doc="Compute the length of a vector.",
412
+ doc="Compute the length of a vector ``x``.",
413
+ require_original_output_arg=True,
395
414
  )
396
415
  add_builtin(
397
416
  "length",
398
417
  input_types={"x": quaternion(dtype=Float)},
399
418
  value_func=sametype_scalar_value_func,
400
419
  group="Vector Math",
401
- doc="Compute the length of a quaternion.",
420
+ doc="Compute the length of a quaternion ``x``.",
421
+ require_original_output_arg=True,
402
422
  )
403
423
  add_builtin(
404
424
  "length_sq",
405
425
  input_types={"x": vector(length=Any, dtype=Scalar)},
406
426
  value_func=sametype_scalar_value_func,
407
427
  group="Vector Math",
408
- doc="Compute the squared length of a 2d vector.",
428
+ doc="Compute the squared length of a 2D vector ``x``.",
409
429
  )
410
430
  add_builtin(
411
431
  "length_sq",
412
432
  input_types={"x": quaternion(dtype=Scalar)},
413
433
  value_func=sametype_scalar_value_func,
414
434
  group="Vector Math",
415
- doc="Compute the squared length of a quaternion.",
435
+ doc="Compute the squared length of a quaternion ``x``.",
416
436
  )
417
437
  add_builtin(
418
438
  "normalize",
419
439
  input_types={"x": vector(length=Any, dtype=Float)},
420
440
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
421
441
  group="Vector Math",
422
- doc="Compute the normalized value of x, if length(x) is 0 then the zero vector is returned.",
442
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0 then the zero vector is returned.",
443
+ require_original_output_arg=True,
423
444
  )
424
445
  add_builtin(
425
446
  "normalize",
426
447
  input_types={"x": quaternion(dtype=Float)},
427
448
  value_func=sametype_value_func(quaternion(dtype=Scalar)),
428
449
  group="Vector Math",
429
- doc="Compute the normalized value of x, if length(x) is 0 then the zero quat is returned.",
450
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0, then the zero quaternion is returned.",
430
451
  )
431
452
 
432
453
  add_builtin(
433
454
  "transpose",
434
455
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
435
- value_func=lambda args, kwds, _: matrix(
436
- shape=(args[0].type._shape_[1], args[0].type._shape_[0]), dtype=args[0].type._wp_scalar_type_
456
+ value_func=lambda arg_types, kwds, _: matrix(
457
+ shape=(arg_types[0]._shape_[1], arg_types[0]._shape_[0]), dtype=arg_types[0]._wp_scalar_type_
437
458
  ),
438
459
  group="Vector Math",
439
- doc="Return the transpose of the matrix m",
460
+ doc="Return the transpose of the matrix ``m``.",
440
461
  )
441
462
 
442
463
 
443
- def value_func_mat_inv(args, kwds, _):
444
- if args is None:
464
+ def value_func_mat_inv(arg_types, kwds, _):
465
+ if arg_types is None:
445
466
  return matrix(shape=(Any, Any), dtype=Float)
446
- return args[0].type
467
+ return arg_types[0]
447
468
 
448
469
 
449
470
  add_builtin(
@@ -451,7 +472,8 @@ add_builtin(
451
472
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
452
473
  value_func=value_func_mat_inv,
453
474
  group="Vector Math",
454
- doc="Return the inverse of a 2x2 matrix m",
475
+ doc="Return the inverse of a 2x2 matrix ``m``.",
476
+ require_original_output_arg=True,
455
477
  )
456
478
 
457
479
  add_builtin(
@@ -459,7 +481,8 @@ add_builtin(
459
481
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
460
482
  value_func=value_func_mat_inv,
461
483
  group="Vector Math",
462
- doc="Return the inverse of a 3x3 matrix m",
484
+ doc="Return the inverse of a 3x3 matrix ``m``.",
485
+ require_original_output_arg=True,
463
486
  )
464
487
 
465
488
  add_builtin(
@@ -467,14 +490,15 @@ add_builtin(
467
490
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
468
491
  value_func=value_func_mat_inv,
469
492
  group="Vector Math",
470
- doc="Return the inverse of a 4x4 matrix m",
493
+ doc="Return the inverse of a 4x4 matrix ``m``.",
494
+ require_original_output_arg=True,
471
495
  )
472
496
 
473
497
 
474
- def value_func_mat_det(args, kwds, _):
475
- if args is None:
498
+ def value_func_mat_det(arg_types, kwds, _):
499
+ if arg_types is None:
476
500
  return Scalar
477
- return args[0].type._wp_scalar_type_
501
+ return arg_types[0]._wp_scalar_type_
478
502
 
479
503
 
480
504
  add_builtin(
@@ -482,7 +506,7 @@ add_builtin(
482
506
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
483
507
  value_func=value_func_mat_det,
484
508
  group="Vector Math",
485
- doc="Return the determinant of a 2x2 matrix m",
509
+ doc="Return the determinant of a 2x2 matrix ``m``.",
486
510
  )
487
511
 
488
512
  add_builtin(
@@ -490,7 +514,7 @@ add_builtin(
490
514
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
491
515
  value_func=value_func_mat_det,
492
516
  group="Vector Math",
493
- doc="Return the determinant of a 3x3 matrix m",
517
+ doc="Return the determinant of a 3x3 matrix ``m``.",
494
518
  )
495
519
 
496
520
  add_builtin(
@@ -498,16 +522,16 @@ add_builtin(
498
522
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
499
523
  value_func=value_func_mat_det,
500
524
  group="Vector Math",
501
- doc="Return the determinant of a 4x4 matrix m",
525
+ doc="Return the determinant of a 4x4 matrix ``m``.",
502
526
  )
503
527
 
504
528
 
505
- def value_func_mat_trace(args, kwds, _):
506
- if args is None:
529
+ def value_func_mat_trace(arg_types, kwds, _):
530
+ if arg_types is None:
507
531
  return Scalar
508
- if args[0].type._shape_[0] != args[0].type._shape_[1]:
509
- raise RuntimeError(f"Matrix shape is {args[0].type._shape_}. Cannot find the trace of non square matrices")
510
- return args[0].type._wp_scalar_type_
532
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
533
+ raise RuntimeError(f"Matrix shape is {arg_types[0]._shape_}. Cannot find the trace of non square matrices")
534
+ return arg_types[0]._wp_scalar_type_
511
535
 
512
536
 
513
537
  add_builtin(
@@ -515,15 +539,15 @@ add_builtin(
515
539
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
516
540
  value_func=value_func_mat_trace,
517
541
  group="Vector Math",
518
- doc="Return the trace of the matrix m",
542
+ doc="Return the trace of the matrix ``m``.",
519
543
  )
520
544
 
521
545
 
522
- def value_func_diag(args, kwds, _):
523
- if args is None:
546
+ def value_func_diag(arg_types, kwds, _):
547
+ if arg_types is None:
524
548
  return matrix(shape=(Any, Any), dtype=Scalar)
525
549
  else:
526
- return matrix(shape=(args[0].type._length_, args[0].type._length_), dtype=args[0].type._wp_scalar_type_)
550
+ return matrix(shape=(arg_types[0]._length_, arg_types[0]._length_), dtype=arg_types[0]._wp_scalar_type_)
527
551
 
528
552
 
529
553
  add_builtin(
@@ -531,7 +555,27 @@ add_builtin(
531
555
  input_types={"d": vector(length=Any, dtype=Scalar)},
532
556
  value_func=value_func_diag,
533
557
  group="Vector Math",
534
- doc="Returns a matrix with the components of the vector d on the diagonal",
558
+ doc="Returns a matrix with the components of the vector ``d`` on the diagonal.",
559
+ )
560
+
561
+
562
+ def value_func_get_diag(arg_types, kwds, _):
563
+ if arg_types is None:
564
+ return vector(length=(Any), dtype=Scalar)
565
+ else:
566
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
567
+ raise RuntimeError(
568
+ f"Matrix shape is {arg_types[0]._shape_}; get_diag is only available for square matrices."
569
+ )
570
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
571
+
572
+
573
+ add_builtin(
574
+ "get_diag",
575
+ input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
576
+ value_func=value_func_get_diag,
577
+ group="Vector Math",
578
+ doc="Returns a vector containing the diagonal elements of the square matrix ``m``.",
535
579
  )
536
580
 
537
581
  add_builtin(
@@ -539,14 +583,15 @@ add_builtin(
539
583
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
540
584
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
541
585
  group="Vector Math",
542
- doc="Component wise multiply of two 2d vectors.",
586
+ doc="Component-wise multiplication of two 2D vectors.",
543
587
  )
544
588
  add_builtin(
545
589
  "cw_div",
546
590
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
547
591
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
548
592
  group="Vector Math",
549
- doc="Component wise division of two 2d vectors.",
593
+ doc="Component-wise division of two 2D vectors.",
594
+ require_original_output_arg=True,
550
595
  )
551
596
 
552
597
  add_builtin(
@@ -554,14 +599,15 @@ add_builtin(
554
599
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
555
600
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
556
601
  group="Vector Math",
557
- doc="Component wise multiply of two 2d vectors.",
602
+ doc="Component-wise multiplication of two 2D vectors.",
558
603
  )
559
604
  add_builtin(
560
605
  "cw_div",
561
606
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
562
607
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
563
608
  group="Vector Math",
564
- doc="Component wise division of two 2d vectors.",
609
+ doc="Component-wise division of two 2D vectors.",
610
+ require_original_output_arg=True,
565
611
  )
566
612
 
567
613
 
@@ -573,16 +619,19 @@ for t in scalar_types_all:
573
619
  t.__name__, input_types={"u": u}, value_type=t, doc="", hidden=True, group="Scalar Math", export=False
574
620
  )
575
621
 
622
+ for u in [bool, builtins.bool]:
623
+ add_builtin(bool.__name__, input_types={"u": u}, value_type=bool, doc="", hidden=True, export=False, namespace="")
624
+
576
625
 
577
- def vector_constructor_func(args, kwds, templates):
578
- if args is None:
626
+ def vector_constructor_func(arg_types, kwds, templates):
627
+ if arg_types is None:
579
628
  return vector(length=Any, dtype=Scalar)
580
629
 
581
630
  if templates is None or len(templates) == 0:
582
631
  # handle construction of anonymous (undeclared) vector types
583
632
 
584
633
  if "length" in kwds:
585
- if len(args) == 0:
634
+ if len(arg_types) == 0:
586
635
  if "dtype" not in kwds:
587
636
  raise RuntimeError(
588
637
  "vec() must have dtype as a keyword argument if it has no positional arguments, e.g.: wp.vector(length=5, dtype=wp.float32)"
@@ -592,34 +641,54 @@ def vector_constructor_func(args, kwds, templates):
592
641
  veclen = kwds["length"]
593
642
  vectype = kwds["dtype"]
594
643
 
595
- elif len(args) == 1:
644
+ elif len(arg_types) == 1:
596
645
  # value initialization e.g.: wp.vec(1.0, length=5)
597
646
  veclen = kwds["length"]
598
- vectype = args[0].type
647
+ vectype = arg_types[0]
648
+ if getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
649
+ # constructor from another vector
650
+ if vectype._length_ != veclen:
651
+ raise RuntimeError(
652
+ f"Incompatible vector lengths for casting copy constructor, {veclen} vs {vectype._length_}"
653
+ )
654
+ vectype = vectype._wp_scalar_type_
599
655
  else:
600
656
  raise RuntimeError(
601
657
  "vec() must have one scalar argument or the dtype keyword argument if the length keyword argument is specified, e.g.: wp.vec(1.0, length=5)"
602
658
  )
603
659
 
604
660
  else:
605
- if len(args) == 0:
661
+ if len(arg_types) == 0:
606
662
  raise RuntimeError(
607
663
  "vec() must have at least one numeric argument, if it's length, dtype is not specified"
608
664
  )
609
665
 
610
666
  if "dtype" in kwds:
667
+ # casting constructor
668
+ if len(arg_types) == 1 and types_equal(
669
+ arg_types[0], vector(length=Any, dtype=Scalar), match_generic=True
670
+ ):
671
+ veclen = arg_types[0]._length_
672
+ vectype = kwds["dtype"]
673
+ templates.append(veclen)
674
+ templates.append(vectype)
675
+ return vector(length=veclen, dtype=vectype)
611
676
  raise RuntimeError(
612
677
  "vec() should not have dtype specified if numeric arguments are given, the dtype will be inferred from the argument types"
613
678
  )
614
679
 
615
680
  # component wise construction of an anonymous vector, e.g. wp.vec(wp.float16(1.0), wp.float16(2.0), ....)
616
681
  # we infer the length and data type from the number and type of the arg values
617
- veclen = len(args)
618
- vectype = args[0].type
619
-
620
- if not all(vectype == a.type for a in args):
682
+ veclen = len(arg_types)
683
+ vectype = arg_types[0]
684
+
685
+ if len(arg_types) == 1 and getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
686
+ # constructor from another vector
687
+ veclen = vectype._length_
688
+ vectype = vectype._wp_scalar_type_
689
+ elif not all(vectype == t for t in arg_types):
621
690
  raise RuntimeError(
622
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
691
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
623
692
  )
624
693
 
625
694
  # update the templates list, so we can generate vec<len, type>() correctly in codegen
@@ -629,9 +698,15 @@ def vector_constructor_func(args, kwds, templates):
629
698
  else:
630
699
  # construction of a predeclared type, e.g.: vec5d
631
700
  veclen, vectype = templates
632
- if not all(vectype == a.type for a in args):
701
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "vec_t":
702
+ # constructor from another vector
703
+ if arg_types[0]._length_ != veclen:
704
+ raise RuntimeError(
705
+ f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {arg_types[0]._length_}"
706
+ )
707
+ elif not all(vectype == t for t in arg_types):
633
708
  raise RuntimeError(
634
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
709
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
635
710
  )
636
711
 
637
712
  retvalue = vector(length=veclen, dtype=vectype)
@@ -640,9 +715,9 @@ def vector_constructor_func(args, kwds, templates):
640
715
 
641
716
  add_builtin(
642
717
  "vector",
643
- input_types={"*args": Scalar, "length": int, "dtype": Scalar},
718
+ input_types={"*arg_types": Scalar, "length": int, "dtype": Scalar},
644
719
  variadic=True,
645
- initializer_list_func=lambda args, _: len(args) > 4,
720
+ initializer_list_func=lambda arg_types, _: len(arg_types) > 4,
646
721
  value_func=vector_constructor_func,
647
722
  native_func="vec_t",
648
723
  doc="Construct a vector of with given length and dtype.",
@@ -651,8 +726,8 @@ add_builtin(
651
726
  )
652
727
 
653
728
 
654
- def matrix_constructor_func(args, kwds, templates):
655
- if args is None:
729
+ def matrix_constructor_func(arg_types, kwds, templates):
730
+ if arg_types is None:
656
731
  return matrix(shape=(Any, Any), dtype=Scalar)
657
732
 
658
733
  if len(templates) == 0:
@@ -660,7 +735,7 @@ def matrix_constructor_func(args, kwds, templates):
660
735
  if "shape" not in kwds:
661
736
  raise RuntimeError("shape keyword must be specified when calling matrix() function")
662
737
 
663
- if len(args) == 0:
738
+ if len(arg_types) == 0:
664
739
  if "dtype" not in kwds:
665
740
  raise RuntimeError("matrix() must have dtype as a keyword argument if it has no positional arguments")
666
741
 
@@ -671,9 +746,16 @@ def matrix_constructor_func(args, kwds, templates):
671
746
  else:
672
747
  # value initialization, e.g.: m = matrix(1.0, shape=(3,2))
673
748
  shape = kwds["shape"]
674
- dtype = args[0].type
749
+ dtype = arg_types[0]
675
750
 
676
- if len(args) > 1 and len(args) != shape[0] * shape[1]:
751
+ if len(arg_types) == 1 and getattr(dtype, "_wp_generic_type_str_", None) == "mat_t":
752
+ # constructor from another matrix
753
+ if arg_types[0]._shape_ != shape:
754
+ raise RuntimeError(
755
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
756
+ )
757
+ dtype = dtype._wp_scalar_type_
758
+ elif len(arg_types) > 1 and len(arg_types) != shape[0] * shape[1]:
677
759
  raise RuntimeError(
678
760
  "Wrong number of arguments for matrix() function, must initialize with either a scalar value, or m*n values"
679
761
  )
@@ -687,68 +769,70 @@ def matrix_constructor_func(args, kwds, templates):
687
769
  shape = (templates[0], templates[1])
688
770
  dtype = templates[2]
689
771
 
690
- if len(args) > 0:
691
- # check scalar arg type matches declared type
692
- if infer_scalar_type(args) != dtype:
693
- raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
694
-
695
- # check vector arg type matches declared type
696
- types = [a.type for a in args]
697
- if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in types):
698
- cols = len(types)
699
- if shape[1] != cols:
700
- raise RuntimeError(
701
- "Wrong number of vectors when attempting to construct a matrix with column vectors"
702
- )
703
-
704
- if not all(a._length_ == shape[0] for a in types):
772
+ if len(arg_types) > 0:
773
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "mat_t":
774
+ # constructor from another matrix with same dimension but possibly different type
775
+ if arg_types[0]._shape_ != shape:
705
776
  raise RuntimeError(
706
- "Wrong vector row count when attempting to construct a matrix with column vectors"
777
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
707
778
  )
708
-
709
779
  else:
710
- # check that we either got 1 arg (scalar construction), or enough values for whole matrix
711
- size = shape[0] * shape[1]
712
- if len(args) > 1 and len(args) != size:
713
- raise RuntimeError(
714
- "Wrong number of scalars when attempting to construct a matrix from a list of components"
715
- )
780
+ # check scalar arg type matches declared type
781
+ if infer_scalar_type(arg_types) != dtype:
782
+ raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
783
+
784
+ # check vector arg type matches declared type
785
+ if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in arg_types):
786
+ cols = len(arg_types)
787
+ if shape[1] != cols:
788
+ raise RuntimeError(
789
+ "Wrong number of vectors when attempting to construct a matrix with column vectors"
790
+ )
791
+
792
+ if not all(a._length_ == shape[0] for a in arg_types):
793
+ raise RuntimeError(
794
+ "Wrong vector row count when attempting to construct a matrix with column vectors"
795
+ )
796
+ else:
797
+ # check that we either got 1 arg (scalar construction), or enough values for whole matrix
798
+ size = shape[0] * shape[1]
799
+ if len(arg_types) > 1 and len(arg_types) != size:
800
+ raise RuntimeError(
801
+ "Wrong number of scalars when attempting to construct a matrix from a list of components"
802
+ )
716
803
 
717
804
  return matrix(shape=shape, dtype=dtype)
718
805
 
719
806
 
720
807
  # only use initializer list if matrix size < 5x5, or for scalar construction
721
- def matrix_initlist_func(args, templates):
808
+ def matrix_initlist_func(arg_types, templates):
722
809
  m, n, dtype = templates
723
- if (
724
- len(args) == 0
725
- or len(args) == 1 # zero construction
810
+ return not (
811
+ len(arg_types) == 0
812
+ or len(arg_types) == 1 # zero construction
726
813
  or (m == n and n < 5) # scalar construction # value construction for small matrices
727
- ):
728
- return False
729
- else:
730
- return True
814
+ )
731
815
 
732
816
 
733
817
  add_builtin(
734
818
  "matrix",
735
- input_types={"*args": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
819
+ input_types={"*arg_types": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
736
820
  variadic=True,
737
821
  initializer_list_func=matrix_initlist_func,
738
822
  value_func=matrix_constructor_func,
739
823
  native_func="mat_t",
740
- doc="Construct a matrix, if positional args are not given then matrix will be zero-initialized.",
824
+ doc="Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized.",
741
825
  group="Vector Math",
742
826
  export=False,
743
827
  )
744
828
 
745
829
 
746
830
  # identity:
747
- def matrix_identity_value_func(args, kwds, templates):
748
- if args is None:
831
+ def matrix_identity_value_func(arg_types, kwds, templates):
832
+ if arg_types is None:
749
833
  return matrix(shape=(Any, Any), dtype=Scalar)
750
834
 
751
- if len(args):
835
+ if len(arg_types):
752
836
  raise RuntimeError("identity() function does not accept positional arguments")
753
837
 
754
838
  if "n" not in kwds:
@@ -779,7 +863,7 @@ add_builtin(
779
863
  )
780
864
 
781
865
 
782
- def matrix_transform_value_func(args, kwds, templates):
866
+ def matrix_transform_value_func(arg_types, kwds, templates):
783
867
  if templates is None:
784
868
  return matrix(shape=(Any, Any), dtype=Float)
785
869
 
@@ -789,7 +873,7 @@ def matrix_transform_value_func(args, kwds, templates):
789
873
  m, n, dtype = templates
790
874
  if (m, n) != (4, 4):
791
875
  raise RuntimeError("Can only construct 4x4 matrices with position, rotation and scale")
792
- if infer_scalar_type(args) != dtype:
876
+ if infer_scalar_type(arg_types) != dtype:
793
877
  raise RuntimeError("Wrong scalar type for mat<{}> constructor".format(",".join(map(str, templates))))
794
878
 
795
879
  return matrix(shape=(4, 4), dtype=dtype)
@@ -804,7 +888,8 @@ add_builtin(
804
888
  },
805
889
  value_func=matrix_transform_value_func,
806
890
  native_func="mat_t",
807
- doc="""Construct a 4x4 transformation matrix that applies the transformations as Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
891
+ doc="""Construct a 4x4 transformation matrix that applies the transformations as
892
+ Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
808
893
  group="Vector Math",
809
894
  export=False,
810
895
  )
@@ -823,8 +908,8 @@ add_builtin(
823
908
  value_type=None,
824
909
  group="Vector Math",
825
910
  export=False,
826
- doc="""Compute the SVD of a 3x3 matrix. The singular values are returned in sigma,
827
- while the left and right basis vectors are returned in U and V.""",
911
+ doc="""Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
912
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
828
913
  )
829
914
 
830
915
  add_builtin(
@@ -837,7 +922,8 @@ add_builtin(
837
922
  value_type=None,
838
923
  group="Vector Math",
839
924
  export=False,
840
- doc="""Compute the QR decomposition of a 3x3 matrix. The orthogonal matrix is returned in Q, while the upper triangular matrix is returned in R.""",
925
+ doc="""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
926
+ while the upper triangular matrix is returned in ``R``.""",
841
927
  )
842
928
 
843
929
  add_builtin(
@@ -850,36 +936,53 @@ add_builtin(
850
936
  value_type=None,
851
937
  group="Vector Math",
852
938
  export=False,
853
- doc="""Compute the eigendecomposition of a 3x3 matrix. The eigenvectors are returned as the columns of Q, while the corresponding eigenvalues are returned in d.""",
939
+ doc="""Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
940
+ while the corresponding eigenvalues are returned in ``d``.""",
854
941
  )
855
942
 
856
943
  # ---------------------------------
857
944
  # Quaternion Math
858
945
 
859
946
 
860
- def quaternion_value_func(args, kwds, templates):
861
- if args is None:
862
- return quaternion(dtype=Scalar)
947
+ def quaternion_value_func(arg_types, kwds, templates):
948
+ if arg_types is None:
949
+ return quaternion(dtype=Float)
863
950
 
864
- # if constructing anonymous quat type then infer output type from arguments
865
951
  if len(templates) == 0:
866
- dtype = infer_scalar_type(args)
952
+ if "dtype" in kwds:
953
+ # casting constructor
954
+ dtype = kwds["dtype"]
955
+ else:
956
+ # if constructing anonymous quat type then infer output type from arguments
957
+ dtype = infer_scalar_type(arg_types)
867
958
  templates.append(dtype)
868
959
  else:
869
- # if constructing predeclared type then check args match expectation
870
- if len(args) > 0 and infer_scalar_type(args) != templates[0]:
960
+ # if constructing predeclared type then check arg_types match expectation
961
+ if len(arg_types) > 0 and infer_scalar_type(arg_types) != templates[0]:
871
962
  raise RuntimeError("Wrong scalar type for quat {} constructor".format(",".join(map(str, templates))))
872
963
 
873
964
  return quaternion(dtype=templates[0])
874
965
 
875
966
 
967
+ def quat_cast_value_func(arg_types, kwds, templates):
968
+ if arg_types is None:
969
+ raise RuntimeError("Missing quaternion argument.")
970
+ if "dtype" not in kwds:
971
+ raise RuntimeError("Missing 'dtype' kwd.")
972
+
973
+ dtype = kwds["dtype"]
974
+ templates.append(dtype)
975
+
976
+ return quaternion(dtype=dtype)
977
+
978
+
876
979
  add_builtin(
877
980
  "quaternion",
878
981
  input_types={},
879
982
  value_func=quaternion_value_func,
880
983
  native_func="quat_t",
881
984
  group="Quaternion Math",
882
- doc="""Construct a zero-initialized quaternion, quaternions are laid out as
985
+ doc="""Construct a zero-initialized quaternion. Quaternions are laid out as
883
986
  [ix, iy, iz, r], where ix, iy, iz are the imaginary part, and r the real part.""",
884
987
  export=False,
885
988
  )
@@ -889,7 +992,7 @@ add_builtin(
889
992
  value_func=quaternion_value_func,
890
993
  native_func="quat_t",
891
994
  group="Quaternion Math",
892
- doc="Create a quaternion using the supplied components (type inferred from component type)",
995
+ doc="Create a quaternion using the supplied components (type inferred from component type).",
893
996
  export=False,
894
997
  )
895
998
  add_builtin(
@@ -898,14 +1001,23 @@ add_builtin(
898
1001
  value_func=quaternion_value_func,
899
1002
  native_func="quat_t",
900
1003
  group="Quaternion Math",
901
- doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type)",
1004
+ doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type).",
1005
+ export=False,
1006
+ )
1007
+ add_builtin(
1008
+ "quaternion",
1009
+ input_types={"q": quaternion(dtype=Float)},
1010
+ value_func=quat_cast_value_func,
1011
+ native_func="quat_t",
1012
+ group="Quaternion Math",
1013
+ doc="Construct a quaternion of type dtype from another quaternion of a different dtype.",
902
1014
  export=False,
903
1015
  )
904
1016
 
905
1017
 
906
- def quat_identity_value_func(args, kwds, templates):
907
- # if args is None then we are in 'export' mode
908
- if args is None:
1018
+ def quat_identity_value_func(arg_types, kwds, templates):
1019
+ # if arg_types is None then we are in 'export' mode
1020
+ if arg_types is None:
909
1021
  return quatf
910
1022
 
911
1023
  if "dtype" not in kwds:
@@ -931,7 +1043,7 @@ add_builtin(
931
1043
  add_builtin(
932
1044
  "quat_from_axis_angle",
933
1045
  input_types={"axis": vector(length=3, dtype=Float), "angle": Float},
934
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1046
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
935
1047
  group="Quaternion Math",
936
1048
  doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
937
1049
  )
@@ -945,49 +1057,50 @@ add_builtin(
945
1057
  add_builtin(
946
1058
  "quat_from_matrix",
947
1059
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
948
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1060
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
949
1061
  group="Quaternion Math",
950
1062
  doc="Construct a quaternion from a 3x3 matrix.",
951
1063
  )
952
1064
  add_builtin(
953
1065
  "quat_rpy",
954
1066
  input_types={"roll": Float, "pitch": Float, "yaw": Float},
955
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1067
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
956
1068
  group="Quaternion Math",
957
1069
  doc="Construct a quaternion representing a combined roll (z), pitch (x), yaw rotations (y) in radians.",
958
1070
  )
959
1071
  add_builtin(
960
1072
  "quat_inverse",
961
1073
  input_types={"q": quaternion(dtype=Float)},
962
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1074
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
963
1075
  group="Quaternion Math",
964
1076
  doc="Compute quaternion conjugate.",
965
1077
  )
966
1078
  add_builtin(
967
1079
  "quat_rotate",
968
1080
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
969
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1081
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
970
1082
  group="Quaternion Math",
971
1083
  doc="Rotate a vector by a quaternion.",
972
1084
  )
973
1085
  add_builtin(
974
1086
  "quat_rotate_inv",
975
1087
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
976
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1088
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
977
1089
  group="Quaternion Math",
978
- doc="Rotate a vector the inverse of a quaternion.",
1090
+ doc="Rotate a vector by the inverse of a quaternion.",
979
1091
  )
980
1092
  add_builtin(
981
1093
  "quat_slerp",
982
1094
  input_types={"q0": quaternion(dtype=Float), "q1": quaternion(dtype=Float), "t": Float},
983
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1095
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
984
1096
  group="Quaternion Math",
985
1097
  doc="Linearly interpolate between two quaternions.",
1098
+ require_original_output_arg=True,
986
1099
  )
987
1100
  add_builtin(
988
1101
  "quat_to_matrix",
989
1102
  input_types={"q": quaternion(dtype=Float)},
990
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(args)),
1103
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(arg_types)),
991
1104
  group="Quaternion Math",
992
1105
  doc="Convert a quaternion to a 3x3 rotation matrix.",
993
1106
  )
@@ -1003,19 +1116,19 @@ add_builtin(
1003
1116
  # Transformations
1004
1117
 
1005
1118
 
1006
- def transform_constructor_value_func(args, kwds, templates):
1119
+ def transform_constructor_value_func(arg_types, kwds, templates):
1007
1120
  if templates is None:
1008
1121
  return transformation(dtype=Scalar)
1009
1122
 
1010
1123
  if len(templates) == 0:
1011
1124
  # if constructing anonymous transform type then infer output type from arguments
1012
- dtype = infer_scalar_type(args)
1125
+ dtype = infer_scalar_type(arg_types)
1013
1126
  templates.append(dtype)
1014
1127
  else:
1015
- # if constructing predeclared type then check args match expectation
1016
- if infer_scalar_type(args) != templates[0]:
1128
+ # if constructing predeclared type then check arg_types match expectation
1129
+ if infer_scalar_type(arg_types) != templates[0]:
1017
1130
  raise RuntimeError(
1018
- f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda x : str(x.type), args))}"
1131
+ f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda t : str(t), arg_types))}"
1019
1132
  )
1020
1133
 
1021
1134
  return transformation(dtype=templates[0])
@@ -1027,13 +1140,13 @@ add_builtin(
1027
1140
  value_func=transform_constructor_value_func,
1028
1141
  native_func="transform_t",
1029
1142
  group="Transformations",
1030
- doc="Construct a rigid body transformation with translation part p and rotation q.",
1143
+ doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
1031
1144
  export=False,
1032
1145
  )
1033
1146
 
1034
1147
 
1035
- def transform_identity_value_func(args, kwds, templates):
1036
- if args is None:
1148
+ def transform_identity_value_func(arg_types, kwds, templates):
1149
+ if arg_types is None:
1037
1150
  return transformf
1038
1151
 
1039
1152
  if "dtype" not in kwds:
@@ -1059,68 +1172,72 @@ add_builtin(
1059
1172
  add_builtin(
1060
1173
  "transform_get_translation",
1061
1174
  input_types={"t": transformation(dtype=Float)},
1062
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1175
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1063
1176
  group="Transformations",
1064
- doc="Return the translational part of a transform.",
1177
+ doc="Return the translational part of a transform ``t``.",
1065
1178
  )
1066
1179
  add_builtin(
1067
1180
  "transform_get_rotation",
1068
1181
  input_types={"t": transformation(dtype=Float)},
1069
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1182
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1070
1183
  group="Transformations",
1071
- doc="Return the rotational part of a transform.",
1184
+ doc="Return the rotational part of a transform ``t``.",
1072
1185
  )
1073
1186
  add_builtin(
1074
1187
  "transform_multiply",
1075
1188
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
1076
- value_func=lambda args, kwds, _: transformation(dtype=infer_scalar_type(args)),
1189
+ value_func=lambda arg_types, kwds, _: transformation(dtype=infer_scalar_type(arg_types)),
1077
1190
  group="Transformations",
1078
1191
  doc="Multiply two rigid body transformations together.",
1079
1192
  )
1080
1193
  add_builtin(
1081
1194
  "transform_point",
1082
1195
  input_types={"t": transformation(dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1083
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1196
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1084
1197
  group="Transformations",
1085
- doc="Apply the transform to a point p treating the homogenous coordinate as w=1 (translation and rotation).",
1198
+ doc="Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1 (translation and rotation).",
1086
1199
  )
1087
1200
  add_builtin(
1088
1201
  "transform_point",
1089
1202
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1090
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1203
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1091
1204
  group="Vector Math",
1092
- doc="""Apply the transform to a point ``p`` treating the homogenous coordinate as w=1. The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``
1093
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``. If the transform is coming from a library that uses row-vectors
1094
- then users should transpose the transformation matrix before calling this method.""",
1205
+ doc="""Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1.
1206
+ The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``.
1207
+ Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``.
1208
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1209
+ matrix before calling this method.""",
1095
1210
  )
1096
1211
  add_builtin(
1097
1212
  "transform_vector",
1098
1213
  input_types={"t": transformation(dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1099
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1214
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1100
1215
  group="Transformations",
1101
- doc="Apply the transform to a vector v treating the homogenous coordinate as w=0 (rotation only).",
1216
+ doc="Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0 (rotation only).",
1102
1217
  )
1103
1218
  add_builtin(
1104
1219
  "transform_vector",
1105
1220
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1106
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1221
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1107
1222
  group="Vector Math",
1108
- doc="""Apply the transform to a vector ``v`` treating the homogenous coordinate as w=0. The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1109
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``. If the transform is coming from a library that uses row-vectors
1110
- then users should transpose the transformation matrix before calling this method.""",
1223
+ doc="""Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0.
1224
+ The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1225
+ note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``.
1226
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1227
+ matrix before calling this method.""",
1111
1228
  )
1112
1229
  add_builtin(
1113
1230
  "transform_inverse",
1114
1231
  input_types={"t": transformation(dtype=Float)},
1115
1232
  value_func=sametype_value_func(transformation(dtype=Float)),
1116
1233
  group="Transformations",
1117
- doc="Compute the inverse of the transform.",
1234
+ doc="Compute the inverse of the transformation ``t``.",
1118
1235
  )
1119
1236
  # ---------------------------------
1120
1237
  # Spatial Math
1121
1238
 
1122
1239
 
1123
- def spatial_vector_constructor_value_func(args, kwds, templates):
1240
+ def spatial_vector_constructor_value_func(arg_types, kwds, templates):
1124
1241
  if templates is None:
1125
1242
  return spatial_vector(dtype=Float)
1126
1243
 
@@ -1128,7 +1245,7 @@ def spatial_vector_constructor_value_func(args, kwds, templates):
1128
1245
  raise RuntimeError("Cannot use a generic type name in a kernel")
1129
1246
 
1130
1247
  vectype = templates[1]
1131
- if len(args) and infer_scalar_type(args) != vectype:
1248
+ if len(arg_types) and infer_scalar_type(arg_types) != vectype:
1132
1249
  raise RuntimeError("Wrong scalar type for spatial_vector<{}> constructor".format(",".join(map(str, templates))))
1133
1250
 
1134
1251
  return vector(length=6, dtype=vectype)
@@ -1140,7 +1257,7 @@ add_builtin(
1140
1257
  value_func=spatial_vector_constructor_value_func,
1141
1258
  native_func="vec_t",
1142
1259
  group="Spatial Math",
1143
- doc="Construct a 6d screw vector from two 3d vectors.",
1260
+ doc="Construct a 6D screw vector from two 3D vectors.",
1144
1261
  export=False,
1145
1262
  )
1146
1263
 
@@ -1148,7 +1265,7 @@ add_builtin(
1148
1265
  add_builtin(
1149
1266
  "spatial_adjoint",
1150
1267
  input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
1151
- value_func=lambda args, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(args)),
1268
+ value_func=lambda arg_types, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(arg_types)),
1152
1269
  group="Spatial Math",
1153
1270
  doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
1154
1271
  export=False,
@@ -1158,36 +1275,36 @@ add_builtin(
1158
1275
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1159
1276
  value_func=sametype_scalar_value_func,
1160
1277
  group="Spatial Math",
1161
- doc="Compute the dot product of two 6d screw vectors.",
1278
+ doc="Compute the dot product of two 6D screw vectors.",
1162
1279
  )
1163
1280
  add_builtin(
1164
1281
  "spatial_cross",
1165
1282
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1166
1283
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1167
1284
  group="Spatial Math",
1168
- doc="Compute the cross-product of two 6d screw vectors.",
1285
+ doc="Compute the cross product of two 6D screw vectors.",
1169
1286
  )
1170
1287
  add_builtin(
1171
1288
  "spatial_cross_dual",
1172
1289
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1173
1290
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1174
1291
  group="Spatial Math",
1175
- doc="Compute the dual cross-product of two 6d screw vectors.",
1292
+ doc="Compute the dual cross product of two 6D screw vectors.",
1176
1293
  )
1177
1294
 
1178
1295
  add_builtin(
1179
1296
  "spatial_top",
1180
1297
  input_types={"a": vector(length=6, dtype=Float)},
1181
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1298
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1182
1299
  group="Spatial Math",
1183
- doc="Return the top (first) part of a 6d screw vector.",
1300
+ doc="Return the top (first) part of a 6D screw vector.",
1184
1301
  )
1185
1302
  add_builtin(
1186
1303
  "spatial_bottom",
1187
1304
  input_types={"a": vector(length=6, dtype=Float)},
1188
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1305
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1189
1306
  group="Spatial Math",
1190
- doc="Return the bottom (second) part of a 6d screw vector.",
1307
+ doc="Return the bottom (second) part of a 6D screw vector.",
1191
1308
  )
1192
1309
 
1193
1310
  add_builtin(
@@ -1341,16 +1458,18 @@ add_builtin(
1341
1458
  },
1342
1459
  value_type=None,
1343
1460
  skip_replay=True,
1344
- doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1461
+ doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1345
1462
 
1346
1463
  :param weights: A layer's network weights with dimensions ``(m, n)``.
1347
1464
  :param bias: An array with dimensions ``(n)``.
1348
1465
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
1349
- :param index: The batch item to process, typically each thread will process 1 item in the batch, in this case index should be ``wp.tid()``
1466
+ :param index: The batch item to process, typically each thread will process one item in the batch, in which case
1467
+ index should be ``wp.tid()``
1350
1468
  :param x: The feature matrix with dimensions ``(n, b)``
1351
1469
  :param out: The network output with dimensions ``(m, b)``
1352
1470
 
1353
- :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch. All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1471
+ :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch.
1472
+ All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1354
1473
  group="Utility",
1355
1474
  )
1356
1475
 
@@ -1363,12 +1482,12 @@ add_builtin(
1363
1482
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1364
1483
  value_type=bvh_query_t,
1365
1484
  group="Geometry",
1366
- doc="""Construct an axis-aligned bounding box query against a bvh object. This query can be used to iterate over all bounds
1367
- inside a bvh. Returns an object that is used to track state during bvh traversal.
1368
-
1369
- :param id: The bvh identifier
1370
- :param lower: The lower bound of the bounding box in bvh space
1371
- :param upper: The upper bound of the bounding box in bvh space""",
1485
+ doc="""Construct an axis-aligned bounding box query against a BVH object. This query can be used to iterate over all bounds
1486
+ inside a BVH.
1487
+
1488
+ :param id: The BVH identifier
1489
+ :param lower: The lower bound of the bounding box in BVH space
1490
+ :param upper: The upper bound of the bounding box in BVH space""",
1372
1491
  )
1373
1492
 
1374
1493
  add_builtin(
@@ -1376,21 +1495,21 @@ add_builtin(
1376
1495
  input_types={"id": uint64, "start": vec3, "dir": vec3},
1377
1496
  value_type=bvh_query_t,
1378
1497
  group="Geometry",
1379
- doc="""Construct a ray query against a bvh object. This query can be used to iterate over all bounds
1380
- that intersect the ray. Returns an object that is used to track state during bvh traversal.
1381
-
1382
- :param id: The bvh identifier
1383
- :param start: The start of the ray in bvh space
1384
- :param dir: The direction of the ray in bvh space""",
1498
+ doc="""Construct a ray query against a BVH object. This query can be used to iterate over all bounds
1499
+ that intersect the ray.
1500
+
1501
+ :param id: The BVH identifier
1502
+ :param start: The start of the ray in BVH space
1503
+ :param dir: The direction of the ray in BVH space""",
1385
1504
  )
1386
1505
 
1387
1506
  add_builtin(
1388
1507
  "bvh_query_next",
1389
1508
  input_types={"query": bvh_query_t, "index": int},
1390
- value_type=bool,
1509
+ value_type=builtins.bool,
1391
1510
  group="Geometry",
1392
- doc="""Move to the next bound returned by the query. The index of the current bound is stored in ``index``, returns ``False``
1393
- if there are no more overlapping bound.""",
1511
+ doc="""Move to the next bound returned by the query.
1512
+ The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
1394
1513
  )
1395
1514
 
1396
1515
  add_builtin(
@@ -1404,17 +1523,256 @@ add_builtin(
1404
1523
  "bary_u": float,
1405
1524
  "bary_v": float,
1406
1525
  },
1407
- value_type=bool,
1526
+ value_type=builtins.bool,
1527
+ group="Geometry",
1528
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1529
+
1530
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1531
+ This method is relatively robust, but does increase computational cost.
1532
+ See below for additional sign determination methods.
1533
+
1534
+ :param id: The mesh identifier
1535
+ :param point: The point in space to query
1536
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1537
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1538
+ Note that mesh must be watertight for this to be robust
1539
+ :param face: Returns the index of the closest face
1540
+ :param bary_u: Returns the barycentric u coordinate of the closest point
1541
+ :param bary_v: Returns the barycentric v coordinate of the closest point""",
1542
+ hidden=True,
1543
+ )
1544
+
1545
+ add_builtin(
1546
+ "mesh_query_point",
1547
+ input_types={
1548
+ "id": uint64,
1549
+ "point": vec3,
1550
+ "max_dist": float,
1551
+ },
1552
+ value_type=mesh_query_point_t,
1553
+ group="Geometry",
1554
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1555
+
1556
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1557
+ This method is relatively robust, but does increase computational cost.
1558
+ See below for additional sign determination methods.
1559
+
1560
+ :param id: The mesh identifier
1561
+ :param point: The point in space to query
1562
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1563
+ require_original_output_arg=True,
1564
+ )
1565
+
1566
+ add_builtin(
1567
+ "mesh_query_point_no_sign",
1568
+ input_types={
1569
+ "id": uint64,
1570
+ "point": vec3,
1571
+ "max_dist": float,
1572
+ "face": int,
1573
+ "bary_u": float,
1574
+ "bary_v": float,
1575
+ },
1576
+ value_type=builtins.bool,
1408
1577
  group="Geometry",
1409
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1578
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1579
+
1580
+ This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1410
1581
 
1411
1582
  :param id: The mesh identifier
1412
1583
  :param point: The point in space to query
1413
1584
  :param max_dist: Mesh faces above this distance will not be considered by the query
1414
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1415
1585
  :param face: Returns the index of the closest face
1416
1586
  :param bary_u: Returns the barycentric u coordinate of the closest point
1417
1587
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1588
+ hidden=True,
1589
+ )
1590
+
1591
+ add_builtin(
1592
+ "mesh_query_point_no_sign",
1593
+ input_types={
1594
+ "id": uint64,
1595
+ "point": vec3,
1596
+ "max_dist": float,
1597
+ },
1598
+ value_type=mesh_query_point_t,
1599
+ group="Geometry",
1600
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1601
+
1602
+ This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1603
+
1604
+ :param id: The mesh identifier
1605
+ :param point: The point in space to query
1606
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1607
+ require_original_output_arg=True,
1608
+ )
1609
+
1610
+ add_builtin(
1611
+ "mesh_query_furthest_point_no_sign",
1612
+ input_types={
1613
+ "id": uint64,
1614
+ "point": vec3,
1615
+ "min_dist": float,
1616
+ "face": int,
1617
+ "bary_u": float,
1618
+ "bary_v": float,
1619
+ },
1620
+ value_type=builtins.bool,
1621
+ group="Geometry",
1622
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point > ``min_dist`` is found.
1623
+
1624
+ This method does not compute the sign of the point (inside/outside).
1625
+
1626
+ :param id: The mesh identifier
1627
+ :param point: The point in space to query
1628
+ :param min_dist: Mesh faces below this distance will not be considered by the query
1629
+ :param face: Returns the index of the furthest face
1630
+ :param bary_u: Returns the barycentric u coordinate of the furthest point
1631
+ :param bary_v: Returns the barycentric v coordinate of the furthest point""",
1632
+ hidden=True,
1633
+ )
1634
+
1635
+ add_builtin(
1636
+ "mesh_query_furthest_point_no_sign",
1637
+ input_types={
1638
+ "id": uint64,
1639
+ "point": vec3,
1640
+ "min_dist": float,
1641
+ },
1642
+ value_type=mesh_query_point_t,
1643
+ group="Geometry",
1644
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
1645
+
1646
+ This method does not compute the sign of the point (inside/outside).
1647
+
1648
+ :param id: The mesh identifier
1649
+ :param point: The point in space to query
1650
+ :param min_dist: Mesh faces below this distance will not be considered by the query""",
1651
+ require_original_output_arg=True,
1652
+ )
1653
+
1654
+ add_builtin(
1655
+ "mesh_query_point_sign_normal",
1656
+ input_types={
1657
+ "id": uint64,
1658
+ "point": vec3,
1659
+ "max_dist": float,
1660
+ "inside": float,
1661
+ "face": int,
1662
+ "bary_u": float,
1663
+ "bary_v": float,
1664
+ "epsilon": float,
1665
+ },
1666
+ defaults={"epsilon": 1.0e-3},
1667
+ value_type=builtins.bool,
1668
+ group="Geometry",
1669
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1670
+
1671
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1672
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1673
+ It is also comparatively fast to compute.
1674
+
1675
+ :param id: The mesh identifier
1676
+ :param point: The point in space to query
1677
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1678
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1679
+ Note that mesh must be watertight for this to be robust
1680
+ :param face: Returns the index of the closest face
1681
+ :param bary_u: Returns the barycentric u coordinate of the closest point
1682
+ :param bary_v: Returns the barycentric v coordinate of the closest point
1683
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1684
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1685
+ hidden=True,
1686
+ )
1687
+
1688
+ add_builtin(
1689
+ "mesh_query_point_sign_normal",
1690
+ input_types={
1691
+ "id": uint64,
1692
+ "point": vec3,
1693
+ "max_dist": float,
1694
+ "epsilon": float,
1695
+ },
1696
+ defaults={"epsilon": 1.0e-3},
1697
+ value_type=mesh_query_point_t,
1698
+ group="Geometry",
1699
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1700
+
1701
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1702
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1703
+ It is also comparatively fast to compute.
1704
+
1705
+ :param id: The mesh identifier
1706
+ :param point: The point in space to query
1707
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1708
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1709
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1710
+ require_original_output_arg=True,
1711
+ )
1712
+
1713
+ add_builtin(
1714
+ "mesh_query_point_sign_winding_number",
1715
+ input_types={
1716
+ "id": uint64,
1717
+ "point": vec3,
1718
+ "max_dist": float,
1719
+ "inside": float,
1720
+ "face": int,
1721
+ "bary_u": float,
1722
+ "bary_v": float,
1723
+ "accuracy": float,
1724
+ "threshold": float,
1725
+ },
1726
+ defaults={"accuracy": 2.0, "threshold": 0.5},
1727
+ value_type=builtins.bool,
1728
+ group="Geometry",
1729
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1730
+
1731
+ Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1732
+ and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1733
+ but also the most expensive.
1734
+
1735
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1736
+
1737
+ :param id: The mesh identifier
1738
+ :param point: The point in space to query
1739
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1740
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1741
+ Note that mesh must be watertight for this to be robust
1742
+ :param face: Returns the index of the closest face
1743
+ :param bary_u: Returns the barycentric u coordinate of the closest point
1744
+ :param bary_v: Returns the barycentric v coordinate of the closest point
1745
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1746
+ :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1747
+ hidden=True,
1748
+ )
1749
+
1750
+ add_builtin(
1751
+ "mesh_query_point_sign_winding_number",
1752
+ input_types={
1753
+ "id": uint64,
1754
+ "point": vec3,
1755
+ "max_dist": float,
1756
+ "accuracy": float,
1757
+ "threshold": float,
1758
+ },
1759
+ defaults={"accuracy": 2.0, "threshold": 0.5},
1760
+ value_type=mesh_query_point_t,
1761
+ group="Geometry",
1762
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
1763
+
1764
+ Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1765
+ and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1766
+ but also the most expensive.
1767
+
1768
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1769
+
1770
+ :param id: The mesh identifier
1771
+ :param point: The point in space to query
1772
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1773
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1774
+ :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1775
+ require_original_output_arg=True,
1418
1776
  )
1419
1777
 
1420
1778
  add_builtin(
@@ -1431,9 +1789,9 @@ add_builtin(
1431
1789
  "normal": vec3,
1432
1790
  "face": int,
1433
1791
  },
1434
- value_type=bool,
1792
+ value_type=builtins.bool,
1435
1793
  group="Geometry",
1436
- doc="""Computes the closest ray hit on the mesh with identifier `id`, returns ``True`` if a point < ``max_t`` is found.
1794
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``, returns ``True`` if a hit < ``max_t`` is found.
1437
1795
 
1438
1796
  :param id: The mesh identifier
1439
1797
  :param start: The start point of the ray
@@ -1442,9 +1800,29 @@ add_builtin(
1442
1800
  :param t: Returns the distance of the closest hit along the ray
1443
1801
  :param bary_u: Returns the barycentric u coordinate of the closest hit
1444
1802
  :param bary_v: Returns the barycentric v coordinate of the closest hit
1445
- :param sign: Returns a value > 0 if the hit ray hit front of the face, returns < 0 otherwise
1803
+ :param sign: Returns a value > 0 if the ray hit in front of the face, returns < 0 otherwise
1446
1804
  :param normal: Returns the face normal
1447
1805
  :param face: Returns the index of the hit face""",
1806
+ hidden=True,
1807
+ )
1808
+
1809
+ add_builtin(
1810
+ "mesh_query_ray",
1811
+ input_types={
1812
+ "id": uint64,
1813
+ "start": vec3,
1814
+ "dir": vec3,
1815
+ "max_t": float,
1816
+ },
1817
+ value_type=mesh_query_ray_t,
1818
+ group="Geometry",
1819
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
1820
+
1821
+ :param id: The mesh identifier
1822
+ :param start: The start point of the ray
1823
+ :param dir: The ray direction (should be normalized)
1824
+ :param max_t: The maximum distance along the ray to check for intersections""",
1825
+ require_original_output_arg=True,
1448
1826
  )
1449
1827
 
1450
1828
  add_builtin(
@@ -1452,9 +1830,9 @@ add_builtin(
1452
1830
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1453
1831
  value_type=mesh_query_aabb_t,
1454
1832
  group="Geometry",
1455
- doc="""Construct an axis-aligned bounding box query against a mesh object. This query can be used to iterate over all triangles
1456
- inside a volume. Returns an object that is used to track state during mesh traversal.
1457
-
1833
+ doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
1834
+ This query can be used to iterate over all triangles inside a volume.
1835
+
1458
1836
  :param id: The mesh identifier
1459
1837
  :param lower: The lower bound of the bounding box in mesh space
1460
1838
  :param upper: The upper bound of the bounding box in mesh space""",
@@ -1463,10 +1841,10 @@ add_builtin(
1463
1841
  add_builtin(
1464
1842
  "mesh_query_aabb_next",
1465
1843
  input_types={"query": mesh_query_aabb_t, "index": int},
1466
- value_type=bool,
1844
+ value_type=builtins.bool,
1467
1845
  group="Geometry",
1468
- doc="""Move to the next triangle overlapping the query bounding box. The index of the current face is stored in ``index``, returns ``False``
1469
- if there are no more overlapping triangles.""",
1846
+ doc="""Move to the next triangle overlapping the query bounding box.
1847
+ The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
1470
1848
  )
1471
1849
 
1472
1850
  add_builtin(
@@ -1474,7 +1852,7 @@ add_builtin(
1474
1852
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1475
1853
  value_type=vec3,
1476
1854
  group="Geometry",
1477
- doc="""Evaluates the position on the mesh given a face index, and barycentric coordinates.""",
1855
+ doc="""Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates.""",
1478
1856
  )
1479
1857
 
1480
1858
  add_builtin(
@@ -1482,7 +1860,7 @@ add_builtin(
1482
1860
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1483
1861
  value_type=vec3,
1484
1862
  group="Geometry",
1485
- doc="""Evaluates the velocity on the mesh given a face index, and barycentric coordinates.""",
1863
+ doc="""Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates.""",
1486
1864
  )
1487
1865
 
1488
1866
  add_builtin(
@@ -1490,14 +1868,14 @@ add_builtin(
1490
1868
  input_types={"id": uint64, "point": vec3, "max_dist": float},
1491
1869
  value_type=hash_grid_query_t,
1492
1870
  group="Geometry",
1493
- doc="""Construct a point query against a hash grid. This query can be used to iterate over all neighboring points withing a
1494
- fixed radius from the query point. Returns an object that is used to track state during neighbor traversal.""",
1871
+ doc="Construct a point query against a :class:`HashGrid`. This query can be used to iterate over all neighboring points "
1872
+ "within a fixed radius from the query point.",
1495
1873
  )
1496
1874
 
1497
1875
  add_builtin(
1498
1876
  "hash_grid_query_next",
1499
1877
  input_types={"query": hash_grid_query_t, "index": int},
1500
- value_type=bool,
1878
+ value_type=builtins.bool,
1501
1879
  group="Geometry",
1502
1880
  doc="""Move to the next point in the hash grid query. The index of the current neighbor is stored in ``index``, returns ``False``
1503
1881
  if there are no more neighbors.""",
@@ -1508,8 +1886,10 @@ add_builtin(
1508
1886
  input_types={"id": uint64, "index": int},
1509
1887
  value_type=int,
1510
1888
  group="Geometry",
1511
- doc="""Return the index of a point in the grid, this can be used to re-order threads such that grid
1512
- traversal occurs in a spatially coherent order.""",
1889
+ doc="""Return the index of a point in the :class:`HashGrid`. This can be used to reorder threads such that grid
1890
+ traversal occurs in a spatially coherent order.
1891
+
1892
+ Returns -1 if the :class:`HashGrid` has not been reserved.""",
1513
1893
  )
1514
1894
 
1515
1895
  add_builtin(
@@ -1608,7 +1988,17 @@ add_builtin(
1608
1988
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1609
1989
  value_type=float,
1610
1990
  group="Volumes",
1611
- doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
1991
+ doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``.
1992
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1993
+ )
1994
+
1995
+ add_builtin(
1996
+ "volume_sample_grad_f",
1997
+ input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "grad": vec3},
1998
+ value_type=float,
1999
+ group="Volumes",
2000
+ doc="""Sample the volume and its gradient given by ``id`` at the volume local-space point ``uvw``.
2001
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1612
2002
  )
1613
2003
 
1614
2004
  add_builtin(
@@ -1616,14 +2006,15 @@ add_builtin(
1616
2006
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1617
2007
  value_type=float,
1618
2008
  group="Volumes",
1619
- doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2009
+ doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
2010
+ If the voxel at this index does not exist, this function returns the background value""",
1620
2011
  )
1621
2012
 
1622
2013
  add_builtin(
1623
2014
  "volume_store_f",
1624
2015
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": float},
1625
2016
  group="Volumes",
1626
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2017
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1627
2018
  )
1628
2019
 
1629
2020
  add_builtin(
@@ -1631,7 +2022,8 @@ add_builtin(
1631
2022
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1632
2023
  value_type=vec3,
1633
2024
  group="Volumes",
1634
- doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
2025
+ doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``.
2026
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1635
2027
  )
1636
2028
 
1637
2029
  add_builtin(
@@ -1639,14 +2031,15 @@ add_builtin(
1639
2031
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1640
2032
  value_type=vec3,
1641
2033
  group="Volumes",
1642
- doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2034
+ doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
2035
+ If the voxel at this index does not exist, this function returns the background value.""",
1643
2036
  )
1644
2037
 
1645
2038
  add_builtin(
1646
2039
  "volume_store_v",
1647
2040
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": vec3},
1648
2041
  group="Volumes",
1649
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2042
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1650
2043
  )
1651
2044
 
1652
2045
  add_builtin(
@@ -1654,7 +2047,7 @@ add_builtin(
1654
2047
  input_types={"id": uint64, "uvw": vec3},
1655
2048
  value_type=int,
1656
2049
  group="Volumes",
1657
- doc="""Sample the int32 volume given by ``id`` at the volume local-space point ``uvw``. """,
2050
+ doc="""Sample the :class:`int32` volume given by ``id`` at the volume local-space point ``uvw``. """,
1658
2051
  )
1659
2052
 
1660
2053
  add_builtin(
@@ -1662,14 +2055,15 @@ add_builtin(
1662
2055
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1663
2056
  value_type=int,
1664
2057
  group="Volumes",
1665
- doc="""Returns the int32 value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2058
+ doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
2059
+ If the voxel at this index does not exist, this function returns the background value.""",
1666
2060
  )
1667
2061
 
1668
2062
  add_builtin(
1669
2063
  "volume_store_i",
1670
2064
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": int},
1671
2065
  group="Volumes",
1672
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2066
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1673
2067
  )
1674
2068
 
1675
2069
  add_builtin(
@@ -1677,28 +2071,28 @@ add_builtin(
1677
2071
  input_types={"id": uint64, "uvw": vec3},
1678
2072
  value_type=vec3,
1679
2073
  group="Volumes",
1680
- doc="""Transform a point defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2074
+ doc="""Transform a point ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1681
2075
  )
1682
2076
  add_builtin(
1683
2077
  "volume_world_to_index",
1684
2078
  input_types={"id": uint64, "xyz": vec3},
1685
2079
  value_type=vec3,
1686
2080
  group="Volumes",
1687
- doc="""Transform a point defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2081
+ doc="""Transform a point ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1688
2082
  )
1689
2083
  add_builtin(
1690
2084
  "volume_index_to_world_dir",
1691
2085
  input_types={"id": uint64, "uvw": vec3},
1692
2086
  value_type=vec3,
1693
2087
  group="Volumes",
1694
- doc="""Transform a direction defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2088
+ doc="""Transform a direction ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1695
2089
  )
1696
2090
  add_builtin(
1697
2091
  "volume_world_to_index_dir",
1698
2092
  input_types={"id": uint64, "xyz": vec3},
1699
2093
  value_type=vec3,
1700
2094
  group="Volumes",
1701
- doc="""Transform a direction defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2095
+ doc="""Transform a direction ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1702
2096
  )
1703
2097
 
1704
2098
 
@@ -1718,7 +2112,7 @@ add_builtin(
1718
2112
  input_types={"seed": int, "offset": int},
1719
2113
  value_type=uint32,
1720
2114
  group="Random",
1721
- doc="""Initialize a new random number generator given a user-defined seed and an offset.
2115
+ doc="""Initialize a new random number generator given a user-defined seed and an offset.
1722
2116
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
1723
2117
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
1724
2118
  )
@@ -1728,31 +2122,31 @@ add_builtin(
1728
2122
  input_types={"state": uint32},
1729
2123
  value_type=int,
1730
2124
  group="Random",
1731
- doc="Return a random integer between [0, 2^32)",
2125
+ doc="Return a random integer in the range [0, 2^32).",
1732
2126
  )
1733
2127
  add_builtin(
1734
2128
  "randi",
1735
2129
  input_types={"state": uint32, "min": int, "max": int},
1736
2130
  value_type=int,
1737
2131
  group="Random",
1738
- doc="Return a random integer between [min, max)",
2132
+ doc="Return a random integer between [min, max).",
1739
2133
  )
1740
2134
  add_builtin(
1741
2135
  "randf",
1742
2136
  input_types={"state": uint32},
1743
2137
  value_type=float,
1744
2138
  group="Random",
1745
- doc="Return a random float between [0.0, 1.0)",
2139
+ doc="Return a random float between [0.0, 1.0).",
1746
2140
  )
1747
2141
  add_builtin(
1748
2142
  "randf",
1749
2143
  input_types={"state": uint32, "min": float, "max": float},
1750
2144
  value_type=float,
1751
2145
  group="Random",
1752
- doc="Return a random float between [min, max)",
2146
+ doc="Return a random float between [min, max).",
1753
2147
  )
1754
2148
  add_builtin(
1755
- "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution"
2149
+ "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution."
1756
2150
  )
1757
2151
 
1758
2152
  add_builtin(
@@ -1760,70 +2154,70 @@ add_builtin(
1760
2154
  input_types={"state": uint32, "cdf": array(dtype=float)},
1761
2155
  value_type=int,
1762
2156
  group="Random",
1763
- doc="Inverse transform sample a cumulative distribution function",
2157
+ doc="Inverse-transform sample a cumulative distribution function.",
1764
2158
  )
1765
2159
  add_builtin(
1766
2160
  "sample_triangle",
1767
2161
  input_types={"state": uint32},
1768
2162
  value_type=vec2,
1769
2163
  group="Random",
1770
- doc="Uniformly sample a triangle. Returns sample barycentric coordinates",
2164
+ doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
1771
2165
  )
1772
2166
  add_builtin(
1773
2167
  "sample_unit_ring",
1774
2168
  input_types={"state": uint32},
1775
2169
  value_type=vec2,
1776
2170
  group="Random",
1777
- doc="Uniformly sample a ring in the xy plane",
2171
+ doc="Uniformly sample a ring in the xy plane.",
1778
2172
  )
1779
2173
  add_builtin(
1780
2174
  "sample_unit_disk",
1781
2175
  input_types={"state": uint32},
1782
2176
  value_type=vec2,
1783
2177
  group="Random",
1784
- doc="Uniformly sample a disk in the xy plane",
2178
+ doc="Uniformly sample a disk in the xy plane.",
1785
2179
  )
1786
2180
  add_builtin(
1787
2181
  "sample_unit_sphere_surface",
1788
2182
  input_types={"state": uint32},
1789
2183
  value_type=vec3,
1790
2184
  group="Random",
1791
- doc="Uniformly sample a unit sphere surface",
2185
+ doc="Uniformly sample a unit sphere surface.",
1792
2186
  )
1793
2187
  add_builtin(
1794
2188
  "sample_unit_sphere",
1795
2189
  input_types={"state": uint32},
1796
2190
  value_type=vec3,
1797
2191
  group="Random",
1798
- doc="Uniformly sample a unit sphere",
2192
+ doc="Uniformly sample a unit sphere.",
1799
2193
  )
1800
2194
  add_builtin(
1801
2195
  "sample_unit_hemisphere_surface",
1802
2196
  input_types={"state": uint32},
1803
2197
  value_type=vec3,
1804
2198
  group="Random",
1805
- doc="Uniformly sample a unit hemisphere surface",
2199
+ doc="Uniformly sample a unit hemisphere surface.",
1806
2200
  )
1807
2201
  add_builtin(
1808
2202
  "sample_unit_hemisphere",
1809
2203
  input_types={"state": uint32},
1810
2204
  value_type=vec3,
1811
2205
  group="Random",
1812
- doc="Uniformly sample a unit hemisphere",
2206
+ doc="Uniformly sample a unit hemisphere.",
1813
2207
  )
1814
2208
  add_builtin(
1815
2209
  "sample_unit_square",
1816
2210
  input_types={"state": uint32},
1817
2211
  value_type=vec2,
1818
2212
  group="Random",
1819
- doc="Uniformly sample a unit square",
2213
+ doc="Uniformly sample a unit square.",
1820
2214
  )
1821
2215
  add_builtin(
1822
2216
  "sample_unit_cube",
1823
2217
  input_types={"state": uint32},
1824
2218
  value_type=vec3,
1825
2219
  group="Random",
1826
- doc="Uniformly sample a unit cube",
2220
+ doc="Uniformly sample a unit cube.",
1827
2221
  )
1828
2222
 
1829
2223
  add_builtin(
@@ -1832,9 +2226,9 @@ add_builtin(
1832
2226
  value_type=uint32,
1833
2227
  group="Random",
1834
2228
  doc="""Generate a random sample from a Poisson distribution.
1835
-
1836
- :param state: RNG state
1837
- :param lam: The expected value of the distribution""",
2229
+
2230
+ :param state: RNG state
2231
+ :param lam: The expected value of the distribution""",
1838
2232
  )
1839
2233
 
1840
2234
  add_builtin(
@@ -1842,28 +2236,28 @@ add_builtin(
1842
2236
  input_types={"state": uint32, "x": float},
1843
2237
  value_type=float,
1844
2238
  group="Random",
1845
- doc="Non-periodic Perlin-style noise in 1d.",
2239
+ doc="Non-periodic Perlin-style noise in 1D.",
1846
2240
  )
1847
2241
  add_builtin(
1848
2242
  "noise",
1849
2243
  input_types={"state": uint32, "xy": vec2},
1850
2244
  value_type=float,
1851
2245
  group="Random",
1852
- doc="Non-periodic Perlin-style noise in 2d.",
2246
+ doc="Non-periodic Perlin-style noise in 2D.",
1853
2247
  )
1854
2248
  add_builtin(
1855
2249
  "noise",
1856
2250
  input_types={"state": uint32, "xyz": vec3},
1857
2251
  value_type=float,
1858
2252
  group="Random",
1859
- doc="Non-periodic Perlin-style noise in 3d.",
2253
+ doc="Non-periodic Perlin-style noise in 3D.",
1860
2254
  )
1861
2255
  add_builtin(
1862
2256
  "noise",
1863
2257
  input_types={"state": uint32, "xyzt": vec4},
1864
2258
  value_type=float,
1865
2259
  group="Random",
1866
- doc="Non-periodic Perlin-style noise in 4d.",
2260
+ doc="Non-periodic Perlin-style noise in 4D.",
1867
2261
  )
1868
2262
 
1869
2263
  add_builtin(
@@ -1871,33 +2265,34 @@ add_builtin(
1871
2265
  input_types={"state": uint32, "x": float, "px": int},
1872
2266
  value_type=float,
1873
2267
  group="Random",
1874
- doc="Periodic Perlin-style noise in 1d.",
2268
+ doc="Periodic Perlin-style noise in 1D.",
1875
2269
  )
1876
2270
  add_builtin(
1877
2271
  "pnoise",
1878
2272
  input_types={"state": uint32, "xy": vec2, "px": int, "py": int},
1879
2273
  value_type=float,
1880
2274
  group="Random",
1881
- doc="Periodic Perlin-style noise in 2d.",
2275
+ doc="Periodic Perlin-style noise in 2D.",
1882
2276
  )
1883
2277
  add_builtin(
1884
2278
  "pnoise",
1885
2279
  input_types={"state": uint32, "xyz": vec3, "px": int, "py": int, "pz": int},
1886
2280
  value_type=float,
1887
2281
  group="Random",
1888
- doc="Periodic Perlin-style noise in 3d.",
2282
+ doc="Periodic Perlin-style noise in 3D.",
1889
2283
  )
1890
2284
  add_builtin(
1891
2285
  "pnoise",
1892
2286
  input_types={"state": uint32, "xyzt": vec4, "px": int, "py": int, "pz": int, "pt": int},
1893
2287
  value_type=float,
1894
2288
  group="Random",
1895
- doc="Periodic Perlin-style noise in 4d.",
2289
+ doc="Periodic Perlin-style noise in 4D.",
1896
2290
  )
1897
2291
 
1898
2292
  add_builtin(
1899
2293
  "curlnoise",
1900
- input_types={"state": uint32, "xy": vec2},
2294
+ input_types={"state": uint32, "xy": vec2, "octaves": uint32, "lacunarity": float, "gain": float},
2295
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
1901
2296
  value_type=vec2,
1902
2297
  group="Random",
1903
2298
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
@@ -1905,7 +2300,8 @@ add_builtin(
1905
2300
  )
1906
2301
  add_builtin(
1907
2302
  "curlnoise",
1908
- input_types={"state": uint32, "xyz": vec3},
2303
+ input_types={"state": uint32, "xyz": vec3, "octaves": uint32, "lacunarity": float, "gain": float},
2304
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
1909
2305
  value_type=vec3,
1910
2306
  group="Random",
1911
2307
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -1913,7 +2309,8 @@ add_builtin(
1913
2309
  )
1914
2310
  add_builtin(
1915
2311
  "curlnoise",
1916
- input_types={"state": uint32, "xyzt": vec4},
2312
+ input_types={"state": uint32, "xyzt": vec4, "octaves": uint32, "lacunarity": float, "gain": float},
2313
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
1917
2314
  value_type=vec3,
1918
2315
  group="Random",
1919
2316
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -1927,7 +2324,7 @@ add_builtin(
1927
2324
  namespace="",
1928
2325
  variadic=True,
1929
2326
  group="Utility",
1930
- doc="Allows printing formatted strings, using C-style format specifiers.",
2327
+ doc="Allows printing formatted strings using C-style format specifiers.",
1931
2328
  )
1932
2329
 
1933
2330
  add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
@@ -1947,9 +2344,12 @@ add_builtin(
1947
2344
  "tid",
1948
2345
  input_types={},
1949
2346
  value_type=int,
2347
+ export=False,
1950
2348
  group="Utility",
1951
- doc="""Return the current thread index. Note that this is the *global* index of the thread in the range [0, dim)
1952
- where dim is the parameter passed to kernel launch.""",
2349
+ doc="""Return the current thread index for a 1D kernel launch. Note that this is the *global* index of the thread in the range [0, dim)
2350
+ where dim is the parameter passed to kernel launch. This function may not be called from user-defined Warp functions.""",
2351
+ namespace="",
2352
+ native_func="builtin_tid1d",
1953
2353
  )
1954
2354
 
1955
2355
  add_builtin(
@@ -1957,7 +2357,10 @@ add_builtin(
1957
2357
  input_types={},
1958
2358
  value_type=[int, int],
1959
2359
  group="Utility",
1960
- doc="""Return the current thread indices for a 2d kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2360
+ doc="""Return the current thread indices for a 2D kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the
2361
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2362
+ namespace="",
2363
+ native_func="builtin_tid2d",
1961
2364
  )
1962
2365
 
1963
2366
  add_builtin(
@@ -1965,7 +2368,10 @@ add_builtin(
1965
2368
  input_types={},
1966
2369
  value_type=[int, int, int],
1967
2370
  group="Utility",
1968
- doc="""Return the current thread indices for a 3d kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2371
+ doc="""Return the current thread indices for a 3D kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the
2372
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2373
+ namespace="",
2374
+ native_func="builtin_tid3d",
1969
2375
  )
1970
2376
 
1971
2377
  add_builtin(
@@ -1973,42 +2379,60 @@ add_builtin(
1973
2379
  input_types={},
1974
2380
  value_type=[int, int, int, int],
1975
2381
  group="Utility",
1976
- doc="""Return the current thread indices for a 4d kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2382
+ doc="""Return the current thread indices for a 4D kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the
2383
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2384
+ namespace="",
2385
+ native_func="builtin_tid4d",
1977
2386
  )
1978
2387
 
1979
2388
 
1980
- add_builtin("copy", variadic=True, hidden=True, export=False, group="Utility")
2389
+ add_builtin(
2390
+ "copy",
2391
+ input_types={"value": Any},
2392
+ value_func=lambda arg_types, kwds, _: arg_types[0],
2393
+ hidden=True,
2394
+ export=False,
2395
+ group="Utility",
2396
+ )
2397
+ add_builtin("assign", variadic=True, hidden=True, export=False, group="Utility")
1981
2398
  add_builtin(
1982
2399
  "select",
1983
2400
  input_types={"cond": bool, "arg1": Any, "arg2": Any},
2401
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2402
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
2403
+ group="Utility",
2404
+ )
2405
+ add_builtin(
2406
+ "select",
2407
+ input_types={"cond": builtins.bool, "arg1": Any, "arg2": Any},
1984
2408
  value_func=lambda args, kwds, _: args[1].type,
1985
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2409
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
1986
2410
  group="Utility",
1987
2411
  )
1988
2412
  for t in int_types:
1989
2413
  add_builtin(
1990
2414
  "select",
1991
2415
  input_types={"cond": t, "arg1": Any, "arg2": Any},
1992
- value_func=lambda args, kwds, _: args[1].type,
1993
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2416
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2417
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
1994
2418
  group="Utility",
1995
2419
  )
1996
2420
  add_builtin(
1997
2421
  "select",
1998
2422
  input_types={"arr": array(dtype=Any), "arg1": Any, "arg2": Any},
1999
- value_func=lambda args, kwds, _: args[1].type,
2000
- doc="Select between two arguments, if array is null then return ``arg1``, otherwise return ``arg2``",
2423
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2424
+ doc="Select between two arguments, if ``arr`` is null then return ``arg1``, otherwise return ``arg2``",
2001
2425
  group="Utility",
2002
2426
  )
2003
2427
 
2004
2428
 
2005
- # does argument checking and type propagation for load()
2006
- def load_value_func(args, kwds, _):
2007
- if not is_array(args[0].type):
2429
+ # does argument checking and type propagation for address()
2430
+ def address_value_func(arg_types, kwds, _):
2431
+ if not is_array(arg_types[0]):
2008
2432
  raise RuntimeError("load() argument 0 must be an array")
2009
2433
 
2010
- num_indices = len(args[1:])
2011
- num_dims = args[0].type.ndim
2434
+ num_indices = len(arg_types[1:])
2435
+ num_dims = arg_types[0].ndim
2012
2436
 
2013
2437
  if num_indices < num_dims:
2014
2438
  raise RuntimeError(
@@ -2021,21 +2445,21 @@ def load_value_func(args, kwds, _):
2021
2445
  )
2022
2446
 
2023
2447
  # check index types
2024
- for a in args[1:]:
2025
- if type_is_int(a.type) == False:
2026
- raise RuntimeError(f"load() index arguments must be of integer type, got index of type {a.type}")
2448
+ for t in arg_types[1:]:
2449
+ if not type_is_int(t):
2450
+ raise RuntimeError(f"address() index arguments must be of integer type, got index of type {t}")
2027
2451
 
2028
- return args[0].type.dtype
2452
+ return Reference(arg_types[0].dtype)
2029
2453
 
2030
2454
 
2031
2455
  # does argument checking and type propagation for view()
2032
- def view_value_func(args, kwds, _):
2033
- if not is_array(args[0].type):
2456
+ def view_value_func(arg_types, kwds, _):
2457
+ if not is_array(arg_types[0]):
2034
2458
  raise RuntimeError("view() argument 0 must be an array")
2035
2459
 
2036
2460
  # check array dim big enough to support view
2037
- num_indices = len(args[1:])
2038
- num_dims = args[0].type.ndim
2461
+ num_indices = len(arg_types[1:])
2462
+ num_dims = arg_types[0].ndim
2039
2463
 
2040
2464
  if num_indices >= num_dims:
2041
2465
  raise RuntimeError(
@@ -2043,27 +2467,28 @@ def view_value_func(args, kwds, _):
2043
2467
  )
2044
2468
 
2045
2469
  # check index types
2046
- for a in args[1:]:
2047
- if type_is_int(a.type) == False:
2048
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {a.type}")
2470
+ for t in arg_types[1:]:
2471
+ if not type_is_int(t):
2472
+ raise RuntimeError(f"view() index arguments must be of integer type, got index of type {t}")
2049
2473
 
2050
2474
  # create an array view with leading dimensions removed
2051
- import copy
2052
-
2053
- view_type = copy.copy(args[0].type)
2054
- view_type.ndim -= num_indices
2055
-
2056
- return view_type
2475
+ dtype = arg_types[0].dtype
2476
+ ndim = num_dims - num_indices
2477
+ if isinstance(arg_types[0], (fabricarray, indexedfabricarray)):
2478
+ # fabric array of arrays: return array attribute as a regular array
2479
+ return array(dtype=dtype, ndim=ndim)
2480
+ else:
2481
+ return type(arg_types[0])(dtype=dtype, ndim=ndim)
2057
2482
 
2058
2483
 
2059
- # does argument checking and type propagation for store()
2060
- def store_value_func(args, kwds, _):
2484
+ # does argument checking and type propagation for array_store()
2485
+ def array_store_value_func(arg_types, kwds, _):
2061
2486
  # check target type
2062
- if not is_array(args[0].type):
2063
- raise RuntimeError("store() argument 0 must be an array")
2487
+ if not is_array(arg_types[0]):
2488
+ raise RuntimeError("array_store() argument 0 must be an array")
2064
2489
 
2065
- num_indices = len(args[1:-1])
2066
- num_dims = args[0].type.ndim
2490
+ num_indices = len(arg_types[1:-1])
2491
+ num_dims = arg_types[0].ndim
2067
2492
 
2068
2493
  # if this happens we should have generated a view instead of a load during code gen
2069
2494
  if num_indices < num_dims:
@@ -2075,31 +2500,63 @@ def store_value_func(args, kwds, _):
2075
2500
  )
2076
2501
 
2077
2502
  # check index types
2078
- for a in args[1:-1]:
2079
- if type_is_int(a.type) == False:
2080
- raise RuntimeError(f"store() index arguments must be of integer type, got index of type {a.type}")
2503
+ for t in arg_types[1:-1]:
2504
+ if not type_is_int(t):
2505
+ raise RuntimeError(f"array_store() index arguments must be of integer type, got index of type {t}")
2081
2506
 
2082
2507
  # check value type
2083
- if not types_equal(args[-1].type, args[0].type.dtype):
2508
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2084
2509
  raise RuntimeError(
2085
- f"store() value argument type ({args[2].type}) must be of the same type as the array ({args[0].type.dtype})"
2510
+ f"array_store() value argument type ({arg_types[2]}) must be of the same type as the array ({arg_types[0].dtype})"
2086
2511
  )
2087
2512
 
2088
2513
  return None
2089
2514
 
2090
2515
 
2091
- add_builtin("load", variadic=True, hidden=True, value_func=load_value_func, group="Utility")
2516
+ # does argument checking for store()
2517
+ def store_value_func(arg_types, kwds, _):
2518
+ # we already stripped the Reference from the argument type prior to this call
2519
+ if not types_equal(arg_types[0], arg_types[1]):
2520
+ raise RuntimeError(f"store() value argument type ({arg_types[1]}) must be of the same type as the reference")
2521
+
2522
+ return None
2523
+
2524
+
2525
+ # does type propagation for load()
2526
+ def load_value_func(arg_types, kwds, _):
2527
+ # we already stripped the Reference from the argument type prior to this call
2528
+ return arg_types[0]
2529
+
2530
+
2531
+ add_builtin("address", variadic=True, hidden=True, value_func=address_value_func, group="Utility")
2092
2532
  add_builtin("view", variadic=True, hidden=True, value_func=view_value_func, group="Utility")
2093
- add_builtin("store", variadic=True, hidden=True, value_func=store_value_func, skip_replay=True, group="Utility")
2533
+ add_builtin(
2534
+ "array_store", variadic=True, hidden=True, value_func=array_store_value_func, skip_replay=True, group="Utility"
2535
+ )
2536
+ add_builtin(
2537
+ "store",
2538
+ input_types={"address": Reference, "value": Any},
2539
+ hidden=True,
2540
+ value_func=store_value_func,
2541
+ skip_replay=True,
2542
+ group="Utility",
2543
+ )
2544
+ add_builtin(
2545
+ "load",
2546
+ input_types={"address": Reference},
2547
+ hidden=True,
2548
+ value_func=load_value_func,
2549
+ group="Utility",
2550
+ )
2094
2551
 
2095
2552
 
2096
- def atomic_op_value_func(args, kwds, _):
2553
+ def atomic_op_value_func(arg_types, kwds, _):
2097
2554
  # check target type
2098
- if not is_array(args[0].type):
2555
+ if not is_array(arg_types[0]):
2099
2556
  raise RuntimeError("atomic() operation argument 0 must be an array")
2100
2557
 
2101
- num_indices = len(args[1:-1])
2102
- num_dims = args[0].type.ndim
2558
+ num_indices = len(arg_types[1:-1])
2559
+ num_dims = arg_types[0].ndim
2103
2560
 
2104
2561
  # if this happens we should have generated a view instead of a load during code gen
2105
2562
  if num_indices < num_dims:
@@ -2111,18 +2568,16 @@ def atomic_op_value_func(args, kwds, _):
2111
2568
  )
2112
2569
 
2113
2570
  # check index types
2114
- for a in args[1:-1]:
2115
- if type_is_int(a.type) == False:
2116
- raise RuntimeError(
2117
- f"atomic() operation index arguments must be of integer type, got index of type {a.type}"
2118
- )
2571
+ for t in arg_types[1:-1]:
2572
+ if not type_is_int(t):
2573
+ raise RuntimeError(f"atomic() operation index arguments must be of integer type, got index of type {t}")
2119
2574
 
2120
- if not types_equal(args[-1].type, args[0].type.dtype):
2575
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2121
2576
  raise RuntimeError(
2122
- f"atomic() value argument ({args[-1].type}) must be of the same type as the array ({args[0].type.dtype})"
2577
+ f"atomic() value argument ({arg_types[-1]}) must be of the same type as the array ({arg_types[0].dtype})"
2123
2578
  )
2124
2579
 
2125
- return args[0].type.dtype
2580
+ return arg_types[0].dtype
2126
2581
 
2127
2582
 
2128
2583
  for array_type in array_types:
@@ -2134,7 +2589,7 @@ for array_type in array_types:
2134
2589
  hidden=hidden,
2135
2590
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2136
2591
  value_func=atomic_op_value_func,
2137
- doc="Atomically add ``value`` onto the array at location given by index.",
2592
+ doc="Atomically add ``value`` onto ``a[i]``.",
2138
2593
  group="Utility",
2139
2594
  skip_replay=True,
2140
2595
  )
@@ -2143,7 +2598,7 @@ for array_type in array_types:
2143
2598
  hidden=hidden,
2144
2599
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2145
2600
  value_func=atomic_op_value_func,
2146
- doc="Atomically add ``value`` onto the array at location given by indices.",
2601
+ doc="Atomically add ``value`` onto ``a[i,j]``.",
2147
2602
  group="Utility",
2148
2603
  skip_replay=True,
2149
2604
  )
@@ -2152,7 +2607,7 @@ for array_type in array_types:
2152
2607
  hidden=hidden,
2153
2608
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2154
2609
  value_func=atomic_op_value_func,
2155
- doc="Atomically add ``value`` onto the array at location given by indices.",
2610
+ doc="Atomically add ``value`` onto ``a[i,j,k]``.",
2156
2611
  group="Utility",
2157
2612
  skip_replay=True,
2158
2613
  )
@@ -2161,7 +2616,7 @@ for array_type in array_types:
2161
2616
  hidden=hidden,
2162
2617
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2163
2618
  value_func=atomic_op_value_func,
2164
- doc="Atomically add ``value`` onto the array at location given by indices.",
2619
+ doc="Atomically add ``value`` onto ``a[i,j,k,l]``.",
2165
2620
  group="Utility",
2166
2621
  skip_replay=True,
2167
2622
  )
@@ -2171,7 +2626,7 @@ for array_type in array_types:
2171
2626
  hidden=hidden,
2172
2627
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2173
2628
  value_func=atomic_op_value_func,
2174
- doc="Atomically subtract ``value`` onto the array at location given by index.",
2629
+ doc="Atomically subtract ``value`` onto ``a[i]``.",
2175
2630
  group="Utility",
2176
2631
  skip_replay=True,
2177
2632
  )
@@ -2180,7 +2635,7 @@ for array_type in array_types:
2180
2635
  hidden=hidden,
2181
2636
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2182
2637
  value_func=atomic_op_value_func,
2183
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2638
+ doc="Atomically subtract ``value`` onto ``a[i,j]``.",
2184
2639
  group="Utility",
2185
2640
  skip_replay=True,
2186
2641
  )
@@ -2189,7 +2644,7 @@ for array_type in array_types:
2189
2644
  hidden=hidden,
2190
2645
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2191
2646
  value_func=atomic_op_value_func,
2192
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2647
+ doc="Atomically subtract ``value`` onto ``a[i,j,k]``.",
2193
2648
  group="Utility",
2194
2649
  skip_replay=True,
2195
2650
  )
@@ -2198,7 +2653,7 @@ for array_type in array_types:
2198
2653
  hidden=hidden,
2199
2654
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2200
2655
  value_func=atomic_op_value_func,
2201
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2656
+ doc="Atomically subtract ``value`` onto ``a[i,j,k,l]``.",
2202
2657
  group="Utility",
2203
2658
  skip_replay=True,
2204
2659
  )
@@ -2208,7 +2663,8 @@ for array_type in array_types:
2208
2663
  hidden=hidden,
2209
2664
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2210
2665
  value_func=atomic_op_value_func,
2211
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2666
+ doc="Compute the minimum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2667
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2212
2668
  group="Utility",
2213
2669
  skip_replay=True,
2214
2670
  )
@@ -2217,7 +2673,8 @@ for array_type in array_types:
2217
2673
  hidden=hidden,
2218
2674
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2219
2675
  value_func=atomic_op_value_func,
2220
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2676
+ doc="Compute the minimum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2677
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2221
2678
  group="Utility",
2222
2679
  skip_replay=True,
2223
2680
  )
@@ -2226,7 +2683,8 @@ for array_type in array_types:
2226
2683
  hidden=hidden,
2227
2684
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2228
2685
  value_func=atomic_op_value_func,
2229
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2686
+ doc="Compute the minimum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2687
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2230
2688
  group="Utility",
2231
2689
  skip_replay=True,
2232
2690
  )
@@ -2235,7 +2693,8 @@ for array_type in array_types:
2235
2693
  hidden=hidden,
2236
2694
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2237
2695
  value_func=atomic_op_value_func,
2238
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2696
+ doc="Compute the minimum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2697
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2239
2698
  group="Utility",
2240
2699
  skip_replay=True,
2241
2700
  )
@@ -2245,7 +2704,8 @@ for array_type in array_types:
2245
2704
  hidden=hidden,
2246
2705
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2247
2706
  value_func=atomic_op_value_func,
2248
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2707
+ doc="Compute the maximum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2708
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2249
2709
  group="Utility",
2250
2710
  skip_replay=True,
2251
2711
  )
@@ -2254,7 +2714,8 @@ for array_type in array_types:
2254
2714
  hidden=hidden,
2255
2715
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2256
2716
  value_func=atomic_op_value_func,
2257
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2717
+ doc="Compute the maximum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2718
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2258
2719
  group="Utility",
2259
2720
  skip_replay=True,
2260
2721
  )
@@ -2263,7 +2724,8 @@ for array_type in array_types:
2263
2724
  hidden=hidden,
2264
2725
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2265
2726
  value_func=atomic_op_value_func,
2266
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2727
+ doc="Compute the maximum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2728
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2267
2729
  group="Utility",
2268
2730
  skip_replay=True,
2269
2731
  )
@@ -2272,26 +2734,27 @@ for array_type in array_types:
2272
2734
  hidden=hidden,
2273
2735
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2274
2736
  value_func=atomic_op_value_func,
2275
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2737
+ doc="Compute the maximum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2738
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2276
2739
  group="Utility",
2277
2740
  skip_replay=True,
2278
2741
  )
2279
2742
 
2280
2743
 
2281
2744
  # used to index into builtin types, i.e.: y = vec3[1]
2282
- def index_value_func(args, kwds, _):
2283
- return args[0].type._wp_scalar_type_
2745
+ def index_value_func(arg_types, kwds, _):
2746
+ return arg_types[0]._wp_scalar_type_
2284
2747
 
2285
2748
 
2286
2749
  add_builtin(
2287
- "index",
2750
+ "extract",
2288
2751
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2289
2752
  value_func=index_value_func,
2290
2753
  hidden=True,
2291
2754
  group="Utility",
2292
2755
  )
2293
2756
  add_builtin(
2294
- "index",
2757
+ "extract",
2295
2758
  input_types={"a": quaternion(dtype=Scalar), "i": int},
2296
2759
  value_func=index_value_func,
2297
2760
  hidden=True,
@@ -2299,14 +2762,14 @@ add_builtin(
2299
2762
  )
2300
2763
 
2301
2764
  add_builtin(
2302
- "index",
2765
+ "extract",
2303
2766
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2304
- value_func=lambda args, kwds, _: vector(length=args[0].type._shape_[1], dtype=args[0].type._wp_scalar_type_),
2767
+ value_func=lambda arg_types, kwds, _: vector(length=arg_types[0]._shape_[1], dtype=arg_types[0]._wp_scalar_type_),
2305
2768
  hidden=True,
2306
2769
  group="Utility",
2307
2770
  )
2308
2771
  add_builtin(
2309
- "index",
2772
+ "extract",
2310
2773
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2311
2774
  value_func=index_value_func,
2312
2775
  hidden=True,
@@ -2314,77 +2777,66 @@ add_builtin(
2314
2777
  )
2315
2778
 
2316
2779
  add_builtin(
2317
- "index",
2780
+ "extract",
2318
2781
  input_types={"a": transformation(dtype=Scalar), "i": int},
2319
2782
  value_func=index_value_func,
2320
2783
  hidden=True,
2321
2784
  group="Utility",
2322
2785
  )
2323
2786
 
2324
- add_builtin("index", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2787
+ add_builtin("extract", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2325
2788
 
2326
2789
 
2327
- def vector_indexset_element_value_func(args, kwds, _):
2328
- vec = args[0]
2329
- index = args[1]
2330
- value = args[2]
2790
+ def vector_indexref_element_value_func(arg_types, kwds, _):
2791
+ vec_type = arg_types[0]
2792
+ # index_type = arg_types[1]
2793
+ value_type = vec_type._wp_scalar_type_
2331
2794
 
2332
- if value.type is not vec.type._wp_scalar_type_:
2333
- raise RuntimeError(
2334
- f"Trying to assign type '{type_repr(value.type)}' to element of a vector with type '{type_repr(vec.type)}'"
2335
- )
2336
-
2337
- return None
2795
+ return Reference(value_type)
2338
2796
 
2339
2797
 
2340
- # implements vector[index] = value
2798
+ # implements &vector[index]
2341
2799
  add_builtin(
2342
- "indexset",
2343
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
2344
- value_func=vector_indexset_element_value_func,
2800
+ "index",
2801
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2802
+ value_func=vector_indexref_element_value_func,
2803
+ hidden=True,
2804
+ group="Utility",
2805
+ skip_replay=True,
2806
+ )
2807
+ # implements &(*vector)[index]
2808
+ add_builtin(
2809
+ "indexref",
2810
+ input_types={"a": Reference, "i": int},
2811
+ value_func=vector_indexref_element_value_func,
2345
2812
  hidden=True,
2346
2813
  group="Utility",
2347
2814
  skip_replay=True,
2348
2815
  )
2349
2816
 
2350
2817
 
2351
- def matrix_indexset_element_value_func(args, kwds, _):
2352
- mat = args[0]
2353
- row = args[1]
2354
- col = args[2]
2355
- value = args[3]
2356
-
2357
- if value.type is not mat.type._wp_scalar_type_:
2358
- raise RuntimeError(
2359
- f"Trying to assign type '{type_repr(value.type)}' to element of a matrix with type '{type_repr(mat.type)}'"
2360
- )
2361
-
2362
- return None
2363
-
2818
+ def matrix_indexref_element_value_func(arg_types, kwds, _):
2819
+ mat_type = arg_types[0]
2820
+ # row_type = arg_types[1]
2821
+ # col_type = arg_types[2]
2822
+ value_type = mat_type._wp_scalar_type_
2364
2823
 
2365
- def matrix_indexset_row_value_func(args, kwds, _):
2366
- mat = args[0]
2367
- row = args[1]
2368
- value = args[2]
2824
+ return Reference(value_type)
2369
2825
 
2370
- if value.type._shape_[0] != mat.type._shape_[1]:
2371
- raise RuntimeError(
2372
- f"Trying to assign vector with length {value.type._length} to matrix with shape {mat.type._shape}, vector length must match the number of matrix columns."
2373
- )
2374
2826
 
2375
- if value.type._wp_scalar_type_ is not mat.type._wp_scalar_type_:
2376
- raise RuntimeError(
2377
- f"Trying to assign vector of type '{type_repr(value.type)}' to row of matrix of type '{type_repr(mat.type)}'"
2378
- )
2827
+ def matrix_indexref_row_value_func(arg_types, kwds, _):
2828
+ mat_type = arg_types[0]
2829
+ row_type = mat_type._wp_row_type_
2830
+ # value_type = arg_types[2]
2379
2831
 
2380
- return None
2832
+ return Reference(row_type)
2381
2833
 
2382
2834
 
2383
2835
  # implements matrix[i] = row
2384
2836
  add_builtin(
2385
- "indexset",
2386
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
2387
- value_func=matrix_indexset_row_value_func,
2837
+ "index",
2838
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2839
+ value_func=matrix_indexref_row_value_func,
2388
2840
  hidden=True,
2389
2841
  group="Utility",
2390
2842
  skip_replay=True,
@@ -2392,29 +2844,29 @@ add_builtin(
2392
2844
 
2393
2845
  # implements matrix[i,j] = scalar
2394
2846
  add_builtin(
2395
- "indexset",
2396
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
2397
- value_func=matrix_indexset_element_value_func,
2847
+ "index",
2848
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2849
+ value_func=matrix_indexref_element_value_func,
2398
2850
  hidden=True,
2399
2851
  group="Utility",
2400
2852
  skip_replay=True,
2401
2853
  )
2402
2854
 
2403
- for t in scalar_types + vector_types:
2855
+ for t in scalar_types + vector_types + [builtins.bool]:
2404
2856
  if "vec" in t.__name__ or "mat" in t.__name__:
2405
2857
  continue
2406
2858
  add_builtin(
2407
2859
  "expect_eq",
2408
2860
  input_types={"arg1": t, "arg2": t},
2409
2861
  value_type=None,
2410
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2862
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2411
2863
  group="Utility",
2412
2864
  hidden=True,
2413
2865
  )
2414
2866
 
2415
2867
 
2416
- def expect_eq_val_func(args, kwds, _):
2417
- if not types_equal(args[0].type, args[1].type):
2868
+ def expect_eq_val_func(arg_types, kwds, _):
2869
+ if not types_equal(arg_types[0], arg_types[1]):
2418
2870
  raise RuntimeError("Can't test equality for objects with different types")
2419
2871
  return None
2420
2872
 
@@ -2423,7 +2875,7 @@ add_builtin(
2423
2875
  "expect_eq",
2424
2876
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2425
2877
  value_func=expect_eq_val_func,
2426
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2878
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2427
2879
  group="Utility",
2428
2880
  hidden=True,
2429
2881
  )
@@ -2431,7 +2883,7 @@ add_builtin(
2431
2883
  "expect_neq",
2432
2884
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2433
2885
  value_func=expect_eq_val_func,
2434
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2886
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2435
2887
  group="Utility",
2436
2888
  hidden=True,
2437
2889
  )
@@ -2440,7 +2892,7 @@ add_builtin(
2440
2892
  "expect_eq",
2441
2893
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2442
2894
  value_func=expect_eq_val_func,
2443
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2895
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2444
2896
  group="Utility",
2445
2897
  hidden=True,
2446
2898
  )
@@ -2448,7 +2900,7 @@ add_builtin(
2448
2900
  "expect_neq",
2449
2901
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2450
2902
  value_func=expect_eq_val_func,
2451
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2903
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2452
2904
  group="Utility",
2453
2905
  hidden=True,
2454
2906
  )
@@ -2457,29 +2909,30 @@ add_builtin(
2457
2909
  "lerp",
2458
2910
  input_types={"a": Float, "b": Float, "t": Float},
2459
2911
  value_func=sametype_value_func(Float),
2460
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2912
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2461
2913
  group="Utility",
2462
2914
  )
2463
2915
  add_builtin(
2464
2916
  "smoothstep",
2465
2917
  input_types={"edge0": Float, "edge1": Float, "x": Float},
2466
2918
  value_func=sametype_value_func(Float),
2467
- doc="Smoothly interpolate between two values edge0 and edge1 using a factor x, and return a result between 0 and 1 using a cubic Hermite interpolation after clamping",
2919
+ doc="""Smoothly interpolate between two values ``edge0`` and ``edge1`` using a factor ``x``,
2920
+ and return a result between 0 and 1 using a cubic Hermite interpolation after clamping.""",
2468
2921
  group="Utility",
2469
2922
  )
2470
2923
 
2471
2924
 
2472
2925
  def lerp_value_func(default):
2473
- def fn(args, kwds, _):
2474
- if args is None:
2926
+ def fn(arg_types, kwds, _):
2927
+ if arg_types is None:
2475
2928
  return default
2476
- scalar_type = args[-1].type
2477
- if not types_equal(args[0].type, args[1].type):
2929
+ scalar_type = arg_types[-1]
2930
+ if not types_equal(arg_types[0], arg_types[1]):
2478
2931
  raise RuntimeError("Can't lerp between objects with different types")
2479
- if args[0].type._wp_scalar_type_ != scalar_type:
2932
+ if arg_types[0]._wp_scalar_type_ != scalar_type:
2480
2933
  raise RuntimeError("'t' parameter must have the same scalar type as objects you're lerping between")
2481
2934
 
2482
- return args[0].type
2935
+ return arg_types[0]
2483
2936
 
2484
2937
  return fn
2485
2938
 
@@ -2488,28 +2941,28 @@ add_builtin(
2488
2941
  "lerp",
2489
2942
  input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "t": Float},
2490
2943
  value_func=lerp_value_func(vector(length=Any, dtype=Float)),
2491
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2944
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2492
2945
  group="Utility",
2493
2946
  )
2494
2947
  add_builtin(
2495
2948
  "lerp",
2496
2949
  input_types={"a": matrix(shape=(Any, Any), dtype=Float), "b": matrix(shape=(Any, Any), dtype=Float), "t": Float},
2497
2950
  value_func=lerp_value_func(matrix(shape=(Any, Any), dtype=Float)),
2498
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2951
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2499
2952
  group="Utility",
2500
2953
  )
2501
2954
  add_builtin(
2502
2955
  "lerp",
2503
2956
  input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
2504
2957
  value_func=lerp_value_func(quaternion(dtype=Float)),
2505
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2958
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2506
2959
  group="Utility",
2507
2960
  )
2508
2961
  add_builtin(
2509
2962
  "lerp",
2510
2963
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float), "t": Float},
2511
2964
  value_func=lerp_value_func(transformation(dtype=Float)),
2512
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2965
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2513
2966
  group="Utility",
2514
2967
  )
2515
2968
 
@@ -2519,14 +2972,14 @@ add_builtin(
2519
2972
  input_types={"arg1": Float, "arg2": Float, "tolerance": Float},
2520
2973
  defaults={"tolerance": 1.0e-6},
2521
2974
  value_type=None,
2522
- doc="Prints an error to stdout if arg1 and arg2 are not closer than tolerance in magnitude",
2975
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2523
2976
  group="Utility",
2524
2977
  )
2525
2978
  add_builtin(
2526
2979
  "expect_near",
2527
2980
  input_types={"arg1": vec3, "arg2": vec3, "tolerance": float},
2528
2981
  value_type=None,
2529
- doc="Prints an error to stdout if any element of arg1 and arg2 are not closer than tolerance in magnitude",
2982
+ doc="Prints an error to stdout if any element of ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2530
2983
  group="Utility",
2531
2984
  )
2532
2985
 
@@ -2537,7 +2990,14 @@ add_builtin(
2537
2990
  "lower_bound",
2538
2991
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
2539
2992
  value_type=int,
2540
- doc="Search a sorted array for the closest element greater than or equal to value.",
2993
+ doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
2994
+ )
2995
+
2996
+ add_builtin(
2997
+ "lower_bound",
2998
+ input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
2999
+ value_type=int,
3000
+ doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
2541
3001
  )
2542
3002
 
2543
3003
  # ---------------------------------
@@ -2617,11 +3077,11 @@ add_builtin("invert", input_types={"x": Int}, value_func=sametype_value_func(Int
2617
3077
 
2618
3078
 
2619
3079
  def scalar_mul_value_func(default):
2620
- def fn(args, kwds, _):
2621
- if args is None:
3080
+ def fn(arg_types, kwds, _):
3081
+ if arg_types is None:
2622
3082
  return default
2623
- scalar = [a.type for a in args if a.type in scalar_types][0]
2624
- compound = [a.type for a in args if a.type not in scalar_types][0]
3083
+ scalar = [t for t in arg_types if t in scalar_types][0]
3084
+ compound = [t for t in arg_types if t not in scalar_types][0]
2625
3085
  if scalar != compound._wp_scalar_type_:
2626
3086
  raise RuntimeError("Object and coefficient must have the same scalar type when multiplying by scalar")
2627
3087
  return compound
@@ -2629,36 +3089,53 @@ def scalar_mul_value_func(default):
2629
3089
  return fn
2630
3090
 
2631
3091
 
2632
- def mul_matvec_value_func(args, kwds, _):
2633
- if args is None:
3092
+ def mul_matvec_value_func(arg_types, kwds, _):
3093
+ if arg_types is None:
3094
+ return vector(length=Any, dtype=Scalar)
3095
+
3096
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
3097
+ raise RuntimeError(
3098
+ f"Can't multiply matrix and vector with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
3099
+ )
3100
+
3101
+ if arg_types[0]._shape_[1] != arg_types[1]._length_:
3102
+ raise RuntimeError(
3103
+ f"Can't multiply matrix of shape {arg_types[0]._shape_} and vector with length {arg_types[1]._length_}"
3104
+ )
3105
+
3106
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
3107
+
3108
+
3109
+ def mul_vecmat_value_func(arg_types, kwds, _):
3110
+ if arg_types is None:
2634
3111
  return vector(length=Any, dtype=Scalar)
2635
3112
 
2636
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3113
+ if arg_types[1]._wp_scalar_type_ != arg_types[0]._wp_scalar_type_:
2637
3114
  raise RuntimeError(
2638
- f"Can't multiply matrix and vector with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3115
+ f"Can't multiply vector and matrix with different types {arg_types[1]._wp_scalar_type_}, {arg_types[0]._wp_scalar_type_}"
2639
3116
  )
2640
3117
 
2641
- if args[0].type._shape_[1] != args[1].type._length_:
3118
+ if arg_types[1]._shape_[0] != arg_types[0]._length_:
2642
3119
  raise RuntimeError(
2643
- f"Can't multiply matrix of shape {args[0].type._shape_} and vector with length {args[1].type._length_}"
3120
+ f"Can't multiply vector with length {arg_types[0]._length_} and matrix of shape {arg_types[1]._shape_}"
2644
3121
  )
2645
3122
 
2646
- return vector(length=args[0].type._shape_[0], dtype=args[0].type._wp_scalar_type_)
3123
+ return vector(length=arg_types[1]._shape_[1], dtype=arg_types[1]._wp_scalar_type_)
2647
3124
 
2648
3125
 
2649
- def mul_matmat_value_func(args, kwds, _):
2650
- if args is None:
3126
+ def mul_matmat_value_func(arg_types, kwds, _):
3127
+ if arg_types is None:
2651
3128
  return matrix(length=Any, dtype=Scalar)
2652
3129
 
2653
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3130
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
2654
3131
  raise RuntimeError(
2655
- f"Can't multiply matrices with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3132
+ f"Can't multiply matrices with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
2656
3133
  )
2657
3134
 
2658
- if args[0].type._shape_[1] != args[1].type._shape_[0]:
2659
- raise RuntimeError(f"Can't multiply matrix of shapes {args[0].type._shape_} and {args[1].type._shape_}")
3135
+ if arg_types[0]._shape_[1] != arg_types[1]._shape_[0]:
3136
+ raise RuntimeError(f"Can't multiply matrix of shapes {arg_types[0]._shape_} and {arg_types[1]._shape_}")
2660
3137
 
2661
- return matrix(shape=(args[0].type._shape_[0], args[1].type._shape_[1]), dtype=args[0].type._wp_scalar_type_)
3138
+ return matrix(shape=(arg_types[0]._shape_[0], arg_types[1]._shape_[1]), dtype=arg_types[0]._wp_scalar_type_)
2662
3139
 
2663
3140
 
2664
3141
  add_builtin(
@@ -2720,6 +3197,13 @@ add_builtin(
2720
3197
  doc="",
2721
3198
  group="Operators",
2722
3199
  )
3200
+ add_builtin(
3201
+ "mul",
3202
+ input_types={"x": vector(length=Any, dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3203
+ value_func=mul_vecmat_value_func,
3204
+ doc="",
3205
+ group="Operators",
3206
+ )
2723
3207
  add_builtin(
2724
3208
  "mul",
2725
3209
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
@@ -2755,7 +3239,12 @@ add_builtin(
2755
3239
  )
2756
3240
 
2757
3241
  add_builtin(
2758
- "div", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), doc="", group="Operators"
3242
+ "div",
3243
+ input_types={"x": Scalar, "y": Scalar},
3244
+ value_func=sametype_value_func(Scalar),
3245
+ doc="",
3246
+ group="Operators",
3247
+ require_original_output_arg=True,
2759
3248
  )
2760
3249
  add_builtin(
2761
3250
  "div",
@@ -2764,6 +3253,13 @@ add_builtin(
2764
3253
  doc="",
2765
3254
  group="Operators",
2766
3255
  )
3256
+ add_builtin(
3257
+ "div",
3258
+ input_types={"x": Scalar, "y": vector(length=Any, dtype=Scalar)},
3259
+ value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
3260
+ doc="",
3261
+ group="Operators",
3262
+ )
2767
3263
  add_builtin(
2768
3264
  "div",
2769
3265
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": Scalar},
@@ -2771,6 +3267,13 @@ add_builtin(
2771
3267
  doc="",
2772
3268
  group="Operators",
2773
3269
  )
3270
+ add_builtin(
3271
+ "div",
3272
+ input_types={"x": Scalar, "y": matrix(shape=(Any, Any), dtype=Scalar)},
3273
+ value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3274
+ doc="",
3275
+ group="Operators",
3276
+ )
2774
3277
  add_builtin(
2775
3278
  "div",
2776
3279
  input_types={"x": quaternion(dtype=Scalar), "y": Scalar},
@@ -2778,6 +3281,13 @@ add_builtin(
2778
3281
  doc="",
2779
3282
  group="Operators",
2780
3283
  )
3284
+ add_builtin(
3285
+ "div",
3286
+ input_types={"x": Scalar, "y": quaternion(dtype=Scalar)},
3287
+ value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
3288
+ doc="",
3289
+ group="Operators",
3290
+ )
2781
3291
 
2782
3292
  add_builtin(
2783
3293
  "floordiv",
@@ -2832,9 +3342,9 @@ add_builtin(
2832
3342
  group="Operators",
2833
3343
  )
2834
3344
 
2835
- add_builtin("unot", input_types={"b": bool}, value_type=bool, doc="", group="Operators")
3345
+ add_builtin("unot", input_types={"b": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
2836
3346
  for t in int_types:
2837
- add_builtin("unot", input_types={"b": t}, value_type=bool, doc="", group="Operators")
3347
+ add_builtin("unot", input_types={"b": t}, value_type=builtins.bool, doc="", group="Operators")
2838
3348
 
2839
3349
 
2840
- add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=bool, doc="", group="Operators")
3350
+ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")