warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; consult the package registry for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -6,20 +6,21 @@
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
8
  import builtins
9
- from typing import Any, Callable, Dict, List, Tuple
9
+ from typing import Any, Callable, Tuple
10
10
 
11
+ from warp.codegen import Reference
11
12
  from warp.types import *
12
13
 
13
14
  from .context import add_builtin
14
15
 
15
16
 
16
17
  def sametype_value_func(default):
17
- def fn(args, kwds, _):
18
- if args is None:
18
+ def fn(arg_types, kwds, _):
19
+ if arg_types is None:
19
20
  return default
20
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
21
- raise RuntimeError(f"Input types must be the same, found: {[type_repr(a.type) for a in args]}")
22
- return args[0].type
21
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
22
+ raise RuntimeError(f"Input types must be the same, found: {[type_repr(t) for t in arg_types]}")
23
+ return arg_types[0]
23
24
 
24
25
  return fn
25
26
 
@@ -47,7 +48,7 @@ add_builtin(
47
48
  "clamp",
48
49
  input_types={"x": Scalar, "a": Scalar, "b": Scalar},
49
50
  value_func=sametype_value_func(Scalar),
50
- doc="Clamp the value of x to the range [a, b].",
51
+ doc="Clamp the value of ``x`` to the range [a, b].",
51
52
  group="Scalar Math",
52
53
  )
53
54
 
@@ -55,14 +56,14 @@ add_builtin(
55
56
  "abs",
56
57
  input_types={"x": Scalar},
57
58
  value_func=sametype_value_func(Scalar),
58
- doc="Return the absolute value of x.",
59
+ doc="Return the absolute value of ``x``.",
59
60
  group="Scalar Math",
60
61
  )
61
62
  add_builtin(
62
63
  "sign",
63
64
  input_types={"x": Scalar},
64
65
  value_func=sametype_value_func(Scalar),
65
- doc="Return -1 if x < 0, return 1 otherwise.",
66
+ doc="Return -1 if ``x`` < 0, return 1 otherwise.",
66
67
  group="Scalar Math",
67
68
  )
68
69
 
@@ -70,14 +71,14 @@ add_builtin(
70
71
  "step",
71
72
  input_types={"x": Scalar},
72
73
  value_func=sametype_value_func(Scalar),
73
- doc="Return 1.0 if x < 0.0, return 0.0 otherwise.",
74
+ doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
74
75
  group="Scalar Math",
75
76
  )
76
77
  add_builtin(
77
78
  "nonzero",
78
79
  input_types={"x": Scalar},
79
80
  value_func=sametype_value_func(Scalar),
80
- doc="Return 1.0 if x is not equal to zero, return 0.0 otherwise.",
81
+ doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
81
82
  group="Scalar Math",
82
83
  )
83
84
 
@@ -85,91 +86,101 @@ add_builtin(
85
86
  "sin",
86
87
  input_types={"x": Float},
87
88
  value_func=sametype_value_func(Float),
88
- doc="Return the sine of x in radians.",
89
+ doc="Return the sine of ``x`` in radians.",
89
90
  group="Scalar Math",
90
91
  )
91
92
  add_builtin(
92
93
  "cos",
93
94
  input_types={"x": Float},
94
95
  value_func=sametype_value_func(Float),
95
- doc="Return the cosine of x in radians.",
96
+ doc="Return the cosine of ``x`` in radians.",
96
97
  group="Scalar Math",
97
98
  )
98
99
  add_builtin(
99
100
  "acos",
100
101
  input_types={"x": Float},
101
102
  value_func=sametype_value_func(Float),
102
- doc="Return arccos of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
103
+ doc="Return arccos of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
103
104
  group="Scalar Math",
104
105
  )
105
106
  add_builtin(
106
107
  "asin",
107
108
  input_types={"x": Float},
108
109
  value_func=sametype_value_func(Float),
109
- doc="Return arcsin of x in radians. Inputs are automatically clamped to [-1.0, 1.0].",
110
+ doc="Return arcsin of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
110
111
  group="Scalar Math",
111
112
  )
112
113
  add_builtin(
113
114
  "sqrt",
114
115
  input_types={"x": Float},
115
116
  value_func=sametype_value_func(Float),
116
- doc="Return the sqrt of x, where x is positive.",
117
+ doc="Return the square root of ``x``, where ``x`` is positive.",
117
118
  group="Scalar Math",
119
+ require_original_output_arg=True,
120
+ )
121
+ add_builtin(
122
+ "cbrt",
123
+ input_types={"x": Float},
124
+ value_func=sametype_value_func(Float),
125
+ doc="Return the cube root of ``x``.",
126
+ group="Scalar Math",
127
+ require_original_output_arg=True,
118
128
  )
119
129
  add_builtin(
120
130
  "tan",
121
131
  input_types={"x": Float},
122
132
  value_func=sametype_value_func(Float),
123
- doc="Return tangent of x in radians.",
133
+ doc="Return the tangent of ``x`` in radians.",
124
134
  group="Scalar Math",
125
135
  )
126
136
  add_builtin(
127
137
  "atan",
128
138
  input_types={"x": Float},
129
139
  value_func=sametype_value_func(Float),
130
- doc="Return arctan of x.",
140
+ doc="Return the arctangent of ``x`` in radians.",
131
141
  group="Scalar Math",
132
142
  )
133
143
  add_builtin(
134
144
  "atan2",
135
145
  input_types={"y": Float, "x": Float},
136
146
  value_func=sametype_value_func(Float),
137
- doc="Return atan2 of x.",
147
+ doc="Return the 2-argument arctangent, atan2, of the point ``(x, y)`` in radians.",
138
148
  group="Scalar Math",
139
149
  )
140
150
  add_builtin(
141
151
  "sinh",
142
152
  input_types={"x": Float},
143
153
  value_func=sametype_value_func(Float),
144
- doc="Return the sinh of x.",
154
+ doc="Return the sinh of ``x``.",
145
155
  group="Scalar Math",
146
156
  )
147
157
  add_builtin(
148
158
  "cosh",
149
159
  input_types={"x": Float},
150
160
  value_func=sametype_value_func(Float),
151
- doc="Return the cosh of x.",
161
+ doc="Return the cosh of ``x``.",
152
162
  group="Scalar Math",
153
163
  )
154
164
  add_builtin(
155
165
  "tanh",
156
166
  input_types={"x": Float},
157
167
  value_func=sametype_value_func(Float),
158
- doc="Return the tanh of x.",
168
+ doc="Return the tanh of ``x``.",
159
169
  group="Scalar Math",
170
+ require_original_output_arg=True,
160
171
  )
161
172
  add_builtin(
162
173
  "degrees",
163
174
  input_types={"x": Float},
164
175
  value_func=sametype_value_func(Float),
165
- doc="Convert radians into degrees.",
176
+ doc="Convert ``x`` from radians into degrees.",
166
177
  group="Scalar Math",
167
178
  )
168
179
  add_builtin(
169
180
  "radians",
170
181
  input_types={"x": Float},
171
182
  value_func=sametype_value_func(Float),
172
- doc="Convert degrees into radians.",
183
+ doc="Convert ``x`` from degrees into radians.",
173
184
  group="Scalar Math",
174
185
  )
175
186
 
@@ -177,36 +188,38 @@ add_builtin(
177
188
  "log",
178
189
  input_types={"x": Float},
179
190
  value_func=sametype_value_func(Float),
180
- doc="Return the natural log (base-e) of x, where x is positive.",
191
+ doc="Return the natural logarithm (base-e) of ``x``, where ``x`` is positive.",
181
192
  group="Scalar Math",
182
193
  )
183
194
  add_builtin(
184
195
  "log2",
185
196
  input_types={"x": Float},
186
197
  value_func=sametype_value_func(Float),
187
- doc="Return the natural log (base-2) of x, where x is positive.",
198
+ doc="Return the binary logarithm (base-2) of ``x``, where ``x`` is positive.",
188
199
  group="Scalar Math",
189
200
  )
190
201
  add_builtin(
191
202
  "log10",
192
203
  input_types={"x": Float},
193
204
  value_func=sametype_value_func(Float),
194
- doc="Return the natural log (base-10) of x, where x is positive.",
205
+ doc="Return the common logarithm (base-10) of ``x``, where ``x`` is positive.",
195
206
  group="Scalar Math",
196
207
  )
197
208
  add_builtin(
198
209
  "exp",
199
210
  input_types={"x": Float},
200
211
  value_func=sametype_value_func(Float),
201
- doc="Return base-e exponential, e^x.",
212
+ doc="Return the value of the exponential function :math:`e^x`.",
202
213
  group="Scalar Math",
214
+ require_original_output_arg=True,
203
215
  )
204
216
  add_builtin(
205
217
  "pow",
206
218
  input_types={"x": Float, "y": Float},
207
219
  value_func=sametype_value_func(Float),
208
- doc="Return the result of x raised to power of y.",
220
+ doc="Return the result of ``x`` raised to power of ``y``.",
209
221
  group="Scalar Math",
222
+ require_original_output_arg=True,
210
223
  )
211
224
 
