warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -5,24 +5,22 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- from .context import add_builtin
8
+ import builtins
9
+ from typing import Any, Callable, Tuple
9
10
 
11
+ from warp.codegen import Reference
10
12
  from warp.types import *
11
13
 
12
- from typing import Tuple
13
- from typing import List
14
- from typing import Dict
15
- from typing import Any
16
- from typing import Callable
14
+ from .context import add_builtin
17
15
 
18
16
 
19
17
  def sametype_value_func(default):
20
- def fn(args, kwds, _):
21
- if args is None:
18
+ def fn(arg_types, kwds, _):
19
+ if arg_types is None:
22
20
  return default
23
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
24
- raise RuntimeError(f"Input types must be the same, found: {[type_repr(a.type) for a in args]}")
25
- return args[0].type
21
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
22
+ raise RuntimeError(f"Input types must be the same, found: {[type_repr(t) for t in arg_types]}")
23
+ return arg_types[0]
26
24
 
27
25
  return fn
28
26
 
@@ -50,7 +48,7 @@ add_builtin(
50
48
  "clamp",
51
49
  input_types={"x": Scalar, "a": Scalar, "b": Scalar},
52
50
  value_func=sametype_value_func(Scalar),
53
- doc="Clamp the value of x to the range [a, b].",
51
+ doc="Clamp the value of ``x`` to the range [a, b].",
54
52
  group="Scalar Math",
55
53
  )
56
54
 
@@ -58,14 +56,14 @@ add_builtin(
58
56
  "abs",
59
57
  input_types={"x": Scalar},
60
58
  value_func=sametype_value_func(Scalar),
61
- doc="Return the absolute value of x.",
59
+ doc="Return the absolute value of ``x``.",
62
60
  group="Scalar Math",
63
61
  )
64
62
  add_builtin(
65
63
  "sign",
66
64
  input_types={"x": Scalar},
67
65
  value_func=sametype_value_func(Scalar),
68
- doc="Return -1 if x < 0, return 1 otherwise.",
66
+ doc="Return -1 if ``x`` < 0, return 1 otherwise.",
69
67
  group="Scalar Math",
70
68
  )
71
69
 
@@ -73,14 +71,14 @@ add_builtin(
73
71
  "step",
74
72
  input_types={"x": Scalar},
75
73
  value_func=sametype_value_func(Scalar),
76
- doc="Return 1.0 if x < 0.0, return 0.0 otherwise.",
74
+ doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
77
75
  group="Scalar Math",
78
76
  )
79
77
  add_builtin(
80
78
  "nonzero",
81
79
  input_types={"x": Scalar},
82
80
  value_func=sametype_value_func(Scalar),
83
- doc="Return 1.0 if x is not equal to zero, return 0.0 otherwise.",
81
+ doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
84
82
  group="Scalar Math",
85
83
  )
86
84
 
@@ -88,91 +86,101 @@ add_builtin(
88
86
  "sin",
89
87
  input_types={"x": Float},
90
88
  value_func=sametype_value_func(Float),
91
- doc="Return the sine of x in radians.",
89
+ doc="Return the sine of ``x`` in radians.",
92
90
  group="Scalar Math",
93
91
  )
94
92
  add_builtin(
95
93
  "cos",
96
94
  input_types={"x": Float},
97
95
  value_func=sametype_value_func(Float),
98
- doc="Return the cosine of x in radians.",
96
+ doc="Return the cosine of ``x`` in radians.",
99
97
  group="Scalar Math",
100
98
  )
101
99
  add_builtin(
102
100
  "acos",
103
101
  input_types={"x": Float},
104
102
  value_func=sametype_value_func(Float),
105
- doc="Return arccos of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
103
+ doc="Return arccos of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
106
104
  group="Scalar Math",
107
105
  )
108
106
  add_builtin(
109
107
  "asin",
110
108
  input_types={"x": Float},
111
109
  value_func=sametype_value_func(Float),
112
- doc="Return arcsin of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
110
+ doc="Return arcsin of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
113
111
  group="Scalar Math",
114
112
  )
115
113
  add_builtin(
116
114
  "sqrt",
117
115
  input_types={"x": Float},
118
116
  value_func=sametype_value_func(Float),
119
- doc="Return the sqrt of x, where x is positive.",
117
+ doc="Return the square root of ``x``, where ``x`` is positive.",
120
118
  group="Scalar Math",
119
+ require_original_output_arg=True,
120
+ )
121
+ add_builtin(
122
+ "cbrt",
123
+ input_types={"x": Float},
124
+ value_func=sametype_value_func(Float),
125
+ doc="Return the cube root of ``x``.",
126
+ group="Scalar Math",
127
+ require_original_output_arg=True,
121
128
  )
122
129
  add_builtin(
123
130
  "tan",
124
131
  input_types={"x": Float},
125
132
  value_func=sametype_value_func(Float),
126
- doc="Return tangent of x in radians.",
133
+ doc="Return the tangent of ``x`` in radians.",
127
134
  group="Scalar Math",
128
135
  )
129
136
  add_builtin(
130
137
  "atan",
131
138
  input_types={"x": Float},
132
139
  value_func=sametype_value_func(Float),
133
- doc="Return arctan of x.",
140
+ doc="Return the arctangent of ``x`` in radians.",
134
141
  group="Scalar Math",
135
142
  )
136
143
  add_builtin(
137
144
  "atan2",
138
145
  input_types={"y": Float, "x": Float},
139
146
  value_func=sametype_value_func(Float),
140
- doc="Return atan2 of x.",
147
+ doc="Return the 2-argument arctangent, atan2, of the point ``(x, y)`` in radians.",
141
148
  group="Scalar Math",
142
149
  )
143
150
  add_builtin(
144
151
  "sinh",
145
152
  input_types={"x": Float},
146
153
  value_func=sametype_value_func(Float),
147
- doc="Return the sinh of x.",
154
+ doc="Return the sinh of ``x``.",
148
155
  group="Scalar Math",
149
156
  )
150
157
  add_builtin(
151
158
  "cosh",
152
159
  input_types={"x": Float},
153
160
  value_func=sametype_value_func(Float),
154
- doc="Return the cosh of x.",
161
+ doc="Return the cosh of ``x``.",
155
162
  group="Scalar Math",
156
163
  )
157
164
  add_builtin(
158
165
  "tanh",
159
166
  input_types={"x": Float},
160
167
  value_func=sametype_value_func(Float),
161
- doc="Return the tanh of x.",
168
+ doc="Return the tanh of ``x``.",
162
169
  group="Scalar Math",
170
+ require_original_output_arg=True,
163
171
  )
164
172
  add_builtin(
165
173
  "degrees",
166
174
  input_types={"x": Float},
167
175
  value_func=sametype_value_func(Float),
168
- doc="Convert radians into degrees.",
176
+ doc="Convert ``x`` from radians into degrees.",
169
177
  group="Scalar Math",
170
178
  )
171
179
  add_builtin(
172
180
  "radians",
173
181
  input_types={"x": Float},
174
182
  value_func=sametype_value_func(Float),
175
- doc="Convert degrees into radians.",
183
+ doc="Convert ``x`` from degrees into radians.",
176
184
  group="Scalar Math",
177
185
  )
178
186
 
@@ -180,36 +188,38 @@ add_builtin(
180
188
  "log",
181
189
  input_types={"x": Float},
182
190
  value_func=sametype_value_func(Float),
183
- doc="Return the natural log (base-e) of x, where x is positive.",
191
+ doc="Return the natural logarithm (base-e) of ``x``, where ``x`` is positive.",
184
192
  group="Scalar Math",
185
193
  )
186
194
  add_builtin(
187
195
  "log2",
188
196
  input_types={"x": Float},
189
197
  value_func=sametype_value_func(Float),
190
- doc="Return the natural log (base-2) of x, where x is positive.",
198
+ doc="Return the binary logarithm (base-2) of ``x``, where ``x`` is positive.",
191
199
  group="Scalar Math",
192
200
  )
193
201
  add_builtin(
194
202
  "log10",
195
203
  input_types={"x": Float},
196
204
  value_func=sametype_value_func(Float),
197
- doc="Return the natural log (base-10) of x, where x is positive.",
205
+ doc="Return the common logarithm (base-10) of ``x``, where ``x`` is positive.",
198
206
  group="Scalar Math",
199
207
  )
200
208
  add_builtin(
201
209
  "exp",
202
210
  input_types={"x": Float},
203
211
  value_func=sametype_value_func(Float),
204
- doc="Return base-e exponential, e^x.",
212
+ doc="Return the value of the exponential function :math:`e^x`.",
205
213
  group="Scalar Math",
214
+ require_original_output_arg=True,
206
215
  )
207
216
  add_builtin(
208
217
  "pow",
209
218
  input_types={"x": Float, "y": Float},
210
219
  value_func=sametype_value_func(Float),
211
- doc="Return the result of x raised to power of y.",
220
+ doc="Return the result of ``x`` raised to power of ``y``.",
212
221
  group="Scalar Math",
222
+ require_original_output_arg=True,
213
223
  )
214
224
 
