warp-lang 1.2.2__py3-none-manylinux2014_aarch64.whl → 1.3.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -5,25 +5,40 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
  import builtins
8
- from typing import Any, Callable, Tuple
8
+ from typing import Any, Callable, Mapping, Sequence
9
9
 
10
- from warp.codegen import Reference
10
+ from warp.codegen import Reference, Var, strip_reference
11
11
  from warp.types import *
12
12
 
13
13
  from .context import add_builtin
14
14
 
15
15
 
16
- def sametypes(arg_types):
17
- return all(types_equal(arg_types[0], t) for t in arg_types[1:])
16
+ def seq_check_equal(seq_1, seq_2):
17
+ if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
18
+ return False
18
19
 
20
+ if len(seq_1) != len(seq_2):
21
+ return False
19
22
 
20
- def sametype_value_func(default):
21
- def fn(arg_types, kwds, _):
23
+ return all(x == y for x, y in zip(seq_1, seq_2))
24
+
25
+
26
+ def sametypes(arg_types: Mapping[str, Any]):
27
+ arg_types_iter = iter(arg_types.values())
28
+ arg_type_0 = next(arg_types_iter)
29
+ return all(types_equal(arg_type_0, t) for t in arg_types_iter)
30
+
31
+
32
+ def sametypes_create_value_func(default):
33
+ def fn(arg_types, arg_values):
22
34
  if arg_types is None:
23
35
  return default
36
+
24
37
  if not sametypes(arg_types):
25
- raise RuntimeError(f"Input types must be the same, found: {[type_repr(t) for t in arg_types]}")
26
- return arg_types[0]
38
+ raise RuntimeError(f"Input types must be the same, got {[type_repr(t) for t in arg_types.values()]}")
39
+
40
+ arg_type_0 = next(iter(arg_types.values()))
41
+ return arg_type_0
27
42
 
28
43
  return fn
29
44
 
@@ -33,39 +48,39 @@ def sametype_value_func(default):
33
48
 
34
49
  add_builtin(
35
50
  "min",
36
- input_types={"x": Scalar, "y": Scalar},
37
- value_func=sametype_value_func(Scalar),
51
+ input_types={"a": Scalar, "b": Scalar},
52
+ value_func=sametypes_create_value_func(Scalar),
38
53
  doc="Return the minimum of two scalars.",
39
54
  group="Scalar Math",
40
55
  )
41
56
 
42
57
  add_builtin(
43
58
  "max",
44
- input_types={"x": Scalar, "y": Scalar},
45
- value_func=sametype_value_func(Scalar),
59
+ input_types={"a": Scalar, "b": Scalar},
60
+ value_func=sametypes_create_value_func(Scalar),
46
61
  doc="Return the maximum of two scalars.",
47
62
  group="Scalar Math",
48
63
  )
49
64
 
50
65
  add_builtin(
51
66
  "clamp",
52
- input_types={"x": Scalar, "a": Scalar, "b": Scalar},
53
- value_func=sametype_value_func(Scalar),
54
- doc="Clamp the value of ``x`` to the range [a, b].",
67
+ input_types={"x": Scalar, "low": Scalar, "high": Scalar},
68
+ value_func=sametypes_create_value_func(Scalar),
69
+ doc="Clamp the value of ``x`` to the range [low, high].",
55
70
  group="Scalar Math",
56
71
  )
57
72
 
58
73
  add_builtin(
59
74
  "abs",
60
75
  input_types={"x": Scalar},
61
- value_func=sametype_value_func(Scalar),
76
+ value_func=sametypes_create_value_func(Scalar),
62
77
  doc="Return the absolute value of ``x``.",
63
78
  group="Scalar Math",
64
79
  )
65
80
  add_builtin(
66
81
  "sign",
67
82
  input_types={"x": Scalar},
68
- value_func=sametype_value_func(Scalar),
83
+ value_func=sametypes_create_value_func(Scalar),
69
84
  doc="Return -1 if ``x`` < 0, return 1 otherwise.",
70
85
  group="Scalar Math",
71
86
  )