212
225
  add_builtin(
@@ -214,9 +227,9 @@ add_builtin(
214
227
  input_types={"x": Float},
215
228
  value_func=sametype_value_func(Float),
216
229
  group="Scalar Math",
217
- doc="""Calculate the nearest integer value, rounding halfway cases away from zero.
218
- This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like ``warp.rint()``.
219
- Differs from ``numpy.round()``, which behaves the same way as ``numpy.rint()``.""",
230
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases away from zero.
231
+ This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
232
+ Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
220
233
  )
221
234
 
222
235
  add_builtin(
@@ -224,9 +237,8 @@ add_builtin(
224
237
  input_types={"x": Float},
225
238
  value_func=sametype_value_func(Float),
226
239
  group="Scalar Math",
227
- doc="""Calculate the nearest integer value, rounding halfway cases to nearest even integer.
228
- It is generally faster than ``warp.round()``.
229
- Equivalent to ``numpy.rint()``.""",
240
+ doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
241
+ It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
230
242
  )
231
243
 
232
244
  add_builtin(
@@ -234,10 +246,10 @@ add_builtin(
234
246
  input_types={"x": Float},
235
247
  value_func=sametype_value_func(Float),
236
248
  group="Scalar Math",
237
- doc="""Calculate the nearest integer that is closer to zero than x.
238
- In other words, it discards the fractional part of x.
239
- It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
240
- Equivalent to ``numpy.trunc()`` and ``numpy.fix()``.""",
249
+ doc="""Return the nearest integer that is closer to zero than ``x``.
250
+ In other words, it discards the fractional part of ``x``.
251
+ It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
252
+ Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
241
253
  )
242
254
 
243
255
  add_builtin(
@@ -245,7 +257,7 @@ add_builtin(
245
257
  input_types={"x": Float},
246
258
  value_func=sametype_value_func(Float),
247
259
  group="Scalar Math",
248
- doc="""Calculate the largest integer that is less than or equal to x.""",
260
+ doc="""Return the largest integer that is less than or equal to ``x``.""",
249
261
  )
250
262
 
251
263
  add_builtin(
@@ -253,22 +265,31 @@ add_builtin(
253
265
  input_types={"x": Float},
254
266
  value_func=sametype_value_func(Float),
255
267
  group="Scalar Math",
256
- doc="""Calculate the smallest integer that is greater than or equal to x.""",
268
+ doc="""Return the smallest integer that is greater than or equal to ``x``.""",
269
+ )
270
+
271
+ add_builtin(
272
+ "frac",
273
+ input_types={"x": Float},
274
+ value_func=sametype_value_func(Float),
275
+ group="Scalar Math",
276
+ doc="""Retrieve the fractional part of x.
277
+ In other words, it discards the integer part of x and is equivalent to ``x - trunc(x)``.""",
257
278
  )
258
279
 
259
280
 
260
- def infer_scalar_type(args):
261
- if args is None:
281
+ def infer_scalar_type(arg_types):
282
+ if arg_types is None:
262
283
  return Scalar
263
284
 
264
- def iterate_scalar_types(args):
265
- for a in args:
266
- if hasattr(a.type, "_wp_scalar_type_"):
267
- yield a.type._wp_scalar_type_
268
- elif a.type in scalar_types:
269
- yield a.type
285
+ def iterate_scalar_types(arg_types):
286
+ for t in arg_types:
287
+ if hasattr(t, "_wp_scalar_type_"):
288
+ yield t._wp_scalar_type_
289
+ elif t in scalar_types:
290
+ yield t
270
291
 
271
- scalarTypes = set(iterate_scalar_types(args))
292
+ scalarTypes = set(iterate_scalar_types(arg_types))
272
293
  if len(scalarTypes) > 1:
273
294
  raise RuntimeError(
274
295
  f"Couldn't figure out return type as arguments have multiple precisions: {list(scalarTypes)}"
@@ -276,13 +297,13 @@ def infer_scalar_type(args):
276
297
  return list(scalarTypes)[0]
277
298
 
278
299
 
279
- def sametype_scalar_value_func(args, kwds, _):
280
- if args is None:
300
+ def sametype_scalar_value_func(arg_types, kwds, _):
301
+ if arg_types is None:
281
302
  return Scalar
282
- if not all(types_equal(args[0].type, a.type) for a in args[1:]):
283
- raise RuntimeError(f"Input types must be exactly the same, {[a.type for a in args]}")
303
+ if not all(types_equal(arg_types[0], t) for t in arg_types[1:]):
304
+ raise RuntimeError(f"Input types must be exactly the same, {[t for t in arg_types]}")
284
305
 
285
- return infer_scalar_type(args)
306
+ return infer_scalar_type(arg_types)
286
307
 
287
308
 
288
309
  # ---------------------------------
@@ -307,14 +328,14 @@ add_builtin(
307
328
  "min",
308
329
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
309
330
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
310
- doc="Return the element wise minimum of two vectors.",
331
+ doc="Return the element-wise minimum of two vectors.",
311
332
  group="Vector Math",
312
333
  )
313
334
  add_builtin(
314
335
  "max",
315
336
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
316
337
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
317
- doc="Return the element wise maximum of two vectors.",
338
+ doc="Return the element-wise maximum of two vectors.",
318
339
  group="Vector Math",
319
340
  )
320
341
 
@@ -322,41 +343,41 @@ add_builtin(
322
343
  "min",
323
344
  input_types={"v": vector(length=Any, dtype=Scalar)},
324
345
  value_func=sametype_scalar_value_func,
325
- doc="Return the minimum element of a vector.",
346
+ doc="Return the minimum element of a vector ``v``.",
326
347
  group="Vector Math",
327
348
  )
328
349
  add_builtin(
329
350
  "max",
330
351
  input_types={"v": vector(length=Any, dtype=Scalar)},
331
352
  value_func=sametype_scalar_value_func,
332
- doc="Return the maximum element of a vector.",
353
+ doc="Return the maximum element of a vector ``v``.",
333
354
  group="Vector Math",
334
355
  )
335
356
 
336
357
  add_builtin(
337
358
  "argmin",
338
359
  input_types={"v": vector(length=Any, dtype=Scalar)},
339
- value_func=lambda args, kwds, _: warp.uint32,
340
- doc="Return the index of the minimum element of a vector.",
360
+ value_func=lambda arg_types, kwds, _: warp.uint32,
361
+ doc="Return the index of the minimum element of a vector ``v``.",
341
362
  group="Vector Math",
342
363
  missing_grad=True,
343
364
  )
344
365
  add_builtin(
345
366
  "argmax",
346
367
  input_types={"v": vector(length=Any, dtype=Scalar)},
347
- value_func=lambda args, kwds, _: warp.uint32,
348
- doc="Return the index of the maximum element of a vector.",
368
+ value_func=lambda arg_types, kwds, _: warp.uint32,
369
+ doc="Return the index of the maximum element of a vector ``v``.",
349
370
  group="Vector Math",
350
371
  missing_grad=True,
351
372
  )
352
373
 
353
374
 
354
- def value_func_outer(args, kwds, _):
355
- if args is None:
375
+ def value_func_outer(arg_types, kwds, _):
376
+ if arg_types is None:
356
377
  return matrix(shape=(Any, Any), dtype=Scalar)
357
378
 
358
- scalarType = infer_scalar_type(args)
359
- vectorLengths = [i.type._length_ for i in args]
379
+ scalarType = infer_scalar_type(arg_types)
380
+ vectorLengths = [t._length_ for t in arg_types]
360
381
  return matrix(shape=(vectorLengths), dtype=scalarType)
361
382
 
362
383
 
@@ -365,7 +386,7 @@ add_builtin(
365
386
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
366
387
  value_func=value_func_outer,
367
388
  group="Vector Math",
368
- doc="Compute the outer product x*y^T for two vec2 objects.",
389
+ doc="Compute the outer product ``x*y^T`` for two vectors.",
369
390
  )
370
391
 
371
392
  add_builtin(
@@ -373,14 +394,14 @@ add_builtin(
373
394
  input_types={"x": vector(length=3, dtype=Scalar), "y": vector(length=3, dtype=Scalar)},
374
395
  value_func=sametype_value_func(vector(length=3, dtype=Scalar)),
375
396
  group="Vector Math",
376
- doc="Compute the cross product of two 3d vectors.",
397
+ doc="Compute the cross product of two 3D vectors.",
377
398
  )
378
399
  add_builtin(
379
400
  "skew",
380
401
  input_types={"x": vector(length=3, dtype=Scalar)},
381
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=args[0].type._wp_scalar_type_),
402
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=arg_types[0]._wp_scalar_type_),
382
403
  group="Vector Math",
383
- doc="Compute the skew symmetric matrix for a 3d vector.",
404
+ doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``x``.",
384
405
  )
385
406
 
386
407
  add_builtin(
@@ -388,59 +409,62 @@ add_builtin(
388
409
  input_types={"x": vector(length=Any, dtype=Float)},
389
410
  value_func=sametype_scalar_value_func,
390
411
  group="Vector Math",
391
- doc="Compute the length of a vector.",
412
+ doc="Compute the length of a vector ``x``.",
413
+ require_original_output_arg=True,
392
414
  )
393
415
  add_builtin(
394
416
  "length",
395
417
  input_types={"x": quaternion(dtype=Float)},
396
418
  value_func=sametype_scalar_value_func,
397
419
  group="Vector Math",
398
- doc="Compute the length of a quaternion.",
420
+ doc="Compute the length of a quaternion ``x``.",
421
+ require_original_output_arg=True,
399
422
  )
400
423
  add_builtin(
401
424
  "length_sq",
402
425
  input_types={"x": vector(length=Any, dtype=Scalar)},
403
426
  value_func=sametype_scalar_value_func,
404
427
  group="Vector Math",
405
- doc="Compute the squared length of a 2d vector.",
428
+ doc="Compute the squared length of a 2D vector ``x``.",
406
429
  )
407
430
  add_builtin(
408
431
  "length_sq",
409
432
  input_types={"x": quaternion(dtype=Scalar)},
410
433
  value_func=sametype_scalar_value_func,
411
434
  group="Vector Math",
412
- doc="Compute the squared length of a quaternion.",
435
+ doc="Compute the squared length of a quaternion ``x``.",
413
436
  )
414
437
  add_builtin(
415
438
  "normalize",
416
439
  input_types={"x": vector(length=Any, dtype=Float)},
417
440
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
418
441
  group="Vector Math",
419
- doc="Compute the normalized value of x, if length(x) is 0 then the zero vector is returned.",
442
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0 then the zero vector is returned.",
443
+ require_original_output_arg=True,
420
444
  )
421
445
  add_builtin(
422
446
  "normalize",
423
447
  input_types={"x": quaternion(dtype=Float)},
424
448
  value_func=sametype_value_func(quaternion(dtype=Scalar)),
425
449
  group="Vector Math",
426
- doc="Compute the normalized value of x, if length(x) is 0 then the zero quat is returned.",
450
+ doc="Compute the normalized value of ``x``. If ``length(x)`` is 0, then the zero quaternion is returned.",
427
451
  )
428
452
 
429
453
  add_builtin(
430
454
  "transpose",
431
455
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
432
- value_func=lambda args, kwds, _: matrix(
433
- shape=(args[0].type._shape_[1], args[0].type._shape_[0]), dtype=args[0].type._wp_scalar_type_
456
+ value_func=lambda arg_types, kwds, _: matrix(
457
+ shape=(arg_types[0]._shape_[1], arg_types[0]._shape_[0]), dtype=arg_types[0]._wp_scalar_type_
434
458
  ),
435
459
  group="Vector Math",
436
- doc="Return the transpose of the matrix m",
460
+ doc="Return the transpose of the matrix ``m``.",
437
461
  )
438
462
 
439
463
 
440
- def value_func_mat_inv(args, kwds, _):
441
- if args is None:
464
+ def value_func_mat_inv(arg_types, kwds, _):
465
+ if arg_types is None:
442
466
  return matrix(shape=(Any, Any), dtype=Float)
443
- return args[0].type
467
+ return arg_types[0]
444
468
 
445
469
 
446
470
  add_builtin(
@@ -448,7 +472,8 @@ add_builtin(
448
472
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
449
473
  value_func=value_func_mat_inv,
450
474
  group="Vector Math",
451
- doc="Return the inverse of a 2x2 matrix m",
475
+ doc="Return the inverse of a 2x2 matrix ``m``.",
476
+ require_original_output_arg=True,
452
477
  )
453
478
 
454
479
  add_builtin(
@@ -456,7 +481,8 @@ add_builtin(
456
481
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
457
482
  value_func=value_func_mat_inv,
458
483
  group="Vector Math",
459
- doc="Return the inverse of a 3x3 matrix m",
484
+ doc="Return the inverse of a 3x3 matrix ``m``.",
485
+ require_original_output_arg=True,
460
486
  )
461
487
 
462
488
  add_builtin(
@@ -464,14 +490,15 @@ add_builtin(
464
490
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
465
491
  value_func=value_func_mat_inv,
466
492
  group="Vector Math",
467
- doc="Return the inverse of a 4x4 matrix m",
493
+ doc="Return the inverse of a 4x4 matrix ``m``.",
494
+ require_original_output_arg=True,
468
495
  )
469
496
 
470
497
 
471
- def value_func_mat_det(args, kwds, _):
472
- if args is None:
498
+ def value_func_mat_det(arg_types, kwds, _):
499
+ if arg_types is None:
473
500
  return Scalar
474
- return args[0].type._wp_scalar_type_
501
+ return arg_types[0]._wp_scalar_type_
475
502
 
476
503
 
477
504
  add_builtin(
@@ -479,7 +506,7 @@ add_builtin(
479
506
  input_types={"m": matrix(shape=(2, 2), dtype=Float)},
480
507
  value_func=value_func_mat_det,
481
508
  group="Vector Math",
482
- doc="Return the determinant of a 2x2 matrix m",
509
+ doc="Return the determinant of a 2x2 matrix ``m``.",
483
510
  )
484
511
 
485
512
  add_builtin(
@@ -487,7 +514,7 @@ add_builtin(
487
514
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
488
515
  value_func=value_func_mat_det,
489
516
  group="Vector Math",
490
- doc="Return the determinant of a 3x3 matrix m",
517
+ doc="Return the determinant of a 3x3 matrix ``m``.",
491
518
  )
492
519
 
493
520
  add_builtin(
@@ -495,16 +522,16 @@ add_builtin(
495
522
  input_types={"m": matrix(shape=(4, 4), dtype=Float)},
496
523
  value_func=value_func_mat_det,
497
524
  group="Vector Math",
498
- doc="Return the determinant of a 4x4 matrix m",
525
+ doc="Return the determinant of a 4x4 matrix ``m``.",
499
526
  )
500
527
 
501
528
 
502
- def value_func_mat_trace(args, kwds, _):
503
- if args is None:
529
+ def value_func_mat_trace(arg_types, kwds, _):
530
+ if arg_types is None:
504
531
  return Scalar
505
- if args[0].type._shape_[0] != args[0].type._shape_[1]:
506
- raise RuntimeError(f"Matrix shape is {args[0].type._shape_}. Cannot find the trace of non square matrices")
507
- return args[0].type._wp_scalar_type_
532
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
533
+ raise RuntimeError(f"Matrix shape is {arg_types[0]._shape_}. Cannot find the trace of non square matrices")
534
+ return arg_types[0]._wp_scalar_type_
508
535
 
509
536
 
510
537
  add_builtin(
@@ -512,15 +539,15 @@ add_builtin(
512
539
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
513
540
  value_func=value_func_mat_trace,
514
541
  group="Vector Math",
515
- doc="Return the trace of the matrix m",
542
+ doc="Return the trace of the matrix ``m``.",
516
543
  )
517
544
 
518
545
 
519
- def value_func_diag(args, kwds, _):
520
- if args is None:
546
+ def value_func_diag(arg_types, kwds, _):
547
+ if arg_types is None:
521
548
  return matrix(shape=(Any, Any), dtype=Scalar)
522
549
  else:
523
- return matrix(shape=(args[0].type._length_, args[0].type._length_), dtype=args[0].type._wp_scalar_type_)
550
+ return matrix(shape=(arg_types[0]._length_, arg_types[0]._length_), dtype=arg_types[0]._wp_scalar_type_)
524
551
 
525
552
 
526
553
  add_builtin(
@@ -528,19 +555,19 @@ add_builtin(
528
555
  input_types={"d": vector(length=Any, dtype=Scalar)},
529
556
  value_func=value_func_diag,
530
557
  group="Vector Math",
531
- doc="Returns a matrix with the components of the vector d on the diagonal",
558
+ doc="Returns a matrix with the components of the vector ``d`` on the diagonal.",
532
559
  )
533
560
 
534
561
 
535
- def value_func_get_diag(args, kwds, _):
536
- if args is None:
562
+ def value_func_get_diag(arg_types, kwds, _):
563
+ if arg_types is None:
537
564
  return vector(length=(Any), dtype=Scalar)
538
565
  else:
539
- if args[0].type._shape_[0] != args[0].type._shape_[1]:
566
+ if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
540
567
  raise RuntimeError(
541
- f"Matrix shape is {args[0].type._shape_}; get_diag is only available for square matrices."
568
+ f"Matrix shape is {arg_types[0]._shape_}; get_diag is only available for square matrices."
542
569
  )
543
- return vector(length=args[0].type._shape_[0], dtype=args[0].type._wp_scalar_type_)
570
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
544
571
 
545
572
 
546
573
  add_builtin(
@@ -548,7 +575,7 @@ add_builtin(
548
575
  input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
549
576
  value_func=value_func_get_diag,
550
577
  group="Vector Math",
551
- doc="Returns a vector containing the diagonal elements of the square matrix.",
578
+ doc="Returns a vector containing the diagonal elements of the square matrix ``m``.",
552
579
  )
553
580
 
554
581
  add_builtin(
@@ -556,14 +583,15 @@ add_builtin(
556
583
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
557
584
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
558
585
  group="Vector Math",
559
- doc="Component wise multiply of two 2d vectors.",
586
+ doc="Component-wise multiplication of two 2D vectors.",
560
587
  )
561
588
  add_builtin(
562
589
  "cw_div",
563
590
  input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
564
591
  value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
565
592
  group="Vector Math",
566
- doc="Component wise division of two 2d vectors.",
593
+ doc="Component-wise division of two 2D vectors.",
594
+ require_original_output_arg=True,
567
595
  )
568
596
 
569
597
  add_builtin(
@@ -571,14 +599,15 @@ add_builtin(
571
599
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
572
600
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
573
601
  group="Vector Math",
574
- doc="Component wise multiply of two 2d vectors.",
602
+ doc="Component-wise multiplication of two 2D vectors.",
575
603
  )
576
604
  add_builtin(
577
605
  "cw_div",
578
606
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
579
607
  value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
580
608
  group="Vector Math",
581
- doc="Component wise division of two 2d vectors.",
609
+ doc="Component-wise division of two 2D vectors.",
610
+ require_original_output_arg=True,
582
611
  )
583
612
 
584
613
 
@@ -594,15 +623,15 @@ for u in [bool, builtins.bool]:
594
623
  add_builtin(bool.__name__, input_types={"u": u}, value_type=bool, doc="", hidden=True, export=False, namespace="")
595
624
 
596
625
 
597
- def vector_constructor_func(args, kwds, templates):
598
- if args is None:
626
+ def vector_constructor_func(arg_types, kwds, templates):
627
+ if arg_types is None:
599
628
  return vector(length=Any, dtype=Scalar)
600
629
 
601
630
  if templates is None or len(templates) == 0:
602
631
  # handle construction of anonymous (undeclared) vector types
603
632
 
604
633
  if "length" in kwds:
605
- if len(args) == 0:
634
+ if len(arg_types) == 0:
606
635
  if "dtype" not in kwds:
607
636
  raise RuntimeError(
608
637
  "vec() must have dtype as a keyword argument if it has no positional arguments, e.g.: wp.vector(length=5, dtype=wp.float32)"
@@ -612,12 +641,12 @@ def vector_constructor_func(args, kwds, templates):
612
641
  veclen = kwds["length"]
613
642
  vectype = kwds["dtype"]
614
643
 
615
- elif len(args) == 1:
644
+ elif len(arg_types) == 1:
616
645
  # value initialization e.g.: wp.vec(1.0, length=5)
617
646
  veclen = kwds["length"]
618
- vectype = args[0].type
647
+ vectype = arg_types[0]
619
648
  if getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
620
- # constructor from another matrix
649
+ # constructor from another vector
621
650
  if vectype._length_ != veclen:
622
651
  raise RuntimeError(
623
652
  f"Incompatible vector lengths for casting copy constructor, {veclen} vs {vectype._length_}"
@@ -629,28 +658,37 @@ def vector_constructor_func(args, kwds, templates):
629
658
  )
630
659
 
631
660
  else:
632
- if len(args) == 0:
661
+ if len(arg_types) == 0:
633
662
  raise RuntimeError(
634
663
  "vec() must have at least one numeric argument, if it's length, dtype is not specified"
635
664
  )
636
665
 
637
666
  if "dtype" in kwds:
667
+ # casting constructor
668
+ if len(arg_types) == 1 and types_equal(
669
+ arg_types[0], vector(length=Any, dtype=Scalar), match_generic=True
670
+ ):
671
+ veclen = arg_types[0]._length_
672
+ vectype = kwds["dtype"]
673
+ templates.append(veclen)
674
+ templates.append(vectype)
675
+ return vector(length=veclen, dtype=vectype)
638
676
  raise RuntimeError(
639
677
  "vec() should not have dtype specified if numeric arguments are given, the dtype will be inferred from the argument types"
640
678
  )
641
679
 
642
680
  # component wise construction of an anonymous vector, e.g. wp.vec(wp.float16(1.0), wp.float16(2.0), ....)
643
681
  # we infer the length and data type from the number and type of the arg values
644
- veclen = len(args)
645
- vectype = args[0].type
682
+ veclen = len(arg_types)
683
+ vectype = arg_types[0]
646
684
 
647
- if len(args) == 1 and getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
685
+ if len(arg_types) == 1 and getattr(vectype, "_wp_generic_type_str_", None) == "vec_t":
648
686
  # constructor from another vector
649
687
  veclen = vectype._length_
650
688
  vectype = vectype._wp_scalar_type_
651
- elif not all(vectype == a.type for a in args):
689
+ elif not all(vectype == t for t in arg_types):
652
690
  raise RuntimeError(
653
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
691
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
654
692
  )
655
693
 
656
694
  # update the templates list, so we can generate vec<len, type>() correctly in codegen
@@ -660,15 +698,15 @@ def vector_constructor_func(args, kwds, templates):
660
698
  else:
661
699
  # construction of a predeclared type, e.g.: vec5d
662
700
  veclen, vectype = templates
663
- if len(args) == 1 and getattr(args[0].type, "_wp_generic_type_str_", None) == "vec_t":
701
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "vec_t":
664
702
  # constructor from another vector
665
- if args[0].type._length_ != veclen:
703
+ if arg_types[0]._length_ != veclen:
666
704
  raise RuntimeError(
667
- f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {args[0].type._length_}"
705
+ f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {arg_types[0]._length_}"
668
706
  )
669
- elif not all(vectype == a.type for a in args):
707
+ elif not all(vectype == t for t in arg_types):
670
708
  raise RuntimeError(
671
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} args of type {vectype}, received { ','.join(map(lambda x : str(x.type), args)) }"
709
+ f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join(map(lambda t : str(t), arg_types)) }"
672
710
  )
673
711
 
674
712
  retvalue = vector(length=veclen, dtype=vectype)
@@ -677,9 +715,9 @@ def vector_constructor_func(args, kwds, templates):
677
715
 
678
716
  add_builtin(
679
717
  "vector",
680
- input_types={"*args": Scalar, "length": int, "dtype": Scalar},
718
+ input_types={"*arg_types": Scalar, "length": int, "dtype": Scalar},
681
719
  variadic=True,
682
- initializer_list_func=lambda args, _: len(args) > 4,
720
+ initializer_list_func=lambda arg_types, _: len(arg_types) > 4,
683
721
  value_func=vector_constructor_func,
684
722
  native_func="vec_t",
685
723
  doc="Construct a vector of with given length and dtype.",
@@ -688,8 +726,8 @@ add_builtin(
688
726
  )
689
727
 
690
728
 
691
- def matrix_constructor_func(args, kwds, templates):
692
- if args is None:
729
+ def matrix_constructor_func(arg_types, kwds, templates):
730
+ if arg_types is None:
693
731
  return matrix(shape=(Any, Any), dtype=Scalar)
694
732
 
695
733
  if len(templates) == 0:
@@ -697,7 +735,7 @@ def matrix_constructor_func(args, kwds, templates):
697
735
  if "shape" not in kwds:
698
736
  raise RuntimeError("shape keyword must be specified when calling matrix() function")
699
737
 
700
- if len(args) == 0:
738
+ if len(arg_types) == 0:
701
739
  if "dtype" not in kwds:
702
740
  raise RuntimeError("matrix() must have dtype as a keyword argument if it has no positional arguments")
703
741
 
@@ -708,16 +746,16 @@ def matrix_constructor_func(args, kwds, templates):
708
746
  else:
709
747
  # value initialization, e.g.: m = matrix(1.0, shape=(3,2))
710
748
  shape = kwds["shape"]
711
- dtype = args[0].type
749
+ dtype = arg_types[0]
712
750
 
713
- if len(args) == 1 and getattr(dtype, "_wp_generic_type_str_", None) == "mat_t":
751
+ if len(arg_types) == 1 and getattr(dtype, "_wp_generic_type_str_", None) == "mat_t":
714
752
  # constructor from another matrix
715
- if types[0]._shape_ != shape:
753
+ if arg_types[0]._shape_ != shape:
716
754
  raise RuntimeError(
717
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {types[0]._shape_}"
755
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
718
756
  )
719
757
  dtype = dtype._wp_scalar_type_
720
- elif len(args) > 1 and len(args) != shape[0] * shape[1]:
758
+ elif len(arg_types) > 1 and len(arg_types) != shape[0] * shape[1]:
721
759
  raise RuntimeError(
722
760
  "Wrong number of arguments for matrix() function, must initialize with either a scalar value, or m*n values"
723
761
  )
@@ -731,35 +769,34 @@ def matrix_constructor_func(args, kwds, templates):
731
769
  shape = (templates[0], templates[1])
732
770
  dtype = templates[2]
733
771
 
734
- if len(args) > 0:
735
- types = [a.type for a in args]
736
- if len(args) == 1 and getattr(types[0], "_wp_generic_type_str_", None) == "mat_t":
772
+ if len(arg_types) > 0:
773
+ if len(arg_types) == 1 and getattr(arg_types[0], "_wp_generic_type_str_", None) == "mat_t":
737
774
  # constructor from another matrix with same dimension but possibly different type
738
- if types[0]._shape_ != shape:
775
+ if arg_types[0]._shape_ != shape:
739
776
  raise RuntimeError(
740
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {types[0]._shape_}"
777
+ f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
741
778
  )
742
779
  else:
743
780
  # check scalar arg type matches declared type
744
- if infer_scalar_type(args) != dtype:
781
+ if infer_scalar_type(arg_types) != dtype:
745
782
  raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
746
783
 
747
784
  # check vector arg type matches declared type
748
- if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in types):
749
- cols = len(types)
785
+ if all(hasattr(a, "_wp_generic_type_str_") and a._wp_generic_type_str_ == "vec_t" for a in arg_types):
786
+ cols = len(arg_types)
750
787
  if shape[1] != cols:
751
788
  raise RuntimeError(
752
789
  "Wrong number of vectors when attempting to construct a matrix with column vectors"
753
790
  )
754
791
 
755
- if not all(a._length_ == shape[0] for a in types):
792
+ if not all(a._length_ == shape[0] for a in arg_types):
756
793
  raise RuntimeError(
757
794
  "Wrong vector row count when attempting to construct a matrix with column vectors"
758
795
  )
759
796
  else:
760
797
  # check that we either got 1 arg (scalar construction), or enough values for whole matrix
761
798
  size = shape[0] * shape[1]
762
- if len(args) > 1 and len(args) != size:
799
+ if len(arg_types) > 1 and len(arg_types) != size:
763
800
  raise RuntimeError(
764
801
  "Wrong number of scalars when attempting to construct a matrix from a list of components"
765
802
  )
@@ -768,37 +805,34 @@ def matrix_constructor_func(args, kwds, templates):
768
805
 
769
806
 
770
807
  # only use initializer list if matrix size < 5x5, or for scalar construction
771
- def matrix_initlist_func(args, templates):
808
+ def matrix_initlist_func(arg_types, templates):
772
809
  m, n, dtype = templates
773
- if (
774
- len(args) == 0
775
- or len(args) == 1 # zero construction
810
+ return not (
811
+ len(arg_types) == 0
812
+ or len(arg_types) == 1 # zero construction
776
813
  or (m == n and n < 5) # scalar construction # value construction for small matrices
777
- ):
778
- return False
779
- else:
780
- return True
814
+ )
781
815
 
782
816
 
783
817
  add_builtin(
784
818
  "matrix",
785
- input_types={"*args": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
819
+ input_types={"*arg_types": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
786
820
  variadic=True,
787
821
  initializer_list_func=matrix_initlist_func,
788
822
  value_func=matrix_constructor_func,
789
823
  native_func="mat_t",
790
- doc="Construct a matrix, if positional args are not given then matrix will be zero-initialized.",
824
+ doc="Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized.",
791
825
  group="Vector Math",
792
826
  export=False,
793
827
  )
794
828
 
795
829
 
796
830
  # identity:
797
- def matrix_identity_value_func(args, kwds, templates):
798
- if args is None:
831
+ def matrix_identity_value_func(arg_types, kwds, templates):
832
+ if arg_types is None:
799
833
  return matrix(shape=(Any, Any), dtype=Scalar)
800
834
 
801
- if len(args):
835
+ if len(arg_types):
802
836
  raise RuntimeError("identity() function does not accept positional arguments")
803
837
 
804
838
  if "n" not in kwds:
@@ -829,7 +863,7 @@ add_builtin(
829
863
  )
830
864
 
831
865
 
832
- def matrix_transform_value_func(args, kwds, templates):
866
+ def matrix_transform_value_func(arg_types, kwds, templates):
833
867
  if templates is None:
834
868
  return matrix(shape=(Any, Any), dtype=Float)
835
869
 
@@ -839,7 +873,7 @@ def matrix_transform_value_func(args, kwds, templates):
839
873
  m, n, dtype = templates
840
874
  if (m, n) != (4, 4):
841
875
  raise RuntimeError("Can only construct 4x4 matrices with position, rotation and scale")
842
- if infer_scalar_type(args) != dtype:
876
+ if infer_scalar_type(arg_types) != dtype:
843
877
  raise RuntimeError("Wrong scalar type for mat<{}> constructor".format(",".join(map(str, templates))))
844
878
 
845
879
  return matrix(shape=(4, 4), dtype=dtype)
@@ -854,7 +888,8 @@ add_builtin(
854
888
  },
855
889
  value_func=matrix_transform_value_func,
856
890
  native_func="mat_t",
857
- doc="""Construct a 4x4 transformation matrix that applies the transformations as Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
891
+ doc="""Construct a 4x4 transformation matrix that applies the transformations as
892
+ Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
858
893
  group="Vector Math",
859
894
  export=False,
860
895
  )
@@ -873,8 +908,8 @@ add_builtin(
873
908
  value_type=None,
874
909
  group="Vector Math",
875
910
  export=False,
876
- doc="""Compute the SVD of a 3x3 matrix. The singular values are returned in sigma,
877
- while the left and right basis vectors are returned in U and V.""",
911
+ doc="""Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
912
+ while the left and right basis vectors are returned in ``U`` and ``V``.""",
878
913
  )
879
914
 
880
915
  add_builtin(
@@ -887,7 +922,8 @@ add_builtin(
887
922
  value_type=None,
888
923
  group="Vector Math",
889
924
  export=False,
890
- doc="""Compute the QR decomposition of a 3x3 matrix. The orthogonal matrix is returned in Q, while the upper triangular matrix is returned in R.""",
925
+ doc="""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
926
+ while the upper triangular matrix is returned in ``R``.""",
891
927
  )
892
928
 
893
929
  add_builtin(
@@ -900,36 +936,53 @@ add_builtin(
900
936
  value_type=None,
901
937
  group="Vector Math",
902
938
  export=False,
903
- doc="""Compute the eigendecomposition of a 3x3 matrix. The eigenvectors are returned as the columns of Q, while the corresponding eigenvalues are returned in d.""",
939
+ doc="""Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
940
+ while the corresponding eigenvalues are returned in ``d``.""",
904
941
  )
905
942
 
906
943
  # ---------------------------------
907
944
  # Quaternion Math
908
945
 
909
946
 
910
- def quaternion_value_func(args, kwds, templates):
911
- if args is None:
912
- return quaternion(dtype=Scalar)
947
+ def quaternion_value_func(arg_types, kwds, templates):
948
+ if arg_types is None:
949
+ return quaternion(dtype=Float)
913
950
 
914
- # if constructing anonymous quat type then infer output type from arguments
915
951
  if len(templates) == 0:
916
- dtype = infer_scalar_type(args)
952
+ if "dtype" in kwds:
953
+ # casting constructor
954
+ dtype = kwds["dtype"]
955
+ else:
956
+ # if constructing anonymous quat type then infer output type from arguments
957
+ dtype = infer_scalar_type(arg_types)
917
958
  templates.append(dtype)
918
959
  else:
919
- # if constructing predeclared type then check args match expectation
920
- if len(args) > 0 and infer_scalar_type(args) != templates[0]:
960
+ # if constructing predeclared type then check arg_types match expectation
961
+ if len(arg_types) > 0 and infer_scalar_type(arg_types) != templates[0]:
921
962
  raise RuntimeError("Wrong scalar type for quat {} constructor".format(",".join(map(str, templates))))
922
963
 
923
964
  return quaternion(dtype=templates[0])
924
965
 
925
966
 
967
+ def quat_cast_value_func(arg_types, kwds, templates):
968
+ if arg_types is None:
969
+ raise RuntimeError("Missing quaternion argument.")
970
+ if "dtype" not in kwds:
971
+ raise RuntimeError("Missing 'dtype' kwd.")
972
+
973
+ dtype = kwds["dtype"]
974
+ templates.append(dtype)
975
+
976
+ return quaternion(dtype=dtype)
977
+
978
+
926
979
  add_builtin(
927
980
  "quaternion",
928
981
  input_types={},
929
982
  value_func=quaternion_value_func,
930
983
  native_func="quat_t",
931
984
  group="Quaternion Math",
932
- doc="""Construct a zero-initialized quaternion, quaternions are laid out as
985
+ doc="""Construct a zero-initialized quaternion. Quaternions are laid out as
933
986
  [ix, iy, iz, r], where ix, iy, iz are the imaginary part, and r the real part.""",
934
987
  export=False,
935
988
  )
@@ -939,7 +992,7 @@ add_builtin(
939
992
  value_func=quaternion_value_func,
940
993
  native_func="quat_t",
941
994
  group="Quaternion Math",
942
- doc="Create a quaternion using the supplied components (type inferred from component type)",
995
+ doc="Create a quaternion using the supplied components (type inferred from component type).",
943
996
  export=False,
944
997
  )
945
998
  add_builtin(
@@ -948,14 +1001,23 @@ add_builtin(
948
1001
  value_func=quaternion_value_func,
949
1002
  native_func="quat_t",
950
1003
  group="Quaternion Math",
951
- doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type)",
1004
+ doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type).",
1005
+ export=False,
1006
+ )
1007
+ add_builtin(
1008
+ "quaternion",
1009
+ input_types={"q": quaternion(dtype=Float)},
1010
+ value_func=quat_cast_value_func,
1011
+ native_func="quat_t",
1012
+ group="Quaternion Math",
1013
+ doc="Construct a quaternion of type dtype from another quaternion of a different dtype.",
952
1014
  export=False,
953
1015
  )
954
1016
 
955
1017
 
956
- def quat_identity_value_func(args, kwds, templates):
957
- # if args is None then we are in 'export' mode
958
- if args is None:
1018
+ def quat_identity_value_func(arg_types, kwds, templates):
1019
+ # if arg_types is None then we are in 'export' mode
1020
+ if arg_types is None:
959
1021
  return quatf
960
1022
 
961
1023
  if "dtype" not in kwds:
@@ -981,7 +1043,7 @@ add_builtin(
981
1043
  add_builtin(
982
1044
  "quat_from_axis_angle",
983
1045
  input_types={"axis": vector(length=3, dtype=Float), "angle": Float},
984
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1046
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
985
1047
  group="Quaternion Math",
986
1048
  doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
987
1049
  )
@@ -995,49 +1057,50 @@ add_builtin(
995
1057
  add_builtin(
996
1058
  "quat_from_matrix",
997
1059
  input_types={"m": matrix(shape=(3, 3), dtype=Float)},
998
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1060
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
999
1061
  group="Quaternion Math",
1000
1062
  doc="Construct a quaternion from a 3x3 matrix.",
1001
1063
  )
1002
1064
  add_builtin(
1003
1065
  "quat_rpy",
1004
1066
  input_types={"roll": Float, "pitch": Float, "yaw": Float},
1005
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1067
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1006
1068
  group="Quaternion Math",
1007
1069
  doc="Construct a quaternion representing a combined roll (z), pitch (x), yaw rotations (y) in radians.",
1008
1070
  )
1009
1071
  add_builtin(
1010
1072
  "quat_inverse",
1011
1073
  input_types={"q": quaternion(dtype=Float)},
1012
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1074
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1013
1075
  group="Quaternion Math",
1014
1076
  doc="Compute quaternion conjugate.",
1015
1077
  )
1016
1078
  add_builtin(
1017
1079
  "quat_rotate",
1018
1080
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1019
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1081
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1020
1082
  group="Quaternion Math",
1021
1083
  doc="Rotate a vector by a quaternion.",
1022
1084
  )
1023
1085
  add_builtin(
1024
1086
  "quat_rotate_inv",
1025
1087
  input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1026
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1088
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1027
1089
  group="Quaternion Math",
1028
- doc="Rotate a vector the inverse of a quaternion.",
1090
+ doc="Rotate a vector by the inverse of a quaternion.",
1029
1091
  )
1030
1092
  add_builtin(
1031
1093
  "quat_slerp",
1032
1094
  input_types={"q0": quaternion(dtype=Float), "q1": quaternion(dtype=Float), "t": Float},
1033
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1095
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1034
1096
  group="Quaternion Math",
1035
1097
  doc="Linearly interpolate between two quaternions.",
1098
+ require_original_output_arg=True,
1036
1099
  )
1037
1100
  add_builtin(
1038
1101
  "quat_to_matrix",
1039
1102
  input_types={"q": quaternion(dtype=Float)},
1040
- value_func=lambda args, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(args)),
1103
+ value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(arg_types)),
1041
1104
  group="Quaternion Math",
1042
1105
  doc="Convert a quaternion to a 3x3 rotation matrix.",
1043
1106
  )
@@ -1053,19 +1116,19 @@ add_builtin(
1053
1116
  # Transformations
1054
1117
 
1055
1118
 
1056
- def transform_constructor_value_func(args, kwds, templates):
1119
+ def transform_constructor_value_func(arg_types, kwds, templates):
1057
1120
  if templates is None:
1058
1121
  return transformation(dtype=Scalar)
1059
1122
 
1060
1123
  if len(templates) == 0:
1061
1124
  # if constructing anonymous transform type then infer output type from arguments
1062
- dtype = infer_scalar_type(args)
1125
+ dtype = infer_scalar_type(arg_types)
1063
1126
  templates.append(dtype)
1064
1127
  else:
1065
- # if constructing predeclared type then check args match expectation
1066
- if infer_scalar_type(args) != templates[0]:
1128
+ # if constructing predeclared type then check arg_types match expectation
1129
+ if infer_scalar_type(arg_types) != templates[0]:
1067
1130
  raise RuntimeError(
1068
- f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda x : str(x.type), args))}"
1131
+ f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join(map(lambda t : str(t), arg_types))}"
1069
1132
  )