215
225
  add_builtin(
@@ -217,9 +227,9 @@ add_builtin(
217
227
  input_types={"x": Float},
218
228
  value_func=sametype_value_func(Float),
219
229
  group="Scalar Math",
220
- doc="""Calculate the nearest integer value, rounding halfway cases away from zero.
221
- This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like ``warp.rint()``.
222
- Differs from ``numpy.round()``, which behaves the same way as ``numpy.rint()``.""",
230
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases away from zero.
231
+ This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
232
+ Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
223
233
  )
224
234
 
225
235
  add_builtin(
@@ -227,9 +237,8 @@ add_builtin(
227
237
  input_types={"x": Float},
228
238
  value_func=sametype_value_func(Float),
229
239
  group="Scalar Math",
230
- doc="""Calculate the nearest integer value, rounding halfway cases to nearest even integer.
231
- It is generally faster than ``warp.round()``.
232
- Equivalent to ``numpy.rint()``.""",
240
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
241
+ It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
233
242
  )
234
243
 
235
244
  add_builtin(
@@ -237,10 +246,10 @@ add_builtin(
237
246
  input_types={"x": Float},
238
247
  value_func=sametype_value_func(Float),
239
248
  group="Scalar Math",
240
- doc="""Calculate the nearest integer that is closer to zero than x.
241
- In other words, it discards the fractional part of x.
242
- It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
243
- Equivalent to ``numpy.trunc()`` and ``numpy.fix()``.""",
249
+ doc="""Return the nearest integer that is closer to zero than ``x``.
250
+ In other words, it discards the fractional part of ``x``.
251
+ It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
252
+ Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
244
253
  )
245
254
 
246
255
  add_builtin(
@@ -248,7 +257,7 @@ add_builtin(
248
257
  input_types={"x": Float},
249
258
  value_func=sametype_value_func(Float),
250
259
  group="Scalar Math",
251
- doc="""Calculate the largest integer that is less than or equal to x.""",
260
+ doc="""Return the largest integer that is less than or equal to ``x``.""",
252
261
  )
253
262
 
254
263
  add_builtin(
@@ -256,22 +265,31 @@ add_builtin(
256
265
  input_types={"x": Float},
257
266
  value_func=sametype_value_func(Float),
258
267
  group="Scalar Math",
259
- doc="""Calculate the smallest integer that is greater than or equal to x.""",
268
+ doc="""Return the smallest integer that is greater than or equal to ``x``.""",
269
+ )
270
+
271
+ add_builtin(
272
+ "frac",
273
+ input_types={"x": Float},
274
+ value_func=sametype_value_func(Float),
275
+ group="Scalar Math",
276
+ doc="""Retrieve the fractional part of x.
277
+ In other words, it discards the integer part of x and is equivalent to ``x - trunc(x)``.""",
260
278
  )
261
279
 
262
280
 
263
- def infer_scalar_type(args):
264
- if args is None:
281
+ def infer_scalar_type(arg_types):
282
+ if arg_types is None:
265
283
  return Scalar
266
284
 
267
- def iterate_scalar_types(args):
268
- for a in args:
269
- if hasattr(a.type, "_wp_scalar_type_"):
270
- yield a.type._wp_scalar_type_
271
- elif a.type in scalar_types:
272
- yield a.type
285
+ def iterate_scalar_types(arg_types):
286
+ for t in arg_types:
287
+ if hasattr(t, "_wp_scalar_type_"):
288
+ yield t._wp_scalar_type_
289
+ elif t in scalar_types:
290
+ yield t
273
291
 
274
- scalarTypes = set(iterate_scalar_types(args))
292
+ scalarTypes = set(iterate_scalar_types(arg_types))
275
293
  if len(scalarTypes) > 1:
276
294
  raise RuntimeError(
277
295
  f"Couldn't figure out return type as arguments have multiple precisions: {list(scalarTypes)}"
@@ -279,13 +297,13 @@ def infer_scalar_type(args):
279
297
  return list(scalarTypes)[0]
280
298
 
281
299
 
282
- def sametype_scalar_value_func(args, kwds, _):
283
- if args is None:
300
+ def sametype_scalar_value_func(arg_types, kwds, _):
301
+ if arg_types is None:
284
302
  return Scalar
285
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
286
- raise RuntimeError(f"Input types must be exactly the same, {[a.type for a in args]}")
303
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
304
+ raise RuntimeError(f"Input types must be exactly the same, {[t for t in arg_types]}")
287
305
 
288
- return infer_scalar_type(args)
306
+ return infer_scalar_type(arg_types)
289
307
 
290
308
 
291
309
  # ---------------------------------
@@ -310,14 +328,14 @@ add_builtin(
310
328
  "min",
311
329
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
312
330
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
313
- doc="Return the element wise minimum of two vectors.",
331
+ doc="Return the element-wise minimum of two vectors.",
314
332
  group="Vector Math",
315
333
  )
316
334
  add_builtin(
317
335
  "max",
318
336
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
319
337
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
320
- doc="Return the element wise maximum of two vectors.",
338
+ doc="Return the element-wise maximum of two vectors.",
321
339
  group="Vector Math",
322
340
  )
323
341
 
@@ -325,41 +343,41 @@ add_builtin(
325
343
  "min",
326
344
  input_types={"v": vector(length=Any, dtype=Scalar)},
327
345
  value_func=sametype_scalar_value_func,
328
- doc="Return the minimum element of a vector.",
346
+ doc="Return the minimum element of a vector ``v``.",
329
347
  group="Vector Math",
330
348
  )
331
349
  add_builtin(
332
350
  "max",
333
351
  input_types={"v": vector(length=Any, dtype=Scalar)},
334
352
  value_func=sametype_scalar_value_func,
335
- doc="Return the maximum element of a vector.",
353
+ doc="Return the maximum element of a vector ``v``.",
336
354
  group="Vector Math",
337
355
  )
338
356
 
339
357
  add_builtin(
340
358
  "argmin",
341
359
  input_types={"v": vector(length=Any, dtype=Scalar)},
342
- value_func=lambda args, kwds, _: warp.uint32,
343
- doc="Return the index of the minimum element of a vector.",
360
+ value_func=lambda arg_types, kwds, _: warp.uint32,
361
+ doc="Return the index of the minimum element of a vector ``v``.",
344
362
  group="Vector Math",
345
363
  missing_grad=True,
346
364
  )
347
365
  add_builtin(
348
366
  "argmax",
349
367
  input_types={"v": vector(length=Any, dtype=Scalar)},
350
- value_func=lambda args, kwds, _: warp.uint32,
351
- doc="Return the index of the maximum element of a vector.",
368
+ value_func=lambda arg_types, kwds, _: warp.uint32,
369
+ doc="Return the index of the maximum element of a vector ``v``.",
352
370
  group="Vector Math",
353
371
  missing_grad=True,
354
372
  )
355
373
 
356
374
 
357
- def value_func_outer(args, kwds, _):
358
- if args is None:
375
+ def value_func_outer(arg_types, kwds, _):
376
+ if arg_types is None:
359
377
  return matrix(shape=(Any, Any), dtype=Scalar)
360
378
 
361
- scalarType = infer_scalar_type(args)
362
- vectorLengths = [i.type._length_ for i in args]
379
+ scalarType = infer_scalar_type(arg_types)
380
+ vectorLengths = [t._length_ for t in arg_types]
363
381
  return matrix(shape=(vectorLengths), dtype=scalarType)
364
382
 
365
383
 
@@ -368,7 +386,7 @@ add_builtin(
368
386
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
369
387
  value_func=value_func_outer,
370
388
  group="Vector Math",
371
- doc="Compute the outer product x*y^T for two vec2 objects.",
389
+ doc="Compute the outer product ``x*y^T`` for two vectors.",
372
390
  )
373
391
 
374
392
  add_builtin(
@@ -376,14 +394,14 @@ add_builtin(
376
394
  input_types={"x": vector(length=3, dtype=Scalar), "y": vector(length=3, dtype=Scalar)},
377
395
  value_func=sametype_value_func(vector(length=3, dtype=Scalar)),
378
396
  group="Vector Math",
379
- doc="Compute the cross product of two 3d vectors.",
397
+ doc="Compute the cross product of two 3D vectors.",
380
398
  )
381
399
  add_builtin(
382
400
  "skew",
383
401
  input_types={"x": vector(length=3, dtype=Scalar)},
384
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=args[0].type._wp_scalar_type_),
402
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=arg_types[0]._wp_scalar_type_),
385
403
  group="Vector Math",
386
- doc="Compute the skew symmetric matrix for a 3d vector.",
404
+ doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``x``.",
387
405
  )
388
406
 
389
407
  add_builtin(
@@ -391,59 +409,62 @@ add_builtin(
391
409
  input_types={"x": vector(length=Any, dtype=Float)},
392
410
  value_func=sametype_scalar_value_func,
393
411
  group="Vector Math",
394
- doc="Compute the length of a vector.",
412
+ doc="Compute the length of a vector ``x``.",
413
+ require_original_output_arg=True,
395
414
  )
396
415
  add_builtin(
397
416
  "length",
398
417
  input_types={"x": quaternion(dtype=Float)},
399
418
  value_func=sametype_scalar_value_func,
400
419
  group="Vector Math",
401
- doc="Compute the length of a quaternion.",
420
+ doc="Compute the length of a quaternion ``x``.",
421
+ require_original_output_arg=True,
402
422
  )
403
423
  add_builtin(
404
424
  "length_sq",
405
425
  input_types={"x": vector(length=Any, dtype=Scalar)},
406
426
  value_func=sametype_scalar_value_func,
407
427
  group="Vector Math",
408
- doc="Compute the squared length of a 2d vector.",
428
+ doc="Compute the squared length of a 2D vector ``x``.",
409
429
  )
410
430
  add_builtin(
411
431
  "length_sq",
412
432
  input_types={"x": quaternion(dtype=Scalar)},
413
433
  value_func=sametype_scalar_value_func,
414
434
  group="Vector Math",
415
- doc="Compute the squared length of a quaternion.",
435
+ doc="Compute the squared length of a quaternion ``x``.",
416
436
  )
417
437
  add_builtin(
418
438
  "normalize",
419
439
  input_types={"x": vector(length=Any, dtype=Float)},
420
440
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
421
441
  group="Vector Math",
422
- doc="Compute the normalized value of x, if length(x) is 0 then the zero vector is returned.",
442
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0 then the zero vector is returned.",
443
+ require_original_output_arg=True,
423
444
  )
424
445
  add_builtin(
425
446
  "normalize",
426
447
  input_types={"x": quaternion(dtype=Float)},
427
448
  value_func=sametype_value_func(quaternion(dtype=Scalar)),
428
449
  group="Vector Math",
429
- doc="Compute the normalized value of x, if length(x) is 0 then the zero quat is returned.",
450
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0, then the zero quaternion is returned.",
430
451
  )
431
452
 
432
453
  add_builtin(
433
454
  "transpose",
434
455
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
435
- value_func=lambda args, kwds, _: matrix(
436
- shape=(args[0].type._shape_[1], args[0].type._shape_[0]), dtype=args[0].type._wp_scalar_type_
456
+ value_func=lambda arg_types, kwds, _: matrix(
457
+ shape=(arg_types[0]._shape_[1], arg_types[0]._shape_[0]), dtype=arg_types[0]._wp_scalar_type_
437
458
  ),
438
459
  group="Vector Math",
439
- doc="Return the transpose of the matrix m",
460
+ doc="Return the transpose of the matrix ``m``.",
440
461
  )
441
462
 
442
463
 
443
- def value_func_mat_inv(args, kwds, _):
444
- if args is None:
464
+ def value_func_mat_inv(arg_types, kwds, _):
465
+ if arg_types is None:
445
466
  return matrix(shape=(Any, Any), dtype=Float)
446
- return args[0].type
467
+ return arg_types[0]
447
468
 
448
469
 
449
470
  add_builtin(
@@ -451,7 +472,8 @@ add_builtin(
451
472
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
452
473
  value_func=value_func_mat_inv,
453
474
  group="Vector Math",
454
- doc="Return the inverse of a 2x2 matrix m",
475
+ doc="Return the inverse of a 2x2 matrix ``m``.",
476
+ require_original_output_arg=True,
455
477
  )
456
478
 
457
479
  add_builtin(
@@ -459,7 +481,8 @@ add_builtin(
459
481
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
460
482
  value_func=value_func_mat_inv,
461
483
  group="Vector Math",
462
- doc="Return the inverse of a 3x3 matrix m",
484
+ doc="Return the inverse of a 3x3 matrix ``m``.",
485
+ require_original_output_arg=True,
463
486
  )
464
487
 
465
488
  add_builtin(
@@ -467,14 +490,15 @@ add_builtin(
467
490
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
468
491
  value_func=value_func_mat_inv,
469
492
  group="Vector Math",
470
- doc="Return the inverse of a 4x4 matrix m",
493
+ doc="Return the inverse of a 4x4 matrix ``m``.",
494
+ require_original_output_arg=True,
471
495
  )
472
496
 
473
497
 
474
- def value_func_mat_det(args, kwds, _):
475
- if args is None:
498
+ def value_func_mat_det(arg_types, kwds, _):
499
+ if arg_types is None:
476
500
  return Scalar
477
- return args[0].type._wp_scalar_type_
501
+ return arg_types[0]._wp_scalar_type_
478
502
 
479
503
 
480
504
  add_builtin(
@@ -482,7 +506,7 @@ add_builtin(
482
506
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
483
507
  value_func=value_func_mat_det,
484
508
  group="Vector Math",
485
- doc="Return the determinant of a 2x2 matrix m",
509
+ doc="Return the determinant of a 2x2 matrix ``m``.",
486
510
  )
487
511
 
488
512
  add_builtin(
@@ -490,7 +514,7 @@ add_builtin(
490
514
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
491
515
  value_func=value_func_mat_det,
492
516
  group="Vector Math",
493
- doc="Return the determinant of a 3x3 matrix m",
517
+ doc="Return the determinant of a 3x3 matrix ``m``.",
494
518
  )
495
519
 
496
520
  add_builtin(
@@ -498,16 +522,16 @@ add_builtin(
498
522
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
499
523
  value_func=value_func_mat_det,
500
524
  group="Vector Math",
501
- doc="Return the determinant of a 4x4 matrix m",
525
+ doc="Return the determinant of a 4x4 matrix ``m``.",
502
526
  )
503
527
 
504
528
 
505
- def value_func_mat_trace(args, kwds, _):
506
- if args is None:
529
+ def value_func_mat_trace(arg_types, kwds, _):
530
+ if arg_types is None:
507
531
  return Scalar
508
- if args[0].type._shape_[0] != args[0].type._shape_[1]:
509
- raise RuntimeError(f"Matrix shape is {args[0].type._shape_}. Cannot find the trace of non square matrices")
510
- return args[0].type._wp_scalar_type_
532
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
533
+ raise RuntimeError(f"Matrix shape is {arg_types[0]._shape_}. Cannot find the trace of non square matrices")
534
+ return arg_types[0]._wp_scalar_type_
511
535
 
512
536
 
513
537
  add_builtin(
@@ -515,15 +539,15 @@ add_builtin(
515
539
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
516
540
  value_func=value_func_mat_trace,
517
541
  group="Vector Math",
518
- doc="Return the trace of the matrix m",
542
+ doc="Return the trace of the matrix ``m``.",
519
543
  )
520
544
 
521
545
 
522
- def value_func_diag(args, kwds, _):
523
- if args is None:
546
+ def value_func_diag(arg_types, kwds, _):
547
+ if arg_types is None:
524
548
  return matrix(shape=(Any, Any), dtype=Scalar)
525
549
  else:
526
- return matrix(shape=(args[0].type._length_, args[0].type._length_), dtype=args[0].type._wp_scalar_type_)
550
+ return matrix(shape=(arg_types[0]._length_, arg_types[0]._length_), dtype=arg_types[0]._wp_scalar_type_)
527
551
 
528
552
 
529
553
  add_builtin(
@@ -531,19 +555,19 @@ add_builtin(
531
555
  input_types={"d": vector(length=Any, dtype=Scalar)},
532
556
  value_func=value_func_diag,
533
557
  group="Vector Math",
534
- doc="Returns a matrix with the components of the vector d on the diagonal",
558
+ doc="Returns a matrix with the components of the vector ``d`` on the diagonal.",
535
559
  )
536
560
 
537
561
 
538
- def value_func_get_diag(args, kwds, _):
539
- if args is None:
562
+ def value_func_get_diag(arg_types, kwds, _):
563
+ if arg_types is None:
540
564
  return vector(length=(Any), dtype=Scalar)
541
565
  else:
542
- if args[0].type._shape_[0] != args[0].type._shape_[1]:
566
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
543
567
  raise RuntimeError(
544
- f"Matrix shape is {args[0].type._shape_}; get_diag is only available for square matrices."
568
+ f"Matrix shape is {arg_types[0]._shape_}; get_diag is only available for square matrices."
545
569
  )
546
- return vector(length=args[0].type._shape_[0], dtype=args[0].type._wp_scalar_type_)
570
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
547
571
 
548
572
 
549
573
  add_builtin(
@@ -551,7 +575,7 @@ add_builtin(
551
575
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
552
576
  value_func=value_func_get_diag,
553
577
  group="Vector Math",
554
- doc="Returns a vector containing the diagonal elements of the square matrix.",
578
+ doc="Returns a vector containing the diagonal elements of the square matrix ``m``.",
555
579
  )
556
580
 
557
581
  add_builtin(
@@ -559,14 +583,15 @@ add_builtin(
559
583
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
560
584
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
561
585
  group="Vector Math",
562
- doc="Component wise multiply of two 2d vectors.",
586
+ doc="Component-wise multiplication of two 2D vectors.",
563
587
  )
564
588
  add_builtin(
565
589
  "cw_div",
566
590
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
567
591
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
568
592
  group="Vector Math",
569
- doc="Component wise division of two 2d vectors.",
593
+ doc="Component-wise division of two 2D vectors.",
594
+ require_original_output_arg=True,
570
595
  )
571
596
 
572
597
  add_builtin(
@@ -574,14 +599,15 @@ add_builtin(
574
599
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
575
600
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
576
601
  group="Vector Math",
577
- doc="Component wise multiply of two 2d vectors.",
602
+ doc="Component-wise multiplication of two 2D vectors.",
578
603
  )
579
604
  add_builtin(
580
605
  "cw_div",
581
606
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
582
607
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
583
608
  group="Vector Math",
584
- doc="Component wise division of two 2d vectors.",
609
+ doc="Component-wise division of two 2D vectors.",
610
+ require_original_output_arg=True,
585
611
  )
586
612
 
587
613
 
@@ -593,16 +619,19 @@ for t in scalar_types_all:
593
619
  t.__name__, input_types={"u": u}, value_type=t, doc="", hidden=True, group="Scalar Math", export=False
594
620
  )
595
621
 
622
+ for u in [bool, builtins.bool]:
623
+ add_builtin(bool.__name__, input_types={"u": u}, value_type=bool, doc="", hidden=True, export=False, namespace="")
624
+
596
625
 
597
- def vector_constructor_func(args, kwds, templates):
598
- if args is None:
626
+ def vector_constructor_func(arg_types, kwds, templates):
627
+ if arg_types is None:
599
628
  return vector(length=Any, dtype=Scalar)
600
629
 
601
630
  if templates is None or len(templates) == 0:
602
631
  # handle construction of anonymous (undeclared) vector types
603
632
 
604
633
  if "length" in kwds:
605
- if len(args) == 0:
634
+ if len(arg_types) == 0:
606
635
  if "dtype" not in kwds:
607
636
  raise RuntimeError(
608
637
  "vec() must have dtype as a keyword argument if it has no positional arguments, e.g.: wp.vector(length=5, dtype=wp.float32)"
@@ -612,12 +641,12 @@ def vector_constructor_func(args, kwds, templates):
612
641
  veclen = kwds["length"]
613
642
  vectype = kwds["dtype"]
614
643
 
615
- elif len(args) == 1:
644
+ elif len(arg_types) == 1:
616
645
  # value initialization e.g.: wp.vec(1.0, length=5)
617
646
  veclen = kwds["length"]
618
- vectype = args[0].type
647
+ vectype = arg_types[0]
619
648
  if getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
620
- # constructor from another matrix
649
+ # constructor from another vector
621
650
  if vectype._length_ != veclen:
622
651
  raise RuntimeError(
623
652
  f"Incompatible vector lengths for casting copy constructor, {veclen} vs {vectype._length_}"
@@ -629,28 +658,37 @@ def vector_constructor_func(args, kwds, templates):
629
658
  )
630
659
 
631
660
  else:
632
- if len(args) == 0:
661
+ if len(arg_types) == 0:
633
662
  raise RuntimeError(
634
663
  "vec() must have at least one numeric argument, if it's length, dtype is not specified"
635
664
  )
636
665
 
637
666
  if "dtype" in kwds:
667
+ # casting constructor
668
+ if len(arg_types) == 1 and types_equal(
669
+ arg_types[0], vector(length=Any, dtype=Scalar), match_generic=True
670
+ ):
671
+ veclen = arg_types[0]._length_
672
+ vectype = kwds["dtype"]
673
+ templates.append(veclen)
674
+ templates.append(vectype)
675
+ return vector(length=veclen, dtype=vectype)
638
676
  raise RuntimeError(
639
677
  "vec() should not have dtype specified if numeric arguments are given, the dtype will be inferred from the argument types"
640
678
  )
641
679
 
642
680
  # component wise construction of an anonymous vector, e.g. wp.vec(wp.float16(1.0), wp.float16(2.0), ....)
643
681
  # we infer the length and data type from the number and type of the arg values
644
- veclen = len(args)
645
- vectype = args[0].type
682
+ veclen = len(arg_types)
683
+ vectype = arg_types[0]
646
684
 
647
- if len(args) == 1 and getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
685
+ if len(arg_types) == 1 and getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
648
686
  # constructor from another vector
649
687
  veclen = vectype._length_
650
688
  vectype = vectype._wp_scalar_type_
651
- elif not all(vectype == a.type for a in args):
689
+ elif not all(vectype == t for t in arg_types):
652
690
  raise RuntimeError(
653
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
691
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
654
692
  )
655
693
 
656
694
  # update the templates list, so we can generate vec<len, type>() correctly in codegen
@@ -660,15 +698,15 @@ def vector_constructor_func(args, kwds, templates):
660
698
  else:
661
699
  # construction of a predeclared type, e.g.: vec5d
662
700
  veclen, vectype = templates
663
- if len(args) == 1 and getattr(args[0].type, "_wp_generic_type_str_", None) == "vec_t":
701
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "vec_t":
664
702
  # constructor from another vector
665
- if args[0].type._length_ != veclen:
703
+ if arg_types[0]._length_ != veclen:
666
704
  raise RuntimeError(
667
- f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {args[0].type._length_}"
705
+ f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {arg_types[0]._length_}"
668
706
  )
669
- elif not all(vectype == a.type for a in args):
707
+ elif not all(vectype == t for t in arg_types):
670
708
  raise RuntimeError(
671
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
709
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
672
710
  )
673
711
 
674
712
  retvalue = vector(length=veclen, dtype=vectype)
@@ -677,9 +715,9 @@ def vector_constructor_func(args, kwds, templates):
677
715
 
678
716
  add_builtin(
679
717
  "vector",
680
- input_types={"*args": Scalar, "length": int, "dtype": Scalar},
718
+ input_types={"*arg_types": Scalar, "length": int, "dtype": Scalar},
681
719
  variadic=True,
682
- initializer_list_func=lambda args, _: len(args) > 4,
720
+ initializer_list_func=lambda arg_types, _: len(arg_types) > 4,
683
721
  value_func=vector_constructor_func,
684
722
  native_func="vec_t",
685
723
  doc="Construct a vector of with given length and dtype.",
@@ -688,8 +726,8 @@ add_builtin(
688
726
  )
689
727
 
690
728
 
691
- def matrix_constructor_func(args, kwds, templates):
692
- if args is None:
729
+ def matrix_constructor_func(arg_types, kwds, templates):
730
+ if arg_types is None:
693
731
  return matrix(shape=(Any, Any), dtype=Scalar)
694
732
 
695
733
  if len(templates) == 0:
@@ -697,7 +735,7 @@ def matrix_constructor_func(args, kwds, templates):
697
735
  if "shape" not in kwds:
698
736
  raise RuntimeError("shape keyword must be specified when calling matrix() function")
699
737
 
700
- if len(args) == 0:
738
+ if len(arg_types) == 0:
701
739
  if "dtype" not in kwds:
702
740
  raise RuntimeError("matrix() must have dtype as a keyword argument if it has no positional arguments")
703
741
 
@@ -708,16 +746,16 @@ def matrix_constructor_func(args, kwds, templates):
708
746
  else:
709
747
  # value initialization, e.g.: m = matrix(1.0, shape=(3,2))
710
748
  shape = kwds["shape"]
711
- dtype = args[0].type
749
+ dtype = arg_types[0]
712
750
 
713
- if len(args) == 1 and getattr(dtype, "_wp_generic_type_str_", None) == "mat_t":
751
+ if len(arg_types) == 1 and getattr(dtype, "_wp_generic_type_str_", None) == "mat_t":
714
752
  # constructor from another matrix
715
- if types[0]._shape_ != shape:
753
+ if arg_types[0]._shape_ != shape:
716
754
  raise RuntimeError(
717
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {types[0]._shape_}"
755
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
718
756
  )
719
757
  dtype = dtype._wp_scalar_type_
720
- elif len(args) > 1 and len(args) != shape[0] * shape[1]:
758
+ elif len(arg_types) > 1 and len(arg_types) != shape[0] * shape[1]:
721
759
  raise RuntimeError(
722
760
  "Wrong number of arguments for matrix() function, must initialize with either a scalar value, or m*n values"
723
761
  )
@@ -731,35 +769,34 @@ def matrix_constructor_func(args, kwds, templates):
731
769
  shape = (templates[0], templates[1])
732
770
  dtype = templates[2]
733
771
 
734
- if len(args) > 0:
735
- types = [a.type for a in args]
736
- if len(args) == 1 and getattr(types[0], "_wp_generic_type_str_", None) == "mat_t":
772
+ if len(arg_types) > 0:
773
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "mat_t":
737
774
  # constructor from another matrix with same dimension but possibly different type
738
- if types[0]._shape_ != shape:
775
+ if arg_types[0]._shape_ != shape:
739
776
  raise RuntimeError(
740
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {types[0]._shape_}"
777
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
741
778
  )
742
779
  else:
743
780
  # check scalar arg type matches declared type
744
- if infer_scalar_type(args) != dtype:
781
+ if infer_scalar_type(arg_types) != dtype:
745
782
  raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
746
783
 
747
784
  # check vector arg type matches declared type
748
- if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in types):
749
- cols = len(types)
785
+ if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in arg_types):
786
+ cols = len(arg_types)
750
787
  if shape[1] != cols:
751
788
  raise RuntimeError(
752
789
  "Wrong number of vectors when attempting to construct a matrix with column vectors"
753
790
  )
754
791
 
755
- if not all(a._length_ == shape[0] for a in types):
792
+ if not all(a._length_ == shape[0] for a in arg_types):
756
793
  raise RuntimeError(
757
794
  "Wrong vector row count when attempting to construct a matrix with column vectors"
758
795
  )
759
796
  else:
760
797
  # check that we either got 1 arg (scalar construction), or enough values for whole matrix
761
798
  size = shape[0] * shape[1]
762
- if len(args) > 1 and len(args) != size:
799
+ if len(arg_types) > 1 and len(arg_types) != size:
763
800
  raise RuntimeError(
764
801
  "Wrong number of scalars when attempting to construct a matrix from a list of components"
765
802
  )
@@ -768,37 +805,34 @@ def matrix_constructor_func(args, kwds, templates):
768
805
 
769
806
 
770
807
  # only use initializer list if matrix size < 5x5, or for scalar construction
771
- def matrix_initlist_func(args, templates):
808
+ def matrix_initlist_func(arg_types, templates):
772
809
  m, n, dtype = templates
773
- if (
774
- len(args) == 0
775
- or len(args) == 1 # zero construction
810
+ return not (
811
+ len(arg_types) == 0
812
+ or len(arg_types) == 1 # zero construction
776
813
  or (m == n and n < 5) # scalar construction # value construction for small matrices
777
- ):
778
- return False
779
- else:
780
- return True
814
+ )
781
815
 
782
816
 
783
817
  add_builtin(
784
818
  "matrix",
785
- input_types={"*args": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
819
+ input_types={"*arg_types": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
786
820
  variadic=True,
787
821
  initializer_list_func=matrix_initlist_func,
788
822
  value_func=matrix_constructor_func,
789
823
  native_func="mat_t",
790
- doc="Construct a matrix, if positional args are not given then matrix will be zero-initialized.",
824
+ doc="Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized.",
791
825
  group="Vector Math",
792
826
  export=False,
793
827
  )
794
828
 
795
829
 
796
830
  # identity:
797
- def matrix_identity_value_func(args, kwds, templates):
798
- if args is None:
831
+ def matrix_identity_value_func(arg_types, kwds, templates):
832
+ if arg_types is None:
799
833
  return matrix(shape=(Any, Any), dtype=Scalar)
800
834
 
801
- if len(args):
835
+ if len(arg_types):
802
836
  raise RuntimeError("identity() function does not accept positional arguments")
803
837
 
804
838
  if "n" not in kwds:
@@ -829,7 +863,7 @@ add_builtin(
829
863
  )
830
864
 
831
865
 
832
- def matrix_transform_value_func(args, kwds, templates):
866
+ def matrix_transform_value_func(arg_types, kwds, templates):
833
867
  if templates is None:
834
868
  return matrix(shape=(Any, Any), dtype=Float)
835
869
 
@@ -839,7 +873,7 @@ def matrix_transform_value_func(args, kwds, templates):
839
873
  m, n, dtype = templates
840
874
  if (m, n) != (4, 4):
841
875
  raise RuntimeError("Can only construct 4x4 matrices with position, rotation and scale")
842
- if infer_scalar_type(args) != dtype:
876
+ if infer_scalar_type(arg_types) != dtype:
843
877
  raise RuntimeError("Wrong scalar type for mat<{}> constructor".format(",".join(map(str, templates))))
844
878
 
845
879
  return matrix(shape=(4, 4), dtype=dtype)
@@ -854,7 +888,8 @@ add_builtin(
854
888
  },
855
889
  value_func=matrix_transform_value_func,
856
890
  native_func="mat_t",
857
- doc="""Construct a 4x4 transformation matrix that applies the transformations as Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
891
+ doc="""Construct a 4x4 transformation matrix that applies the transformations as
892
+ Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
858
893
  group="Vector Math",
859
894
  export=False,
860
895
  )
@@ -873,8 +908,8 @@ add_builtin(
873
908
  value_type=None,
874
909
  group="Vector Math",
875
910
  export=False,
876
- doc="""Compute the SVD of a 3x3 matrix. The singular values are returned in sigma,
877
- while the left and right basis vectors are returned in U and V.""",
911
+ doc="""Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
912
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
878
913
  )
879
914
 
880
915
  add_builtin(
@@ -887,7 +922,8 @@ add_builtin(
887
922
  value_type=None,
888
923
  group="Vector Math",
889
924
  export=False,
890
- doc="""Compute the QR decomposition of a 3x3 matrix. The orthogonal matrix is returned in Q, while the upper triangular matrix is returned in R.""",
925
+ doc="""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
926
+ while the upper triangular matrix is returned in ``R``.""",
891
927
  )
892
928
 
893
929
  add_builtin(
@@ -900,36 +936,53 @@ add_builtin(
900
936
  value_type=None,
901
937
  group="Vector Math",
902
938
  export=False,
903
- doc="""Compute the eigendecomposition of a 3x3 matrix. The eigenvectors are returned as the columns of Q, while the corresponding eigenvalues are returned in d.""",
939
+ doc="""Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
940
+ while the corresponding eigenvalues are returned in ``d``.""",
904
941
  )
905
942
 
906
943
  # ---------------------------------
907
944
  # Quaternion Math
908
945
 
909
946
 
910
- def quaternion_value_func(args, kwds, templates):
911
- if args is None:
912
- return quaternion(dtype=Scalar)
947
+ def quaternion_value_func(arg_types, kwds, templates):
948
+ if arg_types is None:
949
+ return quaternion(dtype=Float)
913
950
 
914
- # if constructing anonymous quat type then infer output type from arguments
915
951
  if len(templates) == 0:
916
- dtype = infer_scalar_type(args)
952
+ if "dtype" in kwds:
953
+ # casting constructor
954
+ dtype = kwds["dtype"]
955
+ else:
956
+ # if constructing anonymous quat type then infer output type from arguments
957
+ dtype = infer_scalar_type(arg_types)
917
958
  templates.append(dtype)
918
959
  else:
919
- # if constructing predeclared type then check args match expectation
920
- if len(args) > 0 and infer_scalar_type(args) != templates[0]:
960
+ # if constructing predeclared type then check arg_types match expectation
961
+ if len(arg_types) > 0 and infer_scalar_type(arg_types) != templates[0]:
921
962
  raise RuntimeError("Wrong scalar type for quat {} constructor".format(",".join(map(str, templates))))
922
963
 
923
964
  return quaternion(dtype=templates[0])
924
965
 
925
966
 
967
+ def quat_cast_value_func(arg_types, kwds, templates):
968
+ if arg_types is None:
969
+ raise RuntimeError("Missing quaternion argument.")
970
+ if "dtype" not in kwds:
971
+ raise RuntimeError("Missing 'dtype' kwd.")
972
+
973
+ dtype = kwds["dtype"]
974
+ templates.append(dtype)
975
+
976
+ return quaternion(dtype=dtype)
977
+
978
+
926
979
  add_builtin(
927
980
  "quaternion",
928
981
  input_types={},
929
982
  value_func=quaternion_value_func,
930
983
  native_func="quat_t",
931
984
  group="Quaternion Math",
932
- doc="""Construct a zero-initialized quaternion, quaternions are laid out as
985
+ doc="""Construct a zero-initialized quaternion. Quaternions are laid out as
933
986
  [ix, iy, iz, r], where ix, iy, iz are the imaginary part, and r the real part.""",
934
987
  export=False,
935
988
  )
@@ -939,7 +992,7 @@ add_builtin(
939
992
  value_func=quaternion_value_func,
940
993
  native_func="quat_t",
941
994
  group="Quaternion Math",
942
- doc="Create a quaternion using the supplied components (type inferred from component type)",
995
+ doc="Create a quaternion using the supplied components (type inferred from component type).",
943
996
  export=False,
944
997
  )
945
998
  add_builtin(
@@ -948,14 +1001,23 @@ add_builtin(
948
1001
  value_func=quaternion_value_func,
949
1002
  native_func="quat_t",
950
1003
  group="Quaternion Math",
951
- doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type)",
1004
+ doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type).",
1005
+ export=False,
1006
+ )
1007
+ add_builtin(
1008
+ "quaternion",
1009
+ input_types={"q": quaternion(dtype=Float)},
1010
+ value_func=quat_cast_value_func,
1011
+ native_func="quat_t",
1012
+ group="Quaternion Math",
1013
+ doc="Construct a quaternion of type dtype from another quaternion of a different dtype.",
952
1014
  export=False,
953
1015
  )
954
1016
 
955
1017
 
956
- def quat_identity_value_func(args, kwds, templates):
957
- # if args is None then we are in 'export' mode
958
- if args is None:
1018
+ def quat_identity_value_func(arg_types, kwds, templates):
1019
+ # if arg_types is None then we are in 'export' mode
1020
+ if arg_types is None:
959
1021
  return quatf
960
1022
 
961
1023
  if "dtype" not in kwds:
@@ -981,7 +1043,7 @@ add_builtin(
981
1043
  add_builtin(
982
1044
  "quat_from_axis_angle",
983
1045
  input_types={"axis": vector(length=3, dtype=Float), "angle": Float},
984
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1046
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
985
1047
  group="Quaternion Math",
986
1048
  doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
987
1049
  )
@@ -995,49 +1057,50 @@ add_builtin(
995
1057
  add_builtin(
996
1058
  "quat_from_matrix",
997
1059
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
998
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1060
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
999
1061
  group="Quaternion Math",
1000
1062
  doc="Construct a quaternion from a 3x3 matrix.",
1001
1063
  )
1002
1064
  add_builtin(
1003
1065
  "quat_rpy",
1004
1066
  input_types={"roll": Float, "pitch": Float, "yaw": Float},
1005
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1067
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1006
1068
  group="Quaternion Math",
1007
1069
  doc="Construct a quaternion representing a combined roll (z), pitch (x), yaw rotations (y) in radians.",
1008
1070
  )
1009
1071
  add_builtin(
1010
1072
  "quat_inverse",
1011
1073
  input_types={"q": quaternion(dtype=Float)},
1012
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1074
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1013
1075
  group="Quaternion Math",
1014
1076
  doc="Compute quaternion conjugate.",
1015
1077
  )
1016
1078
  add_builtin(
1017
1079
  "quat_rotate",
1018
1080
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1019
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1081
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1020
1082
  group="Quaternion Math",
1021
1083
  doc="Rotate a vector by a quaternion.",
1022
1084
  )
1023
1085
  add_builtin(
1024
1086
  "quat_rotate_inv",
1025
1087
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1026
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1088
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1027
1089
  group="Quaternion Math",
1028
- doc="Rotate a vector the inverse of a quaternion.",
1090
+ doc="Rotate a vector by the inverse of a quaternion.",
1029
1091
  )
1030
1092
  add_builtin(
1031
1093
  "quat_slerp",
1032
1094
  input_types={"q0": quaternion(dtype=Float), "q1": quaternion(dtype=Float), "t": Float},
1033
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1095
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1034
1096
  group="Quaternion Math",
1035
1097
  doc="Linearly interpolate between two quaternions.",
1098
+ require_original_output_arg=True,
1036
1099
  )
1037
1100
  add_builtin(
1038
1101
  "quat_to_matrix",
1039
1102
  input_types={"q": quaternion(dtype=Float)},
1040
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(args)),
1103
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(arg_types)),
1041
1104
  group="Quaternion Math",
1042
1105
  doc="Convert a quaternion to a 3x3 rotation matrix.",
1043
1106
  )
@@ -1053,19 +1116,19 @@ add_builtin(
1053
1116
  # Transformations
1054
1117
 
1055
1118
 
1056
- def transform_constructor_value_func(args, kwds, templates):
1119
+ def transform_constructor_value_func(arg_types, kwds, templates):
1057
1120
  if templates is None:
1058
1121
  return transformation(dtype=Scalar)
1059
1122
 
1060
1123
  if len(templates) == 0:
1061
1124
  # if constructing anonymous transform type then infer output type from arguments
1062
- dtype = infer_scalar_type(args)
1125
+ dtype = infer_scalar_type(arg_types)
1063
1126
  templates.append(dtype)
1064
1127
  else:
1065
- # if constructing predeclared type then check args match expectation
1066
- if infer_scalar_type(args) != templates[0]:
1128
+ # if constructing predeclared type then check arg_types match expectation
1129
+ if infer_scalar_type(arg_types) != templates[0]:
1067
1130
  raise RuntimeError(
1068
- f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda x : str(x.type), args))}"
1131
+ f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda t : str(t), arg_types))}"
1069
1132
  )
1070
1133
 
1071
1134
  return transformation(dtype=templates[0])
@@ -1077,13 +1140,13 @@ add_builtin(
1077
1140
  value_func=transform_constructor_value_func,
1078
1141
  native_func="transform_t",
1079
1142
  group="Transformations",
1080
- doc="Construct a rigid body transformation with translation part p and rotation q.",
1143
+ doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
1081
1144
  export=False,
1082
1145
  )
1083
1146
 
1084
1147
 
1085
- def transform_identity_value_func(args, kwds, templates):
1086
- if args is None:
1148
+ def transform_identity_value_func(arg_types, kwds, templates):
1149
+ if arg_types is None:
1087
1150
  return transformf
1088
1151
 
1089
1152
  if "dtype" not in kwds:
@@ -1109,68 +1172,72 @@ add_builtin(
1109
1172
  add_builtin(
1110
1173
  "transform_get_translation",
1111
1174
  input_types={"t": transformation(dtype=Float)},
1112
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1175
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1113
1176
  group="Transformations",
1114
- doc="Return the translational part of a transform.",
1177
+ doc="Return the translational part of a transform ``t``.",
1115
1178
  )
1116
1179
  add_builtin(
1117
1180
  "transform_get_rotation",
1118
1181
  input_types={"t": transformation(dtype=Float)},
1119
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1182
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1120
1183
  group="Transformations",
1121
- doc="Return the rotational part of a transform.",
1184
+ doc="Return the rotational part of a transform ``t``.",
1122
1185
  )
1123
1186
  add_builtin(
1124
1187
  "transform_multiply",
1125
1188
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
1126
- value_func=lambda args, kwds, _: transformation(dtype=infer_scalar_type(args)),
1189
+ value_func=lambda arg_types, kwds, _: transformation(dtype=infer_scalar_type(arg_types)),
1127
1190
  group="Transformations",
1128
1191
  doc="Multiply two rigid body transformations together.",
1129
1192
  )
1130
1193
  add_builtin(
1131
1194
  "transform_point",
1132
1195
  input_types={"t": transformation(dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1133
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1196
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1134
1197
  group="Transformations",
1135
- doc="Apply the transform to a point p treating the homogenous coordinate as w=1 (translation and rotation).",
1198
+ doc="Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1 (translation and rotation).",
1136
1199
  )
1137
1200
  add_builtin(
1138
1201
  "transform_point",
1139
1202
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1140
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1203
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1141
1204
  group="Vector Math",
1142
- doc="""Apply the transform to a point ``p`` treating the homogenous coordinate as w=1. The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``
1143
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``. If the transform is coming from a library that uses row-vectors
1144
- then users should transpose the transformation matrix before calling this method.""",
1205
+ doc="""Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1.
1206
+ The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``.
1207
+ Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``.
1208
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1209
+ matrix before calling this method.""",
1145
1210
  )
1146
1211
  add_builtin(
1147
1212
  "transform_vector",
1148
1213
  input_types={"t": transformation(dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1149
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1214
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1150
1215
  group="Transformations",
1151
- doc="Apply the transform to a vector v treating the homogenous coordinate as w=0 (rotation only).",
1216
+ doc="Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0 (rotation only).",
1152
1217
  )
1153
1218
  add_builtin(
1154
1219
  "transform_vector",
1155
1220
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1156
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1221
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1157
1222
  group="Vector Math",
1158
- doc="""Apply the transform to a vector ``v`` treating the homogenous coordinate as w=0. The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1159
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``. If the transform is coming from a library that uses row-vectors
1160
- then users should transpose the transformation matrix before calling this method.""",
1223
+ doc="""Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0.
1224
+ The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1225
+ note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``.
1226
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1227
+ matrix before calling this method.""",
1161
1228
  )
1162
1229
  add_builtin(
1163
1230
  "transform_inverse",
1164
1231
  input_types={"t": transformation(dtype=Float)},
1165
1232
  value_func=sametype_value_func(transformation(dtype=Float)),
1166
1233
  group="Transformations",
1167
- doc="Compute the inverse of the transform.",
1234
+ doc="Compute the inverse of the transformation ``t``.",
1168
1235
  )
1169
1236
  # ---------------------------------
1170
1237
  # Spatial Math
1171
1238
 
1172
1239
 
1173
- def spatial_vector_constructor_value_func(args, kwds, templates):
1240
+ def spatial_vector_constructor_value_func(arg_types, kwds, templates):
1174
1241
  if templates is None:
1175
1242
  return spatial_vector(dtype=Float)
1176
1243
 
@@ -1178,7 +1245,7 @@ def spatial_vector_constructor_value_func(args, kwds, templates):
1178
1245
  raise RuntimeError("Cannot use a generic type name in a kernel")
1179
1246
 
1180
1247
  vectype = templates[1]
1181
- if len(args) and infer_scalar_type(args) != vectype:
1248
+ if len(arg_types) and infer_scalar_type(arg_types) != vectype:
1182
1249
  raise RuntimeError("Wrong scalar type for spatial_vector<{}> constructor".format(",".join(map(str, templates))))
1183
1250
 
1184
1251
  return vector(length=6, dtype=vectype)
@@ -1190,7 +1257,7 @@ add_builtin(
1190
1257
  value_func=spatial_vector_constructor_value_func,
1191
1258
  native_func="vec_t",
1192
1259
  group="Spatial Math",
1193
- doc="Construct a 6d screw vector from two 3d vectors.",
1260
+ doc="Construct a 6D screw vector from two 3D vectors.",
1194
1261
  export=False,
1195
1262
  )
1196
1263
 
@@ -1198,7 +1265,7 @@ add_builtin(
1198
1265
  add_builtin(
1199
1266
  "spatial_adjoint",
1200
1267
  input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
1201
- value_func=lambda args, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(args)),
1268
+ value_func=lambda arg_types, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(arg_types)),
1202
1269
  group="Spatial Math",
1203
1270
  doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
1204
1271
  export=False,
@@ -1208,36 +1275,36 @@ add_builtin(
1208
1275
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1209
1276
  value_func=sametype_scalar_value_func,
1210
1277
  group="Spatial Math",
1211
- doc="Compute the dot product of two 6d screw vectors.",
1278
+ doc="Compute the dot product of two 6D screw vectors.",
1212
1279
  )
1213
1280
  add_builtin(
1214
1281
  "spatial_cross",
1215
1282
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1216
1283
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1217
1284
  group="Spatial Math",
1218
- doc="Compute the cross-product of two 6d screw vectors.",
1285
+ doc="Compute the cross product of two 6D screw vectors.",
1219
1286
  )
1220
1287
  add_builtin(
1221
1288
  "spatial_cross_dual",
1222
1289
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1223
1290
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1224
1291
  group="Spatial Math",
1225
- doc="Compute the dual cross-product of two 6d screw vectors.",
1292
+ doc="Compute the dual cross product of two 6D screw vectors.",
1226
1293
  )
1227
1294
 
1228
1295
  add_builtin(
1229
1296
  "spatial_top",
1230
1297
  input_types={"a": vector(length=6, dtype=Float)},
1231
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1298
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1232
1299
  group="Spatial Math",
1233
- doc="Return the top (first) part of a 6d screw vector.",
1300
+ doc="Return the top (first) part of a 6D screw vector.",
1234
1301
  )
1235
1302
  add_builtin(
1236
1303
  "spatial_bottom",
1237
1304
  input_types={"a": vector(length=6, dtype=Float)},
1238
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1305
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1239
1306
  group="Spatial Math",
1240
- doc="Return the bottom (second) part of a 6d screw vector.",
1307
+ doc="Return the bottom (second) part of a 6D screw vector.",
1241
1308
  )
1242
1309
 
1243
1310
  add_builtin(
@@ -1391,16 +1458,18 @@ add_builtin(
1391
1458
  },
1392
1459
  value_type=None,
1393
1460
  skip_replay=True,
1394
- doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1461
+ doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1395
1462
 
1396
1463
  :param weights: A layer's network weights with dimensions ``(m, n)``.
1397
1464
  :param bias: An array with dimensions ``(n)``.
1398
1465
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
1399
- :param index: The batch item to process, typically each thread will process 1 item in the batch, in this case index should be ``wp.tid()``
1466
+ :param index: The batch item to process, typically each thread will process one item in the batch, in which case
1467
+ index should be ``wp.tid()``
1400
1468
  :param x: The feature matrix with dimensions ``(n, b)``
1401
1469
  :param out: The network output with dimensions ``(m, b)``
1402
1470
 
1403
- :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch. All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1471
+ :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch.
1472
+ All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1404
1473
  group="Utility",
1405
1474
  )
1406
1475
 
@@ -1413,12 +1482,12 @@ add_builtin(
1413
1482
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1414
1483
  value_type=bvh_query_t,
1415
1484
  group="Geometry",
1416
- doc="""Construct an axis-aligned bounding box query against a bvh object. This query can be used to iterate over all bounds
1417
- inside a bvh. Returns an object that is used to track state during bvh traversal.
1418
-
1419
- :param id: The bvh identifier
1420
- :param lower: The lower bound of the bounding box in bvh space
1421
- :param upper: The upper bound of the bounding box in bvh space""",
1485
+ doc="""Construct an axis-aligned bounding box query against a BVH object. This query can be used to iterate over all bounds
1486
+ inside a BVH.
1487
+
1488
+ :param id: The BVH identifier
1489
+ :param lower: The lower bound of the bounding box in BVH space
1490
+ :param upper: The upper bound of the bounding box in BVH space""",
1422
1491
  )
1423
1492
 
1424
1493
  add_builtin(
@@ -1426,21 +1495,21 @@ add_builtin(
1426
1495
  input_types={"id": uint64, "start": vec3, "dir": vec3},
1427
1496
  value_type=bvh_query_t,
1428
1497
  group="Geometry",
1429
- doc="""Construct a ray query against a bvh object. This query can be used to iterate over all bounds
1430
- that intersect the ray. Returns an object that is used to track state during bvh traversal.
1431
-
1432
- :param id: The bvh identifier
1433
- :param start: The start of the ray in bvh space
1434
- :param dir: The direction of the ray in bvh space""",
1498
+ doc="""Construct a ray query against a BVH object. This query can be used to iterate over all bounds
1499
+ that intersect the ray.
1500
+
1501
+ :param id: The BVH identifier
1502
+ :param start: The start of the ray in BVH space
1503
+ :param dir: The direction of the ray in BVH space""",
1435
1504
  )
1436
1505
 
1437
1506
  add_builtin(
1438
1507
  "bvh_query_next",
1439
1508
  input_types={"query": bvh_query_t, "index": int},
1440
- value_type=bool,
1509
+ value_type=builtins.bool,
1441
1510
  group="Geometry",
1442
- doc="""Move to the next bound returned by the query. The index of the current bound is stored in ``index``, returns ``False``
1443
- if there are no more overlapping bound.""",
1511
+ doc="""Move to the next bound returned by the query.
1512
+ The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
1444
1513
  )
1445
1514
 
1446
1515
  add_builtin(
@@ -1454,20 +1523,44 @@ add_builtin(
1454
1523
  "bary_u": float,
1455
1524
  "bary_v": float,
1456
1525
  },
1457
- value_type=bool,
1526
+ value_type=builtins.bool,
1458
1527
  group="Geometry",
1459
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1528
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1460
1529
 
1461
- Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside. This method is relatively robust, but
1462
- does increase computational cost. See below for additional sign determination methods.
1530
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1531
+ This method is relatively robust, but does increase computational cost.
1532
+ See below for additional sign determination methods.
1463
1533
 
1464
1534
  :param id: The mesh identifier
1465
1535
  :param point: The point in space to query
1466
1536
  :param max_dist: Mesh faces above this distance will not be considered by the query
1467
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1537
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1538
+ Note that mesh must be watertight for this to be robust
1468
1539
  :param face: Returns the index of the closest face
1469
1540
  :param bary_u: Returns the barycentric u coordinate of the closest point
1470
1541
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1542
+ hidden=True,
1543
+ )
1544
+
1545
+ add_builtin(
1546
+ "mesh_query_point",
1547
+ input_types={
1548
+ "id": uint64,
1549
+ "point": vec3,
1550
+ "max_dist": float,
1551
+ },
1552
+ value_type=mesh_query_point_t,
1553
+ group="Geometry",
1554
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1555
+
1556
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1557
+ This method is relatively robust, but does increase computational cost.
1558
+ See below for additional sign determination methods.
1559
+
1560
+ :param id: The mesh identifier
1561
+ :param point: The point in space to query
1562
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1563
+ require_original_output_arg=True,
1471
1564
  )
1472
1565
 
1473
1566
  add_builtin(
@@ -1480,9 +1573,9 @@ add_builtin(
1480
1573
  "bary_u": float,
1481
1574
  "bary_v": float,
1482
1575
  },
1483
- value_type=bool,
1576
+ value_type=builtins.bool,
1484
1577
  group="Geometry",
1485
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1578
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1486
1579
 
1487
1580
  This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1488
1581
 
@@ -1492,6 +1585,70 @@ add_builtin(
1492
1585
  :param face: Returns the index of the closest face
1493
1586
  :param bary_u: Returns the barycentric u coordinate of the closest point
1494
1587
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1588
+ hidden=True,
1589
+ )
1590
+
1591
+ add_builtin(
1592
+ "mesh_query_point_no_sign",
1593
+ input_types={
1594
+ "id": uint64,
1595
+ "point": vec3,
1596
+ "max_dist": float,
1597
+ },
1598
+ value_type=mesh_query_point_t,
1599
+ group="Geometry",
1600
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1601
+
1602
+ This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1603
+
1604
+ :param id: The mesh identifier
1605
+ :param point: The point in space to query
1606
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1607
+ require_original_output_arg=True,
1608
+ )
1609
+
1610
+ add_builtin(
1611
+ "mesh_query_furthest_point_no_sign",
1612
+ input_types={
1613
+ "id": uint64,
1614
+ "point": vec3,
1615
+ "min_dist": float,
1616
+ "face": int,
1617
+ "bary_u": float,
1618
+ "bary_v": float,
1619
+ },
1620
+ value_type=builtins.bool,
1621
+ group="Geometry",
1622
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point > ``min_dist`` is found.
1623
+
1624
+ This method does not compute the sign of the point (inside/outside).
1625
+
1626
+ :param id: The mesh identifier
1627
+ :param point: The point in space to query
1628
+ :param min_dist: Mesh faces below this distance will not be considered by the query
1629
+ :param face: Returns the index of the furthest face
1630
+ :param bary_u: Returns the barycentric u coordinate of the furthest point
1631
+ :param bary_v: Returns the barycentric v coordinate of the furthest point""",
1632
+ hidden=True,
1633
+ )
1634
+
1635
+ add_builtin(
1636
+ "mesh_query_furthest_point_no_sign",
1637
+ input_types={
1638
+ "id": uint64,
1639
+ "point": vec3,
1640
+ "min_dist": float,
1641
+ },
1642
+ value_type=mesh_query_point_t,
1643
+ group="Geometry",
1644
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
1645
+
1646
+ This method does not compute the sign of the point (inside/outside).
1647
+
1648
+ :param id: The mesh identifier
1649
+ :param point: The point in space to query
1650
+ :param min_dist: Mesh faces below this distance will not be considered by the query""",
1651
+ require_original_output_arg=True,
1495
1652
  )
1496
1653
 
1497
1654
  add_builtin(
@@ -1507,21 +1664,50 @@ add_builtin(
1507
1664
  "epsilon": float,
1508
1665
  },
1509
1666
  defaults={"epsilon": 1.0e-3},
1510
- value_type=bool,
1667
+ value_type=builtins.bool,
1511
1668
  group="Geometry",
1512
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1513
-
1514
- Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal. This approach to sign determination is robust for well conditioned meshes
1515
- that are watertight and non-self intersecting, it is also comparatively fast to compute.
1669
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1670
+
1671
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1672
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1673
+ It is also comparatively fast to compute.
1516
1674
 
1517
1675
  :param id: The mesh identifier
1518
1676
  :param point: The point in space to query
1519
1677
  :param max_dist: Mesh faces above this distance will not be considered by the query
1520
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1678
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1679
+ Note that mesh must be watertight for this to be robust
1521
1680
  :param face: Returns the index of the closest face
1522
1681
  :param bary_u: Returns the barycentric u coordinate of the closest point
1523
1682
  :param bary_v: Returns the barycentric v coordinate of the closest point
1524
- :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1683
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1684
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1685
+ hidden=True,
1686
+ )
1687
+
1688
+ add_builtin(
1689
+ "mesh_query_point_sign_normal",
1690
+ input_types={
1691
+ "id": uint64,
1692
+ "point": vec3,
1693
+ "max_dist": float,
1694
+ "epsilon": float,
1695
+ },
1696
+ defaults={"epsilon": 1.0e-3},
1697
+ value_type=mesh_query_point_t,
1698
+ group="Geometry",
1699
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1700
+
1701
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1702
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1703
+ It is also comparatively fast to compute.
1704
+
1705
+ :param id: The mesh identifier
1706
+ :param point: The point in space to query
1707
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1708
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1709
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1710
+ require_original_output_arg=True,
1525
1711
  )
1526
1712
 
1527
1713
  add_builtin(
@@ -1538,25 +1724,55 @@ add_builtin(
1538
1724
  "threshold": float,
1539
1725
  },
1540
1726
  defaults={"accuracy": 2.0, "threshold": 0.5},
1541
- value_type=bool,
1727
+ value_type=builtins.bool,
1542
1728
  group="Geometry",
1543
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1544
-
1729
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1730
+
1545
1731
  Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1546
1732
  and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1547
1733
  but also the most expensive.
1548
-
1549
- Note that the Mesh object must be constructed with ``suport_winding_number=True`` for this method to return correct results.
1734
+
1735
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1550
1736
 
1551
1737
  :param id: The mesh identifier
1552
1738
  :param point: The point in space to query
1553
1739
  :param max_dist: Mesh faces above this distance will not be considered by the query
1554
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1740
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1741
+ Note that mesh must be watertight for this to be robust
1555
1742
  :param face: Returns the index of the closest face
1556
1743
  :param bary_u: Returns the barycentric u coordinate of the closest point
1557
1744
  :param bary_v: Returns the barycentric v coordinate of the closest point
1558
- :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second order dipole approximation, default 2.0
1745
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1746
+ :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1747
+ hidden=True,
1748
+ )
1749
+
1750
+ add_builtin(
1751
+ "mesh_query_point_sign_winding_number",
1752
+ input_types={
1753
+ "id": uint64,
1754
+ "point": vec3,
1755
+ "max_dist": float,
1756
+ "accuracy": float,
1757
+ "threshold": float,
1758
+ },
1759
+ defaults={"accuracy": 2.0, "threshold": 0.5},
1760
+ value_type=mesh_query_point_t,
1761
+ group="Geometry",
1762
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
1763
+
1764
+ Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1765
+ and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1766
+ but also the most expensive.
1767
+
1768
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1769
+
1770
+ :param id: The mesh identifier
1771
+ :param point: The point in space to query
1772
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1773
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1559
1774
  :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1775
+ require_original_output_arg=True,
1560
1776
  )
1561
1777
 
1562
1778
  add_builtin(
@@ -1573,9 +1789,9 @@ add_builtin(
1573
1789
  "normal": vec3,
1574
1790
  "face": int,
1575
1791
  },
1576
- value_type=bool,
1792
+ value_type=builtins.bool,
1577
1793
  group="Geometry",
1578
- doc="""Computes the closest ray hit on the mesh with identifier `id`, returns ``True`` if a point < ``max_t`` is found.
1794
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``, returns ``True`` if a hit < ``max_t`` is found.
1579
1795
 
1580
1796
  :param id: The mesh identifier
1581
1797
  :param start: The start point of the ray
@@ -1584,9 +1800,29 @@ add_builtin(
1584
1800
  :param t: Returns the distance of the closest hit along the ray
1585
1801
  :param bary_u: Returns the barycentric u coordinate of the closest hit
1586
1802
  :param bary_v: Returns the barycentric v coordinate of the closest hit
1587
- :param sign: Returns a value > 0 if the hit ray hit front of the face, returns < 0 otherwise
1803
+ :param sign: Returns a value > 0 if the ray hit in front of the face, returns < 0 otherwise
1588
1804
  :param normal: Returns the face normal
1589
1805
  :param face: Returns the index of the hit face""",
1806
+ hidden=True,
1807
+ )
1808
+
1809
+ add_builtin(
1810
+ "mesh_query_ray",
1811
+ input_types={
1812
+ "id": uint64,
1813
+ "start": vec3,
1814
+ "dir": vec3,
1815
+ "max_t": float,
1816
+ },
1817
+ value_type=mesh_query_ray_t,
1818
+ group="Geometry",
1819
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
1820
+
1821
+ :param id: The mesh identifier
1822
+ :param start: The start point of the ray
1823
+ :param dir: The ray direction (should be normalized)
1824
+ :param max_t: The maximum distance along the ray to check for intersections""",
1825
+ require_original_output_arg=True,
1590
1826
  )
1591
1827
 
1592
1828
  add_builtin(
@@ -1594,9 +1830,9 @@ add_builtin(
1594
1830
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1595
1831
  value_type=mesh_query_aabb_t,
1596
1832
  group="Geometry",
1597
- doc="""Construct an axis-aligned bounding box query against a mesh object. This query can be used to iterate over all triangles
1598
- inside a volume. Returns an object that is used to track state during mesh traversal.
1599
-
1833
+ doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
1834
+ This query can be used to iterate over all triangles inside a volume.
1835
+
1600
1836
  :param id: The mesh identifier
1601
1837
  :param lower: The lower bound of the bounding box in mesh space
1602
1838
  :param upper: The upper bound of the bounding box in mesh space""",
@@ -1605,10 +1841,10 @@ add_builtin(
1605
1841
  add_builtin(
1606
1842
  "mesh_query_aabb_next",
1607
1843
  input_types={"query": mesh_query_aabb_t, "index": int},
1608
- value_type=bool,
1844
+ value_type=builtins.bool,
1609
1845
  group="Geometry",
1610
- doc="""Move to the next triangle overlapping the query bounding box. The index of the current face is stored in ``index``, returns ``False``
1611
- if there are no more overlapping triangles.""",
1846
+ doc="""Move to the next triangle overlapping the query bounding box.
1847
+ The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
1612
1848
  )
1613
1849
 
1614
1850
  add_builtin(
@@ -1616,7 +1852,7 @@ add_builtin(
1616
1852
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1617
1853
  value_type=vec3,
1618
1854
  group="Geometry",
1619
- doc="""Evaluates the position on the mesh given a face index, and barycentric coordinates.""",
1855
+ doc="""Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates.""",
1620
1856
  )
1621
1857
 
1622
1858
  add_builtin(
@@ -1624,7 +1860,7 @@ add_builtin(
1624
1860
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1625
1861
  value_type=vec3,
1626
1862
  group="Geometry",
1627
- doc="""Evaluates the velocity on the mesh given a face index, and barycentric coordinates.""",
1863
+ doc="""Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates.""",
1628
1864
  )
1629
1865
 
1630
1866
  add_builtin(
@@ -1632,14 +1868,14 @@ add_builtin(
1632
1868
  input_types={"id": uint64, "point": vec3, "max_dist": float},
1633
1869
  value_type=hash_grid_query_t,
1634
1870
  group="Geometry",
1635
- doc="""Construct a point query against a hash grid. This query can be used to iterate over all neighboring points withing a
1636
- fixed radius from the query point. Returns an object that is used to track state during neighbor traversal.""",
1871
+ doc="Construct a point query against a :class:`HashGrid`. This query can be used to iterate over all neighboring points "
1872
+ "within a fixed radius from the query point.",
1637
1873
  )
1638
1874
 
1639
1875
  add_builtin(
1640
1876
  "hash_grid_query_next",
1641
1877
  input_types={"query": hash_grid_query_t, "index": int},
1642
- value_type=bool,
1878
+ value_type=builtins.bool,
1643
1879
  group="Geometry",
1644
1880
  doc="""Move to the next point in the hash grid query. The index of the current neighbor is stored in ``index``, returns ``False``
1645
1881
  if there are no more neighbors.""",
@@ -1650,8 +1886,10 @@ add_builtin(
1650
1886
  input_types={"id": uint64, "index": int},
1651
1887
  value_type=int,
1652
1888
  group="Geometry",
1653
- doc="""Return the index of a point in the grid, this can be used to re-order threads such that grid
1654
- traversal occurs in a spatially coherent order.""",
1889
+ doc="""Return the index of a point in the :class:`HashGrid`. This can be used to reorder threads such that grid
1890
+ traversal occurs in a spatially coherent order.
1891
+
1892
+ Returns -1 if the :class:`HashGrid` has not been reserved.""",
1655
1893
  )
1656
1894
 
1657
1895
  add_builtin(
@@ -1750,7 +1988,17 @@ add_builtin(
1750
1988
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1751
1989
  value_type=float,
1752
1990
  group="Volumes",
1753
- doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
1991
+ doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``.
1992
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1993
+ )
1994
+
1995
+ add_builtin(
1996
+ "volume_sample_grad_f",
1997
+ input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "grad": vec3},
1998
+ value_type=float,
1999
+ group="Volumes",
2000
+ doc="""Sample the volume and its gradient given by ``id`` at the volume local-space point ``uvw``.
2001
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1754
2002
  )
1755
2003
 
1756
2004
  add_builtin(
@@ -1758,14 +2006,15 @@ add_builtin(
1758
2006
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1759
2007
  value_type=float,
1760
2008
  group="Volumes",
1761
- doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2009
+ doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
2010
+ If the voxel at this index does not exist, this function returns the background value""",
1762
2011
  )
1763
2012
 
1764
2013
  add_builtin(
1765
2014
  "volume_store_f",
1766
2015
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": float},
1767
2016
  group="Volumes",
1768
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2017
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1769
2018
  )
1770
2019
 
1771
2020
  add_builtin(
@@ -1773,7 +2022,8 @@ add_builtin(
1773
2022
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1774
2023
  value_type=vec3,
1775
2024
  group="Volumes",
1776
- doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
2025
+ doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``.
2026
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
1777
2027
  )
1778
2028
 
1779
2029
  add_builtin(
@@ -1781,14 +2031,15 @@ add_builtin(
1781
2031
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1782
2032
  value_type=vec3,
1783
2033
  group="Volumes",
1784
- doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2034
+ doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
2035
+ If the voxel at this index does not exist, this function returns the background value.""",
1785
2036
  )
1786
2037
 
1787
2038
  add_builtin(
1788
2039
  "volume_store_v",
1789
2040
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": vec3},
1790
2041
  group="Volumes",
1791
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2042
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1792
2043
  )
1793
2044
 
1794
2045
  add_builtin(
@@ -1796,7 +2047,7 @@ add_builtin(
1796
2047
  input_types={"id": uint64, "uvw": vec3},
1797
2048
  value_type=int,
1798
2049
  group="Volumes",
1799
- doc="""Sample the int32 volume given by ``id`` at the volume local-space point ``uvw``. """,
2050
+ doc="""Sample the :class:`int32` volume given by ``id`` at the volume local-space point ``uvw``. """,
1800
2051
  )
1801
2052
 
1802
2053
  add_builtin(
@@ -1804,14 +2055,15 @@ add_builtin(
1804
2055
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1805
2056
  value_type=int,
1806
2057
  group="Volumes",
1807
- doc="""Returns the int32 value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2058
+ doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
2059
+ If the voxel at this index does not exist, this function returns the background value.""",
1808
2060
  )
1809
2061
 
1810
2062
  add_builtin(
1811
2063
  "volume_store_i",
1812
2064
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": int},
1813
2065
  group="Volumes",
1814
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2066
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1815
2067
  )
1816
2068
 
1817
2069
  add_builtin(
@@ -1819,28 +2071,28 @@ add_builtin(
1819
2071
  input_types={"id": uint64, "uvw": vec3},
1820
2072
  value_type=vec3,
1821
2073
  group="Volumes",
1822
- doc="""Transform a point defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2074
+ doc="""Transform a point ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1823
2075
  )
1824
2076
  add_builtin(
1825
2077
  "volume_world_to_index",
1826
2078
  input_types={"id": uint64, "xyz": vec3},
1827
2079
  value_type=vec3,
1828
2080
  group="Volumes",
1829
- doc="""Transform a point defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2081
+ doc="""Transform a point ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1830
2082
  )
1831
2083
  add_builtin(
1832
2084
  "volume_index_to_world_dir",
1833
2085
  input_types={"id": uint64, "uvw": vec3},
1834
2086
  value_type=vec3,
1835
2087
  group="Volumes",
1836
- doc="""Transform a direction defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2088
+ doc="""Transform a direction ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1837
2089
  )
1838
2090
  add_builtin(
1839
2091
  "volume_world_to_index_dir",
1840
2092
  input_types={"id": uint64, "xyz": vec3},
1841
2093
  value_type=vec3,
1842
2094
  group="Volumes",
1843
- doc="""Transform a direction defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2095
+ doc="""Transform a direction ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1844
2096
  )
1845
2097
 
1846
2098
 
@@ -1860,7 +2112,7 @@ add_builtin(
1860
2112
  input_types={"seed": int, "offset": int},
1861
2113
  value_type=uint32,
1862
2114
  group="Random",
1863
- doc="""Initialize a new random number generator given a user-defined seed and an offset.
2115
+ doc="""Initialize a new random number generator given a user-defined seed and an offset.
1864
2116
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
1865
2117
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
1866
2118
  )
@@ -1870,31 +2122,31 @@ add_builtin(
1870
2122
  input_types={"state": uint32},
1871
2123
  value_type=int,
1872
2124
  group="Random",
1873
- doc="Return a random integer between [0, 2^32)",
2125
+ doc="Return a random integer in the range [0, 2^32).",
1874
2126
  )
1875
2127
  add_builtin(
1876
2128
  "randi",
1877
2129
  input_types={"state": uint32, "min": int, "max": int},
1878
2130
  value_type=int,
1879
2131
  group="Random",
1880
- doc="Return a random integer between [min, max)",
2132
+ doc="Return a random integer between [min, max).",
1881
2133
  )
1882
2134
  add_builtin(
1883
2135
  "randf",
1884
2136
  input_types={"state": uint32},
1885
2137
  value_type=float,
1886
2138
  group="Random",
1887
- doc="Return a random float between [0.0, 1.0)",
2139
+ doc="Return a random float between [0.0, 1.0).",
1888
2140
  )
1889
2141
  add_builtin(
1890
2142
  "randf",
1891
2143
  input_types={"state": uint32, "min": float, "max": float},
1892
2144
  value_type=float,
1893
2145
  group="Random",
1894
- doc="Return a random float between [min, max)",
2146
+ doc="Return a random float between [min, max).",
1895
2147
  )
1896
2148
  add_builtin(
1897
- "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution"
2149
+ "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution."
1898
2150
  )
1899
2151
 
1900
2152
  add_builtin(
@@ -1902,70 +2154,70 @@ add_builtin(
1902
2154
  input_types={"state": uint32, "cdf": array(dtype=float)},
1903
2155
  value_type=int,
1904
2156
  group="Random",
1905
- doc="Inverse transform sample a cumulative distribution function",
2157
+ doc="Inverse-transform sample a cumulative distribution function.",
1906
2158
  )
1907
2159
  add_builtin(
1908
2160
  "sample_triangle",
1909
2161
  input_types={"state": uint32},
1910
2162
  value_type=vec2,
1911
2163
  group="Random",
1912
- doc="Uniformly sample a triangle. Returns sample barycentric coordinates",
2164
+ doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
1913
2165
  )
1914
2166
  add_builtin(
1915
2167
  "sample_unit_ring",
1916
2168
  input_types={"state": uint32},
1917
2169
  value_type=vec2,
1918
2170
  group="Random",
1919
- doc="Uniformly sample a ring in the xy plane",
2171
+ doc="Uniformly sample a ring in the xy plane.",
1920
2172
  )
1921
2173
  add_builtin(
1922
2174
  "sample_unit_disk",
1923
2175
  input_types={"state": uint32},
1924
2176
  value_type=vec2,
1925
2177
  group="Random",
1926
- doc="Uniformly sample a disk in the xy plane",
2178
+ doc="Uniformly sample a disk in the xy plane.",
1927
2179
  )
1928
2180
  add_builtin(
1929
2181
  "sample_unit_sphere_surface",
1930
2182
  input_types={"state": uint32},
1931
2183
  value_type=vec3,
1932
2184
  group="Random",
1933
- doc="Uniformly sample a unit sphere surface",
2185
+ doc="Uniformly sample a unit sphere surface.",
1934
2186
  )
1935
2187
  add_builtin(
1936
2188
  "sample_unit_sphere",
1937
2189
  input_types={"state": uint32},
1938
2190
  value_type=vec3,
1939
2191
  group="Random",
1940
- doc="Uniformly sample a unit sphere",
2192
+ doc="Uniformly sample a unit sphere.",
1941
2193
  )
1942
2194
  add_builtin(
1943
2195
  "sample_unit_hemisphere_surface",
1944
2196
  input_types={"state": uint32},
1945
2197
  value_type=vec3,
1946
2198
  group="Random",
1947
- doc="Uniformly sample a unit hemisphere surface",
2199
+ doc="Uniformly sample a unit hemisphere surface.",
1948
2200
  )
1949
2201
  add_builtin(
1950
2202
  "sample_unit_hemisphere",
1951
2203
  input_types={"state": uint32},
1952
2204
  value_type=vec3,
1953
2205
  group="Random",
1954
- doc="Uniformly sample a unit hemisphere",
2206
+ doc="Uniformly sample a unit hemisphere.",
1955
2207
  )
1956
2208
  add_builtin(
1957
2209
  "sample_unit_square",
1958
2210
  input_types={"state": uint32},
1959
2211
  value_type=vec2,
1960
2212
  group="Random",
1961
- doc="Uniformly sample a unit square",
2213
+ doc="Uniformly sample a unit square.",
1962
2214
  )
1963
2215
  add_builtin(
1964
2216
  "sample_unit_cube",
1965
2217
  input_types={"state": uint32},
1966
2218
  value_type=vec3,
1967
2219
  group="Random",
1968
- doc="Uniformly sample a unit cube",
2220
+ doc="Uniformly sample a unit cube.",
1969
2221
  )
1970
2222
 
1971
2223
  add_builtin(
@@ -1974,9 +2226,9 @@ add_builtin(
1974
2226
  value_type=uint32,
1975
2227
  group="Random",
1976
2228
  doc="""Generate a random sample from a Poisson distribution.
1977
-
1978
- :param state: RNG state
1979
- :param lam: The expected value of the distribution""",
2229
+
2230
+ :param state: RNG state
2231
+ :param lam: The expected value of the distribution""",
1980
2232
  )
1981
2233
 
1982
2234
  add_builtin(
@@ -1984,28 +2236,28 @@ add_builtin(
1984
2236
  input_types={"state": uint32, "x": float},
1985
2237
  value_type=float,
1986
2238
  group="Random",
1987
- doc="Non-periodic Perlin-style noise in 1d.",
2239
+ doc="Non-periodic Perlin-style noise in 1D.",
1988
2240
  )
1989
2241
  add_builtin(
1990
2242
  "noise",
1991
2243
  input_types={"state": uint32, "xy": vec2},
1992
2244
  value_type=float,
1993
2245
  group="Random",
1994
- doc="Non-periodic Perlin-style noise in 2d.",
2246
+ doc="Non-periodic Perlin-style noise in 2D.",
1995
2247
  )
1996
2248
  add_builtin(
1997
2249
  "noise",
1998
2250
  input_types={"state": uint32, "xyz": vec3},
1999
2251
  value_type=float,
2000
2252
  group="Random",
2001
- doc="Non-periodic Perlin-style noise in 3d.",
2253
+ doc="Non-periodic Perlin-style noise in 3D.",
2002
2254
  )
2003
2255
  add_builtin(
2004
2256
  "noise",
2005
2257
  input_types={"state": uint32, "xyzt": vec4},
2006
2258
  value_type=float,
2007
2259
  group="Random",
2008
- doc="Non-periodic Perlin-style noise in 4d.",
2260
+ doc="Non-periodic Perlin-style noise in 4D.",
2009
2261
  )
2010
2262
 
2011
2263
  add_builtin(
@@ -2013,33 +2265,34 @@ add_builtin(
2013
2265
  input_types={"state": uint32, "x": float, "px": int},
2014
2266
  value_type=float,
2015
2267
  group="Random",
2016
- doc="Periodic Perlin-style noise in 1d.",
2268
+ doc="Periodic Perlin-style noise in 1D.",
2017
2269
  )
2018
2270
  add_builtin(
2019
2271
  "pnoise",
2020
2272
  input_types={"state": uint32, "xy": vec2, "px": int, "py": int},
2021
2273
  value_type=float,
2022
2274
  group="Random",
2023
- doc="Periodic Perlin-style noise in 2d.",
2275
+ doc="Periodic Perlin-style noise in 2D.",
2024
2276
  )
2025
2277
  add_builtin(
2026
2278
  "pnoise",
2027
2279
  input_types={"state": uint32, "xyz": vec3, "px": int, "py": int, "pz": int},
2028
2280
  value_type=float,
2029
2281
  group="Random",
2030
- doc="Periodic Perlin-style noise in 3d.",
2282
+ doc="Periodic Perlin-style noise in 3D.",
2031
2283
  )
2032
2284
  add_builtin(
2033
2285
  "pnoise",
2034
2286
  input_types={"state": uint32, "xyzt": vec4, "px": int, "py": int, "pz": int, "pt": int},
2035
2287
  value_type=float,
2036
2288
  group="Random",
2037
- doc="Periodic Perlin-style noise in 4d.",
2289
+ doc="Periodic Perlin-style noise in 4D.",
2038
2290
  )
2039
2291
 
2040
2292
  add_builtin(
2041
2293
  "curlnoise",
2042
- input_types={"state": uint32, "xy": vec2},
2294
+ input_types={"state": uint32, "xy": vec2, "octaves": uint32, "lacunarity": float, "gain": float},
2295
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2043
2296
  value_type=vec2,
2044
2297
  group="Random",
2045
2298
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
@@ -2047,7 +2300,8 @@ add_builtin(
2047
2300
  )
2048
2301
  add_builtin(
2049
2302
  "curlnoise",
2050
- input_types={"state": uint32, "xyz": vec3},
2303
+ input_types={"state": uint32, "xyz": vec3, "octaves": uint32, "lacunarity": float, "gain": float},
2304
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2051
2305
  value_type=vec3,
2052
2306
  group="Random",
2053
2307
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -2055,7 +2309,8 @@ add_builtin(
2055
2309
  )
2056
2310
  add_builtin(
2057
2311
  "curlnoise",
2058
- input_types={"state": uint32, "xyzt": vec4},
2312
+ input_types={"state": uint32, "xyzt": vec4, "octaves": uint32, "lacunarity": float, "gain": float},
2313
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2059
2314
  value_type=vec3,
2060
2315
  group="Random",
2061
2316
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -2069,7 +2324,7 @@ add_builtin(
2069
2324
  namespace="",
2070
2325
  variadic=True,
2071
2326
  group="Utility",
2072
- doc="Allows printing formatted strings, using C-style format specifiers.",
2327
+ doc="Allows printing formatted strings using C-style format specifiers.",
2073
2328
  )
2074
2329
 
2075
2330
  add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
@@ -2089,9 +2344,12 @@ add_builtin(
2089
2344
  "tid",
2090
2345
  input_types={},
2091
2346
  value_type=int,
2347
+ export=False,
2092
2348
  group="Utility",
2093
- doc="""Return the current thread index. Note that this is the *global* index of the thread in the range [0, dim)
2094
- where dim is the parameter passed to kernel launch.""",
2349
+ doc="""Return the current thread index for a 1D kernel launch. Note that this is the *global* index of the thread in the range [0, dim)
2350
+ where dim is the parameter passed to kernel launch. This function may not be called from user-defined Warp functions.""",
2351
+ namespace="",
2352
+ native_func="builtin_tid1d",
2095
2353
  )
2096
2354
 
2097
2355
  add_builtin(
@@ -2099,7 +2357,10 @@ add_builtin(
2099
2357
  input_types={},
2100
2358
  value_type=[int, int],
2101
2359
  group="Utility",
2102
- doc="""Return the current thread indices for a 2d kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2360
+ doc="""Return the current thread indices for a 2D kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the
2361
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2362
+ namespace="",
2363
+ native_func="builtin_tid2d",
2103
2364
  )
2104
2365
 
2105
2366
  add_builtin(
@@ -2107,7 +2368,10 @@ add_builtin(
2107
2368
  input_types={},
2108
2369
  value_type=[int, int, int],
2109
2370
  group="Utility",
2110
- doc="""Return the current thread indices for a 3d kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2371
+ doc="""Return the current thread indices for a 3D kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the
2372
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2373
+ namespace="",
2374
+ native_func="builtin_tid3d",
2111
2375
  )
2112
2376
 
2113
2377
  add_builtin(
@@ -2115,42 +2379,60 @@ add_builtin(
2115
2379
  input_types={},
2116
2380
  value_type=[int, int, int, int],
2117
2381
  group="Utility",
2118
- doc="""Return the current thread indices for a 4d kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2382
+ doc="""Return the current thread indices for a 4D kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the
2383
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2384
+ namespace="",
2385
+ native_func="builtin_tid4d",
2119
2386
  )
2120
2387
 
2121
2388
 
2122
- add_builtin("copy", variadic=True, hidden=True, export=False, group="Utility")
2389
+ add_builtin(
2390
+ "copy",
2391
+ input_types={"value": Any},
2392
+ value_func=lambda arg_types, kwds, _: arg_types[0],
2393
+ hidden=True,
2394
+ export=False,
2395
+ group="Utility",
2396
+ )
2397
+ add_builtin("assign", variadic=True, hidden=True, export=False, group="Utility")
2123
2398
  add_builtin(
2124
2399
  "select",
2125
2400
  input_types={"cond": bool, "arg1": Any, "arg2": Any},
2401
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2402
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
2403
+ group="Utility",
2404
+ )
2405
+ add_builtin(
2406
+ "select",
2407
+ input_types={"cond": builtins.bool, "arg1": Any, "arg2": Any},
2126
2408
  value_func=lambda args, kwds, _: args[1].type,
2127
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2409
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
2128
2410
  group="Utility",
2129
2411
  )
2130
2412
  for t in int_types:
2131
2413
  add_builtin(
2132
2414
  "select",
2133
2415
  input_types={"cond": t, "arg1": Any, "arg2": Any},
2134
- value_func=lambda args, kwds, _: args[1].type,
2135
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2416
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2417
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
2136
2418
  group="Utility",
2137
2419
  )
2138
2420
  add_builtin(
2139
2421
  "select",
2140
2422
  input_types={"arr": array(dtype=Any), "arg1": Any, "arg2": Any},
2141
- value_func=lambda args, kwds, _: args[1].type,
2142
- doc="Select between two arguments, if array is null then return ``arg1``, otherwise return ``arg2``",
2423
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2424
+ doc="Select between two arguments, if ``arr`` is null then return ``arg1``, otherwise return ``arg2``",
2143
2425
  group="Utility",
2144
2426
  )
2145
2427
 
2146
2428
 
2147
- # does argument checking and type propagation for load()
2148
- def load_value_func(args, kwds, _):
2149
- if not is_array(args[0].type):
2429
+ # does argument checking and type propagation for address()
2430
+ def address_value_func(arg_types, kwds, _):
2431
+ if not is_array(arg_types[0]):
2150
2432
  raise RuntimeError("load() argument 0 must be an array")
2151
2433
 
2152
- num_indices = len(args[1:])
2153
- num_dims = args[0].type.ndim
2434
+ num_indices = len(arg_types[1:])
2435
+ num_dims = arg_types[0].ndim
2154
2436
 
2155
2437
  if num_indices < num_dims:
2156
2438
  raise RuntimeError(
@@ -2163,21 +2445,21 @@ def load_value_func(args, kwds, _):
2163
2445
  )
2164
2446
 
2165
2447
  # check index types
2166
- for a in args[1:]:
2167
- if type_is_int(a.type) == False:
2168
- raise RuntimeError(f"load() index arguments must be of integer type, got index of type {a.type}")
2448
+ for t in arg_types[1:]:
2449
+ if not type_is_int(t):
2450
+ raise RuntimeError(f"address() index arguments must be of integer type, got index of type {t}")
2169
2451
 
2170
- return args[0].type.dtype
2452
+ return Reference(arg_types[0].dtype)
2171
2453
 
2172
2454
 
2173
2455
  # does argument checking and type propagation for view()
2174
- def view_value_func(args, kwds, _):
2175
- if not is_array(args[0].type):
2456
+ def view_value_func(arg_types, kwds, _):
2457
+ if not is_array(arg_types[0]):
2176
2458
  raise RuntimeError("view() argument 0 must be an array")
2177
2459
 
2178
2460
  # check array dim big enough to support view
2179
- num_indices = len(args[1:])
2180
- num_dims = args[0].type.ndim
2461
+ num_indices = len(arg_types[1:])
2462
+ num_dims = arg_types[0].ndim
2181
2463
 
2182
2464
  if num_indices >= num_dims:
2183
2465
  raise RuntimeError(
@@ -2185,27 +2467,28 @@ def view_value_func(args, kwds, _):
2185
2467
  )
2186
2468
 
2187
2469
  # check index types
2188
- for a in args[1:]:
2189
- if type_is_int(a.type) == False:
2190
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {a.type}")
2470
+ for t in arg_types[1:]:
2471
+ if not type_is_int(t):
2472
+ raise RuntimeError(f"view() index arguments must be of integer type, got index of type {t}")
2191
2473
 
2192
2474
  # create an array view with leading dimensions removed
2193
- import copy
2194
-
2195
- view_type = copy.copy(args[0].type)
2196
- view_type.ndim -= num_indices
2197
-
2198
- return view_type
2475
+ dtype = arg_types[0].dtype
2476
+ ndim = num_dims - num_indices
2477
+ if isinstance(arg_types[0], (fabricarray, indexedfabricarray)):
2478
+ # fabric array of arrays: return array attribute as a regular array
2479
+ return array(dtype=dtype, ndim=ndim)
2480
+ else:
2481
+ return type(arg_types[0])(dtype=dtype, ndim=ndim)
2199
2482
 
2200
2483
 
2201
- # does argument checking and type propagation for store()
2202
- def store_value_func(args, kwds, _):
2484
+ # does argument checking and type propagation for array_store()
2485
+ def array_store_value_func(arg_types, kwds, _):
2203
2486
  # check target type
2204
- if not is_array(args[0].type):
2205
- raise RuntimeError("store() argument 0 must be an array")
2487
+ if not is_array(arg_types[0]):
2488
+ raise RuntimeError("array_store() argument 0 must be an array")
2206
2489
 
2207
- num_indices = len(args[1:-1])
2208
- num_dims = args[0].type.ndim
2490
+ num_indices = len(arg_types[1:-1])
2491
+ num_dims = arg_types[0].ndim
2209
2492
 
2210
2493
  # if this happens we should have generated a view instead of a load during code gen
2211
2494
  if num_indices < num_dims:
@@ -2217,31 +2500,63 @@ def store_value_func(args, kwds, _):
2217
2500
  )
2218
2501
 
2219
2502
  # check index types
2220
- for a in args[1:-1]:
2221
- if type_is_int(a.type) == False:
2222
- raise RuntimeError(f"store() index arguments must be of integer type, got index of type {a.type}")
2503
+ for t in arg_types[1:-1]:
2504
+ if not type_is_int(t):
2505
+ raise RuntimeError(f"array_store() index arguments must be of integer type, got index of type {t}")
2223
2506
 
2224
2507
  # check value type
2225
- if not types_equal(args[-1].type, args[0].type.dtype):
2508
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2226
2509
  raise RuntimeError(
2227
- f"store() value argument type ({args[2].type}) must be of the same type as the array ({args[0].type.dtype})"
2510
+ f"array_store() value argument type ({arg_types[2]}) must be of the same type as the array ({arg_types[0].dtype})"
2228
2511
  )
2229
2512
 
2230
2513
  return None
2231
2514
 
2232
2515
 
2233
- add_builtin("load", variadic=True, hidden=True, value_func=load_value_func, group="Utility")
2516
+ # does argument checking for store()
2517
+ def store_value_func(arg_types, kwds, _):
2518
+ # we already stripped the Reference from the argument type prior to this call
2519
+ if not types_equal(arg_types[0], arg_types[1]):
2520
+ raise RuntimeError(f"store() value argument type ({arg_types[1]}) must be of the same type as the reference")
2521
+
2522
+ return None
2523
+
2524
+
2525
+ # does type propagation for load()
2526
+ def load_value_func(arg_types, kwds, _):
2527
+ # we already stripped the Reference from the argument type prior to this call
2528
+ return arg_types[0]
2529
+
2530
+
2531
+ add_builtin("address", variadic=True, hidden=True, value_func=address_value_func, group="Utility")
2234
2532
  add_builtin("view", variadic=True, hidden=True, value_func=view_value_func, group="Utility")
2235
- add_builtin("store", variadic=True, hidden=True, value_func=store_value_func, skip_replay=True, group="Utility")
2533
+ add_builtin(
2534
+ "array_store", variadic=True, hidden=True, value_func=array_store_value_func, skip_replay=True, group="Utility"
2535
+ )
2536
+ add_builtin(
2537
+ "store",
2538
+ input_types={"address": Reference, "value": Any},
2539
+ hidden=True,
2540
+ value_func=store_value_func,
2541
+ skip_replay=True,
2542
+ group="Utility",
2543
+ )
2544
+ add_builtin(
2545
+ "load",
2546
+ input_types={"address": Reference},
2547
+ hidden=True,
2548
+ value_func=load_value_func,
2549
+ group="Utility",
2550
+ )
2236
2551
 
2237
2552
 
2238
- def atomic_op_value_func(args, kwds, _):
2553
+ def atomic_op_value_func(arg_types, kwds, _):
2239
2554
  # check target type
2240
- if not is_array(args[0].type):
2555
+ if not is_array(arg_types[0]):
2241
2556
  raise RuntimeError("atomic() operation argument 0 must be an array")
2242
2557
 
2243
- num_indices = len(args[1:-1])
2244
- num_dims = args[0].type.ndim
2558
+ num_indices = len(arg_types[1:-1])
2559
+ num_dims = arg_types[0].ndim
2245
2560
 
2246
2561
  # if this happens we should have generated a view instead of a load during code gen
2247
2562
  if num_indices < num_dims:
@@ -2253,18 +2568,16 @@ def atomic_op_value_func(args, kwds, _):
2253
2568
  )
2254
2569
 
2255
2570
  # check index types
2256
- for a in args[1:-1]:
2257
- if type_is_int(a.type) == False:
2258
- raise RuntimeError(
2259
- f"atomic() operation index arguments must be of integer type, got index of type {a.type}"
2260
- )
2571
+ for t in arg_types[1:-1]:
2572
+ if not type_is_int(t):
2573
+ raise RuntimeError(f"atomic() operation index arguments must be of integer type, got index of type {t}")
2261
2574
 
2262
- if not types_equal(args[-1].type, args[0].type.dtype):
2575
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2263
2576
  raise RuntimeError(
2264
- f"atomic() value argument ({args[-1].type}) must be of the same type as the array ({args[0].type.dtype})"
2577
+ f"atomic() value argument ({arg_types[-1]}) must be of the same type as the array ({arg_types[0].dtype})"
2265
2578
  )
2266
2579
 
2267
- return args[0].type.dtype
2580
+ return arg_types[0].dtype
2268
2581
 
2269
2582
 
2270
2583
  for array_type in array_types:
@@ -2276,7 +2589,7 @@ for array_type in array_types:
2276
2589
  hidden=hidden,
2277
2590
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2278
2591
  value_func=atomic_op_value_func,
2279
- doc="Atomically add ``value`` onto the array at location given by index.",
2592
+ doc="Atomically add ``value`` onto ``a[i]``.",
2280
2593
  group="Utility",
2281
2594
  skip_replay=True,
2282
2595
  )
@@ -2285,7 +2598,7 @@ for array_type in array_types:
2285
2598
  hidden=hidden,
2286
2599
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2287
2600
  value_func=atomic_op_value_func,
2288
- doc="Atomically add ``value`` onto the array at location given by indices.",
2601
+ doc="Atomically add ``value`` onto ``a[i,j]``.",
2289
2602
  group="Utility",
2290
2603
  skip_replay=True,
2291
2604
  )
@@ -2294,7 +2607,7 @@ for array_type in array_types:
2294
2607
  hidden=hidden,
2295
2608
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2296
2609
  value_func=atomic_op_value_func,
2297
- doc="Atomically add ``value`` onto the array at location given by indices.",
2610
+ doc="Atomically add ``value`` onto ``a[i,j,k]``.",
2298
2611
  group="Utility",
2299
2612
  skip_replay=True,
2300
2613
  )
@@ -2303,7 +2616,7 @@ for array_type in array_types:
2303
2616
  hidden=hidden,
2304
2617
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2305
2618
  value_func=atomic_op_value_func,
2306
- doc="Atomically add ``value`` onto the array at location given by indices.",
2619
+ doc="Atomically add ``value`` onto ``a[i,j,k,l]``.",
2307
2620
  group="Utility",
2308
2621
  skip_replay=True,
2309
2622
  )
@@ -2313,7 +2626,7 @@ for array_type in array_types:
2313
2626
  hidden=hidden,
2314
2627
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2315
2628
  value_func=atomic_op_value_func,
2316
- doc="Atomically subtract ``value`` onto the array at location given by index.",
2629
+ doc="Atomically subtract ``value`` onto ``a[i]``.",
2317
2630
  group="Utility",
2318
2631
  skip_replay=True,
2319
2632
  )
@@ -2322,7 +2635,7 @@ for array_type in array_types:
2322
2635
  hidden=hidden,
2323
2636
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2324
2637
  value_func=atomic_op_value_func,
2325
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2638
+ doc="Atomically subtract ``value`` onto ``a[i,j]``.",
2326
2639
  group="Utility",
2327
2640
  skip_replay=True,
2328
2641
  )
@@ -2331,7 +2644,7 @@ for array_type in array_types:
2331
2644
  hidden=hidden,
2332
2645
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2333
2646
  value_func=atomic_op_value_func,
2334
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2647
+ doc="Atomically subtract ``value`` onto ``a[i,j,k]``.",
2335
2648
  group="Utility",
2336
2649
  skip_replay=True,
2337
2650
  )
@@ -2340,7 +2653,7 @@ for array_type in array_types:
2340
2653
  hidden=hidden,
2341
2654
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2342
2655
  value_func=atomic_op_value_func,
2343
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2656
+ doc="Atomically subtract ``value`` onto ``a[i,j,k,l]``.",
2344
2657
  group="Utility",
2345
2658
  skip_replay=True,
2346
2659
  )
@@ -2350,7 +2663,8 @@ for array_type in array_types:
2350
2663
  hidden=hidden,
2351
2664
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2352
2665
  value_func=atomic_op_value_func,
2353
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2666
+ doc="Compute the minimum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2667
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2354
2668
  group="Utility",
2355
2669
  skip_replay=True,
2356
2670
  )
@@ -2359,7 +2673,8 @@ for array_type in array_types:
2359
2673
  hidden=hidden,
2360
2674
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2361
2675
  value_func=atomic_op_value_func,
2362
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2676
+ doc="Compute the minimum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2677
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2363
2678
  group="Utility",
2364
2679
  skip_replay=True,
2365
2680
  )
@@ -2368,7 +2683,8 @@ for array_type in array_types:
2368
2683
  hidden=hidden,
2369
2684
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2370
2685
  value_func=atomic_op_value_func,
2371
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2686
+ doc="Compute the minimum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2687
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2372
2688
  group="Utility",
2373
2689
  skip_replay=True,
2374
2690
  )
@@ -2377,7 +2693,8 @@ for array_type in array_types:
2377
2693
  hidden=hidden,
2378
2694
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2379
2695
  value_func=atomic_op_value_func,
2380
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2696
+ doc="Compute the minimum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2697
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2381
2698
  group="Utility",
2382
2699
  skip_replay=True,
2383
2700
  )
@@ -2387,7 +2704,8 @@ for array_type in array_types:
2387
2704
  hidden=hidden,
2388
2705
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2389
2706
  value_func=atomic_op_value_func,
2390
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2707
+ doc="Compute the maximum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2708
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2391
2709
  group="Utility",
2392
2710
  skip_replay=True,
2393
2711
  )
@@ -2396,7 +2714,8 @@ for array_type in array_types:
2396
2714
  hidden=hidden,
2397
2715
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2398
2716
  value_func=atomic_op_value_func,
2399
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2717
+ doc="Compute the maximum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2718
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2400
2719
  group="Utility",
2401
2720
  skip_replay=True,
2402
2721
  )
@@ -2405,7 +2724,8 @@ for array_type in array_types:
2405
2724
  hidden=hidden,
2406
2725
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2407
2726
  value_func=atomic_op_value_func,
2408
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2727
+ doc="Compute the maximum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2728
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2409
2729
  group="Utility",
2410
2730
  skip_replay=True,
2411
2731
  )
@@ -2414,26 +2734,27 @@ for array_type in array_types:
2414
2734
  hidden=hidden,
2415
2735
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2416
2736
  value_func=atomic_op_value_func,
2417
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2737
+ doc="Compute the maximum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2738
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2418
2739
  group="Utility",
2419
2740
  skip_replay=True,
2420
2741
  )
2421
2742
 
2422
2743
 
2423
2744
  # used to index into builtin types, i.e.: y = vec3[1]
2424
- def index_value_func(args, kwds, _):
2425
- return args[0].type._wp_scalar_type_
2745
+ def index_value_func(arg_types, kwds, _):
2746
+ return arg_types[0]._wp_scalar_type_
2426
2747
 
2427
2748
 
2428
2749
  add_builtin(
2429
- "index",
2750
+ "extract",
2430
2751
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2431
2752
  value_func=index_value_func,
2432
2753
  hidden=True,
2433
2754
  group="Utility",
2434
2755
  )
2435
2756
  add_builtin(
2436
- "index",
2757
+ "extract",
2437
2758
  input_types={"a": quaternion(dtype=Scalar), "i": int},
2438
2759
  value_func=index_value_func,
2439
2760
  hidden=True,
@@ -2441,14 +2762,14 @@ add_builtin(
2441
2762
  )
2442
2763
 
2443
2764
  add_builtin(
2444
- "index",
2765
+ "extract",
2445
2766
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2446
- value_func=lambda args, kwds, _: vector(length=args[0].type._shape_[1], dtype=args[0].type._wp_scalar_type_),
2767
+ value_func=lambda arg_types, kwds, _: vector(length=arg_types[0]._shape_[1], dtype=arg_types[0]._wp_scalar_type_),
2447
2768
  hidden=True,
2448
2769
  group="Utility",
2449
2770
  )
2450
2771
  add_builtin(
2451
- "index",
2772
+ "extract",
2452
2773
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2453
2774
  value_func=index_value_func,
2454
2775
  hidden=True,
@@ -2456,77 +2777,66 @@ add_builtin(
2456
2777
  )
2457
2778
 
2458
2779
  add_builtin(
2459
- "index",
2780
+ "extract",
2460
2781
  input_types={"a": transformation(dtype=Scalar), "i": int},
2461
2782
  value_func=index_value_func,
2462
2783
  hidden=True,
2463
2784
  group="Utility",
2464
2785
  )
2465
2786
 
2466
- add_builtin("index", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2787
+ add_builtin("extract", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2467
2788
 
2468
2789
 
2469
- def vector_indexset_element_value_func(args, kwds, _):
2470
- vec = args[0]
2471
- index = args[1]
2472
- value = args[2]
2790
+ def vector_indexref_element_value_func(arg_types, kwds, _):
2791
+ vec_type = arg_types[0]
2792
+ # index_type = arg_types[1]
2793
+ value_type = vec_type._wp_scalar_type_
2473
2794
 
2474
- if value.type is not vec.type._wp_scalar_type_:
2475
- raise RuntimeError(
2476
- f"Trying to assign type '{type_repr(value.type)}' to element of a vector with type '{type_repr(vec.type)}'"
2477
- )
2478
-
2479
- return None
2795
+ return Reference(value_type)
2480
2796
 
2481
2797
 
2482
- # implements vector[index] = value
2798
+ # implements &vector[index]
2483
2799
  add_builtin(
2484
- "indexset",
2485
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
2486
- value_func=vector_indexset_element_value_func,
2800
+ "index",
2801
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2802
+ value_func=vector_indexref_element_value_func,
2803
+ hidden=True,
2804
+ group="Utility",
2805
+ skip_replay=True,
2806
+ )
2807
+ # implements &(*vector)[index]
2808
+ add_builtin(
2809
+ "indexref",
2810
+ input_types={"a": Reference, "i": int},
2811
+ value_func=vector_indexref_element_value_func,
2487
2812
  hidden=True,
2488
2813
  group="Utility",
2489
2814
  skip_replay=True,
2490
2815
  )
2491
2816
 
2492
2817
 
2493
- def matrix_indexset_element_value_func(args, kwds, _):
2494
- mat = args[0]
2495
- row = args[1]
2496
- col = args[2]
2497
- value = args[3]
2498
-
2499
- if value.type is not mat.type._wp_scalar_type_:
2500
- raise RuntimeError(
2501
- f"Trying to assign type '{type_repr(value.type)}' to element of a matrix with type '{type_repr(mat.type)}'"
2502
- )
2503
-
2504
- return None
2505
-
2818
+ def matrix_indexref_element_value_func(arg_types, kwds, _):
2819
+ mat_type = arg_types[0]
2820
+ # row_type = arg_types[1]
2821
+ # col_type = arg_types[2]
2822
+ value_type = mat_type._wp_scalar_type_
2506
2823
 
2507
- def matrix_indexset_row_value_func(args, kwds, _):
2508
- mat = args[0]
2509
- row = args[1]
2510
- value = args[2]
2824
+ return Reference(value_type)
2511
2825
 
2512
- if value.type._shape_[0] != mat.type._shape_[1]:
2513
- raise RuntimeError(
2514
- f"Trying to assign vector with length {value.type._length} to matrix with shape {mat.type._shape}, vector length must match the number of matrix columns."
2515
- )
2516
2826
 
2517
- if value.type._wp_scalar_type_ is not mat.type._wp_scalar_type_:
2518
- raise RuntimeError(
2519
- f"Trying to assign vector of type '{type_repr(value.type)}' to row of matrix of type '{type_repr(mat.type)}'"
2520
- )
2827
+ def matrix_indexref_row_value_func(arg_types, kwds, _):
2828
+ mat_type = arg_types[0]
2829
+ row_type = mat_type._wp_row_type_
2830
+ # value_type = arg_types[2]
2521
2831
 
2522
- return None
2832
+ return Reference(row_type)
2523
2833
 
2524
2834
 
2525
2835
  # implements matrix[i] = row
2526
2836
  add_builtin(
2527
- "indexset",
2528
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
2529
- value_func=matrix_indexset_row_value_func,
2837
+ "index",
2838
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2839
+ value_func=matrix_indexref_row_value_func,
2530
2840
  hidden=True,
2531
2841
  group="Utility",
2532
2842
  skip_replay=True,
@@ -2534,29 +2844,29 @@ add_builtin(
2534
2844
 
2535
2845
  # implements matrix[i,j] = scalar
2536
2846
  add_builtin(
2537
- "indexset",
2538
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
2539
- value_func=matrix_indexset_element_value_func,
2847
+ "index",
2848
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2849
+ value_func=matrix_indexref_element_value_func,
2540
2850
  hidden=True,
2541
2851
  group="Utility",
2542
2852
  skip_replay=True,
2543
2853
  )
2544
2854
 
2545
- for t in scalar_types + vector_types:
2855
+ for t in scalar_types + vector_types + [builtins.bool]:
2546
2856
  if "vec" in t.__name__ or "mat" in t.__name__:
2547
2857
  continue
2548
2858
  add_builtin(
2549
2859
  "expect_eq",
2550
2860
  input_types={"arg1": t, "arg2": t},
2551
2861
  value_type=None,
2552
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2862
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2553
2863
  group="Utility",
2554
2864
  hidden=True,
2555
2865
  )
2556
2866
 
2557
2867
 
2558
- def expect_eq_val_func(args, kwds, _):
2559
- if not types_equal(args[0].type, args[1].type):
2868
+ def expect_eq_val_func(arg_types, kwds, _):
2869
+ if not types_equal(arg_types[0], arg_types[1]):
2560
2870
  raise RuntimeError("Can't test equality for objects with different types")
2561
2871
  return None
2562
2872
 
@@ -2565,7 +2875,7 @@ add_builtin(
2565
2875
  "expect_eq",
2566
2876
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2567
2877
  value_func=expect_eq_val_func,
2568
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2878
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2569
2879
  group="Utility",
2570
2880
  hidden=True,
2571
2881
  )
@@ -2573,7 +2883,7 @@ add_builtin(
2573
2883
  "expect_neq",
2574
2884
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2575
2885
  value_func=expect_eq_val_func,
2576
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2886
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2577
2887
  group="Utility",
2578
2888
  hidden=True,
2579
2889
  )
@@ -2582,7 +2892,7 @@ add_builtin(
2582
2892
  "expect_eq",
2583
2893
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2584
2894
  value_func=expect_eq_val_func,
2585
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2895
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2586
2896
  group="Utility",
2587
2897
  hidden=True,
2588
2898
  )
@@ -2590,7 +2900,7 @@ add_builtin(
2590
2900
  "expect_neq",
2591
2901
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2592
2902
  value_func=expect_eq_val_func,
2593
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2903
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2594
2904
  group="Utility",
2595
2905
  hidden=True,
2596
2906
  )
@@ -2599,29 +2909,30 @@ add_builtin(
2599
2909
  "lerp",
2600
2910
  input_types={"a": Float, "b": Float, "t": Float},
2601
2911
  value_func=sametype_value_func(Float),
2602
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2912
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2603
2913
  group="Utility",
2604
2914
  )
2605
2915
  add_builtin(
2606
2916
  "smoothstep",
2607
2917
  input_types={"edge0": Float, "edge1": Float, "x": Float},
2608
2918
  value_func=sametype_value_func(Float),
2609
- doc="Smoothly interpolate between two values edge0 and edge1 using a factor x, and return a result between 0 and 1 using a cubic Hermite interpolation after clamping",
2919
+ doc="""Smoothly interpolate between two values ``edge0`` and ``edge1`` using a factor ``x``,
2920
+ and return a result between 0 and 1 using a cubic Hermite interpolation after clamping.""",
2610
2921
  group="Utility",
2611
2922
  )
2612
2923
 
2613
2924
 
2614
2925
  def lerp_value_func(default):
2615
- def fn(args, kwds, _):
2616
- if args is None:
2926
+ def fn(arg_types, kwds, _):
2927
+ if arg_types is None:
2617
2928
  return default
2618
- scalar_type = args[-1].type
2619
- if not types_equal(args[0].type, args[1].type):
2929
+ scalar_type = arg_types[-1]
2930
+ if not types_equal(arg_types[0], arg_types[1]):
2620
2931
  raise RuntimeError("Can't lerp between objects with different types")
2621
- if args[0].type._wp_scalar_type_ != scalar_type:
2932
+ if arg_types[0]._wp_scalar_type_ != scalar_type:
2622
2933
  raise RuntimeError("'t' parameter must have the same scalar type as objects you're lerping between")
2623
2934
 
2624
- return args[0].type
2935
+ return arg_types[0]
2625
2936
 
2626
2937
  return fn
2627
2938
 
@@ -2630,28 +2941,28 @@ add_builtin(
2630
2941
  "lerp",
2631
2942
  input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "t": Float},
2632
2943
  value_func=lerp_value_func(vector(length=Any, dtype=Float)),
2633
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2944
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2634
2945
  group="Utility",
2635
2946
  )
2636
2947
  add_builtin(
2637
2948
  "lerp",
2638
2949
  input_types={"a": matrix(shape=(Any, Any), dtype=Float), "b": matrix(shape=(Any, Any), dtype=Float), "t": Float},
2639
2950
  value_func=lerp_value_func(matrix(shape=(Any, Any), dtype=Float)),
2640
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2951
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2641
2952
  group="Utility",
2642
2953
  )
2643
2954
  add_builtin(
2644
2955
  "lerp",
2645
2956
  input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
2646
2957
  value_func=lerp_value_func(quaternion(dtype=Float)),
2647
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2958
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2648
2959
  group="Utility",
2649
2960
  )
2650
2961
  add_builtin(
2651
2962
  "lerp",
2652
2963
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float), "t": Float},
2653
2964
  value_func=lerp_value_func(transformation(dtype=Float)),
2654
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2965
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2655
2966
  group="Utility",
2656
2967
  )
2657
2968
 
@@ -2661,14 +2972,14 @@ add_builtin(
2661
2972
  input_types={"arg1": Float, "arg2": Float, "tolerance": Float},
2662
2973
  defaults={"tolerance": 1.0e-6},
2663
2974
  value_type=None,
2664
- doc="Prints an error to stdout if arg1 and arg2 are not closer than tolerance in magnitude",
2975
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2665
2976
  group="Utility",
2666
2977
  )
2667
2978
  add_builtin(
2668
2979
  "expect_near",
2669
2980
  input_types={"arg1": vec3, "arg2": vec3, "tolerance": float},
2670
2981
  value_type=None,
2671
- doc="Prints an error to stdout if any element of arg1 and arg2 are not closer than tolerance in magnitude",
2982
+ doc="Prints an error to stdout if any element of ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2672
2983
  group="Utility",
2673
2984
  )
2674
2985
 
@@ -2679,14 +2990,14 @@ add_builtin(
2679
2990
  "lower_bound",
2680
2991
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
2681
2992
  value_type=int,
2682
- doc="Search a sorted array for the closest element greater than or equal to value.",
2993
+ doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
2683
2994
  )
2684
2995
 
2685
2996
  add_builtin(
2686
2997
  "lower_bound",
2687
2998
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
2688
2999
  value_type=int,
2689
- doc="Search a sorted array range [arr_begin, arr_end) for the closest element greater than or equal to value.",
3000
+ doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
2690
3001
  )
2691
3002
 
2692
3003
  # ---------------------------------
@@ -2766,11 +3077,11 @@ add_builtin("invert", input_types={"x": Int}, value_func=sametype_value_func(Int
2766
3077
 
2767
3078
 
2768
3079
  def scalar_mul_value_func(default):
2769
- def fn(args, kwds, _):
2770
- if args is None:
3080
+ def fn(arg_types, kwds, _):
3081
+ if arg_types is None:
2771
3082
  return default
2772
- scalar = [a.type for a in args if a.type in scalar_types][0]
2773
- compound = [a.type for a in args if a.type not in scalar_types][0]
3083
+ scalar = [t for t in arg_types if t in scalar_types][0]
3084
+ compound = [t for t in arg_types if t not in scalar_types][0]
2774
3085
  if scalar != compound._wp_scalar_type_:
2775
3086
  raise RuntimeError("Object and coefficient must have the same scalar type when multiplying by scalar")
2776
3087
  return compound
@@ -2778,36 +3089,53 @@ def scalar_mul_value_func(default):
2778
3089
  return fn
2779
3090
 
2780
3091
 
2781
- def mul_matvec_value_func(args, kwds, _):
2782
- if args is None:
3092
+ def mul_matvec_value_func(arg_types, kwds, _):
3093
+ if arg_types is None:
3094
+ return vector(length=Any, dtype=Scalar)
3095
+
3096
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
3097
+ raise RuntimeError(
3098
+ f"Can't multiply matrix and vector with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
3099
+ )
3100
+
3101
+ if arg_types[0]._shape_[1] != arg_types[1]._length_:
3102
+ raise RuntimeError(
3103
+ f"Can't multiply matrix of shape {arg_types[0]._shape_} and vector with length {arg_types[1]._length_}"
3104
+ )
3105
+
3106
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
3107
+
3108
+
3109
+ def mul_vecmat_value_func(arg_types, kwds, _):
3110
+ if arg_types is None:
2783
3111
  return vector(length=Any, dtype=Scalar)
2784
3112
 
2785
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3113
+ if arg_types[1]._wp_scalar_type_ != arg_types[0]._wp_scalar_type_:
2786
3114
  raise RuntimeError(
2787
- f"Can't multiply matrix and vector with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3115
+ f"Can't multiply vector and matrix with different types {arg_types[1]._wp_scalar_type_}, {arg_types[0]._wp_scalar_type_}"
2788
3116
  )
2789
3117
 
2790
- if args[0].type._shape_[1] != args[1].type._length_:
3118
+ if arg_types[1]._shape_[0] != arg_types[0]._length_:
2791
3119
  raise RuntimeError(
2792
- f"Can't multiply matrix of shape {args[0].type._shape_} and vector with length {args[1].type._length_}"
3120
+ f"Can't multiply vector with length {arg_types[0]._length_} and matrix of shape {arg_types[1]._shape_}"
2793
3121
  )
2794
3122
 
2795
- return vector(length=args[0].type._shape_[0], dtype=args[0].type._wp_scalar_type_)
3123
+ return vector(length=arg_types[1]._shape_[1], dtype=arg_types[1]._wp_scalar_type_)
2796
3124
 
2797
3125
 
2798
- def mul_matmat_value_func(args, kwds, _):
2799
- if args is None:
3126
+ def mul_matmat_value_func(arg_types, kwds, _):
3127
+ if arg_types is None:
2800
3128
  return matrix(length=Any, dtype=Scalar)
2801
3129
 
2802
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3130
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
2803
3131
  raise RuntimeError(
2804
- f"Can't multiply matrices with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3132
+ f"Can't multiply matrices with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
2805
3133
  )
2806
3134
 
2807
- if args[0].type._shape_[1] != args[1].type._shape_[0]:
2808
- raise RuntimeError(f"Can't multiply matrix of shapes {args[0].type._shape_} and {args[1].type._shape_}")
3135
+ if arg_types[0]._shape_[1] != arg_types[1]._shape_[0]:
3136
+ raise RuntimeError(f"Can't multiply matrix of shapes {arg_types[0]._shape_} and {arg_types[1]._shape_}")
2809
3137
 
2810
- return matrix(shape=(args[0].type._shape_[0], args[1].type._shape_[1]), dtype=args[0].type._wp_scalar_type_)
3138
+ return matrix(shape=(arg_types[0]._shape_[0], arg_types[1]._shape_[1]), dtype=arg_types[0]._wp_scalar_type_)
2811
3139
 
2812
3140
 
2813
3141
  add_builtin(
@@ -2869,6 +3197,13 @@ add_builtin(
2869
3197
  doc="",
2870
3198
  group="Operators",
2871
3199
  )
3200
+ add_builtin(
3201
+ "mul",
3202
+ input_types={"x": vector(length=Any, dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3203
+ value_func=mul_vecmat_value_func,
3204
+ doc="",
3205
+ group="Operators",
3206
+ )
2872
3207
  add_builtin(
2873
3208
  "mul",
2874
3209
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
@@ -2904,7 +3239,12 @@ add_builtin(
2904
3239
  )
2905
3240
 
2906
3241
  add_builtin(
2907
- "div", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), doc="", group="Operators"
3242
+ "div",
3243
+ input_types={"x": Scalar, "y": Scalar},
3244
+ value_func=sametype_value_func(Scalar),
3245
+ doc="",
3246
+ group="Operators",
3247
+ require_original_output_arg=True,
2908
3248
  )
2909
3249
  add_builtin(
2910
3250
  "div",
@@ -2913,6 +3253,13 @@ add_builtin(
2913
3253
  doc="",
2914
3254
  group="Operators",
2915
3255
  )
3256
+ add_builtin(
3257
+ "div",
3258
+ input_types={"x": Scalar, "y": vector(length=Any, dtype=Scalar)},
3259
+ value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
3260
+ doc="",
3261
+ group="Operators",
3262
+ )
2916
3263
  add_builtin(
2917
3264
  "div",
2918
3265
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": Scalar},
@@ -2920,6 +3267,13 @@ add_builtin(
2920
3267
  doc="",
2921
3268
  group="Operators",
2922
3269
  )
3270
+ add_builtin(
3271
+ "div",
3272
+ input_types={"x": Scalar, "y": matrix(shape=(Any, Any), dtype=Scalar)},
3273
+ value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3274
+ doc="",
3275
+ group="Operators",
3276
+ )
2923
3277
  add_builtin(
2924
3278
  "div",
2925
3279
  input_types={"x": quaternion(dtype=Scalar), "y": Scalar},
@@ -2927,6 +3281,13 @@ add_builtin(
2927
3281
  doc="",
2928
3282
  group="Operators",
2929
3283
  )
3284
+ add_builtin(
3285
+ "div",
3286
+ input_types={"x": Scalar, "y": quaternion(dtype=Scalar)},
3287
+ value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
3288
+ doc="",
3289
+ group="Operators",
3290
+ )
2930
3291
 
2931
3292
  add_builtin(
2932
3293
  "floordiv",
@@ -2981,9 +3342,9 @@ add_builtin(
2981
3342
  group="Operators",
2982
3343
  )
2983
3344
 
2984
- add_builtin("unot", input_types={"b": bool}, value_type=bool, doc="", group="Operators")
3345
+ add_builtin("unot", input_types={"b": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
2985
3346
  for t in int_types:
2986
- add_builtin("unot", input_types={"b": t}, value_type=bool, doc="", group="Operators")
3347
+ add_builtin("unot", input_types={"b": t}, value_type=builtins.bool, doc="", group="Operators")
2987
3348
 
2988
3349
 
2989
- add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=bool, doc="", group="Operators")
3350
+ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")