@@ -73,14 +88,14 @@ add_builtin(
73
88
  add_builtin(
74
89
  "step",
75
90
  input_types={"x": Scalar},
76
- value_func=sametype_value_func(Scalar),
91
+ value_func=sametypes_create_value_func(Scalar),
77
92
  doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
78
93
  group="Scalar Math",
79
94
  )
80
95
  add_builtin(
81
96
  "nonzero",
82
97
  input_types={"x": Scalar},
83
- value_func=sametype_value_func(Scalar),
98
+ value_func=sametypes_create_value_func(Scalar),
84
99
  doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
85
100
  group="Scalar Math",
86
101
  )
@@ -88,35 +103,35 @@ add_builtin(
88
103
  add_builtin(
89
104
  "sin",
90
105
  input_types={"x": Float},
91
- value_func=sametype_value_func(Float),
106
+ value_func=sametypes_create_value_func(Float),
92
107
  doc="Return the sine of ``x`` in radians.",
93
108
  group="Scalar Math",
94
109
  )
95
110
  add_builtin(
96
111
  "cos",
97
112
  input_types={"x": Float},
98
- value_func=sametype_value_func(Float),
113
+ value_func=sametypes_create_value_func(Float),
99
114
  doc="Return the cosine of ``x`` in radians.",
100
115
  group="Scalar Math",
101
116
  )
102
117
  add_builtin(
103
118
  "acos",
104
119
  input_types={"x": Float},
105
- value_func=sametype_value_func(Float),
120
+ value_func=sametypes_create_value_func(Float),
106
121
  doc="Return arccos of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
107
122
  group="Scalar Math",
108
123
  )
109
124
  add_builtin(
110
125
  "asin",
111
126
  input_types={"x": Float},
112
- value_func=sametype_value_func(Float),
127
+ value_func=sametypes_create_value_func(Float),
113
128
  doc="Return arcsin of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
114
129
  group="Scalar Math",
115
130
  )
116
131
  add_builtin(
117
132
  "sqrt",
118
133
  input_types={"x": Float},
119
- value_func=sametype_value_func(Float),
134
+ value_func=sametypes_create_value_func(Float),
120
135
  doc="Return the square root of ``x``, where ``x`` is positive.",
121
136
  group="Scalar Math",
122
137
  require_original_output_arg=True,
@@ -124,7 +139,7 @@ add_builtin(
124
139
  add_builtin(
125
140
  "cbrt",
126
141
  input_types={"x": Float},
127
- value_func=sametype_value_func(Float),
142
+ value_func=sametypes_create_value_func(Float),
128
143
  doc="Return the cube root of ``x``.",
129
144
  group="Scalar Math",
130
145
  require_original_output_arg=True,
@@ -132,42 +147,42 @@ add_builtin(
132
147
  add_builtin(
133
148
  "tan",
134
149
  input_types={"x": Float},
135
- value_func=sametype_value_func(Float),
150
+ value_func=sametypes_create_value_func(Float),
136
151
  doc="Return the tangent of ``x`` in radians.",
137
152
  group="Scalar Math",
138
153
  )
139
154
  add_builtin(
140
155
  "atan",
141
156
  input_types={"x": Float},
142
- value_func=sametype_value_func(Float),
157
+ value_func=sametypes_create_value_func(Float),
143
158
  doc="Return the arctangent of ``x`` in radians.",
144
159
  group="Scalar Math",
145
160
  )
146
161
  add_builtin(
147
162
  "atan2",
148
163
  input_types={"y": Float, "x": Float},
149
- value_func=sametype_value_func(Float),
164
+ value_func=sametypes_create_value_func(Float),
150
165
  doc="Return the 2-argument arctangent, atan2, of the point ``(x, y)`` in radians.",
151
166
  group="Scalar Math",
152
167
  )
153
168
  add_builtin(
154
169
  "sinh",
155
170
  input_types={"x": Float},
156
- value_func=sametype_value_func(Float),
171
+ value_func=sametypes_create_value_func(Float),
157
172
  doc="Return the sinh of ``x``.",
158
173
  group="Scalar Math",
159
174
  )
160
175
  add_builtin(
161
176
  "cosh",
162
177
  input_types={"x": Float},
163
- value_func=sametype_value_func(Float),
178
+ value_func=sametypes_create_value_func(Float),
164
179
  doc="Return the cosh of ``x``.",
165
180
  group="Scalar Math",
166
181
  )
167
182
  add_builtin(
168
183
  "tanh",
169
184
  input_types={"x": Float},
170
- value_func=sametype_value_func(Float),
185
+ value_func=sametypes_create_value_func(Float),
171
186
  doc="Return the tanh of ``x``.",
172
187
  group="Scalar Math",
173
188
  require_original_output_arg=True,
@@ -175,14 +190,14 @@ add_builtin(
175
190
  add_builtin(
176
191
  "degrees",
177
192
  input_types={"x": Float},
178
- value_func=sametype_value_func(Float),
193
+ value_func=sametypes_create_value_func(Float),
179
194
  doc="Convert ``x`` from radians into degrees.",
180
195
  group="Scalar Math",
181
196
  )
182
197
  add_builtin(
183
198
  "radians",
184
199
  input_types={"x": Float},
185
- value_func=sametype_value_func(Float),
200
+ value_func=sametypes_create_value_func(Float),
186
201
  doc="Convert ``x`` from degrees into radians.",
187
202
  group="Scalar Math",
188
203
  )
@@ -190,28 +205,28 @@ add_builtin(
190
205
  add_builtin(
191
206
  "log",
192
207
  input_types={"x": Float},
193
- value_func=sametype_value_func(Float),
208
+ value_func=sametypes_create_value_func(Float),
194
209
  doc="Return the natural logarithm (base-e) of ``x``, where ``x`` is positive.",
195
210
  group="Scalar Math",
196
211
  )
197
212
  add_builtin(
198
213
  "log2",
199
214
  input_types={"x": Float},
200
- value_func=sametype_value_func(Float),
215
+ value_func=sametypes_create_value_func(Float),
201
216
  doc="Return the binary logarithm (base-2) of ``x``, where ``x`` is positive.",
202
217
  group="Scalar Math",
203
218
  )
204
219
  add_builtin(
205
220
  "log10",
206
221
  input_types={"x": Float},
207
- value_func=sametype_value_func(Float),
222
+ value_func=sametypes_create_value_func(Float),
208
223
  doc="Return the common logarithm (base-10) of ``x``, where ``x`` is positive.",
209
224
  group="Scalar Math",
210
225
  )
211
226
  add_builtin(
212
227
  "exp",
213
228
  input_types={"x": Float},
214
- value_func=sametype_value_func(Float),
229
+ value_func=sametypes_create_value_func(Float),
215
230
  doc="Return the value of the exponential function :math:`e^x`.",
216
231
  group="Scalar Math",
217
232
  require_original_output_arg=True,
@@ -219,7 +234,7 @@ add_builtin(
219
234
  add_builtin(
220
235
  "pow",
221
236
  input_types={"x": Float, "y": Float},
222
- value_func=sametype_value_func(Float),
237
+ value_func=sametypes_create_value_func(Float),
223
238
  doc="Return the result of ``x`` raised to power of ``y``.",
224
239
  group="Scalar Math",
225
240
  require_original_output_arg=True,
@@ -228,7 +243,7 @@ add_builtin(
228
243
  add_builtin(
229
244
  "round",
230
245
  input_types={"x": Float},
231
- value_func=sametype_value_func(Float),
246
+ value_func=sametypes_create_value_func(Float),
232
247
  group="Scalar Math",
233
248
  doc="""Return the nearest integer value to ``x``, rounding halfway cases away from zero.
234
249
 
@@ -239,7 +254,7 @@ add_builtin(
239
254
  add_builtin(
240
255
  "rint",
241
256
  input_types={"x": Float},
242
- value_func=sametype_value_func(Float),
257
+ value_func=sametypes_create_value_func(Float),
243
258
  group="Scalar Math",
244
259
  doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
245
260
 
@@ -249,19 +264,19 @@ add_builtin(
249
264
  add_builtin(
250
265
  "trunc",
251
266
  input_types={"x": Float},
252
- value_func=sametype_value_func(Float),
267
+ value_func=sametypes_create_value_func(Float),
253
268
  group="Scalar Math",
254
269
  doc="""Return the nearest integer that is closer to zero than ``x``.
255
270
 
256
271
  In other words, it discards the fractional part of ``x``.
257
- It is similar to casting ``float(int(x))``, but preserves the negative sign when x is in the range [-0.0, -1.0).
272
+ It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
258
273
  Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
259
274
  )
260
275
 
261
276
  add_builtin(
262
277
  "floor",
263
278
  input_types={"x": Float},
264
- value_func=sametype_value_func(Float),
279
+ value_func=sametypes_create_value_func(Float),
265
280
  group="Scalar Math",
266
281
  doc="""Return the largest integer that is less than or equal to ``x``.""",
267
282
  )
@@ -269,7 +284,7 @@ add_builtin(
269
284
  add_builtin(
270
285
  "ceil",
271
286
  input_types={"x": Float},
272
- value_func=sametype_value_func(Float),
287
+ value_func=sametypes_create_value_func(Float),
273
288
  group="Scalar Math",
274
289
  doc="""Return the smallest integer that is greater than or equal to ``x``.""",
275
290
  )
@@ -277,127 +292,145 @@ add_builtin(
277
292
  add_builtin(
278
293
  "frac",
279
294
  input_types={"x": Float},
280
- value_func=sametype_value_func(Float),
295
+ value_func=sametypes_create_value_func(Float),
281
296
  group="Scalar Math",
282
- doc="""Retrieve the fractional part of x.
297
+ doc="""Retrieve the fractional part of ``x``.
283
298
 
284
- In other words, it discards the integer part of x and is equivalent to ``x - trunc(x)``.""",
299
+ In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
285
300
  )
286
301
 
287
302
  add_builtin(
288
303
  "isfinite",
289
- input_types={"x": Scalar},
304
+ input_types={"a": Scalar},
290
305
  value_type=builtins.bool,
291
306
  group="Scalar Math",
292
- doc="""Return ``True`` if x is a finite number, otherwise return ``False``.""",
307
+ doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
293
308
  )
294
309
  add_builtin(
295
310
  "isfinite",
296
- input_types={"x": vector(length=Any, dtype=Scalar)},
311
+ input_types={"a": vector(length=Any, dtype=Scalar)},
297
312
  value_type=builtins.bool,
298
313
  group="Vector Math",
299
- doc="Return ``True`` if all elements of the vector ``x`` are finite, otherwise return ``False``.",
314
+ doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
300
315
  )
301
316
  add_builtin(
302
317
  "isfinite",
303
- input_types={"x": quaternion(dtype=Scalar)},
318
+ input_types={"a": quaternion(dtype=Scalar)},
304
319
  value_type=builtins.bool,
305
320
  group="Vector Math",
306
- doc="Return ``True`` if all elements of the quaternion ``x`` are finite, otherwise return ``False``.",
321
+ doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
307
322
  )
308
323
  add_builtin(
309
324
  "isfinite",
310
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
325
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
311
326
  value_type=builtins.bool,
312
327
  group="Vector Math",
313
- doc="Return ``True`` if all elements of the matrix ``m`` are finite, otherwise return ``False``.",
328
+ doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
314
329
  )
315
330
 
316
331
  add_builtin(
317
332
  "isnan",
318
- input_types={"x": Scalar},
333
+ input_types={"a": Scalar},
319
334
  value_type=builtins.bool,
320
- doc="Return ``True`` if ``x`` is NaN, otherwise return ``False``.",
335
+ doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
321
336
  group="Scalar Math",
322
337
  )
323
338
  add_builtin(
324
339
  "isnan",
325
- input_types={"x": vector(length=Any, dtype=Scalar)},
340
+ input_types={"a": vector(length=Any, dtype=Scalar)},
326
341
  value_type=builtins.bool,
327
342
  group="Vector Math",
328
- doc="Return ``True`` if any element of the vector ``x`` is NaN, otherwise return ``False``.",
343
+ doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
329
344
  )
330
345
  add_builtin(
331
346
  "isnan",
332
- input_types={"x": quaternion(dtype=Scalar)},
347
+ input_types={"a": quaternion(dtype=Scalar)},
333
348
  value_type=builtins.bool,
334
349
  group="Vector Math",
335
- doc="Return ``True`` if any element of the quaternion ``x`` is NaN, otherwise return ``False``.",
350
+ doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
336
351
  )
337
352
  add_builtin(
338
353
  "isnan",
339
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
354
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
340
355
  value_type=builtins.bool,
341
356
  group="Vector Math",
342
- doc="Return ``True`` if any element of the matrix ``m`` is NaN, otherwise return ``False``.",
357
+ doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
343
358
  )
344
359
 
345
360
  add_builtin(
346
361
  "isinf",
347
- input_types={"x": Scalar},
362
+ input_types={"a": Scalar},
348
363
  value_type=builtins.bool,
349
364
  group="Scalar Math",
350
- doc="""Return ``True`` if x is positive or negative infinity, otherwise return ``False``.""",
365
+ doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
351
366
  )
352
367
  add_builtin(
353
368
  "isinf",
354
- input_types={"x": vector(length=Any, dtype=Scalar)},
369
+ input_types={"a": vector(length=Any, dtype=Scalar)},
355
370
  value_type=builtins.bool,
356
371
  group="Vector Math",
357
- doc="Return ``True`` if any element of the vector ``x`` is positive or negative infinity, otherwise return ``False``.",
372
+ doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
358
373
  )
359
374
  add_builtin(
360
375
  "isinf",
361
- input_types={"x": quaternion(dtype=Scalar)},
376
+ input_types={"a": quaternion(dtype=Scalar)},
362
377
  value_type=builtins.bool,
363
378
  group="Vector Math",
364
- doc="Return ``True`` if any element of the quaternion ``x`` is positive or negative infinity, otherwise return ``False``.",
379
+ doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
365
380
  )
366
381
  add_builtin(
367
382
  "isinf",
368
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
383
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
369
384
  value_type=builtins.bool,
370
385
  group="Vector Math",
371
- doc="Return ``True`` if any element of the matrix ``m`` is positive or negative infinity, otherwise return ``False``.",
386
+ doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
372
387
  )
373
388
 
374
389
 
375
- def infer_scalar_type(arg_types):
390
+ def scalar_infer_type(arg_types: Mapping[str, type]):
376
391
  if arg_types is None:
377
392
  return Scalar
378
393
 
379
- def iterate_scalar_types(arg_types):
380
- for t in arg_types:
381
- if hasattr(t, "_wp_scalar_type_"):
382
- yield t._wp_scalar_type_
383
- elif t in scalar_and_bool_types:
384
- yield t
394
+ if isinstance(arg_types, Mapping):
395
+ arg_types = tuple(arg_types.values())
385
396
 
386
- scalarTypes = set(iterate_scalar_types(arg_types))
387
- if len(scalarTypes) > 1:
397
+ scalar_types = set()
398
+ for t in arg_types:
399
+ t = strip_reference(t)
400
+ if hasattr(t, "_wp_scalar_type_"):
401
+ scalar_types.add(t._wp_scalar_type_)
402
+ elif t in scalar_and_bool_types:
403
+ scalar_types.add(t)
404
+
405
+ if len(scalar_types) > 1:
388
406
  raise RuntimeError(
389
- f"Couldn't figure out return type as arguments have multiple precisions: {list(scalarTypes)}"
407
+ f"Couldn't figure out return type as arguments have multiple precisions: {list(scalar_types)}"
390
408
  )
391
- return list(scalarTypes)[0]
409
+ return next(iter(scalar_types))
392
410
 
393
411
 
394
- def sametype_scalar_value_func(arg_types, kwds, _):
412
+ def scalar_sametypes_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
395
413
  if arg_types is None:
396
414
  return Scalar
415
+
397
416
  if not sametypes(arg_types):
398
- raise RuntimeError(f"Input types must be exactly the same, {list(arg_types)}")
417
+ raise RuntimeError(f"Input types must be exactly the same, got {[type_repr(t) for t in arg_types.values()]}")
418
+
419
+ return scalar_infer_type(arg_types)
420
+
399
421
 
400
- return infer_scalar_type(arg_types)
422
+ def float_infer_type(arg_types: Mapping[str, type]):
423
+ if arg_types is None:
424
+ return Float
425
+
426
+ return scalar_infer_type(arg_types)
427
+
428
+
429
+ def float_sametypes_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
430
+ if arg_types is None:
431
+ return Float
432
+
433
+ return scalar_sametypes_value_func(arg_types, arg_values)
401
434
 
402
435
 
403
436
  # ---------------------------------
@@ -405,290 +438,312 @@ def sametype_scalar_value_func(arg_types, kwds, _):
405
438
 
406
439
  add_builtin(
407
440
  "dot",
408
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
441
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
409
442
  constraint=sametypes,
410
- value_func=sametype_scalar_value_func,
443
+ value_func=scalar_sametypes_value_func,
411
444
  group="Vector Math",
412
445
  doc="Compute the dot product between two vectors.",
413
446
  )
414
447
  add_builtin(
415
448
  "ddot",
416
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
449
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
417
450
  constraint=sametypes,
418
- value_func=sametype_scalar_value_func,
451
+ value_func=scalar_sametypes_value_func,
419
452
  group="Vector Math",
420
453
  doc="Compute the double dot product between two matrices.",
421
454
  )
422
455
 
423
456
  add_builtin(
424
457
  "min",
425
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
458
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
426
459
  constraint=sametypes,
427
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
460
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
428
461
  doc="Return the element-wise minimum of two vectors.",
429
462
  group="Vector Math",
430
463
  )
431
464
  add_builtin(
432
465
  "max",
433
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
466
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
434
467
  constraint=sametypes,
435
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
468
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
436
469
  doc="Return the element-wise maximum of two vectors.",
437
470
  group="Vector Math",
438
471
  )
439
472
 
440
473
  add_builtin(
441
474
  "min",
442
- input_types={"v": vector(length=Any, dtype=Scalar)},
443
- value_func=sametype_scalar_value_func,
444
- doc="Return the minimum element of a vector ``v``.",
475
+ input_types={"a": vector(length=Any, dtype=Scalar)},
476
+ value_func=scalar_sametypes_value_func,
477
+ doc="Return the minimum element of a vector ``a``.",
445
478
  group="Vector Math",
446
479
  )
447
480
  add_builtin(
448
481
  "max",
449
- input_types={"v": vector(length=Any, dtype=Scalar)},
450
- value_func=sametype_scalar_value_func,
451
- doc="Return the maximum element of a vector ``v``.",
482
+ input_types={"a": vector(length=Any, dtype=Scalar)},
483
+ value_func=scalar_sametypes_value_func,
484
+ doc="Return the maximum element of a vector ``a``.",
452
485
  group="Vector Math",
453
486
  )
454
487
 
455
488
  add_builtin(
456
489
  "argmin",
457
- input_types={"v": vector(length=Any, dtype=Scalar)},
458
- value_func=lambda arg_types, kwds, _: warp.uint32,
459
- doc="Return the index of the minimum element of a vector ``v``.",
490
+ input_types={"a": vector(length=Any, dtype=Scalar)},
491
+ value_func=lambda arg_types, arg_values: warp.uint32,
492
+ doc="Return the index of the minimum element of a vector ``a``.",
460
493
  group="Vector Math",
461
494
  missing_grad=True,
462
495
  )
463
496
  add_builtin(
464
497
  "argmax",
465
- input_types={"v": vector(length=Any, dtype=Scalar)},
466
- value_func=lambda arg_types, kwds, _: warp.uint32,
467
- doc="Return the index of the maximum element of a vector ``v``.",
498
+ input_types={"a": vector(length=Any, dtype=Scalar)},
499
+ value_func=lambda arg_types, arg_values: warp.uint32,
500
+ doc="Return the index of the maximum element of a vector ``a``.",
468
501
  group="Vector Math",
469
502
  missing_grad=True,
470
503
  )
471
504
 
505
+ add_builtin(
506
+ "abs",
507
+ input_types={"x": vector(length=Any, dtype=Scalar)},
508
+ constraint=sametypes,
509
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
510
+ doc="Return the absolute values of the elements of ``x``.",
511
+ group="Vector Math",
512
+ )
513
+
514
+ add_builtin(
515
+ "sign",
516
+ input_types={"x": vector(length=Any, dtype=Scalar)},
517
+ constraint=sametypes,
518
+ value_func=sametypes_create_value_func(Scalar),
519
+ doc="Return -1 for the negative elements of ``x``, and 1 otherwise.",
520
+ group="Vector Math",
521
+ )
472
522
 
473
- def value_func_outer(arg_types, kwds, _):
523
+
524
+ def outer_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
474
525
  if arg_types is None:
475
526
  return matrix(shape=(Any, Any), dtype=Scalar)
476
527
 
477
- scalarType = infer_scalar_type(arg_types)
478
- vectorLengths = [t._length_ for t in arg_types]
528
+ arg_type_values = tuple(arg_types.values())
529
+ scalarType = scalar_infer_type(arg_type_values)
530
+ vectorLengths = tuple(t._length_ for t in arg_type_values)
479
531
  return matrix(shape=(vectorLengths), dtype=scalarType)
480
532
 
481
533
 
482
534
  add_builtin(
483
535
  "outer",
484
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
485
- value_func=value_func_outer,
536
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
537
+ value_func=outer_value_func,
486
538
  group="Vector Math",
487
- doc="Compute the outer product ``x*y^T`` for two vectors.",
539
+ doc="Compute the outer product ``a*b^T`` for two vectors.",
488
540
  )
489
541
 
490
542
  add_builtin(
491
543
  "cross",
492
- input_types={"x": vector(length=3, dtype=Scalar), "y": vector(length=3, dtype=Scalar)},
493
- value_func=sametype_value_func(vector(length=3, dtype=Scalar)),
544
+ input_types={"a": vector(length=3, dtype=Scalar), "b": vector(length=3, dtype=Scalar)},
545
+ value_func=sametypes_create_value_func(vector(length=3, dtype=Scalar)),
494
546
  group="Vector Math",
495
547
  doc="Compute the cross product of two 3D vectors.",
496
548
  )
497
549
  add_builtin(
498
550
  "skew",
499
- input_types={"x": vector(length=3, dtype=Scalar)},
500
- value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=arg_types[0]._wp_scalar_type_),
551
+ input_types={"vec": vector(length=3, dtype=Scalar)},
552
+ value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=arg_types["vec"]._wp_scalar_type_),
501
553
  group="Vector Math",
502
- doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``x``.",
554
+ doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``vec``.",
503
555
  )
504
556
 
505
557
  add_builtin(
506
558
  "length",
507
- input_types={"x": vector(length=Any, dtype=Float)},
508
- value_func=sametype_scalar_value_func,
559
+ input_types={"a": vector(length=Any, dtype=Float)},
560
+ value_func=float_sametypes_value_func,
509
561
  group="Vector Math",
510
- doc="Compute the length of a floating-point vector ``x``.",
562
+ doc="Compute the length of a floating-point vector ``a``.",
511
563
  require_original_output_arg=True,
512
564
  )
513
565
  add_builtin(
514
566
  "length",
515
- input_types={"x": quaternion(dtype=Float)},
516
- value_func=sametype_scalar_value_func,
567
+ input_types={"a": quaternion(dtype=Float)},
568
+ value_func=float_sametypes_value_func,
517
569
  group="Vector Math",
518
- doc="Compute the length of a quaternion ``x``.",
570
+ doc="Compute the length of a quaternion ``a``.",
519
571
  require_original_output_arg=True,
520
572
  )
521
573
  add_builtin(
522
574
  "length_sq",
523
- input_types={"x": vector(length=Any, dtype=Scalar)},
524
- value_func=sametype_scalar_value_func,
575
+ input_types={"a": vector(length=Any, dtype=Scalar)},
576
+ value_func=scalar_sametypes_value_func,
525
577
  group="Vector Math",
526
- doc="Compute the squared length of a vector ``x``.",
578
+ doc="Compute the squared length of a vector ``a``.",
527
579
  )
528
580
  add_builtin(
529
581
  "length_sq",
530
- input_types={"x": quaternion(dtype=Scalar)},
531
- value_func=sametype_scalar_value_func,
582
+ input_types={"a": quaternion(dtype=Scalar)},
583
+ value_func=scalar_sametypes_value_func,
532
584
  group="Vector Math",
533
- doc="Compute the squared length of a quaternion ``x``.",
585
+ doc="Compute the squared length of a quaternion ``a``.",
534
586
  )
535
587
  add_builtin(
536
588
  "normalize",
537
- input_types={"x": vector(length=Any, dtype=Float)},
538
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
589
+ input_types={"a": vector(length=Any, dtype=Float)},
590
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Float)),
539
591
  group="Vector Math",
540
- doc="Compute the normalized value of ``x``. If ``length(x)`` is 0 then the zero vector is returned.",
592
+ doc="Compute the normalized value of ``a``. If ``length(a)`` is 0 then the zero vector is returned.",
541
593
  require_original_output_arg=True,
542
594
  )
543
595
  add_builtin(
544
596
  "normalize",
545
- input_types={"x": quaternion(dtype=Float)},
546
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
597
+ input_types={"a": quaternion(dtype=Float)},
598
+ value_func=sametypes_create_value_func(quaternion(dtype=Float)),
547
599
  group="Vector Math",
548
- doc="Compute the normalized value of ``x``. If ``length(x)`` is 0, then the zero quaternion is returned.",
600
+ doc="Compute the normalized value of ``a``. If ``length(a)`` is 0, then the zero quaternion is returned.",
549
601
  )
550
602
 
551
603
  add_builtin(
552
604
  "transpose",
553
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
554
- value_func=lambda arg_types, kwds, _: matrix(
555
- shape=(arg_types[0]._shape_[1], arg_types[0]._shape_[0]), dtype=arg_types[0]._wp_scalar_type_
605
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
606
+ value_func=lambda arg_types, arg_values: matrix(
607
+ shape=(arg_types["a"]._shape_[1], arg_types["a"]._shape_[0]), dtype=arg_types["a"]._wp_scalar_type_
556
608
  ),
557
609
  group="Vector Math",
558
- doc="Return the transpose of the matrix ``m``.",
610
+ doc="Return the transpose of the matrix ``a``.",
559
611
  )
560
612
 
561
613
 
562
- def value_func_mat_inv(arg_types, kwds, _):
614
+ def inverse_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
563
615
  if arg_types is None:
564
616
  return matrix(shape=(Any, Any), dtype=Float)
565
- return arg_types[0]
617
+
618
+ return arg_types["a"]
566
619
 
567
620
 
568
621
  add_builtin(
569
622
  "inverse",
570
- input_types={"m": matrix(shape=(2, 2), dtype=Float)},
571
- value_func=value_func_mat_inv,
623
+ input_types={"a": matrix(shape=(2, 2), dtype=Float)},
624
+ value_func=inverse_value_func,
572
625
  group="Vector Math",
573
- doc="Return the inverse of a 2x2 matrix ``m``.",
626
+ doc="Return the inverse of a 2x2 matrix ``a``.",
574
627
  require_original_output_arg=True,
575
628
  )
576
629
 
577
630
  add_builtin(
578
631
  "inverse",
579
- input_types={"m": matrix(shape=(3, 3), dtype=Float)},
580
- value_func=value_func_mat_inv,
632
+ input_types={"a": matrix(shape=(3, 3), dtype=Float)},
633
+ value_func=inverse_value_func,
581
634
  group="Vector Math",
582
- doc="Return the inverse of a 3x3 matrix ``m``.",
635
+ doc="Return the inverse of a 3x3 matrix ``a``.",
583
636
  require_original_output_arg=True,
584
637
  )
585
638
 
586
639
  add_builtin(
587
640
  "inverse",
588
- input_types={"m": matrix(shape=(4, 4), dtype=Float)},
589
- value_func=value_func_mat_inv,
641
+ input_types={"a": matrix(shape=(4, 4), dtype=Float)},
642
+ value_func=inverse_value_func,
590
643
  group="Vector Math",
591
- doc="Return the inverse of a 4x4 matrix ``m``.",
644
+ doc="Return the inverse of a 4x4 matrix ``a``.",
592
645
  require_original_output_arg=True,
593
646
  )
594
647
 
595
648
 
596
- def value_func_mat_det(arg_types, kwds, _):
649
+ def determinant_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
597
650
  if arg_types is None:
598
- return Scalar
599
- return arg_types[0]._wp_scalar_type_
651
+ return Float
652
+
653
+ return arg_types["a"]._wp_scalar_type_
600
654
 
601
655
 
602
656
  add_builtin(
603
657
  "determinant",
604
- input_types={"m": matrix(shape=(2, 2), dtype=Float)},
605
- value_func=value_func_mat_det,
658
+ input_types={"a": matrix(shape=(2, 2), dtype=Float)},
659
+ value_func=determinant_value_func,
606
660
  group="Vector Math",
607
- doc="Return the determinant of a 2x2 matrix ``m``.",
661
+ doc="Return the determinant of a 2x2 matrix ``a``.",
608
662
  )
609
663
 
610
664
  add_builtin(
611
665
  "determinant",
612
- input_types={"m": matrix(shape=(3, 3), dtype=Float)},
613
- value_func=value_func_mat_det,
666
+ input_types={"a": matrix(shape=(3, 3), dtype=Float)},
667
+ value_func=determinant_value_func,
614
668
  group="Vector Math",
615
- doc="Return the determinant of a 3x3 matrix ``m``.",
669
+ doc="Return the determinant of a 3x3 matrix ``a``.",
616
670
  )
617
671
 
618
672
  add_builtin(
619
673
  "determinant",
620
- input_types={"m": matrix(shape=(4, 4), dtype=Float)},
621
- value_func=value_func_mat_det,
674
+ input_types={"a": matrix(shape=(4, 4), dtype=Float)},
675
+ value_func=determinant_value_func,
622
676
  group="Vector Math",
623
- doc="Return the determinant of a 4x4 matrix ``m``.",
677
+ doc="Return the determinant of a 4x4 matrix ``a``.",
624
678
  )
625
679
 
626
680
 
627
- def value_func_mat_trace(arg_types, kwds, _):
681
+ def trace_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
628
682
  if arg_types is None:
629
683
  return Scalar
630
- if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
631
- raise RuntimeError(f"Matrix shape is {arg_types[0]._shape_}. Cannot find the trace of non square matrices")
632
- return arg_types[0]._wp_scalar_type_
684
+
685
+ if arg_types["a"]._shape_[0] != arg_types["a"]._shape_[1]:
686
+ raise RuntimeError(f"Matrix shape is {arg_types['a']._shape_}. Cannot find the trace of non square matrices")
687
+ return arg_types["a"]._wp_scalar_type_
633
688
 
634
689
 
635
690
  add_builtin(
636
691
  "trace",
637
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
638
- value_func=value_func_mat_trace,
692
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
693
+ value_func=trace_value_func,
639
694
  group="Vector Math",
640
- doc="Return the trace of the matrix ``m``.",
695
+ doc="Return the trace of the matrix ``a``.",
641
696
  )
642
697
 
643
698
 
644
- def value_func_diag(arg_types, kwds, _):
699
+ def diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
645
700
  if arg_types is None:
646
701
  return matrix(shape=(Any, Any), dtype=Scalar)
647
- else:
648
- return matrix(shape=(arg_types[0]._length_, arg_types[0]._length_), dtype=arg_types[0]._wp_scalar_type_)
702
+
703
+ return matrix(shape=(arg_types["vec"]._length_, arg_types["vec"]._length_), dtype=arg_types["vec"]._wp_scalar_type_)
649
704
 
650
705
 
651
706
  add_builtin(
652
707
  "diag",
653
- input_types={"d": vector(length=Any, dtype=Scalar)},
654
- value_func=value_func_diag,
708
+ input_types={"vec": vector(length=Any, dtype=Scalar)},
709
+ value_func=diag_value_func,
655
710
  group="Vector Math",
656
- doc="Returns a matrix with the components of the vector ``d`` on the diagonal.",
711
+ doc="Returns a matrix with the components of the vector ``vec`` on the diagonal.",
657
712
  )
658
713
 
659
714
 
660
- def value_func_get_diag(arg_types, kwds, _):
715
+ def get_diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
661
716
  if arg_types is None:
662
717
  return vector(length=(Any), dtype=Scalar)
663
- else:
664
- if arg_types[0]._shape_[0] != arg_types[0]._shape_[1]:
665
- raise RuntimeError(
666
- f"Matrix shape is {arg_types[0]._shape_}; get_diag is only available for square matrices."
667
- )
668
- return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
718
+
719
+ if arg_types["mat"]._shape_[0] != arg_types["mat"]._shape_[1]:
720
+ raise RuntimeError(
721
+ f"Matrix shape is {arg_types['mat']._shape_}; get_diag is only available for square matrices."
722
+ )
723
+ return vector(length=arg_types["mat"]._shape_[0], dtype=arg_types["mat"]._wp_scalar_type_)
669
724
 
670
725
 
671
726
  add_builtin(
672
727
  "get_diag",
673
- input_types={"m": matrix(shape=(Any, Any), dtype=Scalar)},
674
- value_func=value_func_get_diag,
728
+ input_types={"mat": matrix(shape=(Any, Any), dtype=Scalar)},
729
+ value_func=get_diag_value_func,
675
730
  group="Vector Math",
676
- doc="Returns a vector containing the diagonal elements of the square matrix ``m``.",
731
+ doc="Returns a vector containing the diagonal elements of the square matrix ``mat``.",
677
732
  )
678
733
 
679
734
  add_builtin(
680
735
  "cw_mul",
681
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
736
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
682
737
  constraint=sametypes,
683
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
738
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
684
739
  group="Vector Math",
685
740
  doc="Component-wise multiplication of two vectors.",
686
741
  )
687
742
  add_builtin(
688
743
  "cw_div",
689
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
744
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
690
745
  constraint=sametypes,
691
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
746
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
692
747
  group="Vector Math",
693
748
  doc="Component-wise division of two vectors.",
694
749
  require_original_output_arg=True,
@@ -696,17 +751,17 @@ add_builtin(
696
751
 
697
752
  add_builtin(
698
753
  "cw_mul",
699
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
754
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
700
755
  constraint=sametypes,
701
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
756
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
702
757
  group="Vector Math",
703
758
  doc="Component-wise multiplication of two matrices.",
704
759
  )
705
760
  add_builtin(
706
761
  "cw_div",
707
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
762
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
708
763
  constraint=sametypes,
709
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
764
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
710
765
  group="Vector Math",
711
766
  doc="Component-wise division of two matrices.",
712
767
  require_original_output_arg=True,
@@ -719,7 +774,7 @@ for t in scalar_types_all:
719
774
  for u in scalar_types_all:
720
775
  add_builtin(
721
776
  t.__name__,
722
- input_types={"u": u},
777
+ input_types={"a": u},
723
778
  value_type=t,
724
779
  doc="",
725
780
  hidden=True,
@@ -729,203 +784,231 @@ for t in scalar_types_all:
729
784
  )
730
785
 
731
786
 
732
- def vector_constructor_func(arg_types, kwds, templates):
787
+ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
733
788
  if arg_types is None:
734
789
  return vector(length=Any, dtype=Scalar)
735
790
 
736
- if templates is None or len(templates) == 0:
737
- # handle construction of anonymous (undeclared) vector types
738
-
739
- if "length" in kwds:
740
- if len(arg_types) == 0:
741
- if "dtype" not in kwds:
742
- raise RuntimeError(
743
- "vec() must have dtype as a keyword argument if it has no positional arguments, e.g.: wp.vector(length=5, dtype=wp.float32)"
744
- )
745
-
746
- # zero initialization e.g.: wp.vector(length=5, dtype=wp.float32)
747
- veclen = kwds["length"]
748
- vectype = kwds["dtype"]
749
-
750
- elif len(arg_types) == 1:
751
- # value initialization e.g.: wp.vec(1.0, length=5)
752
- veclen = kwds["length"]
753
- vectype = arg_types[0]
754
- if type_is_vector(vectype):
755
- # constructor from another vector
756
- if vectype._length_ != veclen:
757
- raise RuntimeError(
758
- f"Incompatible vector lengths for casting copy constructor, {veclen} vs {vectype._length_}"
759
- )
760
- vectype = vectype._wp_scalar_type_
761
- else:
791
+ length = arg_values.get("length", None)
792
+ dtype = arg_values.get("dtype", None)
793
+
794
+ variadic_arg_types = arg_types.get("args", ())
795
+ variadic_arg_count = len(variadic_arg_types)
796
+ if variadic_arg_count == 0:
797
+ # Zero-initialization, e.g.: `wp.vecXX()`, `wp.vector(length=2, dtype=wp.float16)`.
798
+ if length is None:
799
+ raise RuntimeError("the `length` argument must be specified when zero-initializing a vector")
800
+
801
+ if dtype is None:
802
+ dtype = float32
803
+ elif variadic_arg_count == 1:
804
+ value_type = strip_reference(variadic_arg_types[0])
805
+ if type_is_vector(value_type):
806
+ # Copy constructor, e.g.: `wp.vecXX(other_vec)`, `wp.vector(other_vec)`.
807
+ if length is None:
808
+ length = value_type._length_
809
+ elif value_type._length_ != length:
762
810
  raise RuntimeError(
763
- "vec() must have one scalar argument or the dtype keyword argument if the length keyword argument is specified, e.g.: wp.vec(1.0, length=5)"
811
+ f"incompatible vector of length {length} given when copy constructing "
812
+ f"a vector of length {value_type._length_}"
764
813
  )
765
814
 
815
+ if dtype is None:
816
+ dtype = value_type._wp_scalar_type_
766
817
  else:
767
- if len(arg_types) == 0:
818
+ # Initialization by filling a value, e.g.: `wp.vecXX(123)`,
819
+ # `wp.vector(123, length=2)`.
820
+ if length is None:
821
+ raise RuntimeError("the `length` argument must be specified when filling a vector with a value")
822
+
823
+ if dtype is None:
824
+ dtype = value_type
825
+ elif value_type != dtype:
768
826
  raise RuntimeError(
769
- "vec() must have at least one numeric argument, if it's length, dtype is not specified"
827
+ f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
770
828
  )
829
+ else:
830
+ # Initializing by value, e.g.: `wp.vec2(1, 2)`, `wp.vector(1, 2, length=2)`.
831
+ if length is None:
832
+ length = variadic_arg_count
833
+ elif length != variadic_arg_count:
834
+ raise RuntimeError(
835
+ f"incompatible number of values given ({variadic_arg_count}) "
836
+ f"when constructing a vector of length {length}"
837
+ )
771
838
 
772
- if "dtype" in kwds:
773
- # casting constructor
774
- if len(arg_types) == 1 and types_equal(
775
- arg_types[0], vector(length=Any, dtype=Scalar), match_generic=True
776
- ):
777
- veclen = arg_types[0]._length_
778
- vectype = kwds["dtype"]
779
- templates.append(veclen)
780
- templates.append(vectype)
781
- return vector(length=veclen, dtype=vectype)
782
- raise RuntimeError(
783
- "vec() should not have dtype specified if numeric arguments are given, the dtype will be inferred from the argument types"
784
- )
839
+ try:
840
+ value_type = scalar_infer_type(variadic_arg_types)
841
+ except RuntimeError:
842
+ raise RuntimeError("all values given when constructing a vector must have the same type") from None
785
843
 
786
- # component wise construction of an anonymous vector, e.g. wp.vec(wp.float16(1.0), wp.float16(2.0), ....)
787
- # we infer the length and data type from the number and type of the arg values
788
- veclen = len(arg_types)
789
- vectype = arg_types[0]
844
+ if dtype is None:
845
+ dtype = value_type
846
+ elif value_type != dtype:
847
+ raise RuntimeError(
848
+ f"all values used to initialize this vector matrix are expected to be of the type `{dtype.__name__}`"
849
+ )
790
850
 
791
- if len(arg_types) == 1 and type_is_vector(vectype):
792
- # constructor from another vector
793
- veclen = vectype._length_
794
- vectype = vectype._wp_scalar_type_
795
- elif not all(vectype == t for t in arg_types):
796
- raise RuntimeError(
797
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join([str(t) for t in arg_types]) }"
798
- )
851
+ if length is None:
852
+ raise RuntimeError("could not infer the `length` argument when calling the `wp.vector()` function")
799
853
 
800
- # update the templates list, so we can generate vec<len, type>() correctly in codegen
801
- templates.append(veclen)
802
- templates.append(vectype)
854
+ if dtype is None:
855
+ raise RuntimeError("could not infer the `dtype` argument when calling the `wp.vector()` function")
856
+
857
+ return vector(length=length, dtype=dtype)
803
858
 
804
- else:
805
- # construction of a predeclared type, e.g.: vec5d
806
- veclen, vectype = templates
807
- if len(arg_types) == 1 and type_is_vector(arg_types[0]):
808
- # constructor from another vector
809
- if arg_types[0]._length_ != veclen:
810
- raise RuntimeError(
811
- f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {arg_types[0]._length_}"
812
- )
813
- elif not all(vectype == t for t in arg_types):
814
- raise RuntimeError(
815
- f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join([str(t) for t in arg_types]) }"
816
- )
817
859
 
818
- retvalue = vector(length=veclen, dtype=vectype)
819
- return retvalue
860
+ def vector_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
861
+ # We're in the codegen stage where we emit the code calling the built-in.
862
+ # Further validate the given argument values if needed and map them
863
+ # to the underlying C++ function's runtime and template params.
864
+
865
+ length = return_type._length_
866
+ dtype = return_type._wp_scalar_type_
867
+
868
+ variadic_args = args.get("args", ())
869
+
870
+ func_args = variadic_args
871
+ template_args = (length, dtype)
872
+ return (func_args, template_args)
820
873
 
821
874
 
822
875
  add_builtin(
823
876
  "vector",
824
- input_types={"*arg_types": Scalar, "length": int, "dtype": Scalar},
877
+ input_types={"*args": Scalar, "length": int, "dtype": Scalar},
878
+ defaults={"length": None, "dtype": None},
825
879
  variadic=True,
826
- initializer_list_func=lambda arg_types, _: len(arg_types) > 4,
827
- value_func=vector_constructor_func,
880
+ initializer_list_func=lambda arg_types, arg_values: len(arg_types.get("args", ())) > 4,
881
+ value_func=vector_value_func,
882
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("length", "dtype")},
883
+ dispatch_func=vector_dispatch_func,
828
884
  native_func="vec_t",
829
- doc="Construct a vector of with given length and dtype.",
885
+ doc="Construct a vector of given length and dtype.",
830
886
  group="Vector Math",
831
887
  export=False,
832
888
  )
833
889
 
834
890
 
835
- def matrix_constructor_func(arg_types, kwds, templates):
891
+ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
836
892
  if arg_types is None:
837
893
  return matrix(shape=(Any, Any), dtype=Scalar)
838
894
 
839
- if len(templates) == 0:
840
- # anonymous construction
841
- if "shape" not in kwds:
842
- raise RuntimeError("shape keyword must be specified when calling matrix() function")
895
+ shape = arg_values.get("shape", None)
896
+ dtype = arg_values.get("dtype", None)
897
+
898
+ variadic_arg_types = arg_types.get("args", ())
899
+ variadic_arg_count = len(variadic_arg_types)
900
+ if variadic_arg_count == 0:
901
+ # Zero-initialization, e.g.: `wp.matXX()`, `wp.matrix(shape=(2, 2), dtype=wp.float16)`.
902
+ if shape is None:
903
+ raise RuntimeError("the `shape` argument must be specified when zero-initializing a matrix")
904
+
905
+ if dtype is None:
906
+ dtype = float32
907
+ elif variadic_arg_count == 1:
908
+ value_type = strip_reference(variadic_arg_types[0])
909
+ if type_is_matrix(value_type):
910
+ # Copy constructor, e.g.: `wp.matXX(other_mat)`, `wp.matrix(other_mat)`.
911
+ if shape is None:
912
+ shape = value_type._shape_
913
+ elif not seq_check_equal(value_type._shape_, shape):
914
+ raise RuntimeError(
915
+ f"incompatible matrix of shape {tuple(shape)} given when copy constructing "
916
+ f"a matrix of shape {tuple(value_type._shape_)}"
917
+ )
843
918
 
844
- if len(arg_types) == 0:
845
- if "dtype" not in kwds:
846
- raise RuntimeError("matrix() must have dtype as a keyword argument if it has no positional arguments")
919
+ if dtype is None:
920
+ dtype = value_type._wp_scalar_type_
921
+ else:
922
+ # Initialization by filling a value, e.g.: `wp.matXX(123)`,
923
+ # `wp.matrix(123, shape=(2, 2))`.
924
+ if shape is None:
925
+ raise RuntimeError("the `shape` argument must be specified when filling a matrix with a value")
926
+
927
+ if dtype is None:
928
+ dtype = value_type
929
+ elif value_type != dtype:
930
+ raise RuntimeError(
931
+ f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
932
+ )
933
+ else:
934
+ # Initializing by value, e.g.: `wp.mat22(1, 2, 3, 4)`, `wp.matrix(1, 2, 3, 4, shape=(2, 2))`.
935
+ if shape is None:
936
+ raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
847
937
 
848
- # zero initialization, e.g.: m = matrix(shape=(3,2), dtype=wp.float16)
849
- shape = kwds["shape"]
850
- dtype = kwds["dtype"]
938
+ if all(type_is_vector(x) for x in variadic_arg_types):
939
+ if shape[1] != variadic_arg_count:
940
+ raise RuntimeError(
941
+ f"incompatible number of column vectors given ({variadic_arg_count}) "
942
+ f"when constructing a matrix of shape {tuple(shape)}"
943
+ )
851
944
 
852
- else:
853
- # value initialization, e.g.: m = matrix(1.0, shape=(3,2))
854
- shape = kwds["shape"]
855
- dtype = arg_types[0]
856
-
857
- if len(arg_types) == 1 and type_is_matrix(dtype):
858
- # constructor from another matrix
859
- if arg_types[0]._shape_ != shape:
860
- raise RuntimeError(
861
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
862
- )
863
- dtype = dtype._wp_scalar_type_
864
- elif len(arg_types) > 1 and len(arg_types) != shape[0] * shape[1]:
945
+ if any(x._length_ != shape[0] for x in variadic_arg_types):
865
946
  raise RuntimeError(
866
- "Wrong number of arguments for matrix() function, must initialize with either a scalar value, or m*n values"
947
+ f"incompatible column vector lengths given when constructing a matrix of shape {tuple(shape)}"
867
948
  )
949
+ elif shape[0] * shape[1] != variadic_arg_count:
950
+ raise RuntimeError(
951
+ f"incompatible number of values given ({variadic_arg_count}) "
952
+ f"when constructing a matrix of shape {tuple(shape)}"
953
+ )
868
954
 
869
- templates.append(shape[0])
870
- templates.append(shape[1])
871
- templates.append(dtype)
955
+ try:
956
+ value_type = scalar_infer_type(variadic_arg_types)
957
+ except RuntimeError:
958
+ raise RuntimeError("all values given when constructing a matrix must have the same type") from None
872
959
 
873
- else:
874
- # predeclared type, e.g.: mat32d
875
- shape = (templates[0], templates[1])
876
- dtype = templates[2]
877
-
878
- if len(arg_types) > 0:
879
- if len(arg_types) == 1 and type_is_matrix(arg_types[0]):
880
- # constructor from another matrix with same dimension but possibly different type
881
- if arg_types[0]._shape_ != shape:
882
- raise RuntimeError(
883
- f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
884
- )
885
- else:
886
- # check scalar arg type matches declared type
887
- if infer_scalar_type(arg_types) != dtype:
888
- raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
889
-
890
- # check vector arg type matches declared type
891
- if all(type_is_vector(a) for a in arg_types):
892
- cols = len(arg_types)
893
- if shape[1] != cols:
894
- raise RuntimeError(
895
- "Wrong number of vectors when attempting to construct a matrix with column vectors"
896
- )
897
-
898
- if not all(a._length_ == shape[0] for a in arg_types):
899
- raise RuntimeError(
900
- "Wrong vector row count when attempting to construct a matrix with column vectors"
901
- )
902
- else:
903
- # check that we either got 1 arg (scalar construction), or enough values for whole matrix
904
- size = shape[0] * shape[1]
905
- if len(arg_types) > 1 and len(arg_types) != size:
906
- raise RuntimeError(
907
- "Wrong number of scalars when attempting to construct a matrix from a list of components"
908
- )
960
+ if dtype is None:
961
+ dtype = value_type
962
+ elif value_type != dtype:
963
+ raise RuntimeError(
964
+ f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
965
+ )
966
+
967
+ if shape is None:
968
+ raise RuntimeError("could not infer the `shape` argument when calling the `wp.matrix()` function")
969
+
970
+ if dtype is None:
971
+ raise RuntimeError("could not infer the `dtype` argument when calling the `wp.matrix()` function")
909
972
 
910
973
  return matrix(shape=shape, dtype=dtype)
911
974
 
912
975
 
976
+ def matrix_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
977
+ # We're in the codegen stage where we emit the code calling the built-in.
978
+ # Further validate the given argument values if needed and map them
979
+ # to the underlying C++ function's runtime and template params.
980
+
981
+ shape = return_type._shape_
982
+ dtype = return_type._wp_scalar_type_
983
+
984
+ variadic_args = args.get("args", ())
985
+
986
+ func_args = variadic_args
987
+ template_args = (*shape, dtype)
988
+ return (func_args, template_args)
989
+
990
+
913
991
  # only use initializer list if matrix size < 5x5, or for scalar construction
914
- def matrix_initlist_func(arg_types, templates):
915
- m, n, dtype = templates
992
+ def matrix_initializer_list_func(args, return_type):
993
+ shape = return_type._shape_
994
+
995
+ variadic_args = args.get("args", ())
996
+ variadic_arg_count = len(variadic_args)
916
997
  return not (
917
- len(arg_types) == 0
918
- or len(arg_types) == 1 # zero construction
919
- or (m == n and n < 5) # scalar construction # value construction for small matrices
998
+ variadic_arg_count <= 1 # zero/fill initialization
999
+ or (shape[0] == shape[1] and shape[1] < 5) # value construction for small matrices
920
1000
  )
921
1001
 
922
1002
 
923
1003
  add_builtin(
924
1004
  "matrix",
925
- input_types={"*arg_types": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
1005
+ input_types={"*args": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
1006
+ defaults={"shape": None, "dtype": None},
926
1007
  variadic=True,
927
- initializer_list_func=matrix_initlist_func,
928
- value_func=matrix_constructor_func,
1008
+ value_func=matrix_value_func,
1009
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("shape", "dtype")},
1010
+ dispatch_func=matrix_dispatch_func,
1011
+ initializer_list_func=matrix_initializer_list_func,
929
1012
  native_func="mat_t",
930
1013
  doc="Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized.",
931
1014
  group="Vector Math",
@@ -933,69 +1016,95 @@ add_builtin(
933
1016
  )
934
1017
 
935
1018
 
936
- # identity:
937
- def matrix_identity_value_func(arg_types, kwds, templates):
1019
+ def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
938
1020
  if arg_types is None:
939
1021
  return matrix(shape=(Any, Any), dtype=Scalar)
940
1022
 
941
- if len(arg_types):
942
- raise RuntimeError("identity() function does not accept positional arguments")
1023
+ n = arg_values["n"]
1024
+ dtype = arg_values["dtype"]
943
1025
 
944
- if "n" not in kwds:
945
- raise RuntimeError("'n' keyword argument must be specified when calling identity() function")
1026
+ if n is None:
1027
+ raise RuntimeError("'n' must be a constant when calling identity()")
946
1028
 
947
- if "dtype" not in kwds:
948
- raise RuntimeError("'dtype' keyword argument must be specified when calling identity() function")
1029
+ return matrix(shape=(n, n), dtype=dtype)
949
1030
 
950
- n, dtype = [kwds["n"], kwds["dtype"]]
951
1031
 
952
- if n is None:
953
- raise RuntimeError("'n' must be a constant when calling identity() function")
1032
+ def identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1033
+ # We're in the codegen stage where we emit the code calling the built-in.
1034
+ # Further validate the given argument values if needed and map them
1035
+ # to the underlying C++ function's runtime and template params.
954
1036
 
955
- templates.append(n)
956
- templates.append(dtype)
1037
+ shape = return_type._shape_
1038
+ dtype = return_type._wp_scalar_type_
957
1039
 
958
- return matrix(shape=(n, n), dtype=dtype)
1040
+ func_args = ()
1041
+ template_args = (shape[0], dtype)
1042
+ return (func_args, template_args)
959
1043
 
960
1044
 
961
1045
  add_builtin(
962
1046
  "identity",
963
1047
  input_types={"n": int, "dtype": Scalar},
964
- value_func=matrix_identity_value_func,
965
- variadic=True,
1048
+ value_func=identity_value_func,
1049
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1050
+ dispatch_func=identity_dispatch_func,
966
1051
  doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
967
1052
  group="Vector Math",
968
1053
  export=False,
969
1054
  )
970
1055
 
971
1056
 
972
- def matrix_transform_value_func(arg_types, kwds, templates):
973
- if templates is None:
1057
+ def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1058
+ if arg_types is None:
974
1059
  return matrix(shape=(4, 4), dtype=Float)
975
1060
 
976
- if len(templates) == 0:
977
- raise RuntimeError("Cannot use a generic type name in a kernel")
1061
+ dtype = arg_values.get("dtype", None)
978
1062
 
979
- m, n, dtype = templates
980
- if (m, n) != (4, 4):
981
- raise RuntimeError("Can only construct 4x4 matrices with position, rotation and scale")
982
- if infer_scalar_type(arg_types) != dtype:
983
- raise RuntimeError("Wrong scalar type for mat<{}> constructor".format(",".join(map(str, templates))))
1063
+ value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1064
+ try:
1065
+ value_type = scalar_infer_type(value_arg_types)
1066
+ except RuntimeError:
1067
+ raise RuntimeError(
1068
+ "all values given when constructing a transformation matrix must have the same type"
1069
+ ) from None
1070
+
1071
+ if dtype is None:
1072
+ dtype = value_type
1073
+ elif value_type != dtype:
1074
+ raise RuntimeError(
1075
+ f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1076
+ )
984
1077
 
985
1078
  return matrix(shape=(4, 4), dtype=dtype)
986
1079
 
987
1080
 
1081
+ def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1082
+ # We're in the codegen stage where we emit the code calling the built-in.
1083
+ # Further validate the given argument values if needed and map them
1084
+ # to the underlying C++ function's runtime and template params.
1085
+
1086
+ dtype = return_type._wp_scalar_type_
1087
+
1088
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
1089
+ template_args = (4, 4, dtype)
1090
+ return (func_args, template_args)
1091
+
1092
+
988
1093
  add_builtin(
989
1094
  "matrix",
990
1095
  input_types={
991
1096
  "pos": vector(length=3, dtype=Float),
992
1097
  "rot": quaternion(dtype=Float),
993
1098
  "scale": vector(length=3, dtype=Float),
1099
+ "dtype": Float,
994
1100
  },
1101
+ defaults={"dtype": None},
995
1102
  value_func=matrix_transform_value_func,
1103
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1104
+ dispatch_func=matrix_transform_dispatch_func,
996
1105
  native_func="mat_t",
997
1106
  doc="""Construct a 4x4 transformation matrix that applies the transformations as
998
- Translation(pos)*Rotation(rot)*Scale(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
1107
+ Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
999
1108
  group="Vector Math",
1000
1109
  export=False,
1001
1110
  )
@@ -1050,42 +1159,69 @@ add_builtin(
1050
1159
  # Quaternion Math
1051
1160
 
1052
1161
 
1053
- def quaternion_value_func(arg_types, kwds, templates):
1162
+ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1054
1163
  if arg_types is None:
1055
1164
  return quaternion(dtype=Float)
1056
1165
 
1057
- if len(templates) == 0:
1058
- if "dtype" in kwds:
1059
- # casting constructor
1060
- dtype = kwds["dtype"]
1061
- else:
1062
- # if constructing anonymous quat type then infer output type from arguments
1063
- dtype = infer_scalar_type(arg_types)
1064
- templates.append(dtype)
1166
+ dtype = arg_values.get("dtype", None)
1167
+
1168
+ variadic_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1169
+ variadic_arg_count = len(variadic_arg_types)
1170
+
1171
+ if variadic_arg_count == 0:
1172
+ # Zero-initialization, e.g.: `wp.quat()`, `wp.quaternion(dtype=wp.float16)`.
1173
+ if dtype is None:
1174
+ dtype = float32
1175
+ elif dtype not in float_types:
1176
+ raise RuntimeError(
1177
+ f"a float type is expected when zero-initializing a quaternion but got `{type(dtype).__name__}` instead"
1178
+ )
1179
+ elif variadic_arg_count == 1:
1180
+ if type_is_quaternion(variadic_arg_types[0]):
1181
+ # Copy constructor, e.g.: `wp.quat(other_vec)`, `wp.quaternion(other_vec)`.
1182
+ in_quat = variadic_arg_types[0]
1183
+ if dtype is None:
1184
+ dtype = in_quat._wp_scalar_type_
1065
1185
  else:
1066
- # if constructing predeclared type then check arg_types match expectation
1067
- if len(arg_types) > 0 and infer_scalar_type(arg_types) != templates[0]:
1068
- raise RuntimeError("Wrong scalar type for quat {} constructor".format(",".join(map(str, templates))))
1186
+ try:
1187
+ value_type = scalar_infer_type(variadic_arg_types)
1188
+ except RuntimeError:
1189
+ raise RuntimeError("all values given when constructing a quaternion must have the same type") from None
1190
+
1191
+ if dtype is None:
1192
+ dtype = value_type
1193
+ elif value_type != dtype:
1194
+ raise RuntimeError(
1195
+ f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
1196
+ )
1069
1197
 
1070
- return quaternion(dtype=templates[0])
1198
+ if dtype is None:
1199
+ raise RuntimeError("could not infer the `dtype` argument when calling the `wp.quaternion()` function")
1071
1200
 
1201
+ return quaternion(dtype=dtype)
1072
1202
 
1073
- def quat_cast_value_func(arg_types, kwds, templates):
1074
- if arg_types is None:
1075
- raise RuntimeError("Missing quaternion argument.")
1076
- if "dtype" not in kwds:
1077
- raise RuntimeError("Missing 'dtype' kwd.")
1078
1203
 
1079
- dtype = kwds["dtype"]
1080
- templates.append(dtype)
1204
+ def quaternion_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1205
+ # We're in the codegen stage where we emit the code calling the built-in.
1206
+ # Further validate the given argument values if needed and map them
1207
+ # to the underlying C++ function's runtime and template params.
1081
1208
 
1082
- return quaternion(dtype=dtype)
1209
+ dtype = return_type._wp_scalar_type_
1210
+
1211
+ variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1212
+
1213
+ func_args = variadic_args
1214
+ template_args = (dtype,)
1215
+ return (func_args, template_args)
1083
1216
 
1084
1217
 
1085
1218
  add_builtin(
1086
1219
  "quaternion",
1087
- input_types={},
1220
+ input_types={"dtype": Float},
1221
+ defaults={"dtype": None},
1088
1222
  value_func=quaternion_value_func,
1223
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1224
+ dispatch_func=quaternion_dispatch_func,
1089
1225
  native_func="quat_t",
1090
1226
  group="Quaternion Math",
1091
1227
  doc="""Construct a zero-initialized quaternion. Quaternions are laid out as
@@ -1096,6 +1232,8 @@ add_builtin(
1096
1232
  "quaternion",
1097
1233
  input_types={"x": Float, "y": Float, "z": Float, "w": Float},
1098
1234
  value_func=quaternion_value_func,
1235
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1236
+ dispatch_func=quaternion_dispatch_func,
1099
1237
  native_func="quat_t",
1100
1238
  group="Quaternion Math",
1101
1239
  doc="Create a quaternion using the supplied components (type inferred from component type).",
@@ -1103,17 +1241,24 @@ add_builtin(
1103
1241
  )
1104
1242
  add_builtin(
1105
1243
  "quaternion",
1106
- input_types={"i": vector(length=3, dtype=Float), "r": Float},
1244
+ input_types={"ijk": vector(length=3, dtype=Float), "real": Float, "dtype": Float},
1245
+ defaults={"dtype": None},
1107
1246
  value_func=quaternion_value_func,
1247
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1248
+ dispatch_func=quaternion_dispatch_func,
1108
1249
  native_func="quat_t",
1109
1250
  group="Quaternion Math",
1110
1251
  doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type).",
1111
1252
  export=False,
1112
1253
  )
1254
+
1113
1255
  add_builtin(
1114
1256
  "quaternion",
1115
- input_types={"q": quaternion(dtype=Float)},
1116
- value_func=quat_cast_value_func,
1257
+ input_types={"quat": quaternion(dtype=Float), "dtype": Float},
1258
+ defaults={"dtype": None},
1259
+ value_func=quaternion_value_func,
1260
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1261
+ dispatch_func=quaternion_dispatch_func,
1117
1262
  native_func="quat_t",
1118
1263
  group="Quaternion Math",
1119
1264
  doc="Construct a quaternion of type dtype from another quaternion of a different dtype.",
@@ -1121,26 +1266,34 @@ add_builtin(
1121
1266
  )
1122
1267
 
1123
1268
 
1124
- def quat_identity_value_func(arg_types, kwds, templates):
1125
- # if arg_types is None then we are in 'export' mode
1269
+ def quat_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1126
1270
  if arg_types is None:
1271
+ # return quaternion(dtype=Float)
1127
1272
  return quatf
1128
1273
 
1129
- if "dtype" not in kwds:
1130
- # defaulting to float32 to preserve current behavior:
1131
- dtype = float32
1132
- else:
1133
- dtype = kwds["dtype"]
1274
+ dtype = arg_types.get("dtype", float32)
1275
+ return quaternion(dtype=dtype)
1134
1276
 
1135
- templates.append(dtype)
1136
1277
 
1137
- return quaternion(dtype=dtype)
1278
+ def quat_identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1279
+ # We're in the codegen stage where we emit the code calling the built-in.
1280
+ # Further validate the given argument values if needed and map them
1281
+ # to the underlying C++ function's runtime and template params.
1282
+
1283
+ dtype = return_type._wp_scalar_type_
1284
+
1285
+ func_args = ()
1286
+ template_args = (dtype,)
1287
+ return (func_args, template_args)
1138
1288
 
1139
1289
 
1140
1290
  add_builtin(
1141
1291
  "quat_identity",
1142
- input_types={},
1292
+ input_types={"dtype": Float},
1293
+ defaults={"dtype": None},
1143
1294
  value_func=quat_identity_value_func,
1295
+ export_func=lambda input_types: {},
1296
+ dispatch_func=quat_identity_dispatch_func,
1144
1297
  group="Quaternion Math",
1145
1298
  doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
1146
1299
  export=True,
@@ -1149,72 +1302,72 @@ add_builtin(
1149
1302
  add_builtin(
1150
1303
  "quat_from_axis_angle",
1151
1304
  input_types={"axis": vector(length=3, dtype=Float), "angle": Float},
1152
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1305
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1153
1306
  group="Quaternion Math",
1154
1307
  doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
1155
1308
  )
1156
1309
  add_builtin(
1157
1310
  "quat_to_axis_angle",
1158
- input_types={"q": quaternion(dtype=Float), "axis": vector(length=3, dtype=Float), "angle": Float},
1311
+ input_types={"quat": quaternion(dtype=Float), "axis": vector(length=3, dtype=Float), "angle": Float},
1159
1312
  value_type=None,
1160
1313
  group="Quaternion Math",
1161
1314
  doc="Extract the rotation axis and angle radians a quaternion represents.",
1162
1315
  )
1163
1316
  add_builtin(
1164
1317
  "quat_from_matrix",
1165
- input_types={"m": matrix(shape=(3, 3), dtype=Float)},
1166
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1318
+ input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
1319
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1167
1320
  group="Quaternion Math",
1168
1321
  doc="Construct a quaternion from a 3x3 matrix.",
1169
1322
  )
1170
1323
  add_builtin(
1171
1324
  "quat_rpy",
1172
1325
  input_types={"roll": Float, "pitch": Float, "yaw": Float},
1173
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1326
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1174
1327
  group="Quaternion Math",
1175
1328
  doc="Construct a quaternion representing a combined roll (z), pitch (x), yaw rotations (y) in radians.",
1176
1329
  )
1177
1330
  add_builtin(
1178
1331
  "quat_inverse",
1179
- input_types={"q": quaternion(dtype=Float)},
1180
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1332
+ input_types={"quat": quaternion(dtype=Float)},
1333
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1181
1334
  group="Quaternion Math",
1182
1335
  doc="Compute quaternion conjugate.",
1183
1336
  )
1184
1337
  add_builtin(
1185
1338
  "quat_rotate",
1186
- input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1187
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1339
+ input_types={"quat": quaternion(dtype=Float), "vec": vector(length=3, dtype=Float)},
1340
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1188
1341
  group="Quaternion Math",
1189
1342
  doc="Rotate a vector by a quaternion.",
1190
1343
  )
1191
1344
  add_builtin(
1192
1345
  "quat_rotate_inv",
1193
- input_types={"q": quaternion(dtype=Float), "p": vector(length=3, dtype=Float)},
1194
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1346
+ input_types={"quat": quaternion(dtype=Float), "vec": vector(length=3, dtype=Float)},
1347
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1195
1348
  group="Quaternion Math",
1196
1349
  doc="Rotate a vector by the inverse of a quaternion.",
1197
1350
  )
1198
1351
  add_builtin(
1199
1352
  "quat_slerp",
1200
- input_types={"q0": quaternion(dtype=Float), "q1": quaternion(dtype=Float), "t": Float},
1201
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1353
+ input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
1354
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1202
1355
  group="Quaternion Math",
1203
1356
  doc="Linearly interpolate between two quaternions.",
1204
1357
  require_original_output_arg=True,
1205
1358
  )
1206
1359
  add_builtin(
1207
1360
  "quat_to_matrix",
1208
- input_types={"q": quaternion(dtype=Float)},
1209
- value_func=lambda arg_types, kwds, _: matrix(shape=(3, 3), dtype=infer_scalar_type(arg_types)),
1361
+ input_types={"quat": quaternion(dtype=Float)},
1362
+ value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=float_infer_type(arg_types)),
1210
1363
  group="Quaternion Math",
1211
1364
  doc="Convert a quaternion to a 3x3 rotation matrix.",
1212
1365
  )
1213
1366
 
1214
1367
  add_builtin(
1215
1368
  "dot",
1216
- input_types={"x": quaternion(dtype=Float), "y": quaternion(dtype=Float)},
1217
- value_func=sametype_scalar_value_func,
1369
+ input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float)},
1370
+ value_func=float_sametypes_value_func,
1218
1371
  group="Quaternion Math",
1219
1372
  doc="Compute the dot product between two quaternions.",
1220
1373
  )
@@ -1222,55 +1375,85 @@ add_builtin(
1222
1375
  # Transformations
1223
1376
 
1224
1377
 
1225
- def transform_constructor_value_func(arg_types, kwds, templates):
1226
- if templates is None:
1227
- return transformation(dtype=Scalar)
1378
+ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1379
+ if arg_types is None:
1380
+ return transformation(dtype=Float)
1228
1381
 
1229
- if len(templates) == 0:
1230
- # if constructing anonymous transform type then infer output type from arguments
1231
- dtype = infer_scalar_type(arg_types)
1232
- templates.append(dtype)
1233
- else:
1234
- # if constructing predeclared type then check arg_types match expectation
1235
- if infer_scalar_type(arg_types) != templates[0]:
1236
- raise RuntimeError(
1237
- f"Wrong scalar type for transform constructor expected {templates[0]}, got {','.join([ str(t) for t in arg_types])}"
1238
- )
1382
+ try:
1383
+ value_type = float_infer_type(arg_types)
1384
+ except RuntimeError:
1385
+ raise RuntimeError(
1386
+ "all values given when constructing a transformation matrix must have the same type"
1387
+ ) from None
1388
+
1389
+ dtype = arg_values.get("dtype", None)
1390
+ if dtype is None:
1391
+ dtype = value_type
1392
+ elif value_type != dtype:
1393
+ raise RuntimeError(
1394
+ f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
1395
+ )
1396
+
1397
+ return transformation(dtype=dtype)
1239
1398
 
1240
- return transformation(dtype=templates[0])
1399
+
1400
+ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1401
+ # We're in the codegen stage where we emit the code calling the built-in.
1402
+ # Further validate the given argument values if needed and map them
1403
+ # to the underlying C++ function's runtime and template params.
1404
+
1405
+ dtype = return_type._wp_scalar_type_
1406
+
1407
+ variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1408
+
1409
+ func_args = variadic_args
1410
+ template_args = (dtype,)
1411
+ return (func_args, template_args)
1241
1412
 
1242
1413
 
1243
1414
  add_builtin(
1244
1415
  "transformation",
1245
- input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float)},
1246
- value_func=transform_constructor_value_func,
1416
+ input_types={"pos": vector(length=3, dtype=Float), "rot": quaternion(dtype=Float), "dtype": Float},
1417
+ defaults={"dtype": None},
1418
+ value_func=transformation_value_func,
1419
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1420
+ dispatch_func=transformation_dispatch_func,
1247
1421
  native_func="transform_t",
1248
1422
  group="Transformations",
1249
- doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
1423
+ doc="Construct a rigid-body transformation with translation part ``pos`` and rotation ``rot``.",
1250
1424
  export=False,
1251
1425
  )
1252
1426
 
1253
1427
 
1254
- def transform_identity_value_func(arg_types, kwds, templates):
1428
+ def transform_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1255
1429
  # if arg_types is None then we are in 'export' mode
1256
1430
  if arg_types is None:
1431
+ # return transformation(dtype=Float)
1257
1432
  return transformf
1258
1433
 
1259
- if "dtype" not in kwds:
1260
- # defaulting to float32 to preserve current behavior:
1261
- dtype = float32
1262
- else:
1263
- dtype = kwds["dtype"]
1434
+ dtype = arg_types.get("dtype", float32)
1435
+ return transformation(dtype=dtype)
1264
1436
 
1265
- templates.append(dtype)
1266
1437
 
1267
- return transformation(dtype=dtype)
1438
+ def transform_identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1439
+ # We're in the codegen stage where we emit the code calling the built-in.
1440
+ # Further validate the given argument values if needed and map them
1441
+ # to the underlying C++ function's runtime and template params.
1442
+
1443
+ dtype = return_type._wp_scalar_type_
1444
+
1445
+ func_args = ()
1446
+ template_args = (dtype,)
1447
+ return (func_args, template_args)
1268
1448
 
1269
1449
 
1270
1450
  add_builtin(
1271
1451
  "transform_identity",
1272
- input_types={},
1452
+ input_types={"dtype": Float},
1453
+ defaults={"dtype": None},
1273
1454
  value_func=transform_identity_value_func,
1455
+ export_func=lambda input_types: {},
1456
+ dispatch_func=transform_identity_dispatch_func,
1274
1457
  group="Transformations",
1275
1458
  doc="Construct an identity transform with zero translation and identity rotation.",
1276
1459
  export=True,
@@ -1278,103 +1461,168 @@ add_builtin(
1278
1461
 
1279
1462
  add_builtin(
1280
1463
  "transform_get_translation",
1281
- input_types={"t": transformation(dtype=Float)},
1282
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1464
+ input_types={"xform": transformation(dtype=Float)},
1465
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1283
1466
  group="Transformations",
1284
- doc="Return the translational part of a transform ``t``.",
1467
+ doc="Return the translational part of a transform ``xform``.",
1285
1468
  )
1286
1469
  add_builtin(
1287
1470
  "transform_get_rotation",
1288
- input_types={"t": transformation(dtype=Float)},
1289
- value_func=lambda arg_types, kwds, _: quaternion(dtype=infer_scalar_type(arg_types)),
1471
+ input_types={"xform": transformation(dtype=Float)},
1472
+ value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
1290
1473
  group="Transformations",
1291
- doc="Return the rotational part of a transform ``t``.",
1474
+ doc="Return the rotational part of a transform ``xform``.",
1292
1475
  )
1293
1476
  add_builtin(
1294
1477
  "transform_multiply",
1295
1478
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
1296
- value_func=lambda arg_types, kwds, _: transformation(dtype=infer_scalar_type(arg_types)),
1479
+ value_func=lambda arg_types, arg_values: transformation(dtype=float_infer_type(arg_types)),
1297
1480
  group="Transformations",
1298
1481
  doc="Multiply two rigid body transformations together.",
1299
1482
  )
1300
1483
  add_builtin(
1301
1484
  "transform_point",
1302
- input_types={"t": transformation(dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1303
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1485
+ input_types={"xform": transformation(dtype=Float), "point": vector(length=3, dtype=Float)},
1486
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1304
1487
  group="Transformations",
1305
- doc="Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1 (translation and rotation).",
1488
+ doc="Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1 (translation and rotation).",
1306
1489
  )
1307
1490
  add_builtin(
1308
1491
  "transform_point",
1309
- input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "p": vector(length=3, dtype=Scalar)},
1310
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1492
+ input_types={"mat": matrix(shape=(4, 4), dtype=Float), "point": vector(length=3, dtype=Float)},
1493
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1311
1494
  group="Vector Math",
1312
- doc="""Apply the transform to a point ``p`` treating the homogeneous coordinate as w=1.
1495
+ doc="""Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1.
1313
1496
 
1314
- The transformation is applied treating ``p`` as a column vector, e.g.: ``y = M*p``.
1315
- Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = p^T*M^T``.
1497
+ The transformation is applied treating ``point`` as a column vector, e.g.: ``y = mat*point``.
1498
+ Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = point^T*mat^T``.
1316
1499
  If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1317
1500
  matrix before calling this method.""",
1318
1501
  )
1319
1502
  add_builtin(
1320
1503
  "transform_vector",
1321
- input_types={"t": transformation(dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1322
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1504
+ input_types={"xform": transformation(dtype=Float), "vec": vector(length=3, dtype=Float)},
1505
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1323
1506
  group="Transformations",
1324
- doc="Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0 (rotation only).",
1507
+ doc="Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0 (rotation only).",
1325
1508
  )
1326
1509
  add_builtin(
1327
1510
  "transform_vector",
1328
- input_types={"m": matrix(shape=(4, 4), dtype=Scalar), "v": vector(length=3, dtype=Scalar)},
1329
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=infer_scalar_type(arg_types)),
1511
+ input_types={"mat": matrix(shape=(4, 4), dtype=Float), "vec": vector(length=3, dtype=Float)},
1512
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
1330
1513
  group="Vector Math",
1331
- doc="""Apply the transform to a vector ``v`` treating the homogeneous coordinate as w=0.
1514
+ doc="""Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0.
1332
1515
 
1333
- The transformation is applied treating ``v`` as a column vector, e.g.: ``y = M*v``
1334
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = v^T*M^T``.
1516
+ The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec``
1517
+ note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = vec^T*mat^T``.
1335
1518
  If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
1336
1519
  matrix before calling this method.""",
1337
1520
  )
1338
1521
  add_builtin(
1339
1522
  "transform_inverse",
1340
- input_types={"t": transformation(dtype=Float)},
1341
- value_func=sametype_value_func(transformation(dtype=Float)),
1523
+ input_types={"xform": transformation(dtype=Float)},
1524
+ value_func=sametypes_create_value_func(transformation(dtype=Float)),
1342
1525
  group="Transformations",
1343
- doc="Compute the inverse of the transformation ``t``.",
1526
+ doc="Compute the inverse of the transformation ``xform``.",
1344
1527
  )
1345
1528
  # ---------------------------------
1346
1529
  # Spatial Math
1347
1530
 
1348
1531
 
1349
- def spatial_vector_constructor_value_func(arg_types, kwds, templates):
1350
- if templates is None:
1532
+ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1533
+ if arg_types is None:
1351
1534
  return spatial_vector(dtype=Float)
1352
1535
 
1353
- if len(templates) == 0:
1354
- raise RuntimeError("Cannot use a generic type name in a kernel")
1536
+ dtype = arg_values.get("dtype", None)
1537
+
1538
+ variadic_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
1539
+ variadic_arg_count = len(variadic_arg_types)
1540
+ if variadic_arg_count == 0:
1541
+ if dtype is None:
1542
+ dtype = float32
1543
+ elif variadic_arg_count == 2:
1544
+ if any(not type_is_vector(x) for x in variadic_arg_types) or any(x._length_ != 3 for x in variadic_arg_types):
1545
+ raise RuntimeError("arguments `w` and `v` are expected to be vectors of length 3")
1546
+ elif variadic_arg_count != 6:
1547
+ raise RuntimeError("2 vectors or 6 scalar values are expected when constructing a spatial vector")
1548
+
1549
+ if variadic_arg_count:
1550
+ try:
1551
+ value_type = float_infer_type(variadic_arg_types)
1552
+ except RuntimeError:
1553
+ raise RuntimeError("all values given when constructing a spatial vector must have the same type") from None
1554
+
1555
+ if dtype is None:
1556
+ dtype = value_type
1557
+ elif value_type != dtype:
1558
+ raise RuntimeError(
1559
+ f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
1560
+ )
1561
+
1562
+ return vector(length=6, dtype=dtype)
1563
+
1355
1564
 
1356
- vectype = templates[1]
1357
- if len(arg_types) and infer_scalar_type(arg_types) != vectype:
1358
- raise RuntimeError("Wrong scalar type for spatial_vector<{}> constructor".format(",".join(map(str, templates))))
1565
+ def spatial_vector_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1566
+ # We're in the codegen stage where we emit the code calling the built-in.
1567
+ # Further validate the given argument values if needed and map them
1568
+ # to the underlying C++ function's runtime and template params.
1359
1569
 
1360
- return vector(length=6, dtype=vectype)
1570
+ length = return_type._length_
1571
+ dtype = return_type._wp_scalar_type_
1572
+
1573
+ variadic_args = tuple(v for k, v in args.items() if k != "dtype")
1574
+
1575
+ func_args = variadic_args
1576
+ template_args = (length, dtype)
1577
+ return (func_args, template_args)
1361
1578
 
1362
1579
 
1363
1580
  add_builtin(
1364
- "vector",
1365
- input_types={"w": vector(length=3, dtype=Float), "v": vector(length=3, dtype=Float)},
1366
- value_func=spatial_vector_constructor_value_func,
1581
+ "spatial_vector",
1582
+ input_types={"dtype": Float},
1583
+ defaults={"dtype": None},
1584
+ value_func=spatial_vector_value_func,
1585
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1586
+ dispatch_func=spatial_vector_dispatch_func,
1587
+ native_func="vec_t",
1588
+ group="Spatial Math",
1589
+ doc="Zero-initialize a 6D screw vector.",
1590
+ export=False,
1591
+ )
1592
+
1593
+
1594
+ add_builtin(
1595
+ "spatial_vector",
1596
+ input_types={"w": vector(length=3, dtype=Float), "v": vector(length=3, dtype=Float), "dtype": Float},
1597
+ defaults={"dtype": None},
1598
+ value_func=spatial_vector_value_func,
1599
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1600
+ dispatch_func=spatial_vector_dispatch_func,
1367
1601
  native_func="vec_t",
1368
1602
  group="Spatial Math",
1369
1603
  doc="Construct a 6D screw vector from two 3D vectors.",
1370
1604
  export=False,
1371
1605
  )
1372
1606
 
1607
+ add_builtin(
1608
+ "spatial_vector",
1609
+ input_types={"wx": Float, "wy": Float, "wz": Float, "vx": Float, "vy": Float, "vz": Float, "dtype": Float},
1610
+ defaults={"dtype": None},
1611
+ initializer_list_func=lambda arg_types, arg_values: True,
1612
+ value_func=spatial_vector_value_func,
1613
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
1614
+ dispatch_func=spatial_vector_dispatch_func,
1615
+ native_func="vec_t",
1616
+ group="Spatial Math",
1617
+ doc="Construct a 6D screw vector from six values.",
1618
+ export=False,
1619
+ )
1620
+
1373
1621
 
1374
1622
  add_builtin(
1375
1623
  "spatial_adjoint",
1376
1624
  input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
1377
- value_func=lambda arg_types, kwds, _: matrix(shape=(6, 6), dtype=infer_scalar_type(arg_types)),
1625
+ value_func=lambda arg_types, arg_values: matrix(shape=(6, 6), dtype=float_infer_type(arg_types)),
1378
1626
  group="Spatial Math",
1379
1627
  doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
1380
1628
  export=False,
@@ -1382,36 +1630,36 @@ add_builtin(
1382
1630
  add_builtin(
1383
1631
  "spatial_dot",
1384
1632
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1385
- value_func=sametype_scalar_value_func,
1633
+ value_func=float_sametypes_value_func,
1386
1634
  group="Spatial Math",
1387
1635
  doc="Compute the dot product of two 6D screw vectors.",
1388
1636
  )
1389
1637
  add_builtin(
1390
1638
  "spatial_cross",
1391
1639
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1392
- value_func=sametype_value_func(vector(length=6, dtype=Float)),
1640
+ value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
1393
1641
  group="Spatial Math",
1394
1642
  doc="Compute the cross product of two 6D screw vectors.",
1395
1643
  )
1396
1644
  add_builtin(
1397
1645
  "spatial_cross_dual",
1398
1646
  input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
1399
- value_func=sametype_value_func(vector(length=6, dtype=Float)),
1647
+ value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
1400
1648
  group="Spatial Math",
1401
1649
  doc="Compute the dual cross product of two 6D screw vectors.",
1402
1650
  )
1403
1651
 
1404
1652
  add_builtin(
1405
1653
  "spatial_top",
1406
- input_types={"a": vector(length=6, dtype=Float)},
1407
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1654
+ input_types={"svec": vector(length=6, dtype=Float)},
1655
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
1408
1656
  group="Spatial Math",
1409
1657
  doc="Return the top (first) part of a 6D screw vector.",
1410
1658
  )
1411
1659
  add_builtin(
1412
1660
  "spatial_bottom",
1413
- input_types={"a": vector(length=6, dtype=Float)},
1414
- value_func=lambda arg_types, kwds, _: vector(length=3, dtype=arg_types[0]._wp_scalar_type_),
1661
+ input_types={"svec": vector(length=6, dtype=Float)},
1662
+ value_func=lambda arg_types, arg_values: vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
1415
1663
  group="Spatial Math",
1416
1664
  doc="Return the bottom (second) part of a 6D screw vector.",
1417
1665
  )
@@ -1588,22 +1836,23 @@ add_builtin(
1588
1836
 
1589
1837
  add_builtin(
1590
1838
  "bvh_query_aabb",
1591
- input_types={"id": uint64, "lower": vec3, "upper": vec3},
1592
- value_type=bvh_query_t,
1839
+ input_types={"id": uint64, "low": vec3, "high": vec3},
1840
+ value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
1593
1841
  group="Geometry",
1594
1842
  doc="""Construct an axis-aligned bounding box query against a BVH object.
1595
1843
 
1596
1844
  This query can be used to iterate over all bounds inside a BVH.
1597
1845
 
1598
1846
  :param id: The BVH identifier
1599
- :param lower: The lower bound of the bounding box in BVH space
1600
- :param upper: The upper bound of the bounding box in BVH space""",
1847
+ :param low: The lower bound of the bounding box in BVH space
1848
+ :param high: The upper bound of the bounding box in BVH space""",
1849
+ export=False,
1601
1850
  )
1602
1851
 
1603
1852
  add_builtin(
1604
1853
  "bvh_query_ray",
1605
1854
  input_types={"id": uint64, "start": vec3, "dir": vec3},
1606
- value_type=bvh_query_t,
1855
+ value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
1607
1856
  group="Geometry",
1608
1857
  doc="""Construct a ray query against a BVH object.
1609
1858
 
@@ -1612,15 +1861,17 @@ add_builtin(
1612
1861
  :param id: The BVH identifier
1613
1862
  :param start: The start of the ray in BVH space
1614
1863
  :param dir: The direction of the ray in BVH space""",
1864
+ export=False,
1615
1865
  )
1616
1866
 
1617
1867
  add_builtin(
1618
1868
  "bvh_query_next",
1619
- input_types={"query": bvh_query_t, "index": int},
1869
+ input_types={"query": BvhQuery, "index": int},
1620
1870
  value_type=builtins.bool,
1621
1871
  group="Geometry",
1622
1872
  doc="""Move to the next bound returned by the query.
1623
1873
  The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
1874
+ export=False,
1624
1875
  )
1625
1876
 
1626
1877
  add_builtin(
@@ -1650,6 +1901,7 @@ add_builtin(
1650
1901
  :param face: Returns the index of the closest face
1651
1902
  :param bary_u: Returns the barycentric u coordinate of the closest point
1652
1903
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1904
+ export=False,
1653
1905
  hidden=True,
1654
1906
  )
1655
1907
 
@@ -1660,7 +1912,7 @@ add_builtin(
1660
1912
  "point": vec3,
1661
1913
  "max_dist": float,
1662
1914
  },
1663
- value_type=mesh_query_point_t,
1915
+ value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
1664
1916
  group="Geometry",
1665
1917
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1666
1918
 
@@ -1672,6 +1924,7 @@ add_builtin(
1672
1924
  :param point: The point in space to query
1673
1925
  :param max_dist: Mesh faces above this distance will not be considered by the query""",
1674
1926
  require_original_output_arg=True,
1927
+ export=False,
1675
1928
  )
1676
1929
 
1677
1930
  add_builtin(
@@ -1696,6 +1949,7 @@ add_builtin(
1696
1949
  :param face: Returns the index of the closest face
1697
1950
  :param bary_u: Returns the barycentric u coordinate of the closest point
1698
1951
  :param bary_v: Returns the barycentric v coordinate of the closest point""",
1952
+ export=False,
1699
1953
  hidden=True,
1700
1954
  )
1701
1955
 
@@ -1706,7 +1960,7 @@ add_builtin(
1706
1960
  "point": vec3,
1707
1961
  "max_dist": float,
1708
1962
  },
1709
- value_type=mesh_query_point_t,
1963
+ value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
1710
1964
  group="Geometry",
1711
1965
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1712
1966
 
@@ -1716,6 +1970,7 @@ add_builtin(
1716
1970
  :param point: The point in space to query
1717
1971
  :param max_dist: Mesh faces above this distance will not be considered by the query""",
1718
1972
  require_original_output_arg=True,
1973
+ export=False,
1719
1974
  )
1720
1975
 
1721
1976
  add_builtin(
@@ -1740,6 +1995,7 @@ add_builtin(
1740
1995
  :param face: Returns the index of the furthest face
1741
1996
  :param bary_u: Returns the barycentric u coordinate of the furthest point
1742
1997
  :param bary_v: Returns the barycentric v coordinate of the furthest point""",
1998
+ export=False,
1743
1999
  hidden=True,
1744
2000
  )
1745
2001
 
@@ -1750,7 +2006,7 @@ add_builtin(
1750
2006
  "point": vec3,
1751
2007
  "min_dist": float,
1752
2008
  },
1753
- value_type=mesh_query_point_t,
2009
+ value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
1754
2010
  group="Geometry",
1755
2011
  doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
1756
2012
 
@@ -1760,6 +2016,7 @@ add_builtin(
1760
2016
  :param point: The point in space to query
1761
2017
  :param min_dist: Mesh faces below this distance will not be considered by the query""",
1762
2018
  require_original_output_arg=True,
2019
+ export=False,
1763
2020
  )
1764
2021
 
1765
2022
  add_builtin(
@@ -1793,6 +2050,7 @@ add_builtin(
1793
2050
  :param bary_v: Returns the barycentric v coordinate of the closest point
1794
2051
  :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1795
2052
  fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
2053
+ export=False,
1796
2054
  hidden=True,
1797
2055
  )
1798
2056
 
@@ -1805,7 +2063,7 @@ add_builtin(
1805
2063
  "epsilon": float,
1806
2064
  },
1807
2065
  defaults={"epsilon": 1.0e-3},
1808
- value_type=mesh_query_point_t,
2066
+ value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
1809
2067
  group="Geometry",
1810
2068
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
1811
2069
 
@@ -1819,6 +2077,7 @@ add_builtin(
1819
2077
  :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1820
2078
  fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
1821
2079
  require_original_output_arg=True,
2080
+ export=False,
1822
2081
  )
1823
2082
 
1824
2083
  add_builtin(
@@ -1855,6 +2114,7 @@ add_builtin(
1855
2114
  :param bary_v: Returns the barycentric v coordinate of the closest point
1856
2115
  :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1857
2116
  :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
2117
+ export=False,
1858
2118
  hidden=True,
1859
2119
  )
1860
2120
 
@@ -1868,7 +2128,7 @@ add_builtin(
1868
2128
  "threshold": float,
1869
2129
  },
1870
2130
  defaults={"accuracy": 2.0, "threshold": 0.5},
1871
- value_type=mesh_query_point_t,
2131
+ value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
1872
2132
  group="Geometry",
1873
2133
  doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
1874
2134
 
@@ -1884,6 +2144,7 @@ add_builtin(
1884
2144
  :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1885
2145
  :param threshold: The threshold of the winding number to be considered inside, default 0.5""",
1886
2146
  require_original_output_arg=True,
2147
+ export=False,
1887
2148
  )
1888
2149
 
1889
2150
  add_builtin(
@@ -1914,6 +2175,7 @@ add_builtin(
1914
2175
  :param sign: Returns a value > 0 if the ray hit in front of the face, returns < 0 otherwise
1915
2176
  :param normal: Returns the face normal
1916
2177
  :param face: Returns the index of the hit face""",
2178
+ export=False,
1917
2179
  hidden=True,
1918
2180
  )
1919
2181
 
@@ -1925,7 +2187,7 @@ add_builtin(
1925
2187
  "dir": vec3,
1926
2188
  "max_t": float,
1927
2189
  },
1928
- value_type=mesh_query_ray_t,
2190
+ value_func=lambda arg_types, _: MeshQueryRay if arg_types is None else mesh_query_ray_t,
1929
2191
  group="Geometry",
1930
2192
  doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
1931
2193
 
@@ -1934,30 +2196,33 @@ add_builtin(
1934
2196
  :param dir: The ray direction (should be normalized)
1935
2197
  :param max_t: The maximum distance along the ray to check for intersections""",
1936
2198
  require_original_output_arg=True,
2199
+ export=False,
1937
2200
  )
1938
2201
 
1939
2202
  add_builtin(
1940
2203
  "mesh_query_aabb",
1941
- input_types={"id": uint64, "lower": vec3, "upper": vec3},
1942
- value_type=mesh_query_aabb_t,
2204
+ input_types={"id": uint64, "low": vec3, "high": vec3},
2205
+ value_func=lambda arg_types, _: MeshQueryAABB if arg_types is None else mesh_query_aabb_t,
1943
2206
  group="Geometry",
1944
2207
  doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
1945
2208
 
1946
2209
  This query can be used to iterate over all triangles inside a volume.
1947
2210
 
1948
2211
  :param id: The mesh identifier
1949
- :param lower: The lower bound of the bounding box in mesh space
1950
- :param upper: The upper bound of the bounding box in mesh space""",
2212
+ :param low: The lower bound of the bounding box in mesh space
2213
+ :param high: The upper bound of the bounding box in mesh space""",
2214
+ export=False,
1951
2215
  )
1952
2216
 
1953
2217
  add_builtin(
1954
2218
  "mesh_query_aabb_next",
1955
- input_types={"query": mesh_query_aabb_t, "index": int},
2219
+ input_types={"query": MeshQueryAABB, "index": int},
1956
2220
  value_type=builtins.bool,
1957
2221
  group="Geometry",
1958
2222
  doc="""Move to the next triangle overlapping the query bounding box.
1959
2223
 
1960
2224
  The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
2225
+ export=False,
1961
2226
  )
1962
2227
 
1963
2228
  add_builtin(
@@ -1966,6 +2231,7 @@ add_builtin(
1966
2231
  value_type=vec3,
1967
2232
  group="Geometry",
1968
2233
  doc="""Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates.""",
2234
+ export=False,
1969
2235
  )
1970
2236
 
1971
2237
  add_builtin(
@@ -1974,26 +2240,29 @@ add_builtin(
1974
2240
  value_type=vec3,
1975
2241
  group="Geometry",
1976
2242
  doc="""Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates.""",
2243
+ export=False,
1977
2244
  )
1978
2245
 
1979
2246
  add_builtin(
1980
2247
  "hash_grid_query",
1981
2248
  input_types={"id": uint64, "point": vec3, "max_dist": float},
1982
- value_type=hash_grid_query_t,
2249
+ value_func=lambda arg_types, _: HashGridQuery if arg_types is None else hash_grid_query_t,
1983
2250
  group="Geometry",
1984
2251
  doc="""Construct a point query against a :class:`HashGrid`.
1985
2252
 
1986
2253
  This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
2254
+ export=False,
1987
2255
  )
1988
2256
 
1989
2257
  add_builtin(
1990
2258
  "hash_grid_query_next",
1991
- input_types={"query": hash_grid_query_t, "index": int},
2259
+ input_types={"query": HashGridQuery, "index": int},
1992
2260
  value_type=builtins.bool,
1993
2261
  group="Geometry",
1994
2262
  doc="""Move to the next point in the hash grid query.
1995
2263
 
1996
2264
  The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
2265
+ export=False,
1997
2266
  )
1998
2267
 
1999
2268
  add_builtin(
@@ -2006,6 +2275,7 @@ add_builtin(
2006
2275
  This can be used to reorder threads such that grid traversal occurs in a spatially coherent order.
2007
2276
 
2008
2277
  Returns -1 if the :class:`HashGrid` has not been reserved.""",
2278
+ export=False,
2009
2279
  )
2010
2280
 
2011
2281
  add_builtin(
@@ -2016,6 +2286,7 @@ add_builtin(
2016
2286
  doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
2017
2287
 
2018
2288
  Returns > 0 if triangles intersect.""",
2289
+ export=False,
2019
2290
  )
2020
2291
 
2021
2292
  add_builtin(
@@ -2025,6 +2296,7 @@ add_builtin(
2025
2296
  missing_grad=True,
2026
2297
  group="Geometry",
2027
2298
  doc="""Retrieves the mesh given its index.""",
2299
+ export=False,
2028
2300
  )
2029
2301
 
2030
2302
  add_builtin(
@@ -2033,6 +2305,7 @@ add_builtin(
2033
2305
  value_type=vec3,
2034
2306
  group="Geometry",
2035
2307
  doc="""Evaluates the face normal the mesh given a face index.""",
2308
+ export=False,
2036
2309
  )
2037
2310
 
2038
2311
  add_builtin(
@@ -2041,6 +2314,7 @@ add_builtin(
2041
2314
  value_type=vec3,
2042
2315
  group="Geometry",
2043
2316
  doc="""Returns the point of the mesh given a index.""",
2317
+ export=False,
2044
2318
  )
2045
2319
 
2046
2320
  add_builtin(
@@ -2049,6 +2323,7 @@ add_builtin(
2049
2323
  value_type=vec3,
2050
2324
  group="Geometry",
2051
2325
  doc="""Returns the velocity of the mesh given a index.""",
2326
+ export=False,
2052
2327
  )
2053
2328
 
2054
2329
  add_builtin(
@@ -2057,6 +2332,7 @@ add_builtin(
2057
2332
  value_type=int,
2058
2333
  group="Geometry",
2059
2334
  doc="""Returns the point-index of the mesh given a face-vertex index.""",
2335
+ export=False,
2060
2336
  )
2061
2337
 
2062
2338
 
@@ -2075,6 +2351,7 @@ add_builtin(
2075
2351
  :param q2: Second point of second edge
2076
2352
  :param epsilon: Zero tolerance for determining if points in an edge are degenerate.
2077
2353
  :param out: vec3 output containing (s,t,d), where `s` in [0,1] is the barycentric weight for the first edge, `t` is the barycentric weight for the second edge, and `d` is the distance between the two edges at these two closest points.""",
2354
+ export=False,
2078
2355
  )
2079
2356
 
2080
2357
  # ---------------------------------
@@ -2096,9 +2373,13 @@ add_builtin(
2096
2373
  # ---------------------------------
2097
2374
  # Iterators
2098
2375
 
2099
- add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", hidden=True)
2100
- add_builtin("iter_next", input_types={"query": hash_grid_query_t}, value_type=int, group="Utility", hidden=True)
2101
- add_builtin("iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", hidden=True)
2376
+ add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
2377
+ add_builtin(
2378
+ "iter_next", input_types={"query": hash_grid_query_t}, value_type=int, group="Utility", export=False, hidden=True
2379
+ )
2380
+ add_builtin(
2381
+ "iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", export=False, hidden=True
2382
+ )
2102
2383
 
2103
2384
  # ---------------------------------
2104
2385
  # Volumes
@@ -2116,26 +2397,46 @@ _volume_supported_value_types = {
2116
2397
  }
2117
2398
 
2118
2399
 
2119
- def volume_value_func(arg_types, kwds, templates):
2120
- try:
2121
- dtype = kwds["dtype"]
2122
- except KeyError as err:
2123
- raise RuntimeError(
2124
- "'dtype' keyword argument must be specified when calling generic volume lookup or sampling functions"
2125
- ) from err
2400
+ def check_volume_value_grad_compatibility(dtype, grad_dtype):
2401
+ if type_is_vector(dtype):
2402
+ expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
2403
+ else:
2404
+ expected = vector(length=3, dtype=dtype)
2126
2405
 
2127
- if dtype not in _volume_supported_value_types:
2128
- raise RuntimeError(f"Unsupported volume type '{type_repr(dtype)}'")
2406
+ if not types_equal(grad_dtype, expected):
2407
+ raise RuntimeError(f"Incompatible gradient type, expected {type_repr(expected)}, got {type_repr(grad_dtype)}")
2129
2408
 
2130
- templates.append(dtype)
2409
+
2410
+ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2411
+ if arg_types is None:
2412
+ return Any
2413
+
2414
+ dtype = arg_values["dtype"]
2415
+
2416
+ if dtype not in _volume_supported_value_types:
2417
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2131
2418
 
2132
2419
  return dtype
2133
2420
 
2134
2421
 
2422
+ def volume_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2423
+ # We're in the codegen stage where we emit the code calling the built-in.
2424
+ # Further validate the given argument values if needed and map them
2425
+ # to the underlying C++ function's runtime and template params.
2426
+
2427
+ dtype = args["dtype"]
2428
+
2429
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
2430
+ template_args = (dtype,)
2431
+ return (func_args, template_args)
2432
+
2433
+
2135
2434
  add_builtin(
2136
2435
  "volume_sample",
2137
2436
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "dtype": Any},
2138
2437
  value_func=volume_value_func,
2438
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
2439
+ dispatch_func=volume_dispatch_func,
2139
2440
  export=False,
2140
2441
  group="Volumes",
2141
2442
  doc="""Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
@@ -2144,31 +2445,38 @@ add_builtin(
2144
2445
  )
2145
2446
 
2146
2447
 
2147
- def check_volume_value_grad_compatibility(dtype, grad_dtype):
2148
- if type_is_vector(dtype):
2149
- expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
2150
- else:
2151
- expected = vector(length=3, dtype=dtype)
2152
-
2153
- if not types_equal(grad_dtype, expected):
2154
- raise RuntimeError(f"Incompatible gradient type, expected {type_repr(expected)}, got {type_repr(grad_dtype)}")
2448
+ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2449
+ if arg_types is None:
2450
+ return Any
2155
2451
 
2452
+ dtype = arg_values["dtype"]
2156
2453
 
2157
- def volume_sample_grad_value_func(arg_types, kwds, templates):
2158
- dtype = volume_value_func(arg_types, kwds, templates)
2454
+ if dtype not in _volume_supported_value_types:
2455
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2159
2456
 
2160
- if len(arg_types) < 4:
2161
- raise RuntimeError("'volume_sample_grad' requires 4 positional arguments")
2457
+ check_volume_value_grad_compatibility(dtype, arg_types["grad"])
2162
2458
 
2163
- grad_type = arg_types[3]
2164
- check_volume_value_grad_compatibility(dtype, grad_type)
2165
2459
  return dtype
2166
2460
 
2167
2461
 
2462
+ def volume_sample_grad_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2463
+ # We're in the codegen stage where we emit the code calling the built-in.
2464
+ # Further validate the given argument values if needed and map them
2465
+ # to the underlying C++ function's runtime and template params.
2466
+
2467
+ dtype = args["dtype"]
2468
+
2469
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
2470
+ template_args = (dtype,)
2471
+ return (func_args, template_args)
2472
+
2473
+
2168
2474
  add_builtin(
2169
2475
  "volume_sample_grad",
2170
2476
  input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "grad": Any, "dtype": Any},
2171
2477
  value_func=volume_sample_grad_value_func,
2478
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
2479
+ dispatch_func=volume_sample_grad_dispatch_func,
2172
2480
  export=False,
2173
2481
  group="Volumes",
2174
2482
  doc="""Sample the volume given by ``id`` and its gradient at the volume local-space point ``uvw``.
@@ -2176,11 +2484,38 @@ add_builtin(
2176
2484
  Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
2177
2485
  )
2178
2486
 
2487
+
2488
+ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2489
+ if arg_types is None:
2490
+ return Any
2491
+
2492
+ dtype = arg_values["dtype"]
2493
+
2494
+ if dtype not in _volume_supported_value_types:
2495
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2496
+
2497
+ return dtype
2498
+
2499
+
2500
+ def volume_lookup_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2501
+ # We're in the codegen stage where we emit the code calling the built-in.
2502
+ # Further validate the given argument values if needed and map them
2503
+ # to the underlying C++ function's runtime and template params.
2504
+
2505
+ dtype = args["dtype"]
2506
+
2507
+ func_args = tuple(v for k, v in args.items() if k != "dtype")
2508
+ template_args = (dtype,)
2509
+ return (func_args, template_args)
2510
+
2511
+
2179
2512
  add_builtin(
2180
2513
  "volume_lookup",
2181
2514
  input_types={"id": uint64, "i": int, "j": int, "k": int, "dtype": Any},
2182
2515
  value_type=int,
2183
- value_func=volume_value_func,
2516
+ value_func=volume_lookup_value_func,
2517
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
2518
+ dispatch_func=volume_lookup_dispatch_func,
2184
2519
  export=False,
2185
2520
  group="Volumes",
2186
2521
  doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
@@ -2189,13 +2524,14 @@ add_builtin(
2189
2524
  )
2190
2525
 
2191
2526
 
2192
- def volume_store_value_func(arg_types, kwds, templates):
2193
- if len(arg_types) < 4:
2194
- raise RuntimeError("'volume_store' requires 5 positional arguments")
2527
+ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2528
+ if arg_types is None:
2529
+ return None
2530
+
2531
+ dtype = arg_types["value"]
2195
2532
 
2196
- dtype = arg_types[4]
2197
2533
  if dtype not in _volume_supported_value_types:
2198
- raise RuntimeError(f"Unsupported volume type '{type_repr(dtype)}'")
2534
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2199
2535
 
2200
2536
  return None
2201
2537
 
@@ -2299,14 +2635,17 @@ add_builtin(
2299
2635
  )
2300
2636
 
2301
2637
 
2302
- def volume_sample_index_value_func(arg_types, kwds, templates):
2303
- if len(arg_types) != 5:
2304
- raise RuntimeError("'volume_sample_index' requires 5 positional arguments")
2638
+ def volume_sample_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2639
+ if arg_types is None:
2640
+ return Any
2641
+
2642
+ dtype = arg_types["voxel_data"].dtype
2305
2643
 
2306
- dtype = arg_types[3].dtype
2644
+ if dtype not in _volume_supported_value_types:
2645
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2307
2646
 
2308
- if not types_equal(dtype, arg_types[4]):
2309
- raise RuntimeError("The 'voxel_data' array and the 'background' value must have the same dtype")
2647
+ if not types_equal(dtype, arg_types["background"]):
2648
+ raise RuntimeError("the `voxel_data` array and the `background` value must have the same dtype")
2310
2649
 
2311
2650
  return dtype
2312
2651
 
@@ -2326,17 +2665,20 @@ add_builtin(
2326
2665
  )
2327
2666
 
2328
2667
 
2329
- def volume_sample_grad_index_value_func(arg_types, kwds, templates):
2330
- if len(arg_types) != 6:
2331
- raise RuntimeError("'volume_sample_grad_index' requires 6 positional arguments")
2668
+ def volume_sample_grad_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2669
+ if arg_types is None:
2670
+ return Any
2671
+
2672
+ dtype = arg_types["voxel_data"].dtype
2332
2673
 
2333
- dtype = arg_types[3].dtype
2674
+ if dtype not in _volume_supported_value_types:
2675
+ raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
2334
2676
 
2335
- if not types_equal(dtype, arg_types[4]):
2336
- raise RuntimeError("The 'voxel_data' array and the 'background' value must have the same dtype")
2677
+ if not types_equal(dtype, arg_types["background"]):
2678
+ raise RuntimeError("the `voxel_data` array and the `background` value must have the same dtype")
2679
+
2680
+ check_volume_value_grad_compatibility(dtype, arg_types["grad"])
2337
2681
 
2338
- grad_type = arg_types[5]
2339
- check_volume_value_grad_compatibility(dtype, grad_type)
2340
2682
  return dtype
2341
2683
 
2342
2684
 
@@ -2434,10 +2776,10 @@ add_builtin(
2434
2776
  )
2435
2777
  add_builtin(
2436
2778
  "randi",
2437
- input_types={"state": uint32, "min": int, "max": int},
2779
+ input_types={"state": uint32, "low": int, "high": int},
2438
2780
  value_type=int,
2439
2781
  group="Random",
2440
- doc="Return a random integer between [min, max).",
2782
+ doc="Return a random integer between [low, high).",
2441
2783
  )
2442
2784
  add_builtin(
2443
2785
  "randf",
@@ -2448,10 +2790,10 @@ add_builtin(
2448
2790
  )
2449
2791
  add_builtin(
2450
2792
  "randf",
2451
- input_types={"state": uint32, "min": float, "max": float},
2793
+ input_types={"state": uint32, "low": float, "high": float},
2452
2794
  value_type=float,
2453
2795
  group="Random",
2454
- doc="Return a random float between [min, max).",
2796
+ doc="Return a random float between [low, high).",
2455
2797
  )
2456
2798
  add_builtin(
2457
2799
  "randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution."
@@ -2600,7 +2942,7 @@ add_builtin(
2600
2942
  add_builtin(
2601
2943
  "curlnoise",
2602
2944
  input_types={"state": uint32, "xy": vec2, "octaves": uint32, "lacunarity": float, "gain": float},
2603
- defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2945
+ defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
2604
2946
  value_type=vec2,
2605
2947
  group="Random",
2606
2948
  doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
@@ -2609,7 +2951,7 @@ add_builtin(
2609
2951
  add_builtin(
2610
2952
  "curlnoise",
2611
2953
  input_types={"state": uint32, "xyz": vec3, "octaves": uint32, "lacunarity": float, "gain": float},
2612
- defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2954
+ defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
2613
2955
  value_type=vec3,
2614
2956
  group="Random",
2615
2957
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
@@ -2618,19 +2960,31 @@ add_builtin(
2618
2960
  add_builtin(
2619
2961
  "curlnoise",
2620
2962
  input_types={"state": uint32, "xyzt": vec4, "octaves": uint32, "lacunarity": float, "gain": float},
2621
- defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
2963
+ defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
2622
2964
  value_type=vec3,
2623
2965
  group="Random",
2624
2966
  doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
2625
2967
  missing_grad=True,
2626
2968
  )
2627
2969
 
2970
+
2971
+ def printf_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2972
+ # We're in the codegen stage where we emit the code calling the built-in.
2973
+ # Further validate the given argument values if needed and map them
2974
+ # to the underlying C++ function's runtime and template params.
2975
+
2976
+ func_args = (args["fmt"], *args["args"])
2977
+ template_args = ()
2978
+ return (func_args, template_args)
2979
+
2980
+
2628
2981
  # note printf calls directly to global CRT printf (no wp:: namespace prefix)
2629
2982
  add_builtin(
2630
2983
  "printf",
2631
- input_types={},
2984
+ input_types={"fmt": str, "*args": Any},
2632
2985
  namespace="",
2633
2986
  variadic=True,
2987
+ dispatch_func=printf_dispatch_func,
2634
2988
  group="Utility",
2635
2989
  doc="Allows printing formatted strings using C-style format specifiers.",
2636
2990
  )
@@ -2709,189 +3063,309 @@ add_builtin(
2709
3063
 
2710
3064
  add_builtin(
2711
3065
  "copy",
2712
- input_types={"value": Any},
2713
- value_func=lambda arg_types, kwds, _: arg_types[0],
3066
+ input_types={"a": Any},
3067
+ value_func=lambda arg_types, arg_values: arg_types["a"],
3068
+ hidden=True,
3069
+ export=False,
3070
+ group="Utility",
3071
+ )
3072
+ add_builtin(
3073
+ "assign",
3074
+ input_types={"dest": Any, "src": Any},
2714
3075
  hidden=True,
2715
3076
  export=False,
2716
3077
  group="Utility",
2717
3078
  )
2718
- add_builtin("assign", variadic=True, hidden=True, export=False, group="Utility")
2719
3079
  add_builtin(
2720
3080
  "select",
2721
- input_types={"cond": builtins.bool, "arg1": Any, "arg2": Any},
2722
- value_func=lambda arg_types, kwds, _: arg_types[1],
2723
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
3081
+ input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
3082
+ value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
3083
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
2724
3084
  group="Utility",
2725
3085
  )
2726
3086
  for t in int_types:
2727
3087
  add_builtin(
2728
3088
  "select",
2729
- input_types={"cond": t, "arg1": Any, "arg2": Any},
2730
- value_func=lambda arg_types, kwds, _: arg_types[1],
2731
- doc="Select between two arguments, if ``cond`` is ``False`` then return ``arg1``, otherwise return ``arg2``",
3089
+ input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
3090
+ value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
3091
+ doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
2732
3092
  group="Utility",
2733
3093
  )
2734
3094
  add_builtin(
2735
3095
  "select",
2736
- input_types={"arr": array(dtype=Any), "arg1": Any, "arg2": Any},
2737
- value_func=lambda arg_types, kwds, _: arg_types[1],
2738
- doc="Select between two arguments, if ``arr`` is null then return ``arg1``, otherwise return ``arg2``",
3096
+ input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
3097
+ value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
3098
+ doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
3099
+ group="Utility",
3100
+ )
3101
+
3102
+
3103
+ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3104
+ if arg_types is None:
3105
+ return array(dtype=Scalar)
3106
+
3107
+ dtype = arg_values["dtype"]
3108
+ shape = arg_values["shape"]
3109
+ return array(dtype=dtype, ndim=len(shape))
3110
+
3111
+
3112
+ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3113
+ # We're in the codegen stage where we emit the code calling the built-in.
3114
+ # Further validate the given argument values if needed and map them
3115
+ # to the underlying C++ function's runtime and template params.
3116
+
3117
+ dtype = return_type.dtype
3118
+
3119
+ func_args = (args["ptr"], *args["shape"])
3120
+ template_args = (dtype,)
3121
+ return (func_args, template_args)
3122
+
3123
+
3124
+ add_builtin(
3125
+ "array",
3126
+ input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Scalar},
3127
+ value_func=array_value_func,
3128
+ export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
3129
+ dispatch_func=array_dispatch_func,
3130
+ native_func="array_t",
2739
3131
  group="Utility",
3132
+ hidden=True,
3133
+ export=False,
3134
+ missing_grad=True,
2740
3135
  )
2741
3136
 
2742
3137
 
2743
3138
  # does argument checking and type propagation for address()
2744
- def address_value_func(arg_types, kwds, _):
2745
- if not is_array(arg_types[0]):
2746
- raise RuntimeError("load() argument 0 must be an array")
3139
+ def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3140
+ arr_type = arg_types["arr"]
3141
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
3142
+
3143
+ if not is_array(arr_type):
3144
+ raise RuntimeError("address() first argument must be an array")
2747
3145
 
2748
- num_indices = len(arg_types) - 1
2749
- num_dims = arg_types[0].ndim
3146
+ idx_count = len(idx_types)
2750
3147
 
2751
- if num_indices < num_dims:
3148
+ if idx_count < arr_type.ndim:
2752
3149
  raise RuntimeError(
2753
3150
  "Num indices < num dimensions for array load, this is a codegen error, should have generated a view instead"
2754
3151
  )
2755
3152
 
2756
- if num_indices > num_dims:
3153
+ if idx_count > arr_type.ndim:
2757
3154
  raise RuntimeError(
2758
- f"Num indices > num dimensions for array load, received {num_indices}, but array only has {num_dims}"
3155
+ f"Num indices > num dimensions for array load, received {idx_count}, but array only has {arr_type.ndim}"
2759
3156
  )
2760
3157
 
2761
3158
  # check index types
2762
- for t in arg_types[1:]:
3159
+ for t in idx_types:
2763
3160
  if not type_is_int(t):
2764
- raise RuntimeError(f"address() index arguments must be of integer type, got index of type {t}")
3161
+ raise RuntimeError(f"address() index arguments must be of integer type, got index of type {type_repr(t)}")
2765
3162
 
2766
- return Reference(arg_types[0].dtype)
3163
+ return Reference(arr_type.dtype)
3164
+
3165
+
3166
+ for array_type in array_types:
3167
+ add_builtin(
3168
+ "address",
3169
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int},
3170
+ defaults={"j": None, "k": None, "l": None},
3171
+ hidden=True,
3172
+ value_func=address_value_func,
3173
+ group="Utility",
3174
+ )
2767
3175
 
2768
3176
 
2769
3177
  # does argument checking and type propagation for view()
2770
- def view_value_func(arg_types, kwds, _):
2771
- if not is_array(arg_types[0]):
2772
- raise RuntimeError("view() argument 0 must be an array")
3178
+ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3179
+ arr_type = arg_types["arr"]
3180
+ idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)
3181
+
3182
+ if not is_array(arr_type):
3183
+ raise RuntimeError("view() first argument must be an array")
2773
3184
 
2774
- # check array dim big enough to support view
2775
- num_indices = len(arg_types) - 1
2776
- num_dims = arg_types[0].ndim
3185
+ idx_count = len(idx_types)
2777
3186
 
2778
- if num_indices >= num_dims:
3187
+ if idx_count >= arr_type.ndim:
2779
3188
  raise RuntimeError(
2780
- f"Trying to create an array view with {num_indices} indices, but the array only has {num_dims} dimension(s). Ensure that the argument type on the function or kernel specifies the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float):"
3189
+ f"Trying to create an array view with {idx_count} indices, "
3190
+ f"but the array only has {arr_type.ndim} dimension(s). "
3191
+ f"Ensure that the argument type on the function or kernel specifies "
3192
+ f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
2781
3193
  )
2782
3194
 
2783
3195
  # check index types
2784
- for t in arg_types[1:]:
3196
+ for t in idx_types:
2785
3197
  if not type_is_int(t):
2786
- raise RuntimeError(f"view() index arguments must be of integer type, got index of type {t}")
3198
+ raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")
2787
3199
 
2788
3200
  # create an array view with leading dimensions removed
2789
- dtype = arg_types[0].dtype
2790
- ndim = num_dims - num_indices
2791
- if isinstance(arg_types[0], (fabricarray, indexedfabricarray)):
3201
+ dtype = arr_type.dtype
3202
+ ndim = arr_type.ndim - idx_count
3203
+ if isinstance(arr_type, (fabricarray, indexedfabricarray)):
2792
3204
  # fabric array of arrays: return array attribute as a regular array
2793
3205
  return array(dtype=dtype, ndim=ndim)
2794
- else:
2795
- return type(arg_types[0])(dtype=dtype, ndim=ndim)
3206
+
3207
+ return type(arr_type)(dtype=dtype, ndim=ndim)
3208
+
3209
+
3210
+ for array_type in array_types:
3211
+ add_builtin(
3212
+ "view",
3213
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int},
3214
+ defaults={"j": None, "k": None},
3215
+ hidden=True,
3216
+ value_func=view_value_func,
3217
+ group="Utility",
3218
+ )
2796
3219
 
2797
3220
 
2798
3221
  # does argument checking and type propagation for array_store()
2799
- def array_store_value_func(arg_types, kwds, _):
2800
- # check target type
2801
- if not is_array(arg_types[0]):
2802
- raise RuntimeError("array_store() argument 0 must be an array")
3222
+ def array_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3223
+ arr_type = arg_types["arr"]
3224
+ value_type = arg_types["value"]
3225
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
2803
3226
 
2804
- num_indices = len(arg_types[1:-1])
2805
- num_dims = arg_types[0].ndim
3227
+ if not is_array(arr_type):
3228
+ raise RuntimeError("array_store() first argument must be an array")
2806
3229
 
2807
- # if this happens we should have generated a view instead of a load during code gen
2808
- if num_indices < num_dims:
2809
- raise RuntimeError("Num indices < num dimensions for array store")
3230
+ idx_count = len(idx_types)
2810
3231
 
2811
- if num_indices > num_dims:
3232
+ if idx_count < arr_type.ndim:
2812
3233
  raise RuntimeError(
2813
- f"Num indices > num dimensions for array store, received {num_indices}, but array only has {num_dims}"
3234
+ "Num indices < num dimensions for array store, this is a codegen error, should have generated a view instead"
3235
+ )
3236
+
3237
+ if idx_count > arr_type.ndim:
3238
+ raise RuntimeError(
3239
+ f"Num indices > num dimensions for array store, received {idx_count}, but array only has {arr_type.ndim}"
2814
3240
  )
2815
3241
 
2816
3242
  # check index types
2817
- for t in arg_types[1:-1]:
3243
+ for t in idx_types:
2818
3244
  if not type_is_int(t):
2819
- raise RuntimeError(f"array_store() index arguments must be of integer type, got index of type {t}")
3245
+ raise RuntimeError(
3246
+ f"array_store() index arguments must be of integer type, got index of type {type_repr(t)}"
3247
+ )
2820
3248
 
2821
3249
  # check value type
2822
- if not types_equal(arg_types[-1], arg_types[0].dtype):
3250
+ if not types_equal(arr_type.dtype, value_type):
2823
3251
  raise RuntimeError(
2824
- f"array_store() value argument type ({arg_types[2]}) must be of the same type as the array ({arg_types[0].dtype})"
3252
+ f"array_store() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
2825
3253
  )
2826
3254
 
2827
3255
  return None
2828
3256
 
2829
3257
 
3258
+ for array_type in array_types:
3259
+ add_builtin(
3260
+ "array_store",
3261
+ input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
3262
+ hidden=True,
3263
+ value_func=array_store_value_func,
3264
+ skip_replay=True,
3265
+ group="Utility",
3266
+ )
3267
+ add_builtin(
3268
+ "array_store",
3269
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3270
+ hidden=True,
3271
+ value_func=array_store_value_func,
3272
+ skip_replay=True,
3273
+ group="Utility",
3274
+ )
3275
+ add_builtin(
3276
+ "array_store",
3277
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3278
+ hidden=True,
3279
+ value_func=array_store_value_func,
3280
+ skip_replay=True,
3281
+ group="Utility",
3282
+ )
3283
+ add_builtin(
3284
+ "array_store",
3285
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3286
+ hidden=True,
3287
+ value_func=array_store_value_func,
3288
+ skip_replay=True,
3289
+ group="Utility",
3290
+ )
3291
+
3292
+
2830
3293
  # does argument checking for store()
2831
- def store_value_func(arg_types, kwds, _):
3294
+ def store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
2832
3295
  # we already stripped the Reference from the argument type prior to this call
2833
- if not types_equal(arg_types[0], arg_types[1]):
2834
- raise RuntimeError(f"store() value argument type ({arg_types[1]}) must be of the same type as the reference")
3296
+ if not types_equal(arg_types["address"], arg_types["value"]):
3297
+ raise RuntimeError(
3298
+ f"store() value argument type ({arg_types['value']}) must be of the same type as the reference"
3299
+ )
2835
3300
 
2836
3301
  return None
2837
3302
 
2838
3303
 
2839
- # does type propagation for load()
2840
- def load_value_func(arg_types, kwds, _):
2841
- # we already stripped the Reference from the argument type prior to this call
2842
- return arg_types[0]
3304
+ def store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3305
+ func_args = (Reference(args["address"]), args["value"])
3306
+ template_args = ()
3307
+ return (func_args, template_args)
2843
3308
 
2844
3309
 
2845
- add_builtin("address", variadic=True, hidden=True, value_func=address_value_func, group="Utility")
2846
- add_builtin("view", variadic=True, hidden=True, value_func=view_value_func, group="Utility")
2847
- add_builtin(
2848
- "array_store", variadic=True, hidden=True, value_func=array_store_value_func, skip_replay=True, group="Utility"
2849
- )
2850
3310
  add_builtin(
2851
3311
  "store",
2852
- input_types={"address": Reference, "value": Any},
2853
- hidden=True,
3312
+ input_types={"address": Any, "value": Any},
2854
3313
  value_func=store_value_func,
3314
+ dispatch_func=store_dispatch_func,
3315
+ hidden=True,
2855
3316
  skip_replay=True,
2856
3317
  group="Utility",
2857
3318
  )
3319
+
3320
+
3321
+ def load_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3322
+ func_args = (Reference(args["address"]),)
3323
+ template_args = ()
3324
+ return (func_args, template_args)
3325
+
3326
+
2858
3327
  add_builtin(
2859
3328
  "load",
2860
- input_types={"address": Reference},
3329
+ input_types={"address": Any},
3330
+ value_func=lambda arg_types, arg_values: arg_types["address"],
3331
+ dispatch_func=load_dispatch_func,
2861
3332
  hidden=True,
2862
- value_func=load_value_func,
2863
3333
  group="Utility",
2864
3334
  )
2865
3335
 
2866
3336
 
2867
- def atomic_op_value_func(arg_types, kwds, _):
2868
- # check target type
2869
- if not is_array(arg_types[0]):
2870
- raise RuntimeError("atomic() operation argument 0 must be an array")
3337
+ def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3338
+ arr_type = arg_types["arr"]
3339
+ value_type = arg_types["value"]
3340
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
2871
3341
 
2872
- num_indices = len(arg_types[1:-1])
2873
- num_dims = arg_types[0].ndim
3342
+ if not is_array(arr_type):
3343
+ raise RuntimeError("atomic() first argument must be an array")
2874
3344
 
2875
- # if this happens we should have generated a view instead of a load during code gen
2876
- if num_indices < num_dims:
2877
- raise RuntimeError("Num indices < num dimensions for atomic array operation")
3345
+ idx_count = len(idx_types)
2878
3346
 
2879
- if num_indices > num_dims:
3347
+ if idx_count < arr_type.ndim:
2880
3348
  raise RuntimeError(
2881
- f"Num indices > num dimensions for atomic array operation, received {num_indices}, but array only has {num_dims}"
3349
+ "Num indices < num dimensions for atomic, this is a codegen error, should have generated a view instead"
3350
+ )
3351
+
3352
+ if idx_count > arr_type.ndim:
3353
+ raise RuntimeError(
3354
+ f"Num indices > num dimensions for atomic, received {idx_count}, but array only has {arr_type.ndim}"
2882
3355
  )
2883
3356
 
2884
3357
  # check index types
2885
- for t in arg_types[1:-1]:
3358
+ for t in idx_types:
2886
3359
  if not type_is_int(t):
2887
- raise RuntimeError(f"atomic() operation index arguments must be of integer type, got index of type {t}")
3360
+ raise RuntimeError(f"atomic() index arguments must be of integer type, got index of type {type_repr(t)}")
2888
3361
 
2889
- if not types_equal(arg_types[-1], arg_types[0].dtype):
3362
+ # check value type
3363
+ if not types_equal(arr_type.dtype, value_type):
2890
3364
  raise RuntimeError(
2891
- f"atomic() value argument ({arg_types[-1]}) must be of the same type as the array ({arg_types[0].dtype})"
3365
+ f"atomic() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
2892
3366
  )
2893
3367
 
2894
- return arg_types[0].dtype
3368
+ return arr_type.dtype
2895
3369
 
2896
3370
 
2897
3371
  for array_type in array_types:
@@ -2901,36 +3375,36 @@ for array_type in array_types:
2901
3375
  add_builtin(
2902
3376
  "atomic_add",
2903
3377
  hidden=hidden,
2904
- input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
3378
+ input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
2905
3379
  value_func=atomic_op_value_func,
2906
- doc="Atomically add ``value`` onto ``a[i]``.",
3380
+ doc="Atomically add ``value`` onto ``arr[i]``.",
2907
3381
  group="Utility",
2908
3382
  skip_replay=True,
2909
3383
  )
2910
3384
  add_builtin(
2911
3385
  "atomic_add",
2912
3386
  hidden=hidden,
2913
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3387
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2914
3388
  value_func=atomic_op_value_func,
2915
- doc="Atomically add ``value`` onto ``a[i,j]``.",
3389
+ doc="Atomically add ``value`` onto ``arr[i,j]``.",
2916
3390
  group="Utility",
2917
3391
  skip_replay=True,
2918
3392
  )
2919
3393
  add_builtin(
2920
3394
  "atomic_add",
2921
3395
  hidden=hidden,
2922
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3396
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2923
3397
  value_func=atomic_op_value_func,
2924
- doc="Atomically add ``value`` onto ``a[i,j,k]``.",
3398
+ doc="Atomically add ``value`` onto ``arr[i,j,k]``.",
2925
3399
  group="Utility",
2926
3400
  skip_replay=True,
2927
3401
  )
2928
3402
  add_builtin(
2929
3403
  "atomic_add",
2930
3404
  hidden=hidden,
2931
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3405
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2932
3406
  value_func=atomic_op_value_func,
2933
- doc="Atomically add ``value`` onto ``a[i,j,k,l]``.",
3407
+ doc="Atomically add ``value`` onto ``arr[i,j,k,l]``.",
2934
3408
  group="Utility",
2935
3409
  skip_replay=True,
2936
3410
  )
@@ -2938,36 +3412,36 @@ for array_type in array_types:
2938
3412
  add_builtin(
2939
3413
  "atomic_sub",
2940
3414
  hidden=hidden,
2941
- input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
3415
+ input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
2942
3416
  value_func=atomic_op_value_func,
2943
- doc="Atomically subtract ``value`` onto ``a[i]``.",
3417
+ doc="Atomically subtract ``value`` onto ``arr[i]``.",
2944
3418
  group="Utility",
2945
3419
  skip_replay=True,
2946
3420
  )
2947
3421
  add_builtin(
2948
3422
  "atomic_sub",
2949
3423
  hidden=hidden,
2950
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3424
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2951
3425
  value_func=atomic_op_value_func,
2952
- doc="Atomically subtract ``value`` onto ``a[i,j]``.",
3426
+ doc="Atomically subtract ``value`` onto ``arr[i,j]``.",
2953
3427
  group="Utility",
2954
3428
  skip_replay=True,
2955
3429
  )
2956
3430
  add_builtin(
2957
3431
  "atomic_sub",
2958
3432
  hidden=hidden,
2959
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3433
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
2960
3434
  value_func=atomic_op_value_func,
2961
- doc="Atomically subtract ``value`` onto ``a[i,j,k]``.",
3435
+ doc="Atomically subtract ``value`` onto ``arr[i,j,k]``.",
2962
3436
  group="Utility",
2963
3437
  skip_replay=True,
2964
3438
  )
2965
3439
  add_builtin(
2966
3440
  "atomic_sub",
2967
3441
  hidden=hidden,
2968
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3442
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
2969
3443
  value_func=atomic_op_value_func,
2970
- doc="Atomically subtract ``value`` onto ``a[i,j,k,l]``.",
3444
+ doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]``.",
2971
3445
  group="Utility",
2972
3446
  skip_replay=True,
2973
3447
  )
@@ -2975,9 +3449,9 @@ for array_type in array_types:
2975
3449
  add_builtin(
2976
3450
  "atomic_min",
2977
3451
  hidden=hidden,
2978
- input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
3452
+ input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
2979
3453
  value_func=atomic_op_value_func,
2980
- doc="""Compute the minimum of ``value`` and ``a[i]`` and atomically update the array.
3454
+ doc="""Compute the minimum of ``value`` and ``arr[i]`` and atomically update the array.
2981
3455
 
2982
3456
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
2983
3457
  group="Utility",
@@ -2986,9 +3460,9 @@ for array_type in array_types:
2986
3460
  add_builtin(
2987
3461
  "atomic_min",
2988
3462
  hidden=hidden,
2989
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3463
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
2990
3464
  value_func=atomic_op_value_func,
2991
- doc="""Compute the minimum of ``value`` and ``a[i,j]`` and atomically update the array.
3465
+ doc="""Compute the minimum of ``value`` and ``arr[i,j]`` and atomically update the array.
2992
3466
 
2993
3467
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
2994
3468
  group="Utility",
@@ -2997,9 +3471,9 @@ for array_type in array_types:
2997
3471
  add_builtin(
2998
3472
  "atomic_min",
2999
3473
  hidden=hidden,
3000
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3474
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3001
3475
  value_func=atomic_op_value_func,
3002
- doc="""Compute the minimum of ``value`` and ``a[i,j,k]`` and atomically update the array.
3476
+ doc="""Compute the minimum of ``value`` and ``arr[i,j,k]`` and atomically update the array.
3003
3477
 
3004
3478
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3005
3479
  group="Utility",
@@ -3008,9 +3482,9 @@ for array_type in array_types:
3008
3482
  add_builtin(
3009
3483
  "atomic_min",
3010
3484
  hidden=hidden,
3011
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3485
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3012
3486
  value_func=atomic_op_value_func,
3013
- doc="""Compute the minimum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.
3487
+ doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]`` and atomically update the array.
3014
3488
 
3015
3489
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3016
3490
  group="Utility",
@@ -3020,9 +3494,9 @@ for array_type in array_types:
3020
3494
  add_builtin(
3021
3495
  "atomic_max",
3022
3496
  hidden=hidden,
3023
- input_types={"a": array_type(dtype=Any), "i": int, "value": Any},
3497
+ input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
3024
3498
  value_func=atomic_op_value_func,
3025
- doc="""Compute the maximum of ``value`` and ``a[i]`` and atomically update the array.
3499
+ doc="""Compute the maximum of ``value`` and ``arr[i]`` and atomically update the array.
3026
3500
 
3027
3501
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3028
3502
  group="Utility",
@@ -3031,9 +3505,9 @@ for array_type in array_types:
3031
3505
  add_builtin(
3032
3506
  "atomic_max",
3033
3507
  hidden=hidden,
3034
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3508
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
3035
3509
  value_func=atomic_op_value_func,
3036
- doc="""Compute the maximum of ``value`` and ``a[i,j]`` and atomically update the array.
3510
+ doc="""Compute the maximum of ``value`` and ``arr[i,j]`` and atomically update the array.
3037
3511
 
3038
3512
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3039
3513
  group="Utility",
@@ -3042,9 +3516,9 @@ for array_type in array_types:
3042
3516
  add_builtin(
3043
3517
  "atomic_max",
3044
3518
  hidden=hidden,
3045
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3519
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
3046
3520
  value_func=atomic_op_value_func,
3047
- doc="""Compute the maximum of ``value`` and ``a[i,j,k]`` and atomically update the array.
3521
+ doc="""Compute the maximum of ``value`` and ``arr[i,j,k]`` and atomically update the array.
3048
3522
 
3049
3523
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3050
3524
  group="Utility",
@@ -3053,9 +3527,9 @@ for array_type in array_types:
3053
3527
  add_builtin(
3054
3528
  "atomic_max",
3055
3529
  hidden=hidden,
3056
- input_types={"a": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3530
+ input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
3057
3531
  value_func=atomic_op_value_func,
3058
- doc="""Compute the maximum of ``value`` and ``a[i,j,k,l]`` and atomically update the array.
3532
+ doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]`` and atomically update the array.
3059
3533
 
3060
3534
  .. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
3061
3535
  group="Utility",
@@ -3064,21 +3538,21 @@ for array_type in array_types:
3064
3538
 
3065
3539
 
3066
3540
  # used to index into builtin types, i.e.: y = vec3[1]
3067
- def index_value_func(arg_types, kwds, _):
3068
- return arg_types[0]._wp_scalar_type_
3541
+ def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3542
+ return arg_types["a"]._wp_scalar_type_
3069
3543
 
3070
3544
 
3071
3545
  add_builtin(
3072
3546
  "extract",
3073
3547
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
3074
- value_func=index_value_func,
3548
+ value_func=extract_value_func,
3075
3549
  hidden=True,
3076
3550
  group="Utility",
3077
3551
  )
3078
3552
  add_builtin(
3079
3553
  "extract",
3080
3554
  input_types={"a": quaternion(dtype=Scalar), "i": int},
3081
- value_func=index_value_func,
3555
+ value_func=extract_value_func,
3082
3556
  hidden=True,
3083
3557
  group="Utility",
3084
3558
  )
@@ -3086,14 +3560,16 @@ add_builtin(
3086
3560
  add_builtin(
3087
3561
  "extract",
3088
3562
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
3089
- value_func=lambda arg_types, kwds, _: vector(length=arg_types[0]._shape_[1], dtype=arg_types[0]._wp_scalar_type_),
3563
+ value_func=lambda arg_types, arg_values: vector(
3564
+ length=arg_types["a"]._shape_[1], dtype=arg_types["a"]._wp_scalar_type_
3565
+ ),
3090
3566
  hidden=True,
3091
3567
  group="Utility",
3092
3568
  )
3093
3569
  add_builtin(
3094
3570
  "extract",
3095
3571
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
3096
- value_func=index_value_func,
3572
+ value_func=extract_value_func,
3097
3573
  hidden=True,
3098
3574
  group="Utility",
3099
3575
  )
@@ -3101,7 +3577,7 @@ add_builtin(
3101
3577
  add_builtin(
3102
3578
  "extract",
3103
3579
  input_types={"a": transformation(dtype=Scalar), "i": int},
3104
- value_func=index_value_func,
3580
+ value_func=extract_value_func,
3105
3581
  hidden=True,
3106
3582
  group="Utility",
3107
3583
  )
@@ -3109,19 +3585,35 @@ add_builtin(
3109
3585
  add_builtin("extract", input_types={"s": shape_t, "i": int}, value_type=int, hidden=True, group="Utility")
3110
3586
 
3111
3587
 
3112
- def vector_indexref_element_value_func(arg_types, kwds, _):
3113
- vec_type = arg_types[0]
3114
- # index_type = arg_types[1]
3588
+ def vector_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3589
+ vec_type = arg_types["a"]
3115
3590
  value_type = vec_type._wp_scalar_type_
3116
3591
 
3117
3592
  return Reference(value_type)
3118
3593
 
3119
3594
 
3595
+ def vector_index_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3596
+ func_args = (Reference(args["a"]), args["i"])
3597
+ template_args = ()
3598
+ return (func_args, template_args)
3599
+
3600
+
3120
3601
  # implements &vector[index]
3121
3602
  add_builtin(
3122
3603
  "index",
3123
3604
  input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
3124
- value_func=vector_indexref_element_value_func,
3605
+ value_func=vector_index_value_func,
3606
+ dispatch_func=vector_index_dispatch_func,
3607
+ hidden=True,
3608
+ group="Utility",
3609
+ skip_replay=True,
3610
+ )
3611
+ # implements &quaternion[index]
3612
+ add_builtin(
3613
+ "index",
3614
+ input_types={"a": quaternion(dtype=Float), "i": int},
3615
+ value_func=vector_index_value_func,
3616
+ dispatch_func=vector_index_dispatch_func,
3125
3617
  hidden=True,
3126
3618
  group="Utility",
3127
3619
  skip_replay=True,
@@ -3129,27 +3621,28 @@ add_builtin(
3129
3621
  # implements &(*vector)[index]
3130
3622
  add_builtin(
3131
3623
  "indexref",
3132
- input_types={"a": Reference, "i": int},
3133
- value_func=vector_indexref_element_value_func,
3624
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
3625
+ value_func=vector_index_value_func,
3626
+ dispatch_func=vector_index_dispatch_func,
3627
+ hidden=True,
3628
+ group="Utility",
3629
+ skip_replay=True,
3630
+ )
3631
+ # implements &(*quaternion)[index]
3632
+ add_builtin(
3633
+ "indexref",
3634
+ input_types={"a": quaternion(dtype=Float), "i": int},
3635
+ value_func=vector_index_value_func,
3636
+ dispatch_func=vector_index_dispatch_func,
3134
3637
  hidden=True,
3135
3638
  group="Utility",
3136
3639
  skip_replay=True,
3137
3640
  )
3138
3641
 
3139
3642
 
3140
- def matrix_indexref_element_value_func(arg_types, kwds, _):
3141
- mat_type = arg_types[0]
3142
- # row_type = arg_types[1]
3143
- # col_type = arg_types[2]
3144
- value_type = mat_type._wp_scalar_type_
3145
-
3146
- return Reference(value_type)
3147
-
3148
-
3149
- def matrix_indexref_row_value_func(arg_types, kwds, _):
3150
- mat_type = arg_types[0]
3643
+ def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3644
+ mat_type = arg_types["a"]
3151
3645
  row_type = mat_type._wp_row_type_
3152
- # value_type = arg_types[2]
3153
3646
 
3154
3647
  return Reference(row_type)
3155
3648
 
@@ -3158,17 +3651,25 @@ def matrix_indexref_row_value_func(arg_types, kwds, _):
3158
3651
  add_builtin(
3159
3652
  "index",
3160
3653
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
3161
- value_func=matrix_indexref_row_value_func,
3654
+ value_func=matrix_index_row_value_func,
3162
3655
  hidden=True,
3163
3656
  group="Utility",
3164
3657
  skip_replay=True,
3165
3658
  )
3166
3659
 
3660
+
3661
+ def matrix_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3662
+ mat_type = arg_types["a"]
3663
+ value_type = mat_type._wp_scalar_type_
3664
+
3665
+ return Reference(value_type)
3666
+
3667
+
3167
3668
  # implements matrix[i,j] = scalar
3168
3669
  add_builtin(
3169
3670
  "index",
3170
3671
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
3171
- value_func=matrix_indexref_element_value_func,
3672
+ value_func=matrix_index_value_func,
3172
3673
  hidden=True,
3173
3674
  group="Utility",
3174
3675
  skip_replay=True,
@@ -3177,56 +3678,58 @@ add_builtin(
3177
3678
  for t in scalar_types + vector_types + (bool,):
3178
3679
  if "vec" in t.__name__ or "mat" in t.__name__:
3179
3680
  continue
3681
+
3180
3682
  add_builtin(
3181
3683
  "expect_eq",
3182
- input_types={"arg1": t, "arg2": t},
3684
+ input_types={"a": t, "b": t},
3183
3685
  value_type=None,
3184
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
3686
+ doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
3185
3687
  group="Utility",
3186
3688
  hidden=True,
3187
3689
  )
3188
3690
 
3189
3691
 
3190
- def expect_eq_val_func(arg_types, kwds, _):
3191
- if not types_equal(arg_types[0], arg_types[1]):
3692
+ def expect_eq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3693
+ if not types_equal(arg_types["a"], arg_types["b"]):
3192
3694
  raise RuntimeError("Can't test equality for objects with different types")
3695
+
3193
3696
  return None
3194
3697
 
3195
3698
 
3196
3699
  add_builtin(
3197
3700
  "expect_eq",
3198
- input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
3701
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
3199
3702
  constraint=sametypes,
3200
- value_func=expect_eq_val_func,
3201
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
3703
+ value_func=expect_eq_value_func,
3704
+ doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
3202
3705
  group="Utility",
3203
3706
  hidden=True,
3204
3707
  )
3205
3708
  add_builtin(
3206
3709
  "expect_neq",
3207
- input_types={"arg1": vector(length=Any, dtype=Scalar), "arg2": vector(length=Any, dtype=Scalar)},
3710
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
3208
3711
  constraint=sametypes,
3209
- value_func=expect_eq_val_func,
3210
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
3712
+ value_func=expect_eq_value_func,
3713
+ doc="Prints an error to stdout if ``a`` and ``b`` are equal",
3211
3714
  group="Utility",
3212
3715
  hidden=True,
3213
3716
  )
3214
3717
 
3215
3718
  add_builtin(
3216
3719
  "expect_eq",
3217
- input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
3720
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
3218
3721
  constraint=sametypes,
3219
- value_func=expect_eq_val_func,
3220
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not equal",
3722
+ value_func=expect_eq_value_func,
3723
+ doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
3221
3724
  group="Utility",
3222
3725
  hidden=True,
3223
3726
  )
3224
3727
  add_builtin(
3225
3728
  "expect_neq",
3226
- input_types={"arg1": matrix(shape=(Any, Any), dtype=Scalar), "arg2": matrix(shape=(Any, Any), dtype=Scalar)},
3729
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
3227
3730
  constraint=sametypes,
3228
- value_func=expect_eq_val_func,
3229
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are equal",
3731
+ value_func=expect_eq_value_func,
3732
+ doc="Prints an error to stdout if ``a`` and ``b`` are equal",
3230
3733
  group="Utility",
3231
3734
  hidden=True,
3232
3735
  )
@@ -3234,35 +3737,36 @@ add_builtin(
3234
3737
  add_builtin(
3235
3738
  "lerp",
3236
3739
  input_types={"a": Float, "b": Float, "t": Float},
3237
- value_func=sametype_value_func(Float),
3740
+ value_func=sametypes_create_value_func(Float),
3238
3741
  doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
3239
3742
  group="Utility",
3240
3743
  )
3241
3744
  add_builtin(
3242
3745
  "smoothstep",
3243
- input_types={"edge0": Float, "edge1": Float, "x": Float},
3244
- value_func=sametype_value_func(Float),
3245
- doc="""Smoothly interpolate between two values ``edge0`` and ``edge1`` using a factor ``x``,
3746
+ input_types={"a": Float, "b": Float, "x": Float},
3747
+ value_func=sametypes_create_value_func(Float),
3748
+ doc="""Smoothly interpolate between two values ``a`` and ``b`` using a factor ``x``,
3246
3749
  and return a result between 0 and 1 using a cubic Hermite interpolation after clamping.""",
3247
3750
  group="Utility",
3248
3751
  )
3249
3752
 
3250
3753
 
3251
- def lerp_constraint(arg_types):
3252
- return types_equal(arg_types[0], arg_types[1])
3754
+ def lerp_constraint(arg_types: Mapping[str, type]):
3755
+ return types_equal(arg_types["a"], arg_types["b"])
3253
3756
 
3254
3757
 
3255
- def lerp_value_func(default):
3256
- def fn(arg_types, kwds, _):
3758
+ def lerp_create_value_func(default):
3759
+ def fn(arg_types, arg_values):
3257
3760
  if arg_types is None:
3258
3761
  return default
3259
- scalar_type = arg_types[-1]
3762
+
3260
3763
  if not lerp_constraint(arg_types):
3261
3764
  raise RuntimeError("Can't lerp between objects with different types")
3262
- if arg_types[0]._wp_scalar_type_ != scalar_type:
3765
+
3766
+ if arg_types["a"]._wp_scalar_type_ != arg_types["t"]:
3263
3767
  raise RuntimeError("'t' parameter must have the same scalar type as objects you're lerping between")
3264
3768
 
3265
- return arg_types[0]
3769
+ return arg_types["a"]
3266
3770
 
3267
3771
  return fn
3268
3772
 
@@ -3271,7 +3775,7 @@ add_builtin(
3271
3775
  "lerp",
3272
3776
  input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "t": Float},
3273
3777
  constraint=lerp_constraint,
3274
- value_func=lerp_value_func(vector(length=Any, dtype=Float)),
3778
+ value_func=lerp_create_value_func(vector(length=Any, dtype=Float)),
3275
3779
  doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
3276
3780
  group="Utility",
3277
3781
  )
@@ -3279,21 +3783,21 @@ add_builtin(
3279
3783
  "lerp",
3280
3784
  input_types={"a": matrix(shape=(Any, Any), dtype=Float), "b": matrix(shape=(Any, Any), dtype=Float), "t": Float},
3281
3785
  constraint=lerp_constraint,
3282
- value_func=lerp_value_func(matrix(shape=(Any, Any), dtype=Float)),
3786
+ value_func=lerp_create_value_func(matrix(shape=(Any, Any), dtype=Float)),
3283
3787
  doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
3284
3788
  group="Utility",
3285
3789
  )
3286
3790
  add_builtin(
3287
3791
  "lerp",
3288
3792
  input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
3289
- value_func=lerp_value_func(quaternion(dtype=Float)),
3793
+ value_func=lerp_create_value_func(quaternion(dtype=Float)),
3290
3794
  doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
3291
3795
  group="Utility",
3292
3796
  )
3293
3797
  add_builtin(
3294
3798
  "lerp",
3295
3799
  input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float), "t": Float},
3296
- value_func=lerp_value_func(transformation(dtype=Float)),
3800
+ value_func=lerp_create_value_func(transformation(dtype=Float)),
3297
3801
  doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
3298
3802
  group="Utility",
3299
3803
  )
@@ -3301,17 +3805,18 @@ add_builtin(
3301
3805
  # fuzzy compare for float values
3302
3806
  add_builtin(
3303
3807
  "expect_near",
3304
- input_types={"arg1": Float, "arg2": Float, "tolerance": Float},
3808
+ input_types={"a": Float, "b": Float, "tolerance": Float},
3305
3809
  defaults={"tolerance": 1.0e-6},
3306
3810
  value_type=None,
3307
- doc="Prints an error to stdout if ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
3811
+ doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
3308
3812
  group="Utility",
3309
3813
  )
3310
3814
  add_builtin(
3311
3815
  "expect_near",
3312
- input_types={"arg1": vec3, "arg2": vec3, "tolerance": float},
3816
+ input_types={"a": vec3, "b": vec3, "tolerance": float},
3817
+ defaults={"tolerance": 1.0e-6},
3313
3818
  value_type=None,
3314
- doc="Prints an error to stdout if any element of ``arg1`` and ``arg2`` are not closer than tolerance in magnitude",
3819
+ doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
3315
3820
  group="Utility",
3316
3821
  )
3317
3822
 
@@ -3335,359 +3840,378 @@ add_builtin(
3335
3840
  # ---------------------------------
3336
3841
  # Operators
3337
3842
 
3338
- add_builtin("add", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
3843
+ add_builtin(
3844
+ "add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
3845
+ )
3339
3846
  add_builtin(
3340
3847
  "add",
3341
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
3848
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
3342
3849
  constraint=sametypes,
3343
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
3850
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
3344
3851
  doc="",
3345
3852
  group="Operators",
3346
3853
  )
3347
3854
  add_builtin(
3348
3855
  "add",
3349
- input_types={"x": quaternion(dtype=Scalar), "y": quaternion(dtype=Scalar)},
3350
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
3856
+ input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
3857
+ value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
3351
3858
  doc="",
3352
3859
  group="Operators",
3353
3860
  )
3354
3861
  add_builtin(
3355
3862
  "add",
3356
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3863
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
3357
3864
  constraint=sametypes,
3358
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3865
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3359
3866
  doc="",
3360
3867
  group="Operators",
3361
3868
  )
3362
3869
  add_builtin(
3363
3870
  "add",
3364
- input_types={"x": transformation(dtype=Scalar), "y": transformation(dtype=Scalar)},
3365
- value_func=sametype_value_func(transformation(dtype=Scalar)),
3871
+ input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
3872
+ value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
3366
3873
  doc="",
3367
3874
  group="Operators",
3368
3875
  )
3369
3876
 
3370
- add_builtin("sub", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
3877
+ add_builtin(
3878
+ "sub", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
3879
+ )
3371
3880
  add_builtin(
3372
3881
  "sub",
3373
- input_types={"x": vector(length=Any, dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
3882
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
3374
3883
  constraint=sametypes,
3375
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
3884
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
3376
3885
  doc="",
3377
3886
  group="Operators",
3378
3887
  )
3379
3888
  add_builtin(
3380
3889
  "sub",
3381
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3890
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
3382
3891
  constraint=sametypes,
3383
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3892
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3384
3893
  doc="",
3385
3894
  group="Operators",
3386
3895
  )
3387
3896
  add_builtin(
3388
3897
  "sub",
3389
- input_types={"x": quaternion(dtype=Scalar), "y": quaternion(dtype=Scalar)},
3390
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
3898
+ input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
3899
+ value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
3391
3900
  doc="",
3392
3901
  group="Operators",
3393
3902
  )
3394
3903
  add_builtin(
3395
3904
  "sub",
3396
- input_types={"x": transformation(dtype=Scalar), "y": transformation(dtype=Scalar)},
3397
- value_func=sametype_value_func(transformation(dtype=Scalar)),
3905
+ input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
3906
+ value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
3398
3907
  doc="",
3399
3908
  group="Operators",
3400
3909
  )
3401
3910
 
3402
3911
  # bitwise operators
3403
- add_builtin("bit_and", input_types={"x": Int, "y": Int}, value_func=sametype_value_func(Int))
3404
- add_builtin("bit_or", input_types={"x": Int, "y": Int}, value_func=sametype_value_func(Int))
3405
- add_builtin("bit_xor", input_types={"x": Int, "y": Int}, value_func=sametype_value_func(Int))
3406
- add_builtin("lshift", input_types={"x": Int, "y": Int}, value_func=sametype_value_func(Int))
3407
- add_builtin("rshift", input_types={"x": Int, "y": Int}, value_func=sametype_value_func(Int))
3408
- add_builtin("invert", input_types={"x": Int}, value_func=sametype_value_func(Int))
3912
+ add_builtin("bit_and", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
3913
+ add_builtin("bit_or", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
3914
+ add_builtin("bit_xor", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
3915
+ add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
3916
+ add_builtin("rshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
3917
+ add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int))
3918
+
3919
+
3920
+ add_builtin(
3921
+ "mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
3922
+ )
3409
3923
 
3410
3924
 
3411
- def scalar_mul_value_func(default):
3412
- def fn(arg_types, kwds, _):
3925
+ def scalar_mul_create_value_func(default):
3926
+ def fn(arg_types, arg_values):
3413
3927
  if arg_types is None:
3414
3928
  return default
3415
- scalar = [t for t in arg_types if t in scalar_types][0]
3416
- compound = [t for t in arg_types if t not in scalar_types][0]
3929
+
3930
+ scalar = next(t for t in arg_types.values() if t in scalar_types)
3931
+ compound = next(t for t in arg_types.values() if t not in scalar_types)
3417
3932
  if scalar != compound._wp_scalar_type_:
3418
3933
  raise RuntimeError("Object and coefficient must have the same scalar type when multiplying by scalar")
3934
+
3419
3935
  return compound
3420
3936
 
3421
3937
  return fn
3422
3938
 
3423
3939
 
3424
- def mul_matvec_constraint(arg_types):
3425
- return arg_types[0]._shape_[1] == arg_types[1]._length_
3426
-
3427
-
3428
- def mul_matvec_value_func(arg_types, kwds, _):
3429
- if arg_types is None:
3430
- return vector(length=Any, dtype=Scalar)
3431
-
3432
- if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
3433
- raise RuntimeError(
3434
- f"Can't multiply matrix and vector with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
3435
- )
3436
-
3437
- if not mul_matmat_constraint(arg_types):
3438
- raise RuntimeError(
3439
- f"Can't multiply matrix of shape {arg_types[0]._shape_} and vector with length {arg_types[1]._length_}"
3440
- )
3441
-
3442
- return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
3443
-
3444
-
3445
- def mul_vecmat_constraint(arg_types):
3446
- return arg_types[1]._shape_[0] == arg_types[0]._length_
3447
-
3448
-
3449
- def mul_vecmat_value_func(arg_types, kwds, _):
3450
- if arg_types is None:
3451
- return vector(length=Any, dtype=Scalar)
3452
-
3453
- if arg_types[1]._wp_scalar_type_ != arg_types[0]._wp_scalar_type_:
3454
- raise RuntimeError(
3455
- f"Can't multiply vector and matrix with different types {arg_types[1]._wp_scalar_type_}, {arg_types[0]._wp_scalar_type_}"
3456
- )
3457
-
3458
- if not mul_vecmat_constraint(arg_types):
3459
- raise RuntimeError(
3460
- f"Can't multiply vector with length {arg_types[0]._length_} and matrix of shape {arg_types[1]._shape_}"
3461
- )
3462
-
3463
- return vector(length=arg_types[1]._shape_[1], dtype=arg_types[1]._wp_scalar_type_)
3464
-
3465
-
3466
- def mul_matmat_constraint(arg_types):
3467
- return arg_types[0]._shape_[1] == arg_types[1]._shape_[0]
3468
-
3469
-
3470
- def mul_matmat_value_func(arg_types, kwds, _):
3471
- if arg_types is None:
3472
- return matrix(length=Any, dtype=Scalar)
3473
-
3474
- if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
3475
- raise RuntimeError(
3476
- f"Can't multiply matrices with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
3477
- )
3478
-
3479
- if not mul_matmat_constraint(arg_types):
3480
- raise RuntimeError(f"Can't multiply matrix of shapes {arg_types[0]._shape_} and {arg_types[1]._shape_}")
3481
-
3482
- return matrix(shape=(arg_types[0]._shape_[0], arg_types[1]._shape_[1]), dtype=arg_types[0]._wp_scalar_type_)
3483
-
3484
-
3485
- add_builtin("mul", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
3486
3940
  add_builtin(
3487
3941
  "mul",
3488
- input_types={"x": vector(length=Any, dtype=Scalar), "y": Scalar},
3489
- value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
3942
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": Scalar},
3943
+ value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
3490
3944
  doc="",
3491
3945
  group="Operators",
3492
3946
  )
3493
3947
  add_builtin(
3494
3948
  "mul",
3495
- input_types={"x": Scalar, "y": vector(length=Any, dtype=Scalar)},
3496
- value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
3949
+ input_types={"a": Scalar, "b": vector(length=Any, dtype=Scalar)},
3950
+ value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
3497
3951
  doc="",
3498
3952
  group="Operators",
3499
3953
  )
3500
3954
  add_builtin(
3501
3955
  "mul",
3502
- input_types={"x": quaternion(dtype=Scalar), "y": Scalar},
3503
- value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
3956
+ input_types={"a": quaternion(dtype=Scalar), "b": Scalar},
3957
+ value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
3504
3958
  doc="",
3505
3959
  group="Operators",
3506
3960
  )
3507
3961
  add_builtin(
3508
3962
  "mul",
3509
- input_types={"x": Scalar, "y": quaternion(dtype=Scalar)},
3510
- value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
3963
+ input_types={"a": Scalar, "b": quaternion(dtype=Scalar)},
3964
+ value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
3511
3965
  doc="",
3512
3966
  group="Operators",
3513
3967
  )
3514
3968
  add_builtin(
3515
3969
  "mul",
3516
- input_types={"x": quaternion(dtype=Scalar), "y": quaternion(dtype=Scalar)},
3517
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
3970
+ input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
3971
+ value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
3518
3972
  doc="",
3519
3973
  group="Operators",
3520
3974
  )
3521
3975
  add_builtin(
3522
3976
  "mul",
3523
- input_types={"x": Scalar, "y": matrix(shape=(Any, Any), dtype=Scalar)},
3524
- value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3977
+ input_types={"a": Scalar, "b": matrix(shape=(Any, Any), dtype=Scalar)},
3978
+ value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3525
3979
  doc="",
3526
3980
  group="Operators",
3527
3981
  )
3528
3982
  add_builtin(
3529
3983
  "mul",
3530
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": Scalar},
3531
- value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3984
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": Scalar},
3985
+ value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3532
3986
  doc="",
3533
3987
  group="Operators",
3534
3988
  )
3989
+
3990
+
3991
+ def matvec_mul_constraint(arg_types: Mapping[str, type]):
3992
+ return arg_types["a"]._shape_[1] == arg_types["b"]._length_
3993
+
3994
+
3995
+ def matvec_mul_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3996
+ if arg_types is None:
3997
+ return vector(length=Any, dtype=Scalar)
3998
+
3999
+ if arg_types["a"]._wp_scalar_type_ != arg_types["b"]._wp_scalar_type_:
4000
+ raise RuntimeError(
4001
+ f"Can't multiply matrix and vector with different types {arg_types['a']._wp_scalar_type_}, {arg_types['b']._wp_scalar_type_}"
4002
+ )
4003
+
4004
+ if not matvec_mul_constraint(arg_types):
4005
+ raise RuntimeError(
4006
+ f"Can't multiply matrix of shape {arg_types['a']._shape_} and vector with length {arg_types['b']._length_}"
4007
+ )
4008
+
4009
+ return vector(length=arg_types["a"]._shape_[0], dtype=arg_types["a"]._wp_scalar_type_)
4010
+
4011
+
3535
4012
  add_builtin(
3536
4013
  "mul",
3537
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": vector(length=Any, dtype=Scalar)},
3538
- constraint=mul_matvec_constraint,
3539
- value_func=mul_matvec_value_func,
4014
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
4015
+ constraint=matvec_mul_constraint,
4016
+ value_func=matvec_mul_value_func,
3540
4017
  doc="",
3541
4018
  group="Operators",
3542
4019
  )
4020
+
4021
+
4022
+ def mul_vecmat_constraint(arg_types: Mapping[str, type]):
4023
+ return arg_types["b"]._shape_[0] == arg_types["a"]._length_
4024
+
4025
+
4026
+ def mul_vecmat_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
4027
+ if arg_types is None:
4028
+ return vector(length=Any, dtype=Scalar)
4029
+
4030
+ if arg_types["b"]._wp_scalar_type_ != arg_types["a"]._wp_scalar_type_:
4031
+ raise RuntimeError(
4032
+ f"Can't multiply vector and matrix with different types {arg_types['b']._wp_scalar_type_}, {arg_types['a']._wp_scalar_type_}"
4033
+ )
4034
+
4035
+ if not mul_vecmat_constraint(arg_types):
4036
+ raise RuntimeError(
4037
+ f"Can't multiply vector with length {arg_types['a']._length_} and matrix of shape {arg_types['b']._shape_}"
4038
+ )
4039
+
4040
+ return vector(length=arg_types["b"]._shape_[1], dtype=arg_types["b"]._wp_scalar_type_)
4041
+
4042
+
3543
4043
  add_builtin(
3544
4044
  "mul",
3545
- input_types={"x": vector(length=Any, dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
4045
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
3546
4046
  constraint=mul_vecmat_constraint,
3547
4047
  value_func=mul_vecmat_value_func,
3548
4048
  doc="",
3549
4049
  group="Operators",
3550
4050
  )
4051
+
4052
+
4053
+ def matmat_mul_constraint(arg_types: Mapping[str, type]):
4054
+ return arg_types["a"]._shape_[1] == arg_types["b"]._shape_[0]
4055
+
4056
+
4057
+ def matmat_mul_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
4058
+ if arg_types is None:
4059
+ return matrix(length=Any, dtype=Scalar)
4060
+
4061
+ if arg_types["a"]._wp_scalar_type_ != arg_types["b"]._wp_scalar_type_:
4062
+ raise RuntimeError(
4063
+ f"Can't multiply matrices with different types {arg_types['a']._wp_scalar_type_}, {arg_types['b']._wp_scalar_type_}"
4064
+ )
4065
+
4066
+ if not matmat_mul_constraint(arg_types):
4067
+ raise RuntimeError(f"Can't multiply matrix of shapes {arg_types['a']._shape_} and {arg_types['b']._shape_}")
4068
+
4069
+ return matrix(shape=(arg_types["a"]._shape_[0], arg_types["b"]._shape_[1]), dtype=arg_types["a"]._wp_scalar_type_)
4070
+
4071
+
3551
4072
  add_builtin(
3552
4073
  "mul",
3553
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": matrix(shape=(Any, Any), dtype=Scalar)},
3554
- constraint=mul_matmat_constraint,
3555
- value_func=mul_matmat_value_func,
4074
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
4075
+ constraint=matmat_mul_constraint,
4076
+ value_func=matmat_mul_value_func,
3556
4077
  doc="",
3557
4078
  group="Operators",
3558
4079
  )
3559
4080
 
4081
+
3560
4082
  add_builtin(
3561
4083
  "mul",
3562
- input_types={"x": transformation(dtype=Scalar), "y": transformation(dtype=Scalar)},
3563
- value_func=sametype_value_func(transformation(dtype=Scalar)),
4084
+ input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
4085
+ value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
3564
4086
  doc="",
3565
4087
  group="Operators",
3566
4088
  )
3567
4089
  add_builtin(
3568
4090
  "mul",
3569
- input_types={"x": Scalar, "y": transformation(dtype=Scalar)},
3570
- value_func=scalar_mul_value_func(transformation(dtype=Scalar)),
4091
+ input_types={"a": Scalar, "b": transformation(dtype=Scalar)},
4092
+ value_func=scalar_mul_create_value_func(transformation(dtype=Scalar)),
3571
4093
  doc="",
3572
4094
  group="Operators",
3573
4095
  )
3574
4096
  add_builtin(
3575
4097
  "mul",
3576
- input_types={"x": transformation(dtype=Scalar), "y": Scalar},
3577
- value_func=scalar_mul_value_func(transformation(dtype=Scalar)),
4098
+ input_types={"a": transformation(dtype=Scalar), "b": Scalar},
4099
+ value_func=scalar_mul_create_value_func(transformation(dtype=Scalar)),
3578
4100
  doc="",
3579
4101
  group="Operators",
3580
4102
  )
3581
4103
 
3582
- add_builtin("mod", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
4104
+ add_builtin(
4105
+ "mod", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
4106
+ )
3583
4107
 
3584
4108
  add_builtin(
3585
4109
  "div",
3586
- input_types={"x": Scalar, "y": Scalar},
3587
- value_func=sametype_value_func(Scalar),
4110
+ input_types={"a": Scalar, "b": Scalar},
4111
+ value_func=sametypes_create_value_func(Scalar),
3588
4112
  doc="",
3589
4113
  group="Operators",
3590
4114
  require_original_output_arg=True,
3591
4115
  )
3592
4116
  add_builtin(
3593
4117
  "div",
3594
- input_types={"x": vector(length=Any, dtype=Scalar), "y": Scalar},
3595
- value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
4118
+ input_types={"a": vector(length=Any, dtype=Scalar), "b": Scalar},
4119
+ value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
3596
4120
  doc="",
3597
4121
  group="Operators",
3598
4122
  )
3599
4123
  add_builtin(
3600
4124
  "div",
3601
- input_types={"x": Scalar, "y": vector(length=Any, dtype=Scalar)},
3602
- value_func=scalar_mul_value_func(vector(length=Any, dtype=Scalar)),
4125
+ input_types={"a": Scalar, "b": vector(length=Any, dtype=Scalar)},
4126
+ value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
3603
4127
  doc="",
3604
4128
  group="Operators",
3605
4129
  )
3606
4130
  add_builtin(
3607
4131
  "div",
3608
- input_types={"x": matrix(shape=(Any, Any), dtype=Scalar), "y": Scalar},
3609
- value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
4132
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": Scalar},
4133
+ value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3610
4134
  doc="",
3611
4135
  group="Operators",
3612
4136
  )
3613
4137
  add_builtin(
3614
4138
  "div",
3615
- input_types={"x": Scalar, "y": matrix(shape=(Any, Any), dtype=Scalar)},
3616
- value_func=scalar_mul_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
4139
+ input_types={"a": Scalar, "b": matrix(shape=(Any, Any), dtype=Scalar)},
4140
+ value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3617
4141
  doc="",
3618
4142
  group="Operators",
3619
4143
  )
3620
4144
  add_builtin(
3621
4145
  "div",
3622
- input_types={"x": quaternion(dtype=Scalar), "y": Scalar},
3623
- value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
4146
+ input_types={"a": quaternion(dtype=Scalar), "b": Scalar},
4147
+ value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
3624
4148
  doc="",
3625
4149
  group="Operators",
3626
4150
  )
3627
4151
  add_builtin(
3628
4152
  "div",
3629
- input_types={"x": Scalar, "y": quaternion(dtype=Scalar)},
3630
- value_func=scalar_mul_value_func(quaternion(dtype=Scalar)),
4153
+ input_types={"a": Scalar, "b": quaternion(dtype=Scalar)},
4154
+ value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
3631
4155
  doc="",
3632
4156
  group="Operators",
3633
4157
  )
3634
4158
 
3635
4159
  add_builtin(
3636
4160
  "floordiv",
3637
- input_types={"x": Scalar, "y": Scalar},
3638
- value_func=sametype_value_func(Scalar),
4161
+ input_types={"a": Scalar, "b": Scalar},
4162
+ value_func=sametypes_create_value_func(Scalar),
3639
4163
  doc="",
3640
4164
  group="Operators",
3641
4165
  )
3642
4166
 
3643
- add_builtin("pos", input_types={"x": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
4167
+ add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
3644
4168
  add_builtin(
3645
4169
  "pos",
3646
4170
  input_types={"x": vector(length=Any, dtype=Scalar)},
3647
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
4171
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
3648
4172
  doc="",
3649
4173
  group="Operators",
3650
4174
  )
3651
4175
  add_builtin(
3652
4176
  "pos",
3653
4177
  input_types={"x": quaternion(dtype=Scalar)},
3654
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
4178
+ value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
3655
4179
  doc="",
3656
4180
  group="Operators",
3657
4181
  )
3658
4182
  add_builtin(
3659
4183
  "pos",
3660
4184
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar)},
3661
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
4185
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3662
4186
  doc="",
3663
4187
  group="Operators",
3664
4188
  )
3665
- add_builtin("neg", input_types={"x": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
4189
+ add_builtin("neg", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
3666
4190
  add_builtin(
3667
4191
  "neg",
3668
4192
  input_types={"x": vector(length=Any, dtype=Scalar)},
3669
- value_func=sametype_value_func(vector(length=Any, dtype=Scalar)),
4193
+ value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
3670
4194
  doc="",
3671
4195
  group="Operators",
3672
4196
  )
3673
4197
  add_builtin(
3674
4198
  "neg",
3675
4199
  input_types={"x": quaternion(dtype=Scalar)},
3676
- value_func=sametype_value_func(quaternion(dtype=Scalar)),
4200
+ value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
3677
4201
  doc="",
3678
4202
  group="Operators",
3679
4203
  )
3680
4204
  add_builtin(
3681
4205
  "neg",
3682
4206
  input_types={"x": matrix(shape=(Any, Any), dtype=Scalar)},
3683
- value_func=sametype_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
4207
+ value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
3684
4208
  doc="",
3685
4209
  group="Operators",
3686
4210
  )
3687
4211
 
3688
- add_builtin("unot", input_types={"b": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
4212
+ add_builtin("unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
3689
4213
  for t in int_types:
3690
- add_builtin("unot", input_types={"b": t}, value_type=builtins.bool, doc="", group="Operators")
4214
+ add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
3691
4215
 
3692
4216
 
3693
4217
  add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")