1070
1133
 
1071
1134
  return transformation(dtype=templates[0])
@@ -1077,13 +1140,13 @@ add_builtin(
1077
1140
  value_func=transform_constructor_value_func,
1078
1141
  native_func="transform_t",
1079
1142
  group="Transformations",
1080
- doc="Construct a rigid body transformation with translation part p and rotation q.",
1143
+ doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
1081
1144
  export=False,
1082
1145
  )
1083
1146
 
1084
1147
 
1085
- def transform_identity_value_func(args, kwds, templates):
1086
- if args is None:
1148
+ def transform_identity_value_func(arg_types, kwds, templates):
1149
+ if arg_types is None:
1087
1150
  return transformf
1088
1151
 
1089
1152
  if "dtype" not in kwds:
@@ -1109,68 +1172,72 @@ add_builtin(
1109
1172
  add_builtin(
1110
1173
  "transform_get_translation",
1111
1174
  input_types={"t": transformation(dtype=Float)},
1112
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1175
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1113
1176
  group="Transformations",
1114
- doc="Return the translational part of a transform.",
1177
+ doc="Return the translational part of a transform ``t``.",
1115
1178
  )
1116
1179
  add_builtin(
1117
1180
  "transform_get_rotation",
1118
1181
  input_types={"t": transformation(dtype=Float)},
1119
- value_func=lambda args, kwds, _: quaternion(dtype=infer_scalar_type(args)),
1182
+ value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1120
1183
  group="Transformations",
1121
- doc="Return the rotational part of a transform.",
1184
+ doc="Return the rotational part of a transform ``t``.",
1122
1185
  )
1123
1186
  add_builtin(
1124
1187
  "transform_multiply",
1125
1188
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
1126
- value_func=lambda args, kwds, _: transformation(dtype=infer_scalar_type(args)),
1189
+ value_func=lambda arg_types, kwds, _: transformation(dtype=infer_scalar_type(arg_types)),
1127
1190
  group="Transformations",
1128
1191
  doc="Multiply two rigid body transformations together.",
1129
1192
  )
1130
1193
  add_builtin(
1131
1194
  "transform_point",
1132
1195
  input_types={"t": transformation(dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1133
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1196
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1134
1197
  group="Transformations",
1135
- doc="Apply the transform to a point p treating the homogenous coordinate as w=1 (translation and rotation).",
1198
+ doc="Apply the transform to a point ``p`` treating the homogenous coordinate as w=1 (translation and rotation).",
1136
1199
  )
1137
1200
  add_builtin(
1138
1201
  "transform_point",
1139
1202
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1140
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1203
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1141
1204
  group="Vector Math",
1142
- doc="""Apply the transform to a point ``p`` treating the homogenous coordinate as w=1. The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``
1143
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``. If the transform is coming from a library that uses row-vectors
1144
- then users should transpose the transformation matrix before calling this method.""",
1205
+ doc="""Apply the transform to a point ``p`` treating the homogenous coordinate as w=1.
1206
+ The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``.
1207
+ Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``.
1208
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1209
+ matrix before calling this method.""",
1145
1210
  )
1146
1211
  add_builtin(
1147
1212
  "transform_vector",
1148
1213
  input_types={"t": transformation(dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1149
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1214
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1150
1215
  group="Transformations",
1151
- doc="Apply the transform to a vector v treating the homogenous coordinate as w=0 (rotation only).",
1216
+ doc="Apply the transform to a vector ``v`` treating the homogenous coordinate as w=0 (rotation only).",
1152
1217
  )
1153
1218
  add_builtin(
1154
1219
  "transform_vector",
1155
1220
  input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1156
- value_func=lambda args, kwds, _: vector(length=3, dtype=infer_scalar_type(args)),
1221
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1157
1222
  group="Vector Math",
1158
- doc="""Apply the transform to a vector ``v`` treating the homogenous coordinate as w=0. The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1159
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``. If the transform is coming from a library that uses row-vectors
1160
- then users should transpose the transformation matrix before calling this method.""",
1223
+ doc="""Apply the transform to a vector ``v`` treating the homogenous coordinate as w=0.
1224
+ The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1225
+ note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``.
1226
+ If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1227
+ matrix before calling this method.""",
1161
1228
  )
1162
1229
  add_builtin(
1163
1230
  "transform_inverse",
1164
1231
  input_types={"t": transformation(dtype=Float)},
1165
1232
  value_func=sametype_value_func(transformation(dtype=Float)),
1166
1233
  group="Transformations",
1167
- doc="Compute the inverse of the transform.",
1234
+ doc="Compute the inverse of the transformation ``t``.",
1168
1235
  )
1169
1236
  # ---------------------------------
1170
1237
  # Spatial Math
1171
1238
 
1172
1239
 
1173
- def spatial_vector_constructor_value_func(args, kwds, templates):
1240
+ def spatial_vector_constructor_value_func(arg_types, kwds, templates):
1174
1241
  if templates is None:
1175
1242
  return spatial_vector(dtype=Float)
1176
1243
 
@@ -1178,7 +1245,7 @@ def spatial_vector_constructor_value_func(args, kwds, templates):
1178
1245
  raise RuntimeError("Cannot use a generic type name in a kernel")
1179
1246
 
1180
1247
  vectype = templates[1]
1181
- if len(args) and infer_scalar_type(args) != vectype:
1248
+ if len(arg_types) and infer_scalar_type(arg_types) != vectype:
1182
1249
  raise RuntimeError("Wrong scalar type for spatial_vector<{}> constructor".format(",".join(map(str, templates))))
1183
1250
 
1184
1251
  return vector(length=6, dtype=vectype)
@@ -1190,7 +1257,7 @@ add_builtin(
1190
1257
  value_func=spatial_vector_constructor_value_func,
1191
1258
  native_func="vec_t",
1192
1259
  group="Spatial Math",
1193
- doc="Construct a 6d screw vector from two 3d vectors.",
1260
+ doc="Construct a 6D screw vector from two 3D vectors.",
1194
1261
  export=False,
1195
1262
  )
1196
1263
 
@@ -1198,7 +1265,7 @@ add_builtin(
1198
1265
  add_builtin(
1199
1266
  "spatial_adjoint",
1200
1267
  input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
1201
- value_func=lambda args, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(args)),
1268
+ value_func=lambda arg_types, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(arg_types)),
1202
1269
  group="Spatial Math",
1203
1270
  doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
1204
1271
  export=False,
@@ -1208,36 +1275,36 @@ add_builtin(
1208
1275
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1209
1276
  value_func=sametype_scalar_value_func,
1210
1277
  group="Spatial Math",
1211
- doc="Compute the dot product of two 6d screw vectors.",
1278
+ doc="Compute the dot product of two 6D screw vectors.",
1212
1279
  )
1213
1280
  add_builtin(
1214
1281
  "spatial_cross",
1215
1282
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1216
1283
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1217
1284
  group="Spatial Math",
1218
- doc="Compute the cross-product of two 6d screw vectors.",
1285
+ doc="Compute the cross product of two 6D screw vectors.",
1219
1286
  )
1220
1287
  add_builtin(
1221
1288
  "spatial_cross_dual",
1222
1289
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1223
1290
  value_func=sametype_value_func(vector(length=6, dtype=Float)),
1224
1291
  group="Spatial Math",
1225
- doc="Compute the dual cross-product of two 6d screw vectors.",
1292
+ doc="Compute the dual cross product of two 6D screw vectors.",
1226
1293
  )
1227
1294
 
1228
1295
  add_builtin(
1229
1296
  "spatial_top",
1230
1297
  input_types={"a": vector(length=6, dtype=Float)},
1231
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1298
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1232
1299
  group="Spatial Math",
1233
- doc="Return the top (first) part of a 6d screw vector.",
1300
+ doc="Return the top (first) part of a 6D screw vector.",
1234
1301
  )
1235
1302
  add_builtin(
1236
1303
  "spatial_bottom",
1237
1304
  input_types={"a": vector(length=6, dtype=Float)},
1238
- value_func=lambda args, kwds, _: vector(length=3, dtype=args[0].type._wp_scalar_type_),
1305
+ value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1239
1306
  group="Spatial Math",
1240
- doc="Return the bottom (second) part of a 6d screw vector.",
1307
+ doc="Return the bottom (second) part of a 6D screw vector.",
1241
1308
  )
1242
1309
 
1243
1310
  add_builtin(
@@ -1391,16 +1458,18 @@ add_builtin(
1391
1458
  },
1392
1459
  value_type=None,
1393
1460
  skip_replay=True,
1394
- doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1461
+ doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1395
1462
 
1396
1463
  :param weights: A layer's network weights with dimensions ``(m, n)``.
1397
1464
  :param bias: An array with dimensions ``(n)``.
1398
1465
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
1399
- :param index: The batch item to process, typically each thread will process 1 item in the batch, in this case index should be ``wp.tid()``
1466
+ :param index: The batch item to process, typically each thread will process one item in the batch, in which case
1467
+ index should be ``wp.tid()``
1400
1468
  :param x: The feature matrix with dimensions ``(n, b)``
1401
1469
  :param out: The network output with dimensions ``(m, b)``
1402
1470
 
1403
- :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch. All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1471
+ :note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch.
1472
+ All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
1404
1473
  group="Utility",
1405
1474
  )
1406
1475
 
@@ -1413,12 +1482,12 @@ add_builtin(
1413
1482
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1414
1483
  value_type=bvh_query_t,
1415
1484
  group="Geometry",
1416
- doc="""Construct an axis-aligned bounding box query against a bvh object. This query can be used to iterate over all bounds
1417
- inside a bvh. Returns an object that is used to track state during bvh traversal.
1418
-
1419
- :param id: The bvh identifier
1420
- :param lower: The lower bound of the bounding box in bvh space
1421
- :param upper: The upper bound of the bounding box in bvh space""",
1485
+ doc="""Construct an axis-aligned bounding box query against a BVH object. This query can be used to iterate over all bounds
1486
+ inside a BVH.
1487
+
1488
+ :param id: The BVH identifier
1489
+ :param lower: The lower bound of the bounding box in BVH space
1490
+ :param upper: The upper bound of the bounding box in BVH space""",
1422
1491
  )
1423
1492
 
1424
1493
  add_builtin(
@@ -1426,12 +1495,12 @@ add_builtin(
1426
1495
  input_types={"id": uint64, "start": vec3, "dir": vec3},
1427
1496
  value_type=bvh_query_t,
1428
1497
  group="Geometry",
1429
- doc="""Construct a ray query against a bvh object. This query can be used to iterate over all bounds
1430
- that intersect the ray. Returns an object that is used to track state during bvh traversal.
1431
-
1432
- :param id: The bvh identifier
1433
- :param start: The start of the ray in bvh space
1434
- :param dir: The direction of the ray in bvh space""",
1498
+ doc="""Construct a ray query against a BVH object. This query can be used to iterate over all bounds
1499
+ that intersect the ray.
1500
+
1501
+ :param id: The BVH identifier
1502
+ :param start: The start of the ray in BVH space
1503
+ :param dir: The direction of the ray in BVH space""",
1435
1504
  )
1436
1505
 
1437
1506
  add_builtin(
@@ -1439,8 +1508,8 @@ add_builtin(
1439
1508
  input_types={"query": bvh_query_t, "index": int},
1440
1509
  value_type=builtins.bool,
1441
1510
  group="Geometry",
1442
- doc="""Move to the next bound returned by the query. The index of the current bound is stored in ``index``, returns ``False``
1443
- if there are no more overlapping bound.""",
1511
+ doc="""Move to the next bound returned by the query.
1512
+ The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
1444
1513
  )
1445
1514
 
1446
1515
  add_builtin(
@@ -1456,18 +1525,42 @@ add_builtin(
1456
1525
  },
1457
1526
  value_type=builtins.bool,
1458
1527
  group="Geometry",
1459
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1528
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1460
1529
 
1461
- Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside. This method is relatively robust, but
1462
- does increase computational cost. See below for additional sign determination methods.
1530
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1531
+ This method is relatively robust, but does increase computational cost.
1532
+ See below for additional sign determination methods.
1463
1533
 
1464
1534
  :param id: The mesh identifier
1465
1535
  :param point: The point in space to query
1466
1536
  :param max_dist: Mesh faces above this distance will not be considered by the query
1467
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1537
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1538
+ Note that mesh must be watertight for this to be robust
1468
1539
  :param face: Returns the index of the closest face
1469
1540
  :param bary_u: Returns the barycentric u coordinate of the closest point
1470
1541
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1542
+ hidden=True,
1543
+ )
1544
+
1545
+ add_builtin(
1546
+ "mesh_query_point",
1547
+ input_types={
1548
+ "id": uint64,
1549
+ "point": vec3,
1550
+ "max_dist": float,
1551
+ },
1552
+ value_type=mesh_query_point_t,
1553
+ group="Geometry",
1554
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1555
+
1556
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
1557
+ This method is relatively robust, but does increase computational cost.
1558
+ See below for additional sign determination methods.
1559
+
1560
+ :param id: The mesh identifier
1561
+ :param point: The point in space to query
1562
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1563
+ require_original_output_arg=True,
1471
1564
  )
1472
1565
 
1473
1566
  add_builtin(
@@ -1482,7 +1575,7 @@ add_builtin(
1482
1575
  },
1483
1576
  value_type=builtins.bool,
1484
1577
  group="Geometry",
1485
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1578
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1486
1579
 
1487
1580
  This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1488
1581
 
@@ -1492,6 +1585,70 @@ add_builtin(
1492
1585
  :param face: Returns the index of the closest face
1493
1586
  :param bary_u: Returns the barycentric u coordinate of the closest point
1494
1587
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1588
+ hidden=True,
1589
+ )
1590
+
1591
+ add_builtin(
1592
+ "mesh_query_point_no_sign",
1593
+ input_types={
1594
+ "id": uint64,
1595
+ "point": vec3,
1596
+ "max_dist": float,
1597
+ },
1598
+ value_type=mesh_query_point_t,
1599
+ group="Geometry",
1600
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1601
+
1602
+ This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
1603
+
1604
+ :param id: The mesh identifier
1605
+ :param point: The point in space to query
1606
+ :param max_dist: Mesh faces above this distance will not be considered by the query""",
1607
+ require_original_output_arg=True,
1608
+ )
1609
+
1610
+ add_builtin(
1611
+ "mesh_query_furthest_point_no_sign",
1612
+ input_types={
1613
+ "id": uint64,
1614
+ "point": vec3,
1615
+ "min_dist": float,
1616
+ "face": int,
1617
+ "bary_u": float,
1618
+ "bary_v": float,
1619
+ },
1620
+ value_type=builtins.bool,
1621
+ group="Geometry",
1622
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point > ``min_dist`` is found.
1623
+
1624
+ This method does not compute the sign of the point (inside/outside).
1625
+
1626
+ :param id: The mesh identifier
1627
+ :param point: The point in space to query
1628
+ :param min_dist: Mesh faces below this distance will not be considered by the query
1629
+ :param face: Returns the index of the furthest face
1630
+ :param bary_u: Returns the barycentric u coordinate of the furthest point
1631
+ :param bary_v: Returns the barycentric v coordinate of the furthest point""",
1632
+ hidden=True,
1633
+ )
1634
+
1635
+ add_builtin(
1636
+ "mesh_query_furthest_point_no_sign",
1637
+ input_types={
1638
+ "id": uint64,
1639
+ "point": vec3,
1640
+ "min_dist": float,
1641
+ },
1642
+ value_type=mesh_query_point_t,
1643
+ group="Geometry",
1644
+ doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
1645
+
1646
+ This method does not compute the sign of the point (inside/outside).
1647
+
1648
+ :param id: The mesh identifier
1649
+ :param point: The point in space to query
1650
+ :param min_dist: Mesh faces below this distance will not be considered by the query""",
1651
+ require_original_output_arg=True,
1495
1652
  )
1496
1653
 
1497
1654
  add_builtin(
@@ -1509,19 +1666,48 @@ add_builtin(
1509
1666
  defaults={"epsilon": 1.0e-3},
1510
1667
  value_type=builtins.bool,
1511
1668
  group="Geometry",
1512
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1513
-
1514
- Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal. This approach to sign determination is robust for well conditioned meshes
1515
- that are watertight and non-self intersecting, it is also comparatively fast to compute.
1669
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space. Returns ``True`` if a point < ``max_dist`` is found.
1670
+
1671
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1672
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1673
+ It is also comparatively fast to compute.
1516
1674
 
1517
1675
  :param id: The mesh identifier
1518
1676
  :param point: The point in space to query
1519
1677
  :param max_dist: Mesh faces above this distance will not be considered by the query
1520
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1678
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1679
+ Note that mesh must be watertight for this to be robust
1521
1680
  :param face: Returns the index of the closest face
1522
1681
  :param bary_u: Returns the barycentric u coordinate of the closest point
1523
1682
  :param bary_v: Returns the barycentric v coordinate of the closest point
1524
- :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1683
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1684
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1685
+ hidden=True,
1686
+ )
1687
+
1688
+ add_builtin(
1689
+ "mesh_query_point_sign_normal",
1690
+ input_types={
1691
+ "id": uint64,
1692
+ "point": vec3,
1693
+ "max_dist": float,
1694
+ "epsilon": float,
1695
+ },
1696
+ defaults={"epsilon": 1.0e-3},
1697
+ value_type=mesh_query_point_t,
1698
+ group="Geometry",
1699
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1700
+
1701
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1702
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1703
+ It is also comparatively fast to compute.
1704
+
1705
+ :param id: The mesh identifier
1706
+ :param point: The point in space to query
1707
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1708
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1709
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1710
+ require_original_output_arg=True,
1525
1711
  )
1526
1712
 
1527
1713
  add_builtin(
@@ -1540,23 +1726,53 @@ add_builtin(
1540
1726
  defaults={"accuracy": 2.0, "threshold": 0.5},
1541
1727
  value_type=builtins.bool,
1542
1728
  group="Geometry",
1543
- doc="""Computes the closest point on the mesh with identifier `id` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1544
-
1729
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space. Returns ``True`` if a point < ``max_dist`` is found.
1730
+
1545
1731
  Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1546
1732
  and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1547
1733
  but also the most expensive.
1548
-
1549
- Note that the Mesh object must be constructed with ``suport_winding_number=True`` for this method to return correct results.
1734
+
1735
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1550
1736
 
1551
1737
  :param id: The mesh identifier
1552
1738
  :param point: The point in space to query
1553
1739
  :param max_dist: Mesh faces above this distance will not be considered by the query
1554
- :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise. Note that mesh must be watertight for this to be robust
1740
+ :param inside: Returns a value < 0 if query point is inside the mesh, >=0 otherwise.
1741
+ Note that mesh must be watertight for this to be robust
1555
1742
  :param face: Returns the index of the closest face
1556
1743
  :param bary_u: Returns the barycentric u coordinate of the closest point
1557
1744
  :param bary_v: Returns the barycentric v coordinate of the closest point
1558
- :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second order dipole approximation, default 2.0
1745
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1746
+ :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1747
+ hidden=True,
1748
+ )
1749
+
1750
+ add_builtin(
1751
+ "mesh_query_point_sign_winding_number",
1752
+ input_types={
1753
+ "id": uint64,
1754
+ "point": vec3,
1755
+ "max_dist": float,
1756
+ "accuracy": float,
1757
+ "threshold": float,
1758
+ },
1759
+ defaults={"accuracy": 2.0, "threshold": 0.5},
1760
+ value_type=mesh_query_point_t,
1761
+ group="Geometry",
1762
+ doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
1763
+
1764
+ Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1765
+ and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1766
+ but also the most expensive.
1767
+
1768
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1769
+
1770
+ :param id: The mesh identifier
1771
+ :param point: The point in space to query
1772
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1773
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1559
1774
  :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1775
+ require_original_output_arg=True,
1560
1776
  )
1561
1777
 
1562
1778
  add_builtin(
@@ -1575,7 +1791,7 @@ add_builtin(
1575
1791
  },
1576
1792
  value_type=builtins.bool,
1577
1793
  group="Geometry",
1578
- doc="""Computes the closest ray hit on the mesh with identifier `id`, returns ``True`` if a point < ``max_t`` is found.
1794
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``, returns ``True`` if a hit < ``max_t`` is found.
1579
1795
 
1580
1796
  :param id: The mesh identifier
1581
1797
  :param start: The start point of the ray
@@ -1584,9 +1800,29 @@ add_builtin(
1584
1800
  :param t: Returns the distance of the closest hit along the ray
1585
1801
  :param bary_u: Returns the barycentric u coordinate of the closest hit
1586
1802
  :param bary_v: Returns the barycentric v coordinate of the closest hit
1587
- :param sign: Returns a value > 0 if the hit ray hit front of the face, returns < 0 otherwise
1803
+ :param sign: Returns a value > 0 if the ray hit in front of the face, returns < 0 otherwise
1588
1804
  :param normal: Returns the face normal
1589
1805
  :param face: Returns the index of the hit face""",
1806
+ hidden=True,
1807
+ )
1808
+
1809
+ add_builtin(
1810
+ "mesh_query_ray",
1811
+ input_types={
1812
+ "id": uint64,
1813
+ "start": vec3,
1814
+ "dir": vec3,
1815
+ "max_t": float,
1816
+ },
1817
+ value_type=mesh_query_ray_t,
1818
+ group="Geometry",
1819
+ doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
1820
+
1821
+ :param id: The mesh identifier
1822
+ :param start: The start point of the ray
1823
+ :param dir: The ray direction (should be normalized)
1824
+ :param max_t: The maximum distance along the ray to check for intersections""",
1825
+ require_original_output_arg=True,
1590
1826
  )
1591
1827
 
1592
1828
  add_builtin(
@@ -1594,9 +1830,9 @@ add_builtin(
1594
1830
  input_types={"id": uint64, "lower": vec3, "upper": vec3},
1595
1831
  value_type=mesh_query_aabb_t,
1596
1832
  group="Geometry",
1597
- doc="""Construct an axis-aligned bounding box query against a mesh object. This query can be used to iterate over all triangles
1598
- inside a volume. Returns an object that is used to track state during mesh traversal.
1599
-
1833
+ doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
1834
+ This query can be used to iterate over all triangles inside a volume.
1835
+
1600
1836
  :param id: The mesh identifier
1601
1837
  :param lower: The lower bound of the bounding box in mesh space
1602
1838
  :param upper: The upper bound of the bounding box in mesh space""",
@@ -1607,8 +1843,8 @@ add_builtin(
1607
1843
  input_types={"query": mesh_query_aabb_t, "index": int},
1608
1844
  value_type=builtins.bool,
1609
1845
  group="Geometry",
1610
- doc="""Move to the next triangle overlapping the query bounding box. The index of the current face is stored in ``index``, returns ``False``
1611
- if there are no more overlapping triangles.""",
1846
+ doc="""Move to the next triangle overlapping the query bounding box.
1847
+ The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
1612
1848
  )
1613
1849
 
1614
1850
  add_builtin(
@@ -1616,7 +1852,7 @@ add_builtin(
1616
1852
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1617
1853
  value_type=vec3,
1618
1854
  group="Geometry",
1619
- doc="""Evaluates the position on the mesh given a face index, and barycentric coordinates.""",
1855
+ doc="""Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates.""",
1620
1856
  )
1621
1857
 
1622
1858
  add_builtin(
@@ -1624,7 +1860,7 @@ add_builtin(
1624
1860
  input_types={"id": uint64, "face": int, "bary_u": float, "bary_v": float},
1625
1861
  value_type=vec3,
1626
1862
  group="Geometry",
1627
- doc="""Evaluates the velocity on the mesh given a face index, and barycentric coordinates.""",
1863
+ doc="""Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates.""",
1628
1864
  )
1629
1865
 
1630
1866
  add_builtin(
@@ -1632,8 +1868,8 @@ add_builtin(
1632
1868
  input_types={"id": uint64, "point": vec3, "max_dist": float},
1633
1869
  value_type=hash_grid_query_t,
1634
1870
  group="Geometry",
1635
- doc="""Construct a point query against a hash grid. This query can be used to iterate over all neighboring points withing a
1636
- fixed radius from the query point. Returns an object that is used to track state during neighbor traversal.""",
1871
+ doc="Construct a point query against a :class:`HashGrid`. This query can be used to iterate over all neighboring points "
1872
+ "within a fixed radius from the query point.",
1637
1873
  )
1638
1874
 
1639
1875
  add_builtin(
@@ -1650,8 +1886,10 @@ add_builtin(
1650
1886
  input_types={"id": uint64, "index": int},
1651
1887
  value_type=int,
1652
1888
  group="Geometry",
1653
- doc="""Return the index of a point in the grid, this can be used to re-order threads such that grid
1654
- traversal occurs in a spatially coherent order.""",
1889
+ doc="""Return the index of a point in the :class:`HashGrid`. This can be used to reorder threads such that grid
1890
+ traversal occurs in a spatially coherent order.
1891
+
1892
+ Returns -1 if the :class:`HashGrid` has not been reserved.""",
1655
1893
  )
1656
1894
 
1657
1895
  add_builtin(
@@ -1750,7 +1988,8 @@ add_builtin(
1750
1988
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1751
1989
  value_type=float,
1752
1990
  group="Volumes",
1753
- doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
1991
+ doc="""Sample the volume given by ``id`` at the volume local-space point ``uvw``.
1992
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`warp.Volume.LINEAR`.""",
1754
1993
  )
1755
1994
 
1756
1995
  add_builtin(
@@ -1758,7 +1997,8 @@ add_builtin(
1758
1997
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "grad": vec3},
1759
1998
  value_type=float,
1760
1999
  group="Volumes",
1761
- doc="""Sample the volume and its gradient given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
2000
+ doc="""Sample the volume and its gradient given by ``id`` at the volume local-space point ``uvw``.
2001
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`warp.Volume.LINEAR`.""",
1762
2002
  )
1763
2003
 
1764
2004
  add_builtin(
@@ -1766,14 +2006,15 @@ add_builtin(
1766
2006
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1767
2007
  value_type=float,
1768
2008
  group="Volumes",
1769
- doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2009
+ doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
2010
+ If the voxel at this index does not exist, this function returns the background value.""",
1770
2011
  )
1771
2012
 
1772
2013
  add_builtin(
1773
2014
  "volume_store_f",
1774
2015
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": float},
1775
2016
  group="Volumes",
1776
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2017
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1777
2018
  )
1778
2019
 
1779
2020
  add_builtin(
@@ -1781,7 +2022,8 @@ add_builtin(
1781
2022
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int},
1782
2023
  value_type=vec3,
1783
2024
  group="Volumes",
1784
- doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``. Interpolation should be ``wp.Volume.CLOSEST``, or ``wp.Volume.LINEAR.``""",
2025
+ doc="""Sample the vector volume given by ``id`` at the volume local-space point ``uvw``.
2026
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`warp.Volume.LINEAR`.""",
1785
2027
  )
1786
2028
 
1787
2029
  add_builtin(
@@ -1789,14 +2031,15 @@ add_builtin(
1789
2031
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1790
2032
  value_type=vec3,
1791
2033
  group="Volumes",
1792
- doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2034
+ doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
2035
+ If the voxel at this index does not exist, this function returns the background value.""",
1793
2036
  )
1794
2037
 
1795
2038
  add_builtin(
1796
2039
  "volume_store_v",
1797
2040
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": vec3},
1798
2041
  group="Volumes",
1799
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2042
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1800
2043
  )
1801
2044
 
1802
2045
  add_builtin(
@@ -1804,7 +2047,7 @@ add_builtin(
1804
2047
  input_types={"id": uint64, "uvw": vec3},
1805
2048
  value_type=int,
1806
2049
  group="Volumes",
1807
- doc="""Sample the int32 volume given by ``id`` at the volume local-space point ``uvw``. """,
2050
+ doc="""Sample the :class:`int32` volume given by ``id`` at the volume local-space point ``uvw``. """,
1808
2051
  )
1809
2052
 
1810
2053
  add_builtin(
@@ -1812,14 +2055,15 @@ add_builtin(
1812
2055
  input_types={"id": uint64, "i": int, "j": int, "k": int},
1813
2056
  value_type=int,
1814
2057
  group="Volumes",
1815
- doc="""Returns the int32 value of voxel with coordinates ``i``, ``j``, ``k``, if the voxel at this index does not exist this function returns the background value""",
2058
+ doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
2059
+ If the voxel at this index does not exist, this function returns the background value.""",
1816
2060
  )
1817
2061
 
1818
2062
  add_builtin(
1819
2063
  "volume_store_i",
1820
2064
  input_types={"id": uint64, "i": int, "j": int, "k": int, "value": int},
1821
2065
  group="Volumes",
1822
- doc="""Store the value at voxel with coordinates ``i``, ``j``, ``k``.""",
2066
+ doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
1823
2067
  )
1824
2068
 
1825
2069
  add_builtin(
@@ -1827,28 +2071,28 @@ add_builtin(
1827
2071
  input_types={"id": uint64, "uvw": vec3},
1828
2072
  value_type=vec3,
1829
2073
  group="Volumes",
1830
- doc="""Transform a point defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2074
+ doc="""Transform a point ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1831
2075
  )
1832
2076
  add_builtin(
1833
2077
  "volume_world_to_index",
1834
2078
  input_types={"id": uint64, "xyz": vec3},
1835
2079
  value_type=vec3,
1836
2080
  group="Volumes",
1837
- doc="""Transform a point defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2081
+ doc="""Transform a point ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1838
2082
  )
1839
2083
  add_builtin(
1840
2084
  "volume_index_to_world_dir",
1841
2085
  input_types={"id": uint64, "uvw": vec3},
1842
2086
  value_type=vec3,
1843
2087
  group="Volumes",
1844
- doc="""Transform a direction defined in volume index space to world space given the volume's intrinsic affine transformation.""",
2088
+ doc="""Transform a direction ``uvw`` defined in volume index space to world space given the volume's intrinsic affine transformation.""",
1845
2089
  )
1846
2090
  add_builtin(
1847
2091
  "volume_world_to_index_dir",
1848
2092
  input_types={"id": uint64, "xyz": vec3},
1849
2093
  value_type=vec3,
1850
2094
  group="Volumes",
1851
- doc="""Transform a direction defined in volume world space to the volume's index space, given the volume's intrinsic affine transformation.""",
2095
+ doc="""Transform a direction ``xyz`` defined in volume world space to the volume's index space given the volume's intrinsic affine transformation.""",
1852
2096
  )
1853
2097
 
1854
2098
 
@@ -1868,7 +2112,7 @@ add_builtin(
1868
2112
  input_types={"seed": int, "offset": int},
1869
2113
  value_type=uint32,
1870
2114
  group="Random",
1871
- doc="""Initialize a new random number generator given a user-defined seed and an offset.
2115
+ doc="""Initialize a new random number generator given a user-defined seed and an offset.
1872
2116
  This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
1873
2117
  but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
1874
2118
  )
@@ -1878,31 +2122,31 @@ add_builtin(
1878
2122
  input_types={"state": uint32},
1879
2123
  value_type=int,
1880
2124
  group="Random",
1881
- doc="Return a random integer between [0, 2^32)",
2125
+ doc="Return a random integer in the range [0, 2^32).",
1882
2126
  )
1883
2127
  add_builtin(
1884
2128
  "randi",
1885
2129
  input_types={"state": uint32, "min": int, "max": int},
1886
2130
  value_type=int,
1887
2131
  group="Random",
1888
- doc="Return a random integer between [min, max)",
2132
+ doc="Return a random integer between [min, max).",
1889
2133
  )
1890
2134
  add_builtin(
1891
2135
  "randf",
1892
2136
  input_types={"state": uint32},
1893
2137
  value_type=float,
1894
2138
  group="Random",
1895
- doc="Return a random float between [0.0, 1.0)",
2139
+ doc="Return a random float between [0.0, 1.0).",
1896
2140
  )
1897
2141
  add_builtin(
1898
2142
  "randf",
1899
2143
  input_types={"state": uint32, "min": float, "max": float},
1900
2144
  value_type=float,
1901
2145
  group="Random",
1902
- doc="Return a random float between [min, max)",
2146
+ doc="Return a random float between [min, max).",
1903
2147
  )
1904
2148
  add_builtin(
1905
- "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution"
2149
+ "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution."
1906
2150
  )
1907
2151
 
1908
2152
  add_builtin(
@@ -1910,70 +2154,70 @@ add_builtin(
1910
2154
  input_types={"state": uint32, "cdf": array(dtype=float)},
1911
2155
  value_type=int,
1912
2156
  group="Random",
1913
- doc="Inverse transform sample a cumulative distribution function",
2157
+ doc="Inverse-transform sample a cumulative distribution function.",
1914
2158
  )
1915
2159
  add_builtin(
1916
2160
  "sample_triangle",
1917
2161
  input_types={"state": uint32},
1918
2162
  value_type=vec2,
1919
2163
  group="Random",
1920
- doc="Uniformly sample a triangle. Returns sample barycentric coordinates",
2164
+ doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
1921
2165
  )
1922
2166
  add_builtin(
1923
2167
  "sample_unit_ring",
1924
2168
  input_types={"state": uint32},
1925
2169
  value_type=vec2,
1926
2170
  group="Random",
1927
- doc="Uniformly sample a ring in the xy plane",
2171
+ doc="Uniformly sample a ring in the xy plane.",
1928
2172
  )
1929
2173
  add_builtin(
1930
2174
  "sample_unit_disk",
1931
2175
  input_types={"state": uint32},
1932
2176
  value_type=vec2,
1933
2177
  group="Random",
1934
- doc="Uniformly sample a disk in the xy plane",
2178
+ doc="Uniformly sample a disk in the xy plane.",
1935
2179
  )
1936
2180
  add_builtin(
1937
2181
  "sample_unit_sphere_surface",
1938
2182
  input_types={"state": uint32},
1939
2183
  value_type=vec3,
1940
2184
  group="Random",
1941
- doc="Uniformly sample a unit sphere surface",
2185
+ doc="Uniformly sample a unit sphere surface.",
1942
2186
  )
1943
2187
  add_builtin(
1944
2188
  "sample_unit_sphere",
1945
2189
  input_types={"state": uint32},
1946
2190
  value_type=vec3,
1947
2191
  group="Random",
1948
- doc="Uniformly sample a unit sphere",
2192
+ doc="Uniformly sample a unit sphere.",
1949
2193
  )
1950
2194
  add_builtin(
1951
2195
  "sample_unit_hemisphere_surface",
1952
2196
  input_types={"state": uint32},
1953
2197
  value_type=vec3,
1954
2198
  group="Random",
1955
- doc="Uniformly sample a unit hemisphere surface",
2199
+ doc="Uniformly sample a unit hemisphere surface.",
1956
2200
  )
1957
2201
  add_builtin(
1958
2202
  "sample_unit_hemisphere",
1959
2203
  input_types={"state": uint32},
1960
2204
  value_type=vec3,
1961
2205
  group="Random",
1962
- doc="Uniformly sample a unit hemisphere",
2206
+ doc="Uniformly sample a unit hemisphere.",
1963
2207
  )
1964
2208
  add_builtin(
1965
2209
  "sample_unit_square",
1966
2210
  input_types={"state": uint32},
1967
2211
  value_type=vec2,
1968
2212
  group="Random",
1969
- doc="Uniformly sample a unit square",
2213
+ doc="Uniformly sample a unit square.",
1970
2214
  )
1971
2215
  add_builtin(
1972
2216
  "sample_unit_cube",
1973
2217
  input_types={"state": uint32},
1974
2218
  value_type=vec3,
1975
2219
  group="Random",
1976
- doc="Uniformly sample a unit cube",
2220
+ doc="Uniformly sample a unit cube.",
1977
2221
  )
1978
2222
 
1979
2223
  add_builtin(
@@ -1982,9 +2226,9 @@ add_builtin(
1982
2226
  value_type=uint32,
1983
2227
  group="Random",
1984
2228
  doc="""Generate a random sample from a Poisson distribution.
1985
-
1986
- :param state: RNG state
1987
- :param lam: The expected value of the distribution""",
2229
+
2230
+ :param state: RNG state
2231
+ :param lam: The expected value of the distribution""",
1988
2232
  )
1989
2233
 
1990
2234
  add_builtin(
@@ -1992,28 +2236,28 @@ add_builtin(
1992
2236
  input_types={"state": uint32, "x": float},
1993
2237
  value_type=float,
1994
2238
  group="Random",
1995
- doc="Non-periodic Perlin-style noise in 1d.",
2239
+ doc="Non-periodic Perlin-style noise in 1D.",
1996
2240
  )
1997
2241
  add_builtin(
1998
2242
  "noise",
1999
2243
  input_types={"state": uint32, "xy": vec2},
2000
2244
  value_type=float,
2001
2245
  group="Random",
2002
- doc="Non-periodic Perlin-style noise in 2d.",
2246
+ doc="Non-periodic Perlin-style noise in 2D.",
2003
2247
  )
2004
2248
  add_builtin(
2005
2249
  "noise",
2006
2250
  input_types={"state": uint32, "xyz": vec3},
2007
2251
  value_type=float,
2008
2252
  group="Random",
2009
- doc="Non-periodic Perlin-style noise in 3d.",
2253
+ doc="Non-periodic Perlin-style noise in 3D.",
2010
2254
  )
2011
2255
  add_builtin(
2012
2256
  "noise",
2013
2257
  input_types={"state": uint32, "xyzt": vec4},
2014
2258
  value_type=float,
2015
2259
  group="Random",
2016
- doc="Non-periodic Perlin-style noise in 4d.",
2260
+ doc="Non-periodic Perlin-style noise in 4D.",
2017
2261
  )
2018
2262
 
2019
2263
  add_builtin(
@@ -2021,33 +2265,34 @@ add_builtin(
2021
2265
  input_types={"state": uint32, "x": float, "px": int},
2022
2266
  value_type=float,
2023
2267
  group="Random",
2024
- doc="Periodic Perlin-style noise in 1d.",
2268
+ doc="Periodic Perlin-style noise in 1D.",
2025
2269
  )
2026
2270
  add_builtin(
2027
2271
  "pnoise",
2028
2272
  input_types={"state": uint32, "xy": vec2, "px": int, "py": int},
2029
2273
  value_type=float,
2030
2274
  group="Random",
2031
- doc="Periodic Perlin-style noise in 2d.",
2275
+ doc="Periodic Perlin-style noise in 2D.",
2032
2276
  )
2033
2277
  add_builtin(
2034
2278
  "pnoise",
2035
2279
  input_types={"state": uint32, "xyz": vec3, "px": int, "py": int, "pz": int},
2036
2280
  value_type=float,
2037
2281
  group="Random",
2038
- doc="Periodic Perlin-style noise in 3d.",
2282
+ doc="Periodic Perlin-style noise in 3D.",
2039
2283
  )
2040
2284
  add_builtin(
2041
2285
  "pnoise",
2042
2286
  input_types={"state": uint32, "xyzt": vec4, "px": int, "py": int, "pz": int, "pt": int},
2043
2287
  value_type=float,
2044
2288
  group="Random",
2045
- doc="Periodic Perlin-style noise in 4d.",
2289
+ doc="Periodic Perlin-style noise in 4D.",
2046
2290
  )
2047
2291
 
2048
2292
  add_builtin(
2049
2293
  "curlnoise",
2050
- input_types={"state": uint32, "xy": vec2},
2294
+ input_types={"state": uint32, "xy": vec2, "octaves": uint32, "lacunarity": float, "gain": float},
2295
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2051
2296
  value_type=vec2,
2052
2297
  group="Random",
2053
2298
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
@@ -2055,7 +2300,8 @@ add_builtin(
2055
2300
  )
2056
2301
  add_builtin(
2057
2302
  "curlnoise",
2058
- input_types={"state": uint32, "xyz": vec3},
2303
+ input_types={"state": uint32, "xyz": vec3, "octaves": uint32, "lacunarity": float, "gain": float},
2304
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2059
2305
  value_type=vec3,
2060
2306
  group="Random",
2061
2307
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -2063,7 +2309,8 @@ add_builtin(
2063
2309
  )
2064
2310
  add_builtin(
2065
2311
  "curlnoise",
2066
- input_types={"state": uint32, "xyzt": vec4},
2312
+ input_types={"state": uint32, "xyzt": vec4, "octaves": uint32, "lacunarity": float, "gain": float},
2313
+ defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2067
2314
  value_type=vec3,
2068
2315
  group="Random",
2069
2316
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -2077,7 +2324,7 @@ add_builtin(
2077
2324
  namespace="",
2078
2325
  variadic=True,
2079
2326
  group="Utility",
2080
- doc="Allows printing formatted strings, using C-style format specifiers.",
2327
+ doc="Allows printing formatted strings using C-style format specifiers.",
2081
2328
  )
2082
2329
 
2083
2330
  add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
@@ -2097,9 +2344,12 @@ add_builtin(
2097
2344
  "tid",
2098
2345
  input_types={},
2099
2346
  value_type=int,
2347
+ export=False,
2100
2348
  group="Utility",
2101
- doc="""Return the current thread index. Note that this is the *global* index of the thread in the range [0, dim)
2102
- where dim is the parameter passed to kernel launch.""",
2349
+ doc="""Return the current thread index for a 1D kernel launch. Note that this is the *global* index of the thread in the range [0, dim)
2350
+ where dim is the parameter passed to kernel launch. This function may not be called from user-defined Warp functions.""",
2351
+ namespace="",
2352
+ native_func="builtin_tid1d",
2103
2353
  )
2104
2354
 
2105
2355
  add_builtin(
@@ -2107,7 +2357,10 @@ add_builtin(
2107
2357
  input_types={},
2108
2358
  value_type=[int, int],
2109
2359
  group="Utility",
2110
- doc="""Return the current thread indices for a 2d kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2360
+ doc="""Return the current thread indices for a 2D kernel launch. Use ``i,j = wp.tid()`` syntax to retrieve the
2361
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2362
+ namespace="",
2363
+ native_func="builtin_tid2d",
2111
2364
  )
2112
2365
 
2113
2366
  add_builtin(
@@ -2115,7 +2368,10 @@ add_builtin(
2115
2368
  input_types={},
2116
2369
  value_type=[int, int, int],
2117
2370
  group="Utility",
2118
- doc="""Return the current thread indices for a 3d kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2371
+ doc="""Return the current thread indices for a 3D kernel launch. Use ``i,j,k = wp.tid()`` syntax to retrieve the
2372
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2373
+ namespace="",
2374
+ native_func="builtin_tid3d",
2119
2375
  )
2120
2376
 
2121
2377
  add_builtin(
@@ -2123,49 +2379,60 @@ add_builtin(
2123
2379
  input_types={},
2124
2380
  value_type=[int, int, int, int],
2125
2381
  group="Utility",
2126
- doc="""Return the current thread indices for a 4d kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the coordinates inside the kernel thread grid.""",
2382
+ doc="""Return the current thread indices for a 4D kernel launch. Use ``i,j,k,l = wp.tid()`` syntax to retrieve the
2383
+ coordinates inside the kernel thread grid. This function may not be called from user-defined Warp functions.""",
2384
+ namespace="",
2385
+ native_func="builtin_tid4d",
2127
2386
  )
2128
2387
 
2129
2388
 
2130
- add_builtin("copy", variadic=True, hidden=True, export=False, group="Utility")
2389
+ add_builtin(
2390
+ "copy",
2391
+ input_types={"value": Any},
2392
+ value_func=lambda arg_types, kwds, _: arg_types[0],
2393
+ hidden=True,
2394
+ export=False,
2395
+ group="Utility",
2396
+ )
2397
+ add_builtin("assign", variadic=True, hidden=True, export=False, group="Utility")
2131
2398
  add_builtin(
2132
2399
  "select",
2133
2400
  input_types={"cond": bool, "arg1": Any, "arg2": Any},
2134
- value_func=lambda args, kwds, _: args[1].type,
2135
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2401
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2402
+ doc="Select between two arguments: if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``.",
2136
2403
  group="Utility",
2137
2404
  )
2138
2405
  add_builtin(
2139
2406
  "select",
2140
2407
  input_types={"cond": builtins.bool, "arg1": Any, "arg2": Any},
2141
2408
  value_func=lambda args, kwds, _: args[1].type,
2142
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2409
+ doc="Select between two arguments: if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``.",
2143
2410
  group="Utility",
2144
2411
  )
2145
2412
  for t in int_types:
2146
2413
  add_builtin(
2147
2414
  "select",
2148
2415
  input_types={"cond": t, "arg1": Any, "arg2": Any},
2149
- value_func=lambda args, kwds, _: args[1].type,
2150
- doc="Select between two arguments, if cond is false then return ``arg1``, otherwise return ``arg2``",
2416
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2417
+ doc="Select between two arguments: if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``.",
2151
2418
  group="Utility",
2152
2419
  )
2153
2420
  add_builtin(
2154
2421
  "select",
2155
2422
  input_types={"arr": array(dtype=Any), "arg1": Any, "arg2": Any},
2156
- value_func=lambda args, kwds, _: args[1].type,
2157
- doc="Select between two arguments, if array is null then return ``arg1``, otherwise return ``arg2``",
2423
+ value_func=lambda arg_types, kwds, _: arg_types[1],
2424
+ doc="Select between two arguments: if ``arr`` is null then return ``arg1``, otherwise return ``arg2``.",
2158
2425
  group="Utility",
2159
2426
  )
2160
2427
 
2161
2428
 
2162
- # does argument checking and type propagation for load()
2163
- def load_value_func(args, kwds, _):
2164
- if not is_array(args[0].type):
2429
+ # does argument checking and type propagation for address()
2430
+ def address_value_func(arg_types, kwds, _):
2431
+ if not is_array(arg_types[0]):
2165
2432
  raise RuntimeError("load() argument 0 must be an array")
2166
2433
 
2167
- num_indices = len(args[1:])
2168
- num_dims = args[0].type.ndim
2434
+ num_indices = len(arg_types[1:])
2435
+ num_dims = arg_types[0].ndim
2169
2436
 
2170
2437
  if num_indices < num_dims:
2171
2438
  raise RuntimeError(
@@ -2178,21 +2445,21 @@ def load_value_func(args, kwds, _):
2178
2445
  )
2179
2446
 
2180
2447
  # check index types
2181
- for a in args[1:]:
2182
- if type_is_int(a.type) == False:
2183
- raise RuntimeError(f"load() index arguments must be of integer type, got index of type {a.type}")
2448
+ for t in arg_types[1:]:
2449
+ if not type_is_int(t):
2450
+ raise RuntimeError(f"address() index arguments must be of integer type, got index of type {t}")
2184
2451
 
2185
- return args[0].type.dtype
2452
+ return Reference(arg_types[0].dtype)
2186
2453
 
2187
2454
 
2188
2455
  # does argument checking and type propagation for view()
2189
- def view_value_func(args, kwds, _):
2190
- if not is_array(args[0].type):
2456
+ def view_value_func(arg_types, kwds, _):
2457
+ if not is_array(arg_types[0]):
2191
2458
  raise RuntimeError("view() argument 0 must be an array")
2192
2459
 
2193
2460
  # check array dim big enough to support view
2194
- num_indices = len(args[1:])
2195
- num_dims = args[0].type.ndim
2461
+ num_indices = len(arg_types[1:])
2462
+ num_dims = arg_types[0].ndim
2196
2463
 
2197
2464
  if num_indices >= num_dims:
2198
2465
  raise RuntimeError(
@@ -2200,27 +2467,28 @@ def view_value_func(args, kwds, _):
2200
2467
  )
2201
2468
 
2202
2469
  # check index types
2203
- for a in args[1:]:
2204
- if type_is_int(a.type) == False:
2205
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {a.type}")
2470
+ for t in arg_types[1:]:
2471
+ if not type_is_int(t):
2472
+ raise RuntimeError(f"view() index arguments must be of integer type, got index of type {t}")
2206
2473
 
2207
2474
  # create an array view with leading dimensions removed
2208
- dtype = args[0].type.dtype
2475
+ dtype = arg_types[0].dtype
2209
2476
  ndim = num_dims - num_indices
2210
- if isinstance(args[0].type, (fabricarray, indexedfabricarray)):
2477
+ if isinstance(arg_types[0], (fabricarray, indexedfabricarray)):
2211
2478
  # fabric array of arrays: return array attribute as a regular array
2212
2479
  return array(dtype=dtype, ndim=ndim)
2213
2480
  else:
2214
- return type(args[0].type)(dtype=dtype, ndim=ndim)
2481
+ return type(arg_types[0])(dtype=dtype, ndim=ndim)
2482
+
2215
2483
 
2216
- # does argument checking and type propagation for store()
2217
- def store_value_func(args, kwds, _):
2484
+ # does argument checking and type propagation for array_store()
2485
+ def array_store_value_func(arg_types, kwds, _):
2218
2486
  # check target type
2219
- if not is_array(args[0].type):
2220
- raise RuntimeError("store() argument 0 must be an array")
2487
+ if not is_array(arg_types[0]):
2488
+ raise RuntimeError("array_store() argument 0 must be an array")
2221
2489
 
2222
- num_indices = len(args[1:-1])
2223
- num_dims = args[0].type.ndim
2490
+ num_indices = len(arg_types[1:-1])
2491
+ num_dims = arg_types[0].ndim
2224
2492
 
2225
2493
  # if this happens we should have generated a view instead of a load during code gen
2226
2494
  if num_indices < num_dims:
@@ -2232,31 +2500,63 @@ def store_value_func(args, kwds, _):
2232
2500
  )
2233
2501
 
2234
2502
  # check index types
2235
- for a in args[1:-1]:
2236
- if type_is_int(a.type) == False:
2237
- raise RuntimeError(f"store() index arguments must be of integer type, got index of type {a.type}")
2503
+ for t in arg_types[1:-1]:
2504
+ if not type_is_int(t):
2505
+ raise RuntimeError(f"array_store() index arguments must be of integer type, got index of type {t}")
2238
2506
 
2239
2507
  # check value type
2240
- if not types_equal(args[-1].type, args[0].type.dtype):
2508
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2241
2509
  raise RuntimeError(
2242
- f"store() value argument type ({args[2].type}) must be of the same type as the array ({args[0].type.dtype})"
2510
+ f"array_store() value argument type ({arg_types[2]}) must be of the same type as the array ({arg_types[0].dtype})"
2243
2511
  )
2244
2512
 
2245
2513
  return None
2246
2514
 
2247
2515
 
2248
- add_builtin("load", variadic=True, hidden=True, value_func=load_value_func, group="Utility")
2516
+ # does argument checking for store()
2517
+ def store_value_func(arg_types, kwds, _):
2518
+ # we already stripped the Reference from the argument type prior to this call
2519
+ if not types_equal(arg_types[0], arg_types[1]):
2520
+ raise RuntimeError(f"store() value argument type ({arg_types[1]}) must be of the same type as the reference")
2521
+
2522
+ return None
2523
+
2524
+
2525
+ # does type propagation for load()
2526
+ def load_value_func(arg_types, kwds, _):
2527
+ # we already stripped the Reference from the argument type prior to this call
2528
+ return arg_types[0]
2529
+
2530
+
2531
+ add_builtin("address", variadic=True, hidden=True, value_func=address_value_func, group="Utility")
2249
2532
  add_builtin("view", variadic=True, hidden=True, value_func=view_value_func, group="Utility")
2250
- add_builtin("store", variadic=True, hidden=True, value_func=store_value_func, skip_replay=True, group="Utility")
2533
+ add_builtin(
2534
+ "array_store", variadic=True, hidden=True, value_func=array_store_value_func, skip_replay=True, group="Utility"
2535
+ )
2536
+ add_builtin(
2537
+ "store",
2538
+ input_types={"address": Reference, "value": Any},
2539
+ hidden=True,
2540
+ value_func=store_value_func,
2541
+ skip_replay=True,
2542
+ group="Utility",
2543
+ )
2544
+ add_builtin(
2545
+ "load",
2546
+ input_types={"address": Reference},
2547
+ hidden=True,
2548
+ value_func=load_value_func,
2549
+ group="Utility",
2550
+ )
2251
2551
 
2252
2552
 
2253
- def atomic_op_value_func(args, kwds, _):
2553
+ def atomic_op_value_func(arg_types, kwds, _):
2254
2554
  # check target type
2255
- if not is_array(args[0].type):
2555
+ if not is_array(arg_types[0]):
2256
2556
  raise RuntimeError("atomic() operation argument 0 must be an array")
2257
2557
 
2258
- num_indices = len(args[1:-1])
2259
- num_dims = args[0].type.ndim
2558
+ num_indices = len(arg_types[1:-1])
2559
+ num_dims = arg_types[0].ndim
2260
2560
 
2261
2561
  # if this happens we should have generated a view instead of a load during code gen
2262
2562
  if num_indices < num_dims:
@@ -2268,18 +2568,16 @@ def atomic_op_value_func(args, kwds, _):
2268
2568
  )
2269
2569
 
2270
2570
  # check index types
2271
- for a in args[1:-1]:
2272
- if type_is_int(a.type) == False:
2273
- raise RuntimeError(
2274
- f"atomic() operation index arguments must be of integer type, got index of type {a.type}"
2275
- )
2571
+ for t in arg_types[1:-1]:
2572
+ if not type_is_int(t):
2573
+ raise RuntimeError(f"atomic() operation index arguments must be of integer type, got index of type {t}")
2276
2574
 
2277
- if not types_equal(args[-1].type, args[0].type.dtype):
2575
+ if not types_equal(arg_types[-1], arg_types[0].dtype):
2278
2576
  raise RuntimeError(
2279
- f"atomic() value argument ({args[-1].type}) must be of the same type as the array ({args[0].type.dtype})"
2577
+ f"atomic() value argument ({arg_types[-1]}) must be of the same type as the array ({arg_types[0].dtype})"
2280
2578
  )
2281
2579
 
2282
- return args[0].type.dtype
2580
+ return arg_types[0].dtype
2283
2581
 
2284
2582
 
2285
2583
  for array_type in array_types:
@@ -2291,7 +2589,7 @@ for array_type in array_types:
2291
2589
  hidden=hidden,
2292
2590
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2293
2591
  value_func=atomic_op_value_func,
2294
- doc="Atomically add ``value`` onto the array at location given by index.",
2592
+ doc="Atomically add ``value`` onto ``a[i]``.",
2295
2593
  group="Utility",
2296
2594
  skip_replay=True,
2297
2595
  )
@@ -2300,7 +2598,7 @@ for array_type in array_types:
2300
2598
  hidden=hidden,
2301
2599
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2302
2600
  value_func=atomic_op_value_func,
2303
- doc="Atomically add ``value`` onto the array at location given by indices.",
2601
+ doc="Atomically add ``value`` onto ``a[i,j]``.",
2304
2602
  group="Utility",
2305
2603
  skip_replay=True,
2306
2604
  )
@@ -2309,7 +2607,7 @@ for array_type in array_types:
2309
2607
  hidden=hidden,
2310
2608
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2311
2609
  value_func=atomic_op_value_func,
2312
- doc="Atomically add ``value`` onto the array at location given by indices.",
2610
+ doc="Atomically add ``value`` onto ``a[i,j,k]``.",
2313
2611
  group="Utility",
2314
2612
  skip_replay=True,
2315
2613
  )
@@ -2318,7 +2616,7 @@ for array_type in array_types:
2318
2616
  hidden=hidden,
2319
2617
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2320
2618
  value_func=atomic_op_value_func,
2321
- doc="Atomically add ``value`` onto the array at location given by indices.",
2619
+ doc="Atomically add ``value`` onto ``a[i,j,k,l]``.",
2322
2620
  group="Utility",
2323
2621
  skip_replay=True,
2324
2622
  )
@@ -2328,7 +2626,7 @@ for array_type in array_types:
2328
2626
  hidden=hidden,
2329
2627
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2330
2628
  value_func=atomic_op_value_func,
2331
- doc="Atomically subtract ``value`` onto the array at location given by index.",
2629
+ doc="Atomically subtract ``value`` onto ``a[i]``.",
2332
2630
  group="Utility",
2333
2631
  skip_replay=True,
2334
2632
  )
@@ -2337,7 +2635,7 @@ for array_type in array_types:
2337
2635
  hidden=hidden,
2338
2636
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2339
2637
  value_func=atomic_op_value_func,
2340
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2638
+ doc="Atomically subtract ``value`` onto ``a[i,j]``.",
2341
2639
  group="Utility",
2342
2640
  skip_replay=True,
2343
2641
  )
@@ -2346,7 +2644,7 @@ for array_type in array_types:
2346
2644
  hidden=hidden,
2347
2645
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2348
2646
  value_func=atomic_op_value_func,
2349
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2647
+ doc="Atomically subtract ``value`` onto ``a[i,j,k]``.",
2350
2648
  group="Utility",
2351
2649
  skip_replay=True,
2352
2650
  )
@@ -2355,7 +2653,7 @@ for array_type in array_types:
2355
2653
  hidden=hidden,
2356
2654
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2357
2655
  value_func=atomic_op_value_func,
2358
- doc="Atomically subtract ``value`` onto the array at location given by indices.",
2656
+ doc="Atomically subtract ``value`` onto ``a[i,j,k,l]``.",
2359
2657
  group="Utility",
2360
2658
  skip_replay=True,
2361
2659
  )
@@ -2365,7 +2663,8 @@ for array_type in array_types:
2365
2663
  hidden=hidden,
2366
2664
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2367
2665
  value_func=atomic_op_value_func,
2368
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2666
+ doc="Compute the minimum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2667
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2369
2668
  group="Utility",
2370
2669
  skip_replay=True,
2371
2670
  )
@@ -2374,7 +2673,8 @@ for array_type in array_types:
2374
2673
  hidden=hidden,
2375
2674
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2376
2675
  value_func=atomic_op_value_func,
2377
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2676
+ doc="Compute the minimum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2677
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2378
2678
  group="Utility",
2379
2679
  skip_replay=True,
2380
2680
  )
@@ -2383,7 +2683,8 @@ for array_type in array_types:
2383
2683
  hidden=hidden,
2384
2684
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2385
2685
  value_func=atomic_op_value_func,
2386
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2686
+ doc="Compute the minimum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2687
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2387
2688
  group="Utility",
2388
2689
  skip_replay=True,
2389
2690
  )
@@ -2392,7 +2693,8 @@ for array_type in array_types:
2392
2693
  hidden=hidden,
2393
2694
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2394
2695
  value_func=atomic_op_value_func,
2395
- doc="Compute the minimum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2696
+ doc="Compute the minimum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2697
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2396
2698
  group="Utility",
2397
2699
  skip_replay=True,
2398
2700
  )
@@ -2402,7 +2704,8 @@ for array_type in array_types:
2402
2704
  hidden=hidden,
2403
2705
  input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
2404
2706
  value_func=atomic_op_value_func,
2405
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2707
+ doc="Compute the maximum of ``value`` and ``a[i]`` and atomically update the array.\n\n"
2708
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2406
2709
  group="Utility",
2407
2710
  skip_replay=True,
2408
2711
  )
@@ -2411,7 +2714,8 @@ for array_type in array_types:
2411
2714
  hidden=hidden,
2412
2715
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2413
2716
  value_func=atomic_op_value_func,
2414
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2717
+ doc="Compute the maximum of ``value`` and ``a[i,j]`` and atomically update the array.\n\n"
2718
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2415
2719
  group="Utility",
2416
2720
  skip_replay=True,
2417
2721
  )
@@ -2420,7 +2724,8 @@ for array_type in array_types:
2420
2724
  hidden=hidden,
2421
2725
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2422
2726
  value_func=atomic_op_value_func,
2423
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2727
+ doc="Compute the maximum of ``value`` and ``a[i,j,k]`` and atomically update the array.\n\n"
2728
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2424
2729
  group="Utility",
2425
2730
  skip_replay=True,
2426
2731
  )
@@ -2429,26 +2734,27 @@ for array_type in array_types:
2429
2734
  hidden=hidden,
2430
2735
  input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2431
2736
  value_func=atomic_op_value_func,
2432
- doc="Compute the maximum of ``value`` and ``array[index]`` and atomically update the array. Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2737
+ doc="Compute the maximum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.\n\n"
2738
+ "Note that for vectors and matrices the operation is only atomic on a per-component basis.",
2433
2739
  group="Utility",
2434
2740
  skip_replay=True,
2435
2741
  )
2436
2742
 
2437
2743
 
2438
2744
  # used to index into builtin types, i.e.: y = vec3[1]
2439
- def index_value_func(args, kwds, _):
2440
- return args[0].type._wp_scalar_type_
2745
+ def index_value_func(arg_types, kwds, _):
2746
+ return arg_types[0]._wp_scalar_type_
2441
2747
 
2442
2748
 
2443
2749
  add_builtin(
2444
- "index",
2750
+ "extract",
2445
2751
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2446
2752
  value_func=index_value_func,
2447
2753
  hidden=True,
2448
2754
  group="Utility",
2449
2755
  )
2450
2756
  add_builtin(
2451
- "index",
2757
+ "extract",
2452
2758
  input_types={"a": quaternion(dtype=Scalar), "i": int},
2453
2759
  value_func=index_value_func,
2454
2760
  hidden=True,
@@ -2456,14 +2762,14 @@ add_builtin(
2456
2762
  )
2457
2763
 
2458
2764
  add_builtin(
2459
- "index",
2765
+ "extract",
2460
2766
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2461
- value_func=lambda args, kwds, _: vector(length=args[0].type._shape_[1], dtype=args[0].type._wp_scalar_type_),
2767
+ value_func=lambda arg_types, kwds, _: vector(length=arg_types[0]._shape_[1], dtype=arg_types[0]._wp_scalar_type_),
2462
2768
  hidden=True,
2463
2769
  group="Utility",
2464
2770
  )
2465
2771
  add_builtin(
2466
- "index",
2772
+ "extract",
2467
2773
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2468
2774
  value_func=index_value_func,
2469
2775
  hidden=True,
@@ -2471,77 +2777,66 @@ add_builtin(
2471
2777
  )
2472
2778
 
2473
2779
  add_builtin(
2474
- "index",
2780
+ "extract",
2475
2781
  input_types={"a": transformation(dtype=Scalar), "i": int},
2476
2782
  value_func=index_value_func,
2477
2783
  hidden=True,
2478
2784
  group="Utility",
2479
2785
  )
2480
2786
 
2481
- add_builtin("index", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2787
+ add_builtin("extract", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
2482
2788
 
2483
2789
 
2484
- def vector_indexset_element_value_func(args, kwds, _):
2485
- vec = args[0]
2486
- index = args[1]
2487
- value = args[2]
2790
+ def vector_indexref_element_value_func(arg_types, kwds, _):
2791
+ vec_type = arg_types[0]
2792
+ # index_type = arg_types[1]
2793
+ value_type = vec_type._wp_scalar_type_
2488
2794
 
2489
- if value.type is not vec.type._wp_scalar_type_:
2490
- raise RuntimeError(
2491
- f"Trying to assign type '{type_repr(value.type)}' to element of a vector with type '{type_repr(vec.type)}'"
2492
- )
2795
+ return Reference(value_type)
2493
2796
 
2494
- return None
2495
2797
 
2496
-
2497
- # implements vector[index] = value
2798
+ # implements &vector[index]
2498
2799
  add_builtin(
2499
- "indexset",
2500
- input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
2501
- value_func=vector_indexset_element_value_func,
2800
+ "index",
2801
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
2802
+ value_func=vector_indexref_element_value_func,
2803
+ hidden=True,
2804
+ group="Utility",
2805
+ skip_replay=True,
2806
+ )
2807
+ # implements &(*vector)[index]
2808
+ add_builtin(
2809
+ "indexref",
2810
+ input_types={"a": Reference, "i": int},
2811
+ value_func=vector_indexref_element_value_func,
2502
2812
  hidden=True,
2503
2813
  group="Utility",
2504
2814
  skip_replay=True,
2505
2815
  )
2506
2816
 
2507
2817
 
2508
- def matrix_indexset_element_value_func(args, kwds, _):
2509
- mat = args[0]
2510
- row = args[1]
2511
- col = args[2]
2512
- value = args[3]
2818
+ def matrix_indexref_element_value_func(arg_types, kwds, _):
2819
+ mat_type = arg_types[0]
2820
+ # row_type = arg_types[1]
2821
+ # col_type = arg_types[2]
2822
+ value_type = mat_type._wp_scalar_type_
2513
2823
 
2514
- if value.type is not mat.type._wp_scalar_type_:
2515
- raise RuntimeError(
2516
- f"Trying to assign type '{type_repr(value.type)}' to element of a matrix with type '{type_repr(mat.type)}'"
2517
- )
2824
+ return Reference(value_type)
2518
2825
 
2519
- return None
2520
2826
 
2827
+ def matrix_indexref_row_value_func(arg_types, kwds, _):
2828
+ mat_type = arg_types[0]
2829
+ row_type = mat_type._wp_row_type_
2830
+ # value_type = arg_types[2]
2521
2831
 
2522
- def matrix_indexset_row_value_func(args, kwds, _):
2523
- mat = args[0]
2524
- row = args[1]
2525
- value = args[2]
2526
-
2527
- if value.type._shape_[0] != mat.type._shape_[1]:
2528
- raise RuntimeError(
2529
- f"Trying to assign vector with length {value.type._length} to matrix with shape {mat.type._shape}, vector length must match the number of matrix columns."
2530
- )
2531
-
2532
- if value.type._wp_scalar_type_ is not mat.type._wp_scalar_type_:
2533
- raise RuntimeError(
2534
- f"Trying to assign vector of type '{type_repr(value.type)}' to row of matrix of type '{type_repr(mat.type)}'"
2535
- )
2536
-
2537
- return None
2832
+ return Reference(row_type)
2538
2833
 
2539
2834
 
2540
2835
  # implements matrix[i] = row
2541
2836
  add_builtin(
2542
- "indexset",
2543
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
2544
- value_func=matrix_indexset_row_value_func,
2837
+ "index",
2838
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
2839
+ value_func=matrix_indexref_row_value_func,
2545
2840
  hidden=True,
2546
2841
  group="Utility",
2547
2842
  skip_replay=True,
@@ -2549,29 +2844,29 @@ add_builtin(
2549
2844
 
2550
2845
  # implements matrix[i,j] = scalar
2551
2846
  add_builtin(
2552
- "indexset",
2553
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
2554
- value_func=matrix_indexset_element_value_func,
2847
+ "index",
2848
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
2849
+ value_func=matrix_indexref_element_value_func,
2555
2850
  hidden=True,
2556
2851
  group="Utility",
2557
2852
  skip_replay=True,
2558
2853
  )
2559
2854
 
2560
- for t in scalar_types + vector_types:
2855
+ for t in scalar_types + vector_types + [builtins.bool]:
2561
2856
  if "vec" in t.__name__ or "mat" in t.__name__:
2562
2857
  continue
2563
2858
  add_builtin(
2564
2859
  "expect_eq",
2565
2860
  input_types={"arg1": t, "arg2": t},
2566
2861
  value_type=None,
2567
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2862
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2568
2863
  group="Utility",
2569
2864
  hidden=True,
2570
2865
  )
2571
2866
 
2572
2867
 
2573
- def expect_eq_val_func(args, kwds, _):
2574
- if not types_equal(args[0].type, args[1].type):
2868
+ def expect_eq_val_func(arg_types, kwds, _):
2869
+ if not types_equal(arg_types[0], arg_types[1]):
2575
2870
  raise RuntimeError("Can't test equality for objects with different types")
2576
2871
  return None
2577
2872
 
@@ -2580,7 +2875,7 @@ add_builtin(
2580
2875
  "expect_eq",
2581
2876
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2582
2877
  value_func=expect_eq_val_func,
2583
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2878
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2584
2879
  group="Utility",
2585
2880
  hidden=True,
2586
2881
  )
@@ -2588,7 +2883,7 @@ add_builtin(
2588
2883
  "expect_neq",
2589
2884
  input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
2590
2885
  value_func=expect_eq_val_func,
2591
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2886
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2592
2887
  group="Utility",
2593
2888
  hidden=True,
2594
2889
  )
@@ -2597,7 +2892,7 @@ add_builtin(
2597
2892
  "expect_eq",
2598
2893
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2599
2894
  value_func=expect_eq_val_func,
2600
- doc="Prints an error to stdout if arg1 and arg2 are not equal",
2895
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
2601
2896
  group="Utility",
2602
2897
  hidden=True,
2603
2898
  )
@@ -2605,7 +2900,7 @@ add_builtin(
2605
2900
  "expect_neq",
2606
2901
  input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
2607
2902
  value_func=expect_eq_val_func,
2608
- doc="Prints an error to stdout if arg1 and arg2 are equal",
2903
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
2609
2904
  group="Utility",
2610
2905
  hidden=True,
2611
2906
  )
@@ -2614,29 +2909,30 @@ add_builtin(
2614
2909
  "lerp",
2615
2910
  input_types={"a": Float, "b": Float, "t": Float},
2616
2911
  value_func=sametype_value_func(Float),
2617
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2912
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2618
2913
  group="Utility",
2619
2914
  )
2620
2915
  add_builtin(
2621
2916
  "smoothstep",
2622
2917
  input_types={"edge0": Float, "edge1": Float, "x": Float},
2623
2918
  value_func=sametype_value_func(Float),
2624
- doc="Smoothly interpolate between two values edge0 and edge1 using a factor x, and return a result between 0 and 1 using a cubic Hermite interpolation after clamping",
2919
+ doc="""Smoothly interpolate between two values ``edge0`` and ``edge1`` using a factor ``x``,
2920
+ and return a result between 0 and 1 using a cubic Hermite interpolation after clamping.""",
2625
2921
  group="Utility",
2626
2922
  )
2627
2923
 
2628
2924
 
2629
2925
  def lerp_value_func(default):
2630
- def fn(args, kwds, _):
2631
- if args is None:
2926
+ def fn(arg_types, kwds, _):
2927
+ if arg_types is None:
2632
2928
  return default
2633
- scalar_type = args[-1].type
2634
- if not types_equal(args[0].type, args[1].type):
2929
+ scalar_type = arg_types[-1]
2930
+ if not types_equal(arg_types[0], arg_types[1]):
2635
2931
  raise RuntimeError("Can't lerp between objects with different types")
2636
- if args[0].type._wp_scalar_type_ != scalar_type:
2932
+ if arg_types[0]._wp_scalar_type_ != scalar_type:
2637
2933
  raise RuntimeError("'t' parameter must have the same scalar type as objects you're lerping between")
2638
2934
 
2639
- return args[0].type
2935
+ return arg_types[0]
2640
2936
 
2641
2937
  return fn
2642
2938
 
@@ -2645,28 +2941,28 @@ add_builtin(
2645
2941
  "lerp",
2646
2942
  input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "t": Float},
2647
2943
  value_func=lerp_value_func(vector(length=Any, dtype=Float)),
2648
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2944
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2649
2945
  group="Utility",
2650
2946
  )
2651
2947
  add_builtin(
2652
2948
  "lerp",
2653
2949
  input_types={"a": matrix(shape=(Any, Any), dtype=Float), "b": matrix(shape=(Any, Any), dtype=Float), "t": Float},
2654
2950
  value_func=lerp_value_func(matrix(shape=(Any, Any), dtype=Float)),
2655
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2951
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2656
2952
  group="Utility",
2657
2953
  )
2658
2954
  add_builtin(
2659
2955
  "lerp",
2660
2956
  input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
2661
2957
  value_func=lerp_value_func(quaternion(dtype=Float)),
2662
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2958
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2663
2959
  group="Utility",
2664
2960
  )
2665
2961
  add_builtin(
2666
2962
  "lerp",
2667
2963
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float), "t": Float},
2668
2964
  value_func=lerp_value_func(transformation(dtype=Float)),
2669
- doc="Linearly interpolate two values a and b using factor t, computed as ``a*(1-t) + b*t``",
2965
+ doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
2670
2966
  group="Utility",
2671
2967
  )
2672
2968
 
@@ -2676,14 +2972,14 @@ add_builtin(
2676
2972
  input_types={"arg1": Float, "arg2": Float, "tolerance": Float},
2677
2973
  defaults={"tolerance": 1.0e-6},
2678
2974
  value_type=None,
2679
- doc="Prints an error to stdout if arg1 and arg2 are not closer than tolerance in magnitude",
2975
+ doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2680
2976
  group="Utility",
2681
2977
  )
2682
2978
  add_builtin(
2683
2979
  "expect_near",
2684
2980
  input_types={"arg1": vec3, "arg2": vec3, "tolerance": float},
2685
2981
  value_type=None,
2686
- doc="Prints an error to stdout if any element of arg1 and arg2 are not closer than tolerance in magnitude",
2982
+ doc="Prints an error to stdout if any element of ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
2687
2983
  group="Utility",
2688
2984
  )
2689
2985
 
@@ -2694,14 +2990,14 @@ add_builtin(
2694
2990
  "lower_bound",
2695
2991
  input_types={"arr": array(dtype=Scalar), "value": Scalar},
2696
2992
  value_type=int,
2697
- doc="Search a sorted array for the closest element greater than or equal to value.",
2993
+ doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
2698
2994
  )
2699
2995
 
2700
2996
  add_builtin(
2701
2997
  "lower_bound",
2702
2998
  input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
2703
2999
  value_type=int,
2704
- doc="Search a sorted array range [arr_begin, arr_end) for the closest element greater than or equal to value.",
3000
+ doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
2705
3001
  )
2706
3002
 
2707
3003
  # ---------------------------------
@@ -2781,11 +3077,11 @@ add_builtin("invert", input_types={"x": Int}, value_func=sametype_value_func(Int
2781
3077
 
2782
3078
 
2783
3079
  def scalar_mul_value_func(default):
2784
- def fn(args, kwds, _):
2785
- if args is None:
3080
+ def fn(arg_types, kwds, _):
3081
+ if arg_types is None:
2786
3082
  return default
2787
- scalar = [a.type for a in args if a.type in scalar_types][0]
2788
- compound = [a.type for a in args if a.type not in scalar_types][0]
3083
+ scalar = [t for t in arg_types if t in scalar_types][0]
3084
+ compound = [t for t in arg_types if t not in scalar_types][0]
2789
3085
  if scalar != compound._wp_scalar_type_:
2790
3086
  raise RuntimeError("Object and coefficient must have the same scalar type when multiplying by scalar")
2791
3087
  return compound
@@ -2793,36 +3089,53 @@ def scalar_mul_value_func(default):
2793
3089
  return fn
2794
3090
 
2795
3091
 
2796
- def mul_matvec_value_func(args, kwds, _):
2797
- if args is None:
3092
+ def mul_matvec_value_func(arg_types, kwds, _):
3093
+ if arg_types is None:
2798
3094
  return vector(length=Any, dtype=Scalar)
2799
3095
 
2800
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3096
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
2801
3097
  raise RuntimeError(
2802
- f"Can't multiply matrix and vector with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3098
+ f"Can't multiply matrix and vector with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
2803
3099
  )
2804
3100
 
2805
- if args[0].type._shape_[1] != args[1].type._length_:
3101
+ if arg_types[0]._shape_[1] != arg_types[1]._length_:
2806
3102
  raise RuntimeError(
2807
- f"Can't multiply matrix of shape {args[0].type._shape_} and vector with length {args[1].type._length_}"
3103
+ f"Can't multiply matrix of shape {arg_types[0]._shape_} and vector with length {arg_types[1]._length_}"
2808
3104
  )
2809
3105
 
2810
- return vector(length=args[0].type._shape_[0], dtype=args[0].type._wp_scalar_type_)
3106
+ return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
3107
+
3108
+
3109
+ def mul_vecmat_value_func(arg_types, kwds, _):
3110
+ if arg_types is None:
3111
+ return vector(length=Any, dtype=Scalar)
2811
3112
 
3113
+ if arg_types[1]._wp_scalar_type_ != arg_types[0]._wp_scalar_type_:
3114
+ raise RuntimeError(
3115
+ f"Can't multiply vector and matrix with different types {arg_types[1]._wp_scalar_type_}, {arg_types[0]._wp_scalar_type_}"
3116
+ )
2812
3117
 
2813
- def mul_matmat_value_func(args, kwds, _):
2814
- if args is None:
3118
+ if arg_types[1]._shape_[0] != arg_types[0]._length_:
3119
+ raise RuntimeError(
3120
+ f"Can't multiply vector with length {arg_types[0]._length_} and matrix of shape {arg_types[1]._shape_}"
3121
+ )
3122
+
3123
+ return vector(length=arg_types[1]._shape_[1], dtype=arg_types[1]._wp_scalar_type_)
3124
+
3125
+
3126
+ def mul_matmat_value_func(arg_types, kwds, _):
3127
+ if arg_types is None:
2815
3128
  return matrix(length=Any, dtype=Scalar)
2816
3129
 
2817
- if args[0].type._wp_scalar_type_ != args[1].type._wp_scalar_type_:
3130
+ if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
2818
3131
  raise RuntimeError(
2819
- f"Can't multiply matrices with different types {args[0].type._wp_scalar_type_}, {args[1].type._wp_scalar_type_}"
3132
+ f"Can't multiply matrices with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
2820
3133
  )
2821
3134
 
2822
- if args[0].type._shape_[1] != args[1].type._shape_[0]:
2823
- raise RuntimeError(f"Can't multiply matrix of shapes {args[0].type._shape_} and {args[1].type._shape_}")
3135
+ if arg_types[0]._shape_[1] != arg_types[1]._shape_[0]:
3136
+ raise RuntimeError(f"Can't multiply matrix of shapes {arg_types[0]._shape_} and {arg_types[1]._shape_}")
2824
3137
 
2825
- return matrix(shape=(args[0].type._shape_[0], args[1].type._shape_[1]), dtype=args[0].type._wp_scalar_type_)
3138
+ return matrix(shape=(arg_types[0]._shape_[0], arg_types[1]._shape_[1]), dtype=arg_types[0]._wp_scalar_type_)
2826
3139
 
2827
3140
 
2828
3141
  add_builtin(
@@ -2884,6 +3197,13 @@ add_builtin(
2884
3197
  doc="",
2885
3198
  group="Operators",
2886
3199
  )
3200
+ add_builtin(
3201
+ "mul",
3202
+ input_types={"x": vector(length=Any, dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3203
+ value_func=mul_vecmat_value_func,
3204
+ doc="",
3205
+ group="Operators",
3206
+ )
2887
3207
  add_builtin(
2888
3208
  "mul",
2889
3209
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
@@ -2919,7 +3239,11 @@ add_builtin(
2919
3239
  )
2920
3240
 
2921
3241
  add_builtin(
2922
- "div", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), doc="", group="Operators"
3242
+ "div",
3243
+ input_types={"x": Scalar, "y": Scalar},
3244
+ value_func=sametype_value_func(Scalar),
3245
+ doc="", group="Operators",
3246
+ require_original_output_arg=True,
2923
3247
  )
2924
3248
  add_builtin(
2925
3249
  "div",
@@ -2928,6 +3252,13 @@ add_builtin(
2928
3252
  doc="",
2929
3253
  group="Operators",
2930
3254
  )
3255
+ add_builtin(
3256
+ "div",
3257
+ input_types={"x": Scalar, "y": vector(length=Any, dtype=Scalar)},
3258
+ value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
3259
+ doc="",
3260
+ group="Operators",
3261
+ )
2931
3262
  add_builtin(
2932
3263
  "div",
2933
3264
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": Scalar},
@@ -2935,6 +3266,13 @@ add_builtin(
2935
3266
  doc="",
2936
3267
  group="Operators",
2937
3268
  )
3269
+ add_builtin(
3270
+ "div",
3271
+ input_types={"x": Scalar, "y": matrix(shape=(Any, Any), dtype=Scalar)},
3272
+ value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3273
+ doc="",
3274
+ group="Operators",
3275
+ )
2938
3276
  add_builtin(
2939
3277
  "div",
2940
3278
  input_types={"x": quaternion(dtype=Scalar), "y": Scalar},
@@ -2942,6 +3280,13 @@ add_builtin(
2942
3280
  doc="",
2943
3281
  group="Operators",
2944
3282
  )
3283
+ add_builtin(
3284
+ "div",
3285
+ input_types={"x": Scalar, "y": quaternion(dtype=Scalar)},
3286
+ value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
3287
+ doc="",
3288
+ group="Operators",
3289
+ )
2945
3290
 
2946
3291
  add_builtin(
2947
3292
  "floordiv",
@@ -2997,7 +3342,6 @@ add_builtin(
2997
3342
  )
2998
3343
 
2999
3344
  add_builtin("unot", input_types={"b": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
3000
- add_builtin("unot", input_types={"b": bool}, value_type=builtins.bool, doc="", group="Operators")
3001
3345
  for t in int_types:
3002
3346
  add_builtin("unot", input_types={"b": t}, value_type=builtins.bool, doc="", group="Operators")
3003
3347