warp-lang 1.2.2__py3-none-manylinux2014_aarch64.whl → 1.3.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp.so +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1410 -886
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +66 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
- warp_lang-1.3.0.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -5,25 +5,40 @@
|
|
|
5
5
|
# distribution of this software and related documentation without an express
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
import builtins
|
|
8
|
-
from typing import Any, Callable,
|
|
8
|
+
from typing import Any, Callable, Mapping, Sequence
|
|
9
9
|
|
|
10
|
-
from warp.codegen import Reference
|
|
10
|
+
from warp.codegen import Reference, Var, strip_reference
|
|
11
11
|
from warp.types import *
|
|
12
12
|
|
|
13
13
|
from .context import add_builtin
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
def
|
|
17
|
-
|
|
16
|
+
def seq_check_equal(seq_1, seq_2):
|
|
17
|
+
if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
|
|
18
|
+
return False
|
|
18
19
|
|
|
20
|
+
if len(seq_1) != len(seq_2):
|
|
21
|
+
return False
|
|
19
22
|
|
|
20
|
-
|
|
21
|
-
|
|
23
|
+
return all(x == y for x, y in zip(seq_1, seq_2))
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def sametypes(arg_types: Mapping[str, Any]):
|
|
27
|
+
arg_types_iter = iter(arg_types.values())
|
|
28
|
+
arg_type_0 = next(arg_types_iter)
|
|
29
|
+
return all(types_equal(arg_type_0, t) for t in arg_types_iter)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def sametypes_create_value_func(default):
|
|
33
|
+
def fn(arg_types, arg_values):
|
|
22
34
|
if arg_types is None:
|
|
23
35
|
return default
|
|
36
|
+
|
|
24
37
|
if not sametypes(arg_types):
|
|
25
38
|
raise RuntimeError(f"Input types must be the same, found: {[type_repr(t) for t in arg_types]}")
|
|
26
|
-
|
|
39
|
+
|
|
40
|
+
arg_type_0 = next(iter(arg_types.values()))
|
|
41
|
+
return arg_type_0
|
|
27
42
|
|
|
28
43
|
return fn
|
|
29
44
|
|
|
@@ -33,39 +48,39 @@ def sametype_value_func(default):
|
|
|
33
48
|
|
|
34
49
|
add_builtin(
|
|
35
50
|
"min",
|
|
36
|
-
input_types={"
|
|
37
|
-
value_func=
|
|
51
|
+
input_types={"a": Scalar, "b": Scalar},
|
|
52
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
38
53
|
doc="Return the minimum of two scalars.",
|
|
39
54
|
group="Scalar Math",
|
|
40
55
|
)
|
|
41
56
|
|
|
42
57
|
add_builtin(
|
|
43
58
|
"max",
|
|
44
|
-
input_types={"
|
|
45
|
-
value_func=
|
|
59
|
+
input_types={"a": Scalar, "b": Scalar},
|
|
60
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
46
61
|
doc="Return the maximum of two scalars.",
|
|
47
62
|
group="Scalar Math",
|
|
48
63
|
)
|
|
49
64
|
|
|
50
65
|
add_builtin(
|
|
51
66
|
"clamp",
|
|
52
|
-
input_types={"x": Scalar, "
|
|
53
|
-
value_func=
|
|
54
|
-
doc="Clamp the value of ``x`` to the range [
|
|
67
|
+
input_types={"x": Scalar, "low": Scalar, "high": Scalar},
|
|
68
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
69
|
+
doc="Clamp the value of ``x`` to the range [low, high].",
|
|
55
70
|
group="Scalar Math",
|
|
56
71
|
)
|
|
57
72
|
|
|
58
73
|
add_builtin(
|
|
59
74
|
"abs",
|
|
60
75
|
input_types={"x": Scalar},
|
|
61
|
-
value_func=
|
|
76
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
62
77
|
doc="Return the absolute value of ``x``.",
|
|
63
78
|
group="Scalar Math",
|
|
64
79
|
)
|
|
65
80
|
add_builtin(
|
|
66
81
|
"sign",
|
|
67
82
|
input_types={"x": Scalar},
|
|
68
|
-
value_func=
|
|
83
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
69
84
|
doc="Return -1 if ``x`` < 0, return 1 otherwise.",
|
|
70
85
|
group="Scalar Math",
|
|
71
86
|
)
|
|
@@ -73,14 +88,14 @@ add_builtin(
|
|
|
73
88
|
add_builtin(
|
|
74
89
|
"step",
|
|
75
90
|
input_types={"x": Scalar},
|
|
76
|
-
value_func=
|
|
91
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
77
92
|
doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
|
|
78
93
|
group="Scalar Math",
|
|
79
94
|
)
|
|
80
95
|
add_builtin(
|
|
81
96
|
"nonzero",
|
|
82
97
|
input_types={"x": Scalar},
|
|
83
|
-
value_func=
|
|
98
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
84
99
|
doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
|
|
85
100
|
group="Scalar Math",
|
|
86
101
|
)
|
|
@@ -88,35 +103,35 @@ add_builtin(
|
|
|
88
103
|
add_builtin(
|
|
89
104
|
"sin",
|
|
90
105
|
input_types={"x": Float},
|
|
91
|
-
value_func=
|
|
106
|
+
value_func=sametypes_create_value_func(Float),
|
|
92
107
|
doc="Return the sine of ``x`` in radians.",
|
|
93
108
|
group="Scalar Math",
|
|
94
109
|
)
|
|
95
110
|
add_builtin(
|
|
96
111
|
"cos",
|
|
97
112
|
input_types={"x": Float},
|
|
98
|
-
value_func=
|
|
113
|
+
value_func=sametypes_create_value_func(Float),
|
|
99
114
|
doc="Return the cosine of ``x`` in radians.",
|
|
100
115
|
group="Scalar Math",
|
|
101
116
|
)
|
|
102
117
|
add_builtin(
|
|
103
118
|
"acos",
|
|
104
119
|
input_types={"x": Float},
|
|
105
|
-
value_func=
|
|
120
|
+
value_func=sametypes_create_value_func(Float),
|
|
106
121
|
doc="Return arccos of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
|
|
107
122
|
group="Scalar Math",
|
|
108
123
|
)
|
|
109
124
|
add_builtin(
|
|
110
125
|
"asin",
|
|
111
126
|
input_types={"x": Float},
|
|
112
|
-
value_func=
|
|
127
|
+
value_func=sametypes_create_value_func(Float),
|
|
113
128
|
doc="Return arcsin of ``x`` in radians. Inputs are automatically clamped to [-1.0, 1.0].",
|
|
114
129
|
group="Scalar Math",
|
|
115
130
|
)
|
|
116
131
|
add_builtin(
|
|
117
132
|
"sqrt",
|
|
118
133
|
input_types={"x": Float},
|
|
119
|
-
value_func=
|
|
134
|
+
value_func=sametypes_create_value_func(Float),
|
|
120
135
|
doc="Return the square root of ``x``, where ``x`` is positive.",
|
|
121
136
|
group="Scalar Math",
|
|
122
137
|
require_original_output_arg=True,
|
|
@@ -124,7 +139,7 @@ add_builtin(
|
|
|
124
139
|
add_builtin(
|
|
125
140
|
"cbrt",
|
|
126
141
|
input_types={"x": Float},
|
|
127
|
-
value_func=
|
|
142
|
+
value_func=sametypes_create_value_func(Float),
|
|
128
143
|
doc="Return the cube root of ``x``.",
|
|
129
144
|
group="Scalar Math",
|
|
130
145
|
require_original_output_arg=True,
|
|
@@ -132,42 +147,42 @@ add_builtin(
|
|
|
132
147
|
add_builtin(
|
|
133
148
|
"tan",
|
|
134
149
|
input_types={"x": Float},
|
|
135
|
-
value_func=
|
|
150
|
+
value_func=sametypes_create_value_func(Float),
|
|
136
151
|
doc="Return the tangent of ``x`` in radians.",
|
|
137
152
|
group="Scalar Math",
|
|
138
153
|
)
|
|
139
154
|
add_builtin(
|
|
140
155
|
"atan",
|
|
141
156
|
input_types={"x": Float},
|
|
142
|
-
value_func=
|
|
157
|
+
value_func=sametypes_create_value_func(Float),
|
|
143
158
|
doc="Return the arctangent of ``x`` in radians.",
|
|
144
159
|
group="Scalar Math",
|
|
145
160
|
)
|
|
146
161
|
add_builtin(
|
|
147
162
|
"atan2",
|
|
148
163
|
input_types={"y": Float, "x": Float},
|
|
149
|
-
value_func=
|
|
164
|
+
value_func=sametypes_create_value_func(Float),
|
|
150
165
|
doc="Return the 2-argument arctangent, atan2, of the point ``(x, y)`` in radians.",
|
|
151
166
|
group="Scalar Math",
|
|
152
167
|
)
|
|
153
168
|
add_builtin(
|
|
154
169
|
"sinh",
|
|
155
170
|
input_types={"x": Float},
|
|
156
|
-
value_func=
|
|
171
|
+
value_func=sametypes_create_value_func(Float),
|
|
157
172
|
doc="Return the sinh of ``x``.",
|
|
158
173
|
group="Scalar Math",
|
|
159
174
|
)
|
|
160
175
|
add_builtin(
|
|
161
176
|
"cosh",
|
|
162
177
|
input_types={"x": Float},
|
|
163
|
-
value_func=
|
|
178
|
+
value_func=sametypes_create_value_func(Float),
|
|
164
179
|
doc="Return the cosh of ``x``.",
|
|
165
180
|
group="Scalar Math",
|
|
166
181
|
)
|
|
167
182
|
add_builtin(
|
|
168
183
|
"tanh",
|
|
169
184
|
input_types={"x": Float},
|
|
170
|
-
value_func=
|
|
185
|
+
value_func=sametypes_create_value_func(Float),
|
|
171
186
|
doc="Return the tanh of ``x``.",
|
|
172
187
|
group="Scalar Math",
|
|
173
188
|
require_original_output_arg=True,
|
|
@@ -175,14 +190,14 @@ add_builtin(
|
|
|
175
190
|
add_builtin(
|
|
176
191
|
"degrees",
|
|
177
192
|
input_types={"x": Float},
|
|
178
|
-
value_func=
|
|
193
|
+
value_func=sametypes_create_value_func(Float),
|
|
179
194
|
doc="Convert ``x`` from radians into degrees.",
|
|
180
195
|
group="Scalar Math",
|
|
181
196
|
)
|
|
182
197
|
add_builtin(
|
|
183
198
|
"radians",
|
|
184
199
|
input_types={"x": Float},
|
|
185
|
-
value_func=
|
|
200
|
+
value_func=sametypes_create_value_func(Float),
|
|
186
201
|
doc="Convert ``x`` from degrees into radians.",
|
|
187
202
|
group="Scalar Math",
|
|
188
203
|
)
|
|
@@ -190,28 +205,28 @@ add_builtin(
|
|
|
190
205
|
add_builtin(
|
|
191
206
|
"log",
|
|
192
207
|
input_types={"x": Float},
|
|
193
|
-
value_func=
|
|
208
|
+
value_func=sametypes_create_value_func(Float),
|
|
194
209
|
doc="Return the natural logarithm (base-e) of ``x``, where ``x`` is positive.",
|
|
195
210
|
group="Scalar Math",
|
|
196
211
|
)
|
|
197
212
|
add_builtin(
|
|
198
213
|
"log2",
|
|
199
214
|
input_types={"x": Float},
|
|
200
|
-
value_func=
|
|
215
|
+
value_func=sametypes_create_value_func(Float),
|
|
201
216
|
doc="Return the binary logarithm (base-2) of ``x``, where ``x`` is positive.",
|
|
202
217
|
group="Scalar Math",
|
|
203
218
|
)
|
|
204
219
|
add_builtin(
|
|
205
220
|
"log10",
|
|
206
221
|
input_types={"x": Float},
|
|
207
|
-
value_func=
|
|
222
|
+
value_func=sametypes_create_value_func(Float),
|
|
208
223
|
doc="Return the common logarithm (base-10) of ``x``, where ``x`` is positive.",
|
|
209
224
|
group="Scalar Math",
|
|
210
225
|
)
|
|
211
226
|
add_builtin(
|
|
212
227
|
"exp",
|
|
213
228
|
input_types={"x": Float},
|
|
214
|
-
value_func=
|
|
229
|
+
value_func=sametypes_create_value_func(Float),
|
|
215
230
|
doc="Return the value of the exponential function :math:`e^x`.",
|
|
216
231
|
group="Scalar Math",
|
|
217
232
|
require_original_output_arg=True,
|
|
@@ -219,7 +234,7 @@ add_builtin(
|
|
|
219
234
|
add_builtin(
|
|
220
235
|
"pow",
|
|
221
236
|
input_types={"x": Float, "y": Float},
|
|
222
|
-
value_func=
|
|
237
|
+
value_func=sametypes_create_value_func(Float),
|
|
223
238
|
doc="Return the result of ``x`` raised to power of ``y``.",
|
|
224
239
|
group="Scalar Math",
|
|
225
240
|
require_original_output_arg=True,
|
|
@@ -228,7 +243,7 @@ add_builtin(
|
|
|
228
243
|
add_builtin(
|
|
229
244
|
"round",
|
|
230
245
|
input_types={"x": Float},
|
|
231
|
-
value_func=
|
|
246
|
+
value_func=sametypes_create_value_func(Float),
|
|
232
247
|
group="Scalar Math",
|
|
233
248
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases away from zero.
|
|
234
249
|
|
|
@@ -239,7 +254,7 @@ add_builtin(
|
|
|
239
254
|
add_builtin(
|
|
240
255
|
"rint",
|
|
241
256
|
input_types={"x": Float},
|
|
242
|
-
value_func=
|
|
257
|
+
value_func=sametypes_create_value_func(Float),
|
|
243
258
|
group="Scalar Math",
|
|
244
259
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
|
|
245
260
|
|
|
@@ -249,19 +264,19 @@ add_builtin(
|
|
|
249
264
|
add_builtin(
|
|
250
265
|
"trunc",
|
|
251
266
|
input_types={"x": Float},
|
|
252
|
-
value_func=
|
|
267
|
+
value_func=sametypes_create_value_func(Float),
|
|
253
268
|
group="Scalar Math",
|
|
254
269
|
doc="""Return the nearest integer that is closer to zero than ``x``.
|
|
255
270
|
|
|
256
271
|
In other words, it discards the fractional part of ``x``.
|
|
257
|
-
It is similar to casting ``float(int(
|
|
272
|
+
It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
|
|
258
273
|
Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
|
|
259
274
|
)
|
|
260
275
|
|
|
261
276
|
add_builtin(
|
|
262
277
|
"floor",
|
|
263
278
|
input_types={"x": Float},
|
|
264
|
-
value_func=
|
|
279
|
+
value_func=sametypes_create_value_func(Float),
|
|
265
280
|
group="Scalar Math",
|
|
266
281
|
doc="""Return the largest integer that is less than or equal to ``x``.""",
|
|
267
282
|
)
|
|
@@ -269,7 +284,7 @@ add_builtin(
|
|
|
269
284
|
add_builtin(
|
|
270
285
|
"ceil",
|
|
271
286
|
input_types={"x": Float},
|
|
272
|
-
value_func=
|
|
287
|
+
value_func=sametypes_create_value_func(Float),
|
|
273
288
|
group="Scalar Math",
|
|
274
289
|
doc="""Return the smallest integer that is greater than or equal to ``x``.""",
|
|
275
290
|
)
|
|
@@ -277,127 +292,145 @@ add_builtin(
|
|
|
277
292
|
add_builtin(
|
|
278
293
|
"frac",
|
|
279
294
|
input_types={"x": Float},
|
|
280
|
-
value_func=
|
|
295
|
+
value_func=sametypes_create_value_func(Float),
|
|
281
296
|
group="Scalar Math",
|
|
282
|
-
doc="""Retrieve the fractional part of x
|
|
297
|
+
doc="""Retrieve the fractional part of ``x``.
|
|
283
298
|
|
|
284
|
-
In other words, it discards the integer part of x and is equivalent to ``x - trunc(x)``.""",
|
|
299
|
+
In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
|
|
285
300
|
)
|
|
286
301
|
|
|
287
302
|
add_builtin(
|
|
288
303
|
"isfinite",
|
|
289
|
-
input_types={"
|
|
304
|
+
input_types={"a": Scalar},
|
|
290
305
|
value_type=builtins.bool,
|
|
291
306
|
group="Scalar Math",
|
|
292
|
-
doc="""Return ``True`` if
|
|
307
|
+
doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
|
|
293
308
|
)
|
|
294
309
|
add_builtin(
|
|
295
310
|
"isfinite",
|
|
296
|
-
input_types={"
|
|
311
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
297
312
|
value_type=builtins.bool,
|
|
298
313
|
group="Vector Math",
|
|
299
|
-
doc="Return ``True`` if all elements of the vector ``
|
|
314
|
+
doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
|
|
300
315
|
)
|
|
301
316
|
add_builtin(
|
|
302
317
|
"isfinite",
|
|
303
|
-
input_types={"
|
|
318
|
+
input_types={"a": quaternion(dtype=Scalar)},
|
|
304
319
|
value_type=builtins.bool,
|
|
305
320
|
group="Vector Math",
|
|
306
|
-
doc="Return ``True`` if all elements of the quaternion ``
|
|
321
|
+
doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
|
|
307
322
|
)
|
|
308
323
|
add_builtin(
|
|
309
324
|
"isfinite",
|
|
310
|
-
input_types={"
|
|
325
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
311
326
|
value_type=builtins.bool,
|
|
312
327
|
group="Vector Math",
|
|
313
|
-
doc="Return ``True`` if all elements of the matrix ``
|
|
328
|
+
doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
|
|
314
329
|
)
|
|
315
330
|
|
|
316
331
|
add_builtin(
|
|
317
332
|
"isnan",
|
|
318
|
-
input_types={"
|
|
333
|
+
input_types={"a": Scalar},
|
|
319
334
|
value_type=builtins.bool,
|
|
320
|
-
doc="Return ``True`` if ``
|
|
335
|
+
doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
|
|
321
336
|
group="Scalar Math",
|
|
322
337
|
)
|
|
323
338
|
add_builtin(
|
|
324
339
|
"isnan",
|
|
325
|
-
input_types={"
|
|
340
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
326
341
|
value_type=builtins.bool,
|
|
327
342
|
group="Vector Math",
|
|
328
|
-
doc="Return ``True`` if any element of the vector ``
|
|
343
|
+
doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
|
|
329
344
|
)
|
|
330
345
|
add_builtin(
|
|
331
346
|
"isnan",
|
|
332
|
-
input_types={"
|
|
347
|
+
input_types={"a": quaternion(dtype=Scalar)},
|
|
333
348
|
value_type=builtins.bool,
|
|
334
349
|
group="Vector Math",
|
|
335
|
-
doc="Return ``True`` if any element of the quaternion ``
|
|
350
|
+
doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
|
|
336
351
|
)
|
|
337
352
|
add_builtin(
|
|
338
353
|
"isnan",
|
|
339
|
-
input_types={"
|
|
354
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
340
355
|
value_type=builtins.bool,
|
|
341
356
|
group="Vector Math",
|
|
342
|
-
doc="Return ``True`` if any element of the matrix ``
|
|
357
|
+
doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
|
|
343
358
|
)
|
|
344
359
|
|
|
345
360
|
add_builtin(
|
|
346
361
|
"isinf",
|
|
347
|
-
input_types={"
|
|
362
|
+
input_types={"a": Scalar},
|
|
348
363
|
value_type=builtins.bool,
|
|
349
364
|
group="Scalar Math",
|
|
350
|
-
doc="""Return ``True`` if
|
|
365
|
+
doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
|
|
351
366
|
)
|
|
352
367
|
add_builtin(
|
|
353
368
|
"isinf",
|
|
354
|
-
input_types={"
|
|
369
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
355
370
|
value_type=builtins.bool,
|
|
356
371
|
group="Vector Math",
|
|
357
|
-
doc="Return ``True`` if any element of the vector ``
|
|
372
|
+
doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
358
373
|
)
|
|
359
374
|
add_builtin(
|
|
360
375
|
"isinf",
|
|
361
|
-
input_types={"
|
|
376
|
+
input_types={"a": quaternion(dtype=Scalar)},
|
|
362
377
|
value_type=builtins.bool,
|
|
363
378
|
group="Vector Math",
|
|
364
|
-
doc="Return ``True`` if any element of the quaternion ``
|
|
379
|
+
doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
365
380
|
)
|
|
366
381
|
add_builtin(
|
|
367
382
|
"isinf",
|
|
368
|
-
input_types={"
|
|
383
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
369
384
|
value_type=builtins.bool,
|
|
370
385
|
group="Vector Math",
|
|
371
|
-
doc="Return ``True`` if any element of the matrix ``
|
|
386
|
+
doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
372
387
|
)
|
|
373
388
|
|
|
374
389
|
|
|
375
|
-
def
|
|
390
|
+
def scalar_infer_type(arg_types: Mapping[str, type]):
|
|
376
391
|
if arg_types is None:
|
|
377
392
|
return Scalar
|
|
378
393
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
if hasattr(t, "_wp_scalar_type_"):
|
|
382
|
-
yield t._wp_scalar_type_
|
|
383
|
-
elif t in scalar_and_bool_types:
|
|
384
|
-
yield t
|
|
394
|
+
if isinstance(arg_types, Mapping):
|
|
395
|
+
arg_types = tuple(arg_types.values())
|
|
385
396
|
|
|
386
|
-
|
|
387
|
-
|
|
397
|
+
scalar_types = set()
|
|
398
|
+
for t in arg_types:
|
|
399
|
+
t = strip_reference(t)
|
|
400
|
+
if hasattr(t, "_wp_scalar_type_"):
|
|
401
|
+
scalar_types.add(t._wp_scalar_type_)
|
|
402
|
+
elif t in scalar_and_bool_types:
|
|
403
|
+
scalar_types.add(t)
|
|
404
|
+
|
|
405
|
+
if len(scalar_types) > 1:
|
|
388
406
|
raise RuntimeError(
|
|
389
|
-
f"Couldn't figure out return type as arguments have multiple precisions: {list(
|
|
407
|
+
f"Couldn't figure out return type as arguments have multiple precisions: {list(scalar_types)}"
|
|
390
408
|
)
|
|
391
|
-
return
|
|
409
|
+
return next(iter(scalar_types))
|
|
392
410
|
|
|
393
411
|
|
|
394
|
-
def
|
|
412
|
+
def scalar_sametypes_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
395
413
|
if arg_types is None:
|
|
396
414
|
return Scalar
|
|
415
|
+
|
|
397
416
|
if not sametypes(arg_types):
|
|
398
417
|
raise RuntimeError(f"Input types must be exactly the same, {list(arg_types)}")
|
|
399
418
|
|
|
400
|
-
return
|
|
419
|
+
return scalar_infer_type(arg_types)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def float_infer_type(arg_types: Mapping[str, type]):
|
|
423
|
+
if arg_types is None:
|
|
424
|
+
return Float
|
|
425
|
+
|
|
426
|
+
return scalar_infer_type(arg_types)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def float_sametypes_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
430
|
+
if arg_types is None:
|
|
431
|
+
return Float
|
|
432
|
+
|
|
433
|
+
return scalar_sametypes_value_func(arg_types, arg_values)
|
|
401
434
|
|
|
402
435
|
|
|
403
436
|
# ---------------------------------
|
|
@@ -405,290 +438,312 @@ def sametype_scalar_value_func(arg_types, kwds, _):
|
|
|
405
438
|
|
|
406
439
|
add_builtin(
|
|
407
440
|
"dot",
|
|
408
|
-
input_types={"
|
|
441
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
409
442
|
constraint=sametypes,
|
|
410
|
-
value_func=
|
|
443
|
+
value_func=scalar_sametypes_value_func,
|
|
411
444
|
group="Vector Math",
|
|
412
445
|
doc="Compute the dot product between two vectors.",
|
|
413
446
|
)
|
|
414
447
|
add_builtin(
|
|
415
448
|
"ddot",
|
|
416
|
-
input_types={"
|
|
449
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
417
450
|
constraint=sametypes,
|
|
418
|
-
value_func=
|
|
451
|
+
value_func=scalar_sametypes_value_func,
|
|
419
452
|
group="Vector Math",
|
|
420
453
|
doc="Compute the double dot product between two matrices.",
|
|
421
454
|
)
|
|
422
455
|
|
|
423
456
|
add_builtin(
|
|
424
457
|
"min",
|
|
425
|
-
input_types={"
|
|
458
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
426
459
|
constraint=sametypes,
|
|
427
|
-
value_func=
|
|
460
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
428
461
|
doc="Return the element-wise minimum of two vectors.",
|
|
429
462
|
group="Vector Math",
|
|
430
463
|
)
|
|
431
464
|
add_builtin(
|
|
432
465
|
"max",
|
|
433
|
-
input_types={"
|
|
466
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
434
467
|
constraint=sametypes,
|
|
435
|
-
value_func=
|
|
468
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
436
469
|
doc="Return the element-wise maximum of two vectors.",
|
|
437
470
|
group="Vector Math",
|
|
438
471
|
)
|
|
439
472
|
|
|
440
473
|
add_builtin(
|
|
441
474
|
"min",
|
|
442
|
-
input_types={"
|
|
443
|
-
value_func=
|
|
444
|
-
doc="Return the minimum element of a vector ``
|
|
475
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
476
|
+
value_func=scalar_sametypes_value_func,
|
|
477
|
+
doc="Return the minimum element of a vector ``a``.",
|
|
445
478
|
group="Vector Math",
|
|
446
479
|
)
|
|
447
480
|
add_builtin(
|
|
448
481
|
"max",
|
|
449
|
-
input_types={"
|
|
450
|
-
value_func=
|
|
451
|
-
doc="Return the maximum element of a vector ``
|
|
482
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
483
|
+
value_func=scalar_sametypes_value_func,
|
|
484
|
+
doc="Return the maximum element of a vector ``a``.",
|
|
452
485
|
group="Vector Math",
|
|
453
486
|
)
|
|
454
487
|
|
|
455
488
|
add_builtin(
|
|
456
489
|
"argmin",
|
|
457
|
-
input_types={"
|
|
458
|
-
value_func=lambda arg_types,
|
|
459
|
-
doc="Return the index of the minimum element of a vector ``
|
|
490
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
491
|
+
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
492
|
+
doc="Return the index of the minimum element of a vector ``a``.",
|
|
460
493
|
group="Vector Math",
|
|
461
494
|
missing_grad=True,
|
|
462
495
|
)
|
|
463
496
|
add_builtin(
|
|
464
497
|
"argmax",
|
|
465
|
-
input_types={"
|
|
466
|
-
value_func=lambda arg_types,
|
|
467
|
-
doc="Return the index of the maximum element of a vector ``
|
|
498
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
499
|
+
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
500
|
+
doc="Return the index of the maximum element of a vector ``a``.",
|
|
468
501
|
group="Vector Math",
|
|
469
502
|
missing_grad=True,
|
|
470
503
|
)
|
|
471
504
|
|
|
505
|
+
add_builtin(
|
|
506
|
+
"abs",
|
|
507
|
+
input_types={"x": vector(length=Any, dtype=Scalar)},
|
|
508
|
+
constraint=sametypes,
|
|
509
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
510
|
+
doc="Return the absolute values of the elements of ``x``.",
|
|
511
|
+
group="Vector Math",
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
add_builtin(
|
|
515
|
+
"sign",
|
|
516
|
+
input_types={"x": vector(length=Any, dtype=Scalar)},
|
|
517
|
+
constraint=sametypes,
|
|
518
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
519
|
+
doc="Return -1 for the negative elements of ``x``, and 1 otherwise.",
|
|
520
|
+
group="Vector Math",
|
|
521
|
+
)
|
|
472
522
|
|
|
473
|
-
|
|
523
|
+
|
|
524
|
+
def outer_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
474
525
|
if arg_types is None:
|
|
475
526
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
476
527
|
|
|
477
|
-
|
|
478
|
-
|
|
528
|
+
arg_type_values = tuple(arg_types.values())
|
|
529
|
+
scalarType = scalar_infer_type(arg_type_values)
|
|
530
|
+
vectorLengths = tuple(t._length_ for t in arg_type_values)
|
|
479
531
|
return matrix(shape=(vectorLengths), dtype=scalarType)
|
|
480
532
|
|
|
481
533
|
|
|
482
534
|
add_builtin(
|
|
483
535
|
"outer",
|
|
484
|
-
input_types={"
|
|
485
|
-
value_func=
|
|
536
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
537
|
+
value_func=outer_value_func,
|
|
486
538
|
group="Vector Math",
|
|
487
|
-
doc="Compute the outer product ``
|
|
539
|
+
doc="Compute the outer product ``a*b^T`` for two vectors.",
|
|
488
540
|
)
|
|
489
541
|
|
|
490
542
|
add_builtin(
|
|
491
543
|
"cross",
|
|
492
|
-
input_types={"
|
|
493
|
-
value_func=
|
|
544
|
+
input_types={"a": vector(length=3, dtype=Scalar), "b": vector(length=3, dtype=Scalar)},
|
|
545
|
+
value_func=sametypes_create_value_func(vector(length=3, dtype=Scalar)),
|
|
494
546
|
group="Vector Math",
|
|
495
547
|
doc="Compute the cross product of two 3D vectors.",
|
|
496
548
|
)
|
|
497
549
|
add_builtin(
|
|
498
550
|
"skew",
|
|
499
|
-
input_types={"
|
|
500
|
-
value_func=lambda arg_types,
|
|
551
|
+
input_types={"vec": vector(length=3, dtype=Scalar)},
|
|
552
|
+
value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=arg_types["vec"]._wp_scalar_type_),
|
|
501
553
|
group="Vector Math",
|
|
502
|
-
doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``
|
|
554
|
+
doc="Compute the skew-symmetric 3x3 matrix for a 3D vector ``vec``.",
|
|
503
555
|
)
|
|
504
556
|
|
|
505
557
|
add_builtin(
|
|
506
558
|
"length",
|
|
507
|
-
input_types={"
|
|
508
|
-
value_func=
|
|
559
|
+
input_types={"a": vector(length=Any, dtype=Float)},
|
|
560
|
+
value_func=float_sametypes_value_func,
|
|
509
561
|
group="Vector Math",
|
|
510
|
-
doc="Compute the length of a floating-point vector ``
|
|
562
|
+
doc="Compute the length of a floating-point vector ``a``.",
|
|
511
563
|
require_original_output_arg=True,
|
|
512
564
|
)
|
|
513
565
|
add_builtin(
|
|
514
566
|
"length",
|
|
515
|
-
input_types={"
|
|
516
|
-
value_func=
|
|
567
|
+
input_types={"a": quaternion(dtype=Float)},
|
|
568
|
+
value_func=float_sametypes_value_func,
|
|
517
569
|
group="Vector Math",
|
|
518
|
-
doc="Compute the length of a quaternion ``
|
|
570
|
+
doc="Compute the length of a quaternion ``a``.",
|
|
519
571
|
require_original_output_arg=True,
|
|
520
572
|
)
|
|
521
573
|
add_builtin(
|
|
522
574
|
"length_sq",
|
|
523
|
-
input_types={"
|
|
524
|
-
value_func=
|
|
575
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
576
|
+
value_func=scalar_sametypes_value_func,
|
|
525
577
|
group="Vector Math",
|
|
526
|
-
doc="Compute the squared length of a vector ``
|
|
578
|
+
doc="Compute the squared length of a vector ``a``.",
|
|
527
579
|
)
|
|
528
580
|
add_builtin(
|
|
529
581
|
"length_sq",
|
|
530
|
-
input_types={"
|
|
531
|
-
value_func=
|
|
582
|
+
input_types={"a": quaternion(dtype=Scalar)},
|
|
583
|
+
value_func=scalar_sametypes_value_func,
|
|
532
584
|
group="Vector Math",
|
|
533
|
-
doc="Compute the squared length of a quaternion ``
|
|
585
|
+
doc="Compute the squared length of a quaternion ``a``.",
|
|
534
586
|
)
|
|
535
587
|
add_builtin(
|
|
536
588
|
"normalize",
|
|
537
|
-
input_types={"
|
|
538
|
-
value_func=
|
|
589
|
+
input_types={"a": vector(length=Any, dtype=Float)},
|
|
590
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Float)),
|
|
539
591
|
group="Vector Math",
|
|
540
|
-
doc="Compute the normalized value of ``
|
|
592
|
+
doc="Compute the normalized value of ``a``. If ``length(a)`` is 0 then the zero vector is returned.",
|
|
541
593
|
require_original_output_arg=True,
|
|
542
594
|
)
|
|
543
595
|
add_builtin(
|
|
544
596
|
"normalize",
|
|
545
|
-
input_types={"
|
|
546
|
-
value_func=
|
|
597
|
+
input_types={"a": quaternion(dtype=Float)},
|
|
598
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Float)),
|
|
547
599
|
group="Vector Math",
|
|
548
|
-
doc="Compute the normalized value of ``
|
|
600
|
+
doc="Compute the normalized value of ``a``. If ``length(a)`` is 0, then the zero quaternion is returned.",
|
|
549
601
|
)
|
|
550
602
|
|
|
551
603
|
add_builtin(
|
|
552
604
|
"transpose",
|
|
553
|
-
input_types={"
|
|
554
|
-
value_func=lambda arg_types,
|
|
555
|
-
shape=(arg_types[
|
|
605
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
606
|
+
value_func=lambda arg_types, arg_values: matrix(
|
|
607
|
+
shape=(arg_types["a"]._shape_[1], arg_types["a"]._shape_[0]), dtype=arg_types["a"]._wp_scalar_type_
|
|
556
608
|
),
|
|
557
609
|
group="Vector Math",
|
|
558
|
-
doc="Return the transpose of the matrix ``
|
|
610
|
+
doc="Return the transpose of the matrix ``a``.",
|
|
559
611
|
)
|
|
560
612
|
|
|
561
613
|
|
|
562
|
-
def
|
|
614
|
+
def inverse_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
563
615
|
if arg_types is None:
|
|
564
616
|
return matrix(shape=(Any, Any), dtype=Float)
|
|
565
|
-
|
|
617
|
+
|
|
618
|
+
return arg_types["a"]
|
|
566
619
|
|
|
567
620
|
|
|
568
621
|
add_builtin(
|
|
569
622
|
"inverse",
|
|
570
|
-
input_types={"
|
|
571
|
-
value_func=
|
|
623
|
+
input_types={"a": matrix(shape=(2, 2), dtype=Float)},
|
|
624
|
+
value_func=inverse_value_func,
|
|
572
625
|
group="Vector Math",
|
|
573
|
-
doc="Return the inverse of a 2x2 matrix ``
|
|
626
|
+
doc="Return the inverse of a 2x2 matrix ``a``.",
|
|
574
627
|
require_original_output_arg=True,
|
|
575
628
|
)
|
|
576
629
|
|
|
577
630
|
add_builtin(
|
|
578
631
|
"inverse",
|
|
579
|
-
input_types={"
|
|
580
|
-
value_func=
|
|
632
|
+
input_types={"a": matrix(shape=(3, 3), dtype=Float)},
|
|
633
|
+
value_func=inverse_value_func,
|
|
581
634
|
group="Vector Math",
|
|
582
|
-
doc="Return the inverse of a 3x3 matrix ``
|
|
635
|
+
doc="Return the inverse of a 3x3 matrix ``a``.",
|
|
583
636
|
require_original_output_arg=True,
|
|
584
637
|
)
|
|
585
638
|
|
|
586
639
|
add_builtin(
|
|
587
640
|
"inverse",
|
|
588
|
-
input_types={"
|
|
589
|
-
value_func=
|
|
641
|
+
input_types={"a": matrix(shape=(4, 4), dtype=Float)},
|
|
642
|
+
value_func=inverse_value_func,
|
|
590
643
|
group="Vector Math",
|
|
591
|
-
doc="Return the inverse of a 4x4 matrix ``
|
|
644
|
+
doc="Return the inverse of a 4x4 matrix ``a``.",
|
|
592
645
|
require_original_output_arg=True,
|
|
593
646
|
)
|
|
594
647
|
|
|
595
648
|
|
|
596
|
-
def
|
|
649
|
+
def determinant_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
597
650
|
if arg_types is None:
|
|
598
|
-
return
|
|
599
|
-
|
|
651
|
+
return Float
|
|
652
|
+
|
|
653
|
+
return arg_types["a"]._wp_scalar_type_
|
|
600
654
|
|
|
601
655
|
|
|
602
656
|
add_builtin(
|
|
603
657
|
"determinant",
|
|
604
|
-
input_types={"
|
|
605
|
-
value_func=
|
|
658
|
+
input_types={"a": matrix(shape=(2, 2), dtype=Float)},
|
|
659
|
+
value_func=determinant_value_func,
|
|
606
660
|
group="Vector Math",
|
|
607
|
-
doc="Return the determinant of a 2x2 matrix ``
|
|
661
|
+
doc="Return the determinant of a 2x2 matrix ``a``.",
|
|
608
662
|
)
|
|
609
663
|
|
|
610
664
|
add_builtin(
|
|
611
665
|
"determinant",
|
|
612
|
-
input_types={"
|
|
613
|
-
value_func=
|
|
666
|
+
input_types={"a": matrix(shape=(3, 3), dtype=Float)},
|
|
667
|
+
value_func=determinant_value_func,
|
|
614
668
|
group="Vector Math",
|
|
615
|
-
doc="Return the determinant of a 3x3 matrix ``
|
|
669
|
+
doc="Return the determinant of a 3x3 matrix ``a``.",
|
|
616
670
|
)
|
|
617
671
|
|
|
618
672
|
add_builtin(
|
|
619
673
|
"determinant",
|
|
620
|
-
input_types={"
|
|
621
|
-
value_func=
|
|
674
|
+
input_types={"a": matrix(shape=(4, 4), dtype=Float)},
|
|
675
|
+
value_func=determinant_value_func,
|
|
622
676
|
group="Vector Math",
|
|
623
|
-
doc="Return the determinant of a 4x4 matrix ``
|
|
677
|
+
doc="Return the determinant of a 4x4 matrix ``a``.",
|
|
624
678
|
)
|
|
625
679
|
|
|
626
680
|
|
|
627
|
-
def
|
|
681
|
+
def trace_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
628
682
|
if arg_types is None:
|
|
629
683
|
return Scalar
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
684
|
+
|
|
685
|
+
if arg_types["a"]._shape_[0] != arg_types["a"]._shape_[1]:
|
|
686
|
+
raise RuntimeError(f"Matrix shape is {arg_types['a']._shape_}. Cannot find the trace of non square matrices")
|
|
687
|
+
return arg_types["a"]._wp_scalar_type_
|
|
633
688
|
|
|
634
689
|
|
|
635
690
|
add_builtin(
|
|
636
691
|
"trace",
|
|
637
|
-
input_types={"
|
|
638
|
-
value_func=
|
|
692
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
693
|
+
value_func=trace_value_func,
|
|
639
694
|
group="Vector Math",
|
|
640
|
-
doc="Return the trace of the matrix ``
|
|
695
|
+
doc="Return the trace of the matrix ``a``.",
|
|
641
696
|
)
|
|
642
697
|
|
|
643
698
|
|
|
644
|
-
def
|
|
699
|
+
def diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
645
700
|
if arg_types is None:
|
|
646
701
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
647
|
-
|
|
648
|
-
|
|
702
|
+
|
|
703
|
+
return matrix(shape=(arg_types["vec"]._length_, arg_types["vec"]._length_), dtype=arg_types["vec"]._wp_scalar_type_)
|
|
649
704
|
|
|
650
705
|
|
|
651
706
|
add_builtin(
|
|
652
707
|
"diag",
|
|
653
|
-
input_types={"
|
|
654
|
-
value_func=
|
|
708
|
+
input_types={"vec": vector(length=Any, dtype=Scalar)},
|
|
709
|
+
value_func=diag_value_func,
|
|
655
710
|
group="Vector Math",
|
|
656
|
-
doc="Returns a matrix with the components of the vector ``
|
|
711
|
+
doc="Returns a matrix with the components of the vector ``vec`` on the diagonal.",
|
|
657
712
|
)
|
|
658
713
|
|
|
659
714
|
|
|
660
|
-
def
|
|
715
|
+
def get_diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
661
716
|
if arg_types is None:
|
|
662
717
|
return vector(length=(Any), dtype=Scalar)
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
718
|
+
|
|
719
|
+
if arg_types["mat"]._shape_[0] != arg_types["mat"]._shape_[1]:
|
|
720
|
+
raise RuntimeError(
|
|
721
|
+
f"Matrix shape is {arg_types['mat']._shape_}; get_diag is only available for square matrices."
|
|
722
|
+
)
|
|
723
|
+
return vector(length=arg_types["mat"]._shape_[0], dtype=arg_types["mat"]._wp_scalar_type_)
|
|
669
724
|
|
|
670
725
|
|
|
671
726
|
add_builtin(
|
|
672
727
|
"get_diag",
|
|
673
|
-
input_types={"
|
|
674
|
-
value_func=
|
|
728
|
+
input_types={"mat": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
729
|
+
value_func=get_diag_value_func,
|
|
675
730
|
group="Vector Math",
|
|
676
|
-
doc="Returns a vector containing the diagonal elements of the square matrix ``
|
|
731
|
+
doc="Returns a vector containing the diagonal elements of the square matrix ``mat``.",
|
|
677
732
|
)
|
|
678
733
|
|
|
679
734
|
add_builtin(
|
|
680
735
|
"cw_mul",
|
|
681
|
-
input_types={"
|
|
736
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
682
737
|
constraint=sametypes,
|
|
683
|
-
value_func=
|
|
738
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
684
739
|
group="Vector Math",
|
|
685
740
|
doc="Component-wise multiplication of two vectors.",
|
|
686
741
|
)
|
|
687
742
|
add_builtin(
|
|
688
743
|
"cw_div",
|
|
689
|
-
input_types={"
|
|
744
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
690
745
|
constraint=sametypes,
|
|
691
|
-
value_func=
|
|
746
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
692
747
|
group="Vector Math",
|
|
693
748
|
doc="Component-wise division of two vectors.",
|
|
694
749
|
require_original_output_arg=True,
|
|
@@ -696,17 +751,17 @@ add_builtin(
|
|
|
696
751
|
|
|
697
752
|
add_builtin(
|
|
698
753
|
"cw_mul",
|
|
699
|
-
input_types={"
|
|
754
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
700
755
|
constraint=sametypes,
|
|
701
|
-
value_func=
|
|
756
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
702
757
|
group="Vector Math",
|
|
703
758
|
doc="Component-wise multiplication of two matrices.",
|
|
704
759
|
)
|
|
705
760
|
add_builtin(
|
|
706
761
|
"cw_div",
|
|
707
|
-
input_types={"
|
|
762
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
708
763
|
constraint=sametypes,
|
|
709
|
-
value_func=
|
|
764
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
710
765
|
group="Vector Math",
|
|
711
766
|
doc="Component-wise division of two matrices.",
|
|
712
767
|
require_original_output_arg=True,
|
|
@@ -719,7 +774,7 @@ for t in scalar_types_all:
|
|
|
719
774
|
for u in scalar_types_all:
|
|
720
775
|
add_builtin(
|
|
721
776
|
t.__name__,
|
|
722
|
-
input_types={"
|
|
777
|
+
input_types={"a": u},
|
|
723
778
|
value_type=t,
|
|
724
779
|
doc="",
|
|
725
780
|
hidden=True,
|
|
@@ -729,203 +784,231 @@ for t in scalar_types_all:
|
|
|
729
784
|
)
|
|
730
785
|
|
|
731
786
|
|
|
732
|
-
def
|
|
787
|
+
def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
733
788
|
if arg_types is None:
|
|
734
789
|
return vector(length=Any, dtype=Scalar)
|
|
735
790
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
# constructor from another vector
|
|
756
|
-
if vectype._length_ != veclen:
|
|
757
|
-
raise RuntimeError(
|
|
758
|
-
f"Incompatible vector lengths for casting copy constructor, {veclen} vs {vectype._length_}"
|
|
759
|
-
)
|
|
760
|
-
vectype = vectype._wp_scalar_type_
|
|
761
|
-
else:
|
|
791
|
+
length = arg_values.get("length", None)
|
|
792
|
+
dtype = arg_values.get("dtype", None)
|
|
793
|
+
|
|
794
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
795
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
796
|
+
if variadic_arg_count == 0:
|
|
797
|
+
# Zero-initialization, e.g.: `wp.vecXX()`, `wp.vector(length=2, dtype=wp.float16)`.
|
|
798
|
+
if length is None:
|
|
799
|
+
raise RuntimeError("the `length` argument must be specified when zero-initializing a vector")
|
|
800
|
+
|
|
801
|
+
if dtype is None:
|
|
802
|
+
dtype = float32
|
|
803
|
+
elif variadic_arg_count == 1:
|
|
804
|
+
value_type = strip_reference(variadic_arg_types[0])
|
|
805
|
+
if type_is_vector(value_type):
|
|
806
|
+
# Copy constructor, e.g.: `wp.vecXX(other_vec)`, `wp.vector(other_vec)`.
|
|
807
|
+
if length is None:
|
|
808
|
+
length = value_type._length_
|
|
809
|
+
elif value_type._length_ != length:
|
|
762
810
|
raise RuntimeError(
|
|
763
|
-
"
|
|
811
|
+
f"incompatible vector of length {length} given when copy constructing "
|
|
812
|
+
f"a vector of length {value_type._length_}"
|
|
764
813
|
)
|
|
765
814
|
|
|
815
|
+
if dtype is None:
|
|
816
|
+
dtype = value_type._wp_scalar_type_
|
|
766
817
|
else:
|
|
767
|
-
|
|
818
|
+
# Initialization by filling a value, e.g.: `wp.vecXX(123)`,
|
|
819
|
+
# `wp.vector(123, length=2)`.
|
|
820
|
+
if length is None:
|
|
821
|
+
raise RuntimeError("the `length` argument must be specified when filling a vector with a value")
|
|
822
|
+
|
|
823
|
+
if dtype is None:
|
|
824
|
+
dtype = value_type
|
|
825
|
+
elif value_type != dtype:
|
|
768
826
|
raise RuntimeError(
|
|
769
|
-
"
|
|
827
|
+
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
770
828
|
)
|
|
829
|
+
else:
|
|
830
|
+
# Initializing by value, e.g.: `wp.vec2(1, 2)`, `wp.vector(1, 2, length=2)`.
|
|
831
|
+
if length is None:
|
|
832
|
+
length = variadic_arg_count
|
|
833
|
+
elif length != variadic_arg_count:
|
|
834
|
+
raise RuntimeError(
|
|
835
|
+
f"incompatible number of values given ({variadic_arg_count}) "
|
|
836
|
+
f"when constructing a vector of length {length}"
|
|
837
|
+
)
|
|
771
838
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
):
|
|
777
|
-
veclen = arg_types[0]._length_
|
|
778
|
-
vectype = kwds["dtype"]
|
|
779
|
-
templates.append(veclen)
|
|
780
|
-
templates.append(vectype)
|
|
781
|
-
return vector(length=veclen, dtype=vectype)
|
|
782
|
-
raise RuntimeError(
|
|
783
|
-
"vec() should not have dtype specified if numeric arguments are given, the dtype will be inferred from the argument types"
|
|
784
|
-
)
|
|
839
|
+
try:
|
|
840
|
+
value_type = scalar_infer_type(variadic_arg_types)
|
|
841
|
+
except RuntimeError:
|
|
842
|
+
raise RuntimeError("all values given when constructing a vector must have the same type") from None
|
|
785
843
|
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
844
|
+
if dtype is None:
|
|
845
|
+
dtype = value_type
|
|
846
|
+
elif value_type != dtype:
|
|
847
|
+
raise RuntimeError(
|
|
848
|
+
f"all values used to initialize this vector matrix are expected to be of the type `{dtype.__name__}`"
|
|
849
|
+
)
|
|
790
850
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
veclen = vectype._length_
|
|
794
|
-
vectype = vectype._wp_scalar_type_
|
|
795
|
-
elif not all(vectype == t for t in arg_types):
|
|
796
|
-
raise RuntimeError(
|
|
797
|
-
f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join([str(t) for t in arg_types]) }"
|
|
798
|
-
)
|
|
851
|
+
if length is None:
|
|
852
|
+
raise RuntimeError("could not infer the `length` argument when calling the `wp.vector()` function")
|
|
799
853
|
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
854
|
+
if dtype is None:
|
|
855
|
+
raise RuntimeError("could not infer the `dtype` argument when calling the `wp.vector()` function")
|
|
856
|
+
|
|
857
|
+
return vector(length=length, dtype=dtype)
|
|
803
858
|
|
|
804
|
-
else:
|
|
805
|
-
# construction of a predeclared type, e.g.: vec5d
|
|
806
|
-
veclen, vectype = templates
|
|
807
|
-
if len(arg_types) == 1 and type_is_vector(arg_types[0]):
|
|
808
|
-
# constructor from another vector
|
|
809
|
-
if arg_types[0]._length_ != veclen:
|
|
810
|
-
raise RuntimeError(
|
|
811
|
-
f"Incompatible matrix sizes for casting copy constructor, {veclen} vs {arg_types[0]._length_}"
|
|
812
|
-
)
|
|
813
|
-
elif not all(vectype == t for t in arg_types):
|
|
814
|
-
raise RuntimeError(
|
|
815
|
-
f"All numeric arguments to vec() constructor should have the same type, expected {veclen} arg_types of type {vectype}, received { ','.join([str(t) for t in arg_types]) }"
|
|
816
|
-
)
|
|
817
859
|
|
|
818
|
-
|
|
819
|
-
|
|
860
|
+
def vector_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
861
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
862
|
+
# Further validate the given argument values if needed and map them
|
|
863
|
+
# to the underlying C++ function's runtime and template params.
|
|
864
|
+
|
|
865
|
+
length = return_type._length_
|
|
866
|
+
dtype = return_type._wp_scalar_type_
|
|
867
|
+
|
|
868
|
+
variadic_args = args.get("args", ())
|
|
869
|
+
|
|
870
|
+
func_args = variadic_args
|
|
871
|
+
template_args = (length, dtype)
|
|
872
|
+
return (func_args, template_args)
|
|
820
873
|
|
|
821
874
|
|
|
822
875
|
add_builtin(
|
|
823
876
|
"vector",
|
|
824
|
-
input_types={"*
|
|
877
|
+
input_types={"*args": Scalar, "length": int, "dtype": Scalar},
|
|
878
|
+
defaults={"length": None, "dtype": None},
|
|
825
879
|
variadic=True,
|
|
826
|
-
initializer_list_func=lambda arg_types,
|
|
827
|
-
value_func=
|
|
880
|
+
initializer_list_func=lambda arg_types, arg_values: len(arg_types.get("args", ())) > 4,
|
|
881
|
+
value_func=vector_value_func,
|
|
882
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("length", "dtype")},
|
|
883
|
+
dispatch_func=vector_dispatch_func,
|
|
828
884
|
native_func="vec_t",
|
|
829
|
-
doc="Construct a vector of
|
|
885
|
+
doc="Construct a vector of given length and dtype.",
|
|
830
886
|
group="Vector Math",
|
|
831
887
|
export=False,
|
|
832
888
|
)
|
|
833
889
|
|
|
834
890
|
|
|
835
|
-
def
|
|
891
|
+
def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
836
892
|
if arg_types is None:
|
|
837
893
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
838
894
|
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
895
|
+
shape = arg_values.get("shape", None)
|
|
896
|
+
dtype = arg_values.get("dtype", None)
|
|
897
|
+
|
|
898
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
899
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
900
|
+
if variadic_arg_count == 0:
|
|
901
|
+
# Zero-initialization, e.g.: `wp.matXX()`, `wp.matrix(shape=(2, 2), dtype=wp.float16)`.
|
|
902
|
+
if shape is None:
|
|
903
|
+
raise RuntimeError("the `shape` argument must be specified when zero-initializing a matrix")
|
|
904
|
+
|
|
905
|
+
if dtype is None:
|
|
906
|
+
dtype = float32
|
|
907
|
+
elif variadic_arg_count == 1:
|
|
908
|
+
value_type = strip_reference(variadic_arg_types[0])
|
|
909
|
+
if type_is_matrix(value_type):
|
|
910
|
+
# Copy constructor, e.g.: `wp.matXX(other_mat)`, `wp.matrix(other_mat)`.
|
|
911
|
+
if shape is None:
|
|
912
|
+
shape = value_type._shape_
|
|
913
|
+
elif not seq_check_equal(value_type._shape_, shape):
|
|
914
|
+
raise RuntimeError(
|
|
915
|
+
f"incompatible matrix of shape {tuple(shape)} given when copy constructing "
|
|
916
|
+
f"a matrix of shape {tuple(value_type._shape_)}"
|
|
917
|
+
)
|
|
843
918
|
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
919
|
+
if dtype is None:
|
|
920
|
+
dtype = value_type._wp_scalar_type_
|
|
921
|
+
else:
|
|
922
|
+
# Initialization by filling a value, e.g.: `wp.matXX(123)`,
|
|
923
|
+
# `wp.matrix(123, shape=(2, 2))`.
|
|
924
|
+
if shape is None:
|
|
925
|
+
raise RuntimeError("the `shape` argument must be specified when filling a matrix with a value")
|
|
926
|
+
|
|
927
|
+
if dtype is None:
|
|
928
|
+
dtype = value_type
|
|
929
|
+
elif value_type != dtype:
|
|
930
|
+
raise RuntimeError(
|
|
931
|
+
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
932
|
+
)
|
|
933
|
+
else:
|
|
934
|
+
# Initializing by value, e.g.: `wp.mat22(1, 2, 3, 4)`, `wp.matrix(1, 2, 3, 4, shape=(2, 2))`.
|
|
935
|
+
if shape is None:
|
|
936
|
+
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
847
937
|
|
|
848
|
-
|
|
849
|
-
shape
|
|
850
|
-
|
|
938
|
+
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
939
|
+
if shape[1] != variadic_arg_count:
|
|
940
|
+
raise RuntimeError(
|
|
941
|
+
f"incompatible number of column vectors given ({variadic_arg_count}) "
|
|
942
|
+
f"when constructing a matrix of shape {tuple(shape)}"
|
|
943
|
+
)
|
|
851
944
|
|
|
852
|
-
|
|
853
|
-
# value initialization, e.g.: m = matrix(1.0, shape=(3,2))
|
|
854
|
-
shape = kwds["shape"]
|
|
855
|
-
dtype = arg_types[0]
|
|
856
|
-
|
|
857
|
-
if len(arg_types) == 1 and type_is_matrix(dtype):
|
|
858
|
-
# constructor from another matrix
|
|
859
|
-
if arg_types[0]._shape_ != shape:
|
|
860
|
-
raise RuntimeError(
|
|
861
|
-
f"Incompatible matrix sizes for casting copy constructor, {shape} vs {arg_types[0]._shape_}"
|
|
862
|
-
)
|
|
863
|
-
dtype = dtype._wp_scalar_type_
|
|
864
|
-
elif len(arg_types) > 1 and len(arg_types) != shape[0] * shape[1]:
|
|
945
|
+
if any(x._length_ != shape[0] for x in variadic_arg_types):
|
|
865
946
|
raise RuntimeError(
|
|
866
|
-
"
|
|
947
|
+
f"incompatible column vector lengths given when constructing a matrix of shape {tuple(shape)}"
|
|
867
948
|
)
|
|
949
|
+
elif shape[0] * shape[1] != variadic_arg_count:
|
|
950
|
+
raise RuntimeError(
|
|
951
|
+
f"incompatible number of values given ({variadic_arg_count}) "
|
|
952
|
+
f"when constructing a matrix of shape {tuple(shape)}"
|
|
953
|
+
)
|
|
868
954
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
955
|
+
try:
|
|
956
|
+
value_type = scalar_infer_type(variadic_arg_types)
|
|
957
|
+
except RuntimeError:
|
|
958
|
+
raise RuntimeError("all values given when constructing a matrix must have the same type") from None
|
|
872
959
|
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
else:
|
|
886
|
-
# check scalar arg type matches declared type
|
|
887
|
-
if infer_scalar_type(arg_types) != dtype:
|
|
888
|
-
raise RuntimeError("Wrong scalar type for mat {} constructor".format(",".join(map(str, templates))))
|
|
889
|
-
|
|
890
|
-
# check vector arg type matches declared type
|
|
891
|
-
if all(type_is_vector(a) for a in arg_types):
|
|
892
|
-
cols = len(arg_types)
|
|
893
|
-
if shape[1] != cols:
|
|
894
|
-
raise RuntimeError(
|
|
895
|
-
"Wrong number of vectors when attempting to construct a matrix with column vectors"
|
|
896
|
-
)
|
|
897
|
-
|
|
898
|
-
if not all(a._length_ == shape[0] for a in arg_types):
|
|
899
|
-
raise RuntimeError(
|
|
900
|
-
"Wrong vector row count when attempting to construct a matrix with column vectors"
|
|
901
|
-
)
|
|
902
|
-
else:
|
|
903
|
-
# check that we either got 1 arg (scalar construction), or enough values for whole matrix
|
|
904
|
-
size = shape[0] * shape[1]
|
|
905
|
-
if len(arg_types) > 1 and len(arg_types) != size:
|
|
906
|
-
raise RuntimeError(
|
|
907
|
-
"Wrong number of scalars when attempting to construct a matrix from a list of components"
|
|
908
|
-
)
|
|
960
|
+
if dtype is None:
|
|
961
|
+
dtype = value_type
|
|
962
|
+
elif value_type != dtype:
|
|
963
|
+
raise RuntimeError(
|
|
964
|
+
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
965
|
+
)
|
|
966
|
+
|
|
967
|
+
if shape is None:
|
|
968
|
+
raise RuntimeError("could not infer the `shape` argument when calling the `wp.matrix()` function")
|
|
969
|
+
|
|
970
|
+
if dtype is None:
|
|
971
|
+
raise RuntimeError("could not infer the `dtype` argument when calling the `wp.matrix()` function")
|
|
909
972
|
|
|
910
973
|
return matrix(shape=shape, dtype=dtype)
|
|
911
974
|
|
|
912
975
|
|
|
976
|
+
def matrix_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
977
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
978
|
+
# Further validate the given argument values if needed and map them
|
|
979
|
+
# to the underlying C++ function's runtime and template params.
|
|
980
|
+
|
|
981
|
+
shape = return_type._shape_
|
|
982
|
+
dtype = return_type._wp_scalar_type_
|
|
983
|
+
|
|
984
|
+
variadic_args = args.get("args", ())
|
|
985
|
+
|
|
986
|
+
func_args = variadic_args
|
|
987
|
+
template_args = (*shape, dtype)
|
|
988
|
+
return (func_args, template_args)
|
|
989
|
+
|
|
990
|
+
|
|
913
991
|
# only use initializer list if matrix size < 5x5, or for scalar construction
|
|
914
|
-
def
|
|
915
|
-
|
|
992
|
+
def matrix_initializer_list_func(args, return_type):
|
|
993
|
+
shape = return_type._shape_
|
|
994
|
+
|
|
995
|
+
variadic_args = args.get("args", ())
|
|
996
|
+
variadic_arg_count = len(variadic_args)
|
|
916
997
|
return not (
|
|
917
|
-
|
|
918
|
-
or
|
|
919
|
-
or (m == n and n < 5) # scalar construction # value construction for small matrices
|
|
998
|
+
variadic_arg_count <= 1 # zero/fill initialization
|
|
999
|
+
or (shape[0] == shape[1] and shape[1] < 5) # value construction for small matrices
|
|
920
1000
|
)
|
|
921
1001
|
|
|
922
1002
|
|
|
923
1003
|
add_builtin(
|
|
924
1004
|
"matrix",
|
|
925
|
-
input_types={"*
|
|
1005
|
+
input_types={"*args": Scalar, "shape": Tuple[int, int], "dtype": Scalar},
|
|
1006
|
+
defaults={"shape": None, "dtype": None},
|
|
926
1007
|
variadic=True,
|
|
927
|
-
|
|
928
|
-
|
|
1008
|
+
value_func=matrix_value_func,
|
|
1009
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("shape", "dtype")},
|
|
1010
|
+
dispatch_func=matrix_dispatch_func,
|
|
1011
|
+
initializer_list_func=matrix_initializer_list_func,
|
|
929
1012
|
native_func="mat_t",
|
|
930
1013
|
doc="Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized.",
|
|
931
1014
|
group="Vector Math",
|
|
@@ -933,69 +1016,95 @@ add_builtin(
|
|
|
933
1016
|
)
|
|
934
1017
|
|
|
935
1018
|
|
|
936
|
-
|
|
937
|
-
def matrix_identity_value_func(arg_types, kwds, templates):
|
|
1019
|
+
def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
938
1020
|
if arg_types is None:
|
|
939
1021
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
940
1022
|
|
|
941
|
-
|
|
942
|
-
|
|
1023
|
+
n = arg_values["n"]
|
|
1024
|
+
dtype = arg_values["dtype"]
|
|
943
1025
|
|
|
944
|
-
if
|
|
945
|
-
raise RuntimeError("'n'
|
|
1026
|
+
if n is None:
|
|
1027
|
+
raise RuntimeError("'n' must be a constant when calling identity()")
|
|
946
1028
|
|
|
947
|
-
|
|
948
|
-
raise RuntimeError("'dtype' keyword argument must be specified when calling identity() function")
|
|
1029
|
+
return matrix(shape=(n, n), dtype=dtype)
|
|
949
1030
|
|
|
950
|
-
n, dtype = [kwds["n"], kwds["dtype"]]
|
|
951
1031
|
|
|
952
|
-
|
|
953
|
-
|
|
1032
|
+
def identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1033
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1034
|
+
# Further validate the given argument values if needed and map them
|
|
1035
|
+
# to the underlying C++ function's runtime and template params.
|
|
954
1036
|
|
|
955
|
-
|
|
956
|
-
|
|
1037
|
+
shape = return_type._shape_
|
|
1038
|
+
dtype = return_type._wp_scalar_type_
|
|
957
1039
|
|
|
958
|
-
|
|
1040
|
+
func_args = ()
|
|
1041
|
+
template_args = (shape[0], dtype)
|
|
1042
|
+
return (func_args, template_args)
|
|
959
1043
|
|
|
960
1044
|
|
|
961
1045
|
add_builtin(
|
|
962
1046
|
"identity",
|
|
963
1047
|
input_types={"n": int, "dtype": Scalar},
|
|
964
|
-
value_func=
|
|
965
|
-
|
|
1048
|
+
value_func=identity_value_func,
|
|
1049
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1050
|
+
dispatch_func=identity_dispatch_func,
|
|
966
1051
|
doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
|
|
967
1052
|
group="Vector Math",
|
|
968
1053
|
export=False,
|
|
969
1054
|
)
|
|
970
1055
|
|
|
971
1056
|
|
|
972
|
-
def matrix_transform_value_func(arg_types,
|
|
973
|
-
if
|
|
1057
|
+
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1058
|
+
if arg_types is None:
|
|
974
1059
|
return matrix(shape=(4, 4), dtype=Float)
|
|
975
1060
|
|
|
976
|
-
|
|
977
|
-
raise RuntimeError("Cannot use a generic type name in a kernel")
|
|
1061
|
+
dtype = arg_values.get("dtype", None)
|
|
978
1062
|
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
raise RuntimeError(
|
|
1063
|
+
value_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
|
|
1064
|
+
try:
|
|
1065
|
+
value_type = scalar_infer_type(value_arg_types)
|
|
1066
|
+
except RuntimeError:
|
|
1067
|
+
raise RuntimeError(
|
|
1068
|
+
"all values given when constructing a transformation matrix must have the same type"
|
|
1069
|
+
) from None
|
|
1070
|
+
|
|
1071
|
+
if dtype is None:
|
|
1072
|
+
dtype = value_type
|
|
1073
|
+
elif value_type != dtype:
|
|
1074
|
+
raise RuntimeError(
|
|
1075
|
+
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1076
|
+
)
|
|
984
1077
|
|
|
985
1078
|
return matrix(shape=(4, 4), dtype=dtype)
|
|
986
1079
|
|
|
987
1080
|
|
|
1081
|
+
def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1082
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1083
|
+
# Further validate the given argument values if needed and map them
|
|
1084
|
+
# to the underlying C++ function's runtime and template params.
|
|
1085
|
+
|
|
1086
|
+
dtype = return_type._wp_scalar_type_
|
|
1087
|
+
|
|
1088
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1089
|
+
template_args = (4, 4, dtype)
|
|
1090
|
+
return (func_args, template_args)
|
|
1091
|
+
|
|
1092
|
+
|
|
988
1093
|
add_builtin(
|
|
989
1094
|
"matrix",
|
|
990
1095
|
input_types={
|
|
991
1096
|
"pos": vector(length=3, dtype=Float),
|
|
992
1097
|
"rot": quaternion(dtype=Float),
|
|
993
1098
|
"scale": vector(length=3, dtype=Float),
|
|
1099
|
+
"dtype": Float,
|
|
994
1100
|
},
|
|
1101
|
+
defaults={"dtype": None},
|
|
995
1102
|
value_func=matrix_transform_value_func,
|
|
1103
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1104
|
+
dispatch_func=matrix_transform_dispatch_func,
|
|
996
1105
|
native_func="mat_t",
|
|
997
1106
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
998
|
-
Translation(pos)*Rotation(rot)*
|
|
1107
|
+
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x""",
|
|
999
1108
|
group="Vector Math",
|
|
1000
1109
|
export=False,
|
|
1001
1110
|
)
|
|
@@ -1050,42 +1159,69 @@ add_builtin(
|
|
|
1050
1159
|
# Quaternion Math
|
|
1051
1160
|
|
|
1052
1161
|
|
|
1053
|
-
def quaternion_value_func(arg_types,
|
|
1162
|
+
def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1054
1163
|
if arg_types is None:
|
|
1055
1164
|
return quaternion(dtype=Float)
|
|
1056
1165
|
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1166
|
+
dtype = arg_values.get("dtype", None)
|
|
1167
|
+
|
|
1168
|
+
variadic_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
|
|
1169
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1170
|
+
|
|
1171
|
+
if variadic_arg_count == 0:
|
|
1172
|
+
# Zero-initialization, e.g.: `wp.quat()`, `wp.quaternion(dtype=wp.float16)`.
|
|
1173
|
+
if dtype is None:
|
|
1174
|
+
dtype = float32
|
|
1175
|
+
elif dtype not in float_types:
|
|
1176
|
+
raise RuntimeError(
|
|
1177
|
+
f"a float type is expected when zero-initializing a quaternion but got `{type(dtype).__name__}` instead"
|
|
1178
|
+
)
|
|
1179
|
+
elif variadic_arg_count == 1:
|
|
1180
|
+
if type_is_quaternion(variadic_arg_types[0]):
|
|
1181
|
+
# Copy constructor, e.g.: `wp.quat(other_vec)`, `wp.quaternion(other_vec)`.
|
|
1182
|
+
in_quat = variadic_arg_types[0]
|
|
1183
|
+
if dtype is None:
|
|
1184
|
+
dtype = in_quat._wp_scalar_type_
|
|
1065
1185
|
else:
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1186
|
+
try:
|
|
1187
|
+
value_type = scalar_infer_type(variadic_arg_types)
|
|
1188
|
+
except RuntimeError:
|
|
1189
|
+
raise RuntimeError("all values given when constructing a quaternion must have the same type") from None
|
|
1190
|
+
|
|
1191
|
+
if dtype is None:
|
|
1192
|
+
dtype = value_type
|
|
1193
|
+
elif value_type != dtype:
|
|
1194
|
+
raise RuntimeError(
|
|
1195
|
+
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1196
|
+
)
|
|
1069
1197
|
|
|
1070
|
-
|
|
1198
|
+
if dtype is None:
|
|
1199
|
+
raise RuntimeError("could not infer the `dtype` argument when calling the `wp.quaternion()` function")
|
|
1071
1200
|
|
|
1201
|
+
return quaternion(dtype=dtype)
|
|
1072
1202
|
|
|
1073
|
-
def quat_cast_value_func(arg_types, kwds, templates):
|
|
1074
|
-
if arg_types is None:
|
|
1075
|
-
raise RuntimeError("Missing quaternion argument.")
|
|
1076
|
-
if "dtype" not in kwds:
|
|
1077
|
-
raise RuntimeError("Missing 'dtype' kwd.")
|
|
1078
1203
|
|
|
1079
|
-
|
|
1080
|
-
|
|
1204
|
+
def quaternion_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1205
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1206
|
+
# Further validate the given argument values if needed and map them
|
|
1207
|
+
# to the underlying C++ function's runtime and template params.
|
|
1081
1208
|
|
|
1082
|
-
|
|
1209
|
+
dtype = return_type._wp_scalar_type_
|
|
1210
|
+
|
|
1211
|
+
variadic_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1212
|
+
|
|
1213
|
+
func_args = variadic_args
|
|
1214
|
+
template_args = (dtype,)
|
|
1215
|
+
return (func_args, template_args)
|
|
1083
1216
|
|
|
1084
1217
|
|
|
1085
1218
|
add_builtin(
|
|
1086
1219
|
"quaternion",
|
|
1087
|
-
input_types={},
|
|
1220
|
+
input_types={"dtype": Float},
|
|
1221
|
+
defaults={"dtype": None},
|
|
1088
1222
|
value_func=quaternion_value_func,
|
|
1223
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1224
|
+
dispatch_func=quaternion_dispatch_func,
|
|
1089
1225
|
native_func="quat_t",
|
|
1090
1226
|
group="Quaternion Math",
|
|
1091
1227
|
doc="""Construct a zero-initialized quaternion. Quaternions are laid out as
|
|
@@ -1096,6 +1232,8 @@ add_builtin(
|
|
|
1096
1232
|
"quaternion",
|
|
1097
1233
|
input_types={"x": Float, "y": Float, "z": Float, "w": Float},
|
|
1098
1234
|
value_func=quaternion_value_func,
|
|
1235
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1236
|
+
dispatch_func=quaternion_dispatch_func,
|
|
1099
1237
|
native_func="quat_t",
|
|
1100
1238
|
group="Quaternion Math",
|
|
1101
1239
|
doc="Create a quaternion using the supplied components (type inferred from component type).",
|
|
@@ -1103,17 +1241,24 @@ add_builtin(
|
|
|
1103
1241
|
)
|
|
1104
1242
|
add_builtin(
|
|
1105
1243
|
"quaternion",
|
|
1106
|
-
input_types={"
|
|
1244
|
+
input_types={"ijk": vector(length=3, dtype=Float), "real": Float, "dtype": Float},
|
|
1245
|
+
defaults={"dtype": None},
|
|
1107
1246
|
value_func=quaternion_value_func,
|
|
1247
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1248
|
+
dispatch_func=quaternion_dispatch_func,
|
|
1108
1249
|
native_func="quat_t",
|
|
1109
1250
|
group="Quaternion Math",
|
|
1110
1251
|
doc="Create a quaternion using the supplied vector/scalar (type inferred from scalar type).",
|
|
1111
1252
|
export=False,
|
|
1112
1253
|
)
|
|
1254
|
+
|
|
1113
1255
|
add_builtin(
|
|
1114
1256
|
"quaternion",
|
|
1115
|
-
input_types={"
|
|
1116
|
-
|
|
1257
|
+
input_types={"quat": quaternion(dtype=Float), "dtype": Float},
|
|
1258
|
+
defaults={"dtype": None},
|
|
1259
|
+
value_func=quaternion_value_func,
|
|
1260
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1261
|
+
dispatch_func=quaternion_dispatch_func,
|
|
1117
1262
|
native_func="quat_t",
|
|
1118
1263
|
group="Quaternion Math",
|
|
1119
1264
|
doc="Construct a quaternion of type dtype from another quaternion of a different dtype.",
|
|
@@ -1121,26 +1266,34 @@ add_builtin(
|
|
|
1121
1266
|
)
|
|
1122
1267
|
|
|
1123
1268
|
|
|
1124
|
-
def quat_identity_value_func(arg_types,
|
|
1125
|
-
# if arg_types is None then we are in 'export' mode
|
|
1269
|
+
def quat_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1126
1270
|
if arg_types is None:
|
|
1271
|
+
# return quaternion(dtype=Float)
|
|
1127
1272
|
return quatf
|
|
1128
1273
|
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
dtype = float32
|
|
1132
|
-
else:
|
|
1133
|
-
dtype = kwds["dtype"]
|
|
1274
|
+
dtype = arg_types.get("dtype", float32)
|
|
1275
|
+
return quaternion(dtype=dtype)
|
|
1134
1276
|
|
|
1135
|
-
templates.append(dtype)
|
|
1136
1277
|
|
|
1137
|
-
|
|
1278
|
+
def quat_identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1279
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1280
|
+
# Further validate the given argument values if needed and map them
|
|
1281
|
+
# to the underlying C++ function's runtime and template params.
|
|
1282
|
+
|
|
1283
|
+
dtype = return_type._wp_scalar_type_
|
|
1284
|
+
|
|
1285
|
+
func_args = ()
|
|
1286
|
+
template_args = (dtype,)
|
|
1287
|
+
return (func_args, template_args)
|
|
1138
1288
|
|
|
1139
1289
|
|
|
1140
1290
|
add_builtin(
|
|
1141
1291
|
"quat_identity",
|
|
1142
|
-
input_types={},
|
|
1292
|
+
input_types={"dtype": Float},
|
|
1293
|
+
defaults={"dtype": None},
|
|
1143
1294
|
value_func=quat_identity_value_func,
|
|
1295
|
+
export_func=lambda input_types: {},
|
|
1296
|
+
dispatch_func=quat_identity_dispatch_func,
|
|
1144
1297
|
group="Quaternion Math",
|
|
1145
1298
|
doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
|
|
1146
1299
|
export=True,
|
|
@@ -1149,72 +1302,72 @@ add_builtin(
|
|
|
1149
1302
|
add_builtin(
|
|
1150
1303
|
"quat_from_axis_angle",
|
|
1151
1304
|
input_types={"axis": vector(length=3, dtype=Float), "angle": Float},
|
|
1152
|
-
value_func=lambda arg_types,
|
|
1305
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1153
1306
|
group="Quaternion Math",
|
|
1154
1307
|
doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
|
|
1155
1308
|
)
|
|
1156
1309
|
add_builtin(
|
|
1157
1310
|
"quat_to_axis_angle",
|
|
1158
|
-
input_types={"
|
|
1311
|
+
input_types={"quat": quaternion(dtype=Float), "axis": vector(length=3, dtype=Float), "angle": Float},
|
|
1159
1312
|
value_type=None,
|
|
1160
1313
|
group="Quaternion Math",
|
|
1161
1314
|
doc="Extract the rotation axis and angle radians a quaternion represents.",
|
|
1162
1315
|
)
|
|
1163
1316
|
add_builtin(
|
|
1164
1317
|
"quat_from_matrix",
|
|
1165
|
-
input_types={"
|
|
1166
|
-
value_func=lambda arg_types,
|
|
1318
|
+
input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
|
|
1319
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1167
1320
|
group="Quaternion Math",
|
|
1168
1321
|
doc="Construct a quaternion from a 3x3 matrix.",
|
|
1169
1322
|
)
|
|
1170
1323
|
add_builtin(
|
|
1171
1324
|
"quat_rpy",
|
|
1172
1325
|
input_types={"roll": Float, "pitch": Float, "yaw": Float},
|
|
1173
|
-
value_func=lambda arg_types,
|
|
1326
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1174
1327
|
group="Quaternion Math",
|
|
1175
1328
|
doc="Construct a quaternion representing a combined roll (z), pitch (x), yaw rotations (y) in radians.",
|
|
1176
1329
|
)
|
|
1177
1330
|
add_builtin(
|
|
1178
1331
|
"quat_inverse",
|
|
1179
|
-
input_types={"
|
|
1180
|
-
value_func=lambda arg_types,
|
|
1332
|
+
input_types={"quat": quaternion(dtype=Float)},
|
|
1333
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1181
1334
|
group="Quaternion Math",
|
|
1182
1335
|
doc="Compute quaternion conjugate.",
|
|
1183
1336
|
)
|
|
1184
1337
|
add_builtin(
|
|
1185
1338
|
"quat_rotate",
|
|
1186
|
-
input_types={"
|
|
1187
|
-
value_func=lambda arg_types,
|
|
1339
|
+
input_types={"quat": quaternion(dtype=Float), "vec": vector(length=3, dtype=Float)},
|
|
1340
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1188
1341
|
group="Quaternion Math",
|
|
1189
1342
|
doc="Rotate a vector by a quaternion.",
|
|
1190
1343
|
)
|
|
1191
1344
|
add_builtin(
|
|
1192
1345
|
"quat_rotate_inv",
|
|
1193
|
-
input_types={"
|
|
1194
|
-
value_func=lambda arg_types,
|
|
1346
|
+
input_types={"quat": quaternion(dtype=Float), "vec": vector(length=3, dtype=Float)},
|
|
1347
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1195
1348
|
group="Quaternion Math",
|
|
1196
1349
|
doc="Rotate a vector by the inverse of a quaternion.",
|
|
1197
1350
|
)
|
|
1198
1351
|
add_builtin(
|
|
1199
1352
|
"quat_slerp",
|
|
1200
|
-
input_types={"
|
|
1201
|
-
value_func=lambda arg_types,
|
|
1353
|
+
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
|
|
1354
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1202
1355
|
group="Quaternion Math",
|
|
1203
1356
|
doc="Linearly interpolate between two quaternions.",
|
|
1204
1357
|
require_original_output_arg=True,
|
|
1205
1358
|
)
|
|
1206
1359
|
add_builtin(
|
|
1207
1360
|
"quat_to_matrix",
|
|
1208
|
-
input_types={"
|
|
1209
|
-
value_func=lambda arg_types,
|
|
1361
|
+
input_types={"quat": quaternion(dtype=Float)},
|
|
1362
|
+
value_func=lambda arg_types, arg_values: matrix(shape=(3, 3), dtype=float_infer_type(arg_types)),
|
|
1210
1363
|
group="Quaternion Math",
|
|
1211
1364
|
doc="Convert a quaternion to a 3x3 rotation matrix.",
|
|
1212
1365
|
)
|
|
1213
1366
|
|
|
1214
1367
|
add_builtin(
|
|
1215
1368
|
"dot",
|
|
1216
|
-
input_types={"
|
|
1217
|
-
value_func=
|
|
1369
|
+
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float)},
|
|
1370
|
+
value_func=float_sametypes_value_func,
|
|
1218
1371
|
group="Quaternion Math",
|
|
1219
1372
|
doc="Compute the dot product between two quaternions.",
|
|
1220
1373
|
)
|
|
@@ -1222,55 +1375,85 @@ add_builtin(
|
|
|
1222
1375
|
# Transformations
|
|
1223
1376
|
|
|
1224
1377
|
|
|
1225
|
-
def
|
|
1226
|
-
if
|
|
1227
|
-
return transformation(dtype=
|
|
1378
|
+
def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1379
|
+
if arg_types is None:
|
|
1380
|
+
return transformation(dtype=Float)
|
|
1228
1381
|
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1382
|
+
try:
|
|
1383
|
+
value_type = float_infer_type(arg_types)
|
|
1384
|
+
except RuntimeError:
|
|
1385
|
+
raise RuntimeError(
|
|
1386
|
+
"all values given when constructing a transformation matrix must have the same type"
|
|
1387
|
+
) from None
|
|
1388
|
+
|
|
1389
|
+
dtype = arg_values.get("dtype", None)
|
|
1390
|
+
if dtype is None:
|
|
1391
|
+
dtype = value_type
|
|
1392
|
+
elif value_type != dtype:
|
|
1393
|
+
raise RuntimeError(
|
|
1394
|
+
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1395
|
+
)
|
|
1396
|
+
|
|
1397
|
+
return transformation(dtype=dtype)
|
|
1239
1398
|
|
|
1240
|
-
|
|
1399
|
+
|
|
1400
|
+
def transformation_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1401
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1402
|
+
# Further validate the given argument values if needed and map them
|
|
1403
|
+
# to the underlying C++ function's runtime and template params.
|
|
1404
|
+
|
|
1405
|
+
dtype = return_type._wp_scalar_type_
|
|
1406
|
+
|
|
1407
|
+
variadic_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1408
|
+
|
|
1409
|
+
func_args = variadic_args
|
|
1410
|
+
template_args = (dtype,)
|
|
1411
|
+
return (func_args, template_args)
|
|
1241
1412
|
|
|
1242
1413
|
|
|
1243
1414
|
add_builtin(
|
|
1244
1415
|
"transformation",
|
|
1245
|
-
input_types={"
|
|
1246
|
-
|
|
1416
|
+
input_types={"pos": vector(length=3, dtype=Float), "rot": quaternion(dtype=Float), "dtype": Float},
|
|
1417
|
+
defaults={"dtype": None},
|
|
1418
|
+
value_func=transformation_value_func,
|
|
1419
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1420
|
+
dispatch_func=transformation_dispatch_func,
|
|
1247
1421
|
native_func="transform_t",
|
|
1248
1422
|
group="Transformations",
|
|
1249
|
-
doc="Construct a rigid-body transformation with translation part ``
|
|
1423
|
+
doc="Construct a rigid-body transformation with translation part ``pos`` and rotation ``rot``.",
|
|
1250
1424
|
export=False,
|
|
1251
1425
|
)
|
|
1252
1426
|
|
|
1253
1427
|
|
|
1254
|
-
def transform_identity_value_func(arg_types,
|
|
1428
|
+
def transform_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1255
1429
|
# if arg_types is None then we are in 'export' mode
|
|
1256
1430
|
if arg_types is None:
|
|
1431
|
+
# return transformation(dtype=Float)
|
|
1257
1432
|
return transformf
|
|
1258
1433
|
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
dtype = float32
|
|
1262
|
-
else:
|
|
1263
|
-
dtype = kwds["dtype"]
|
|
1434
|
+
dtype = arg_types.get("dtype", float32)
|
|
1435
|
+
return transformation(dtype=dtype)
|
|
1264
1436
|
|
|
1265
|
-
templates.append(dtype)
|
|
1266
1437
|
|
|
1267
|
-
|
|
1438
|
+
def transform_identity_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1439
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1440
|
+
# Further validate the given argument values if needed and map them
|
|
1441
|
+
# to the underlying C++ function's runtime and template params.
|
|
1442
|
+
|
|
1443
|
+
dtype = return_type._wp_scalar_type_
|
|
1444
|
+
|
|
1445
|
+
func_args = ()
|
|
1446
|
+
template_args = (dtype,)
|
|
1447
|
+
return (func_args, template_args)
|
|
1268
1448
|
|
|
1269
1449
|
|
|
1270
1450
|
add_builtin(
|
|
1271
1451
|
"transform_identity",
|
|
1272
|
-
input_types={},
|
|
1452
|
+
input_types={"dtype": Float},
|
|
1453
|
+
defaults={"dtype": None},
|
|
1273
1454
|
value_func=transform_identity_value_func,
|
|
1455
|
+
export_func=lambda input_types: {},
|
|
1456
|
+
dispatch_func=transform_identity_dispatch_func,
|
|
1274
1457
|
group="Transformations",
|
|
1275
1458
|
doc="Construct an identity transform with zero translation and identity rotation.",
|
|
1276
1459
|
export=True,
|
|
@@ -1278,103 +1461,168 @@ add_builtin(
|
|
|
1278
1461
|
|
|
1279
1462
|
add_builtin(
|
|
1280
1463
|
"transform_get_translation",
|
|
1281
|
-
input_types={"
|
|
1282
|
-
value_func=lambda arg_types,
|
|
1464
|
+
input_types={"xform": transformation(dtype=Float)},
|
|
1465
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1283
1466
|
group="Transformations",
|
|
1284
|
-
doc="Return the translational part of a transform ``
|
|
1467
|
+
doc="Return the translational part of a transform ``xform``.",
|
|
1285
1468
|
)
|
|
1286
1469
|
add_builtin(
|
|
1287
1470
|
"transform_get_rotation",
|
|
1288
|
-
input_types={"
|
|
1289
|
-
value_func=lambda arg_types,
|
|
1471
|
+
input_types={"xform": transformation(dtype=Float)},
|
|
1472
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1290
1473
|
group="Transformations",
|
|
1291
|
-
doc="Return the rotational part of a transform ``
|
|
1474
|
+
doc="Return the rotational part of a transform ``xform``.",
|
|
1292
1475
|
)
|
|
1293
1476
|
add_builtin(
|
|
1294
1477
|
"transform_multiply",
|
|
1295
1478
|
input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
|
|
1296
|
-
value_func=lambda arg_types,
|
|
1479
|
+
value_func=lambda arg_types, arg_values: transformation(dtype=float_infer_type(arg_types)),
|
|
1297
1480
|
group="Transformations",
|
|
1298
1481
|
doc="Multiply two rigid body transformations together.",
|
|
1299
1482
|
)
|
|
1300
1483
|
add_builtin(
|
|
1301
1484
|
"transform_point",
|
|
1302
|
-
input_types={"
|
|
1303
|
-
value_func=lambda arg_types,
|
|
1485
|
+
input_types={"xform": transformation(dtype=Float), "point": vector(length=3, dtype=Float)},
|
|
1486
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1304
1487
|
group="Transformations",
|
|
1305
|
-
doc="Apply the transform to a point ``
|
|
1488
|
+
doc="Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1 (translation and rotation).",
|
|
1306
1489
|
)
|
|
1307
1490
|
add_builtin(
|
|
1308
1491
|
"transform_point",
|
|
1309
|
-
input_types={"
|
|
1310
|
-
value_func=lambda arg_types,
|
|
1492
|
+
input_types={"mat": matrix(shape=(4, 4), dtype=Float), "point": vector(length=3, dtype=Float)},
|
|
1493
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1311
1494
|
group="Vector Math",
|
|
1312
|
-
doc="""Apply the transform to a point ``
|
|
1495
|
+
doc="""Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1.
|
|
1313
1496
|
|
|
1314
|
-
The transformation is applied treating ``
|
|
1315
|
-
Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T =
|
|
1497
|
+
The transformation is applied treating ``point`` as a column vector, e.g.: ``y = mat*point``.
|
|
1498
|
+
Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = point^T*mat^T``.
|
|
1316
1499
|
If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
|
|
1317
1500
|
matrix before calling this method.""",
|
|
1318
1501
|
)
|
|
1319
1502
|
add_builtin(
|
|
1320
1503
|
"transform_vector",
|
|
1321
|
-
input_types={"
|
|
1322
|
-
value_func=lambda arg_types,
|
|
1504
|
+
input_types={"xform": transformation(dtype=Float), "vec": vector(length=3, dtype=Float)},
|
|
1505
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1323
1506
|
group="Transformations",
|
|
1324
|
-
doc="Apply the transform to a vector ``
|
|
1507
|
+
doc="Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0 (rotation only).",
|
|
1325
1508
|
)
|
|
1326
1509
|
add_builtin(
|
|
1327
1510
|
"transform_vector",
|
|
1328
|
-
input_types={"
|
|
1329
|
-
value_func=lambda arg_types,
|
|
1511
|
+
input_types={"mat": matrix(shape=(4, 4), dtype=Float), "vec": vector(length=3, dtype=Float)},
|
|
1512
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=float_infer_type(arg_types)),
|
|
1330
1513
|
group="Vector Math",
|
|
1331
|
-
doc="""Apply the transform to a vector ``
|
|
1514
|
+
doc="""Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0.
|
|
1332
1515
|
|
|
1333
|
-
The transformation is applied treating ``
|
|
1334
|
-
note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T =
|
|
1516
|
+
The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec``
|
|
1517
|
+
note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = vec^T*mat^T``.
|
|
1335
1518
|
If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
|
|
1336
1519
|
matrix before calling this method.""",
|
|
1337
1520
|
)
|
|
1338
1521
|
add_builtin(
|
|
1339
1522
|
"transform_inverse",
|
|
1340
|
-
input_types={"
|
|
1341
|
-
value_func=
|
|
1523
|
+
input_types={"xform": transformation(dtype=Float)},
|
|
1524
|
+
value_func=sametypes_create_value_func(transformation(dtype=Float)),
|
|
1342
1525
|
group="Transformations",
|
|
1343
|
-
doc="Compute the inverse of the transformation ``
|
|
1526
|
+
doc="Compute the inverse of the transformation ``xform``.",
|
|
1344
1527
|
)
|
|
1345
1528
|
# ---------------------------------
|
|
1346
1529
|
# Spatial Math
|
|
1347
1530
|
|
|
1348
1531
|
|
|
1349
|
-
def
|
|
1350
|
-
if
|
|
1532
|
+
def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1533
|
+
if arg_types is None:
|
|
1351
1534
|
return spatial_vector(dtype=Float)
|
|
1352
1535
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1536
|
+
dtype = arg_values.get("dtype", None)
|
|
1537
|
+
|
|
1538
|
+
variadic_arg_types = tuple(v for k, v in arg_types.items() if k != "dtype")
|
|
1539
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1540
|
+
if variadic_arg_count == 0:
|
|
1541
|
+
if dtype is None:
|
|
1542
|
+
dtype = float32
|
|
1543
|
+
elif variadic_arg_count == 2:
|
|
1544
|
+
if any(not type_is_vector(x) for x in variadic_arg_types) or any(x._length_ != 3 for x in variadic_arg_types):
|
|
1545
|
+
raise RuntimeError("arguments `w` and `v` are expected to be vectors of length 3")
|
|
1546
|
+
elif variadic_arg_count != 6:
|
|
1547
|
+
raise RuntimeError("2 vectors or 6 scalar values are expected when constructing a spatial vector")
|
|
1548
|
+
|
|
1549
|
+
if variadic_arg_count:
|
|
1550
|
+
try:
|
|
1551
|
+
value_type = float_infer_type(variadic_arg_types)
|
|
1552
|
+
except RuntimeError:
|
|
1553
|
+
raise RuntimeError("all values given when constructing a spatial vector must have the same type") from None
|
|
1554
|
+
|
|
1555
|
+
if dtype is None:
|
|
1556
|
+
dtype = value_type
|
|
1557
|
+
elif value_type != dtype:
|
|
1558
|
+
raise RuntimeError(
|
|
1559
|
+
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1560
|
+
)
|
|
1561
|
+
|
|
1562
|
+
return vector(length=6, dtype=dtype)
|
|
1563
|
+
|
|
1355
1564
|
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1565
|
+
def spatial_vector_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1566
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1567
|
+
# Further validate the given argument values if needed and map them
|
|
1568
|
+
# to the underlying C++ function's runtime and template params.
|
|
1359
1569
|
|
|
1360
|
-
|
|
1570
|
+
length = return_type._length_
|
|
1571
|
+
dtype = return_type._wp_scalar_type_
|
|
1572
|
+
|
|
1573
|
+
variadic_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1574
|
+
|
|
1575
|
+
func_args = variadic_args
|
|
1576
|
+
template_args = (length, dtype)
|
|
1577
|
+
return (func_args, template_args)
|
|
1361
1578
|
|
|
1362
1579
|
|
|
1363
1580
|
add_builtin(
|
|
1364
|
-
"
|
|
1365
|
-
input_types={"
|
|
1366
|
-
|
|
1581
|
+
"spatial_vector",
|
|
1582
|
+
input_types={"dtype": Float},
|
|
1583
|
+
defaults={"dtype": None},
|
|
1584
|
+
value_func=spatial_vector_value_func,
|
|
1585
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1586
|
+
dispatch_func=spatial_vector_dispatch_func,
|
|
1587
|
+
native_func="vec_t",
|
|
1588
|
+
group="Spatial Math",
|
|
1589
|
+
doc="Zero-initialize a 6D screw vector.",
|
|
1590
|
+
export=False,
|
|
1591
|
+
)
|
|
1592
|
+
|
|
1593
|
+
|
|
1594
|
+
add_builtin(
|
|
1595
|
+
"spatial_vector",
|
|
1596
|
+
input_types={"w": vector(length=3, dtype=Float), "v": vector(length=3, dtype=Float), "dtype": Float},
|
|
1597
|
+
defaults={"dtype": None},
|
|
1598
|
+
value_func=spatial_vector_value_func,
|
|
1599
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1600
|
+
dispatch_func=spatial_vector_dispatch_func,
|
|
1367
1601
|
native_func="vec_t",
|
|
1368
1602
|
group="Spatial Math",
|
|
1369
1603
|
doc="Construct a 6D screw vector from two 3D vectors.",
|
|
1370
1604
|
export=False,
|
|
1371
1605
|
)
|
|
1372
1606
|
|
|
1607
|
+
add_builtin(
|
|
1608
|
+
"spatial_vector",
|
|
1609
|
+
input_types={"wx": Float, "wy": Float, "wz": Float, "vx": Float, "vy": Float, "vz": Float, "dtype": Float},
|
|
1610
|
+
defaults={"dtype": None},
|
|
1611
|
+
initializer_list_func=lambda arg_types, arg_values: True,
|
|
1612
|
+
value_func=spatial_vector_value_func,
|
|
1613
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1614
|
+
dispatch_func=spatial_vector_dispatch_func,
|
|
1615
|
+
native_func="vec_t",
|
|
1616
|
+
group="Spatial Math",
|
|
1617
|
+
doc="Construct a 6D screw vector from six values.",
|
|
1618
|
+
export=False,
|
|
1619
|
+
)
|
|
1620
|
+
|
|
1373
1621
|
|
|
1374
1622
|
add_builtin(
|
|
1375
1623
|
"spatial_adjoint",
|
|
1376
1624
|
input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
|
|
1377
|
-
value_func=lambda arg_types,
|
|
1625
|
+
value_func=lambda arg_types, arg_values: matrix(shape=(6, 6), dtype=float_infer_type(arg_types)),
|
|
1378
1626
|
group="Spatial Math",
|
|
1379
1627
|
doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
|
|
1380
1628
|
export=False,
|
|
@@ -1382,36 +1630,36 @@ add_builtin(
|
|
|
1382
1630
|
add_builtin(
|
|
1383
1631
|
"spatial_dot",
|
|
1384
1632
|
input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
|
|
1385
|
-
value_func=
|
|
1633
|
+
value_func=float_sametypes_value_func,
|
|
1386
1634
|
group="Spatial Math",
|
|
1387
1635
|
doc="Compute the dot product of two 6D screw vectors.",
|
|
1388
1636
|
)
|
|
1389
1637
|
add_builtin(
|
|
1390
1638
|
"spatial_cross",
|
|
1391
1639
|
input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
|
|
1392
|
-
value_func=
|
|
1640
|
+
value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
|
|
1393
1641
|
group="Spatial Math",
|
|
1394
1642
|
doc="Compute the cross product of two 6D screw vectors.",
|
|
1395
1643
|
)
|
|
1396
1644
|
add_builtin(
|
|
1397
1645
|
"spatial_cross_dual",
|
|
1398
1646
|
input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
|
|
1399
|
-
value_func=
|
|
1647
|
+
value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
|
|
1400
1648
|
group="Spatial Math",
|
|
1401
1649
|
doc="Compute the dual cross product of two 6D screw vectors.",
|
|
1402
1650
|
)
|
|
1403
1651
|
|
|
1404
1652
|
add_builtin(
|
|
1405
1653
|
"spatial_top",
|
|
1406
|
-
input_types={"
|
|
1407
|
-
value_func=lambda arg_types,
|
|
1654
|
+
input_types={"svec": vector(length=6, dtype=Float)},
|
|
1655
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
|
|
1408
1656
|
group="Spatial Math",
|
|
1409
1657
|
doc="Return the top (first) part of a 6D screw vector.",
|
|
1410
1658
|
)
|
|
1411
1659
|
add_builtin(
|
|
1412
1660
|
"spatial_bottom",
|
|
1413
|
-
input_types={"
|
|
1414
|
-
value_func=lambda arg_types,
|
|
1661
|
+
input_types={"svec": vector(length=6, dtype=Float)},
|
|
1662
|
+
value_func=lambda arg_types, arg_values: vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_),
|
|
1415
1663
|
group="Spatial Math",
|
|
1416
1664
|
doc="Return the bottom (second) part of a 6D screw vector.",
|
|
1417
1665
|
)
|
|
@@ -1588,22 +1836,23 @@ add_builtin(
|
|
|
1588
1836
|
|
|
1589
1837
|
add_builtin(
|
|
1590
1838
|
"bvh_query_aabb",
|
|
1591
|
-
input_types={"id": uint64, "
|
|
1592
|
-
|
|
1839
|
+
input_types={"id": uint64, "low": vec3, "high": vec3},
|
|
1840
|
+
value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
|
|
1593
1841
|
group="Geometry",
|
|
1594
1842
|
doc="""Construct an axis-aligned bounding box query against a BVH object.
|
|
1595
1843
|
|
|
1596
1844
|
This query can be used to iterate over all bounds inside a BVH.
|
|
1597
1845
|
|
|
1598
1846
|
:param id: The BVH identifier
|
|
1599
|
-
:param
|
|
1600
|
-
:param
|
|
1847
|
+
:param low: The lower bound of the bounding box in BVH space
|
|
1848
|
+
:param high: The upper bound of the bounding box in BVH space""",
|
|
1849
|
+
export=False,
|
|
1601
1850
|
)
|
|
1602
1851
|
|
|
1603
1852
|
add_builtin(
|
|
1604
1853
|
"bvh_query_ray",
|
|
1605
1854
|
input_types={"id": uint64, "start": vec3, "dir": vec3},
|
|
1606
|
-
|
|
1855
|
+
value_func=lambda arg_types, _: BvhQuery if arg_types is None else bvh_query_t,
|
|
1607
1856
|
group="Geometry",
|
|
1608
1857
|
doc="""Construct a ray query against a BVH object.
|
|
1609
1858
|
|
|
@@ -1612,15 +1861,17 @@ add_builtin(
|
|
|
1612
1861
|
:param id: The BVH identifier
|
|
1613
1862
|
:param start: The start of the ray in BVH space
|
|
1614
1863
|
:param dir: The direction of the ray in BVH space""",
|
|
1864
|
+
export=False,
|
|
1615
1865
|
)
|
|
1616
1866
|
|
|
1617
1867
|
add_builtin(
|
|
1618
1868
|
"bvh_query_next",
|
|
1619
|
-
input_types={"query":
|
|
1869
|
+
input_types={"query": BvhQuery, "index": int},
|
|
1620
1870
|
value_type=builtins.bool,
|
|
1621
1871
|
group="Geometry",
|
|
1622
1872
|
doc="""Move to the next bound returned by the query.
|
|
1623
1873
|
The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
|
|
1874
|
+
export=False,
|
|
1624
1875
|
)
|
|
1625
1876
|
|
|
1626
1877
|
add_builtin(
|
|
@@ -1650,6 +1901,7 @@ add_builtin(
|
|
|
1650
1901
|
:param face: Returns the index of the closest face
|
|
1651
1902
|
:param bary_u: Returns the barycentric u coordinate of the closest point
|
|
1652
1903
|
:param bary_v: Returns the barycentric v coordinate of the closest point""",
|
|
1904
|
+
export=False,
|
|
1653
1905
|
hidden=True,
|
|
1654
1906
|
)
|
|
1655
1907
|
|
|
@@ -1660,7 +1912,7 @@ add_builtin(
|
|
|
1660
1912
|
"point": vec3,
|
|
1661
1913
|
"max_dist": float,
|
|
1662
1914
|
},
|
|
1663
|
-
|
|
1915
|
+
value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
|
|
1664
1916
|
group="Geometry",
|
|
1665
1917
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
1666
1918
|
|
|
@@ -1672,6 +1924,7 @@ add_builtin(
|
|
|
1672
1924
|
:param point: The point in space to query
|
|
1673
1925
|
:param max_dist: Mesh faces above this distance will not be considered by the query""",
|
|
1674
1926
|
require_original_output_arg=True,
|
|
1927
|
+
export=False,
|
|
1675
1928
|
)
|
|
1676
1929
|
|
|
1677
1930
|
add_builtin(
|
|
@@ -1696,6 +1949,7 @@ add_builtin(
|
|
|
1696
1949
|
:param face: Returns the index of the closest face
|
|
1697
1950
|
:param bary_u: Returns the barycentric u coordinate of the closest point
|
|
1698
1951
|
:param bary_v: Returns the barycentric v coordinate of the closest point""",
|
|
1952
|
+
export=False,
|
|
1699
1953
|
hidden=True,
|
|
1700
1954
|
)
|
|
1701
1955
|
|
|
@@ -1706,7 +1960,7 @@ add_builtin(
|
|
|
1706
1960
|
"point": vec3,
|
|
1707
1961
|
"max_dist": float,
|
|
1708
1962
|
},
|
|
1709
|
-
|
|
1963
|
+
value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
|
|
1710
1964
|
group="Geometry",
|
|
1711
1965
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
1712
1966
|
|
|
@@ -1716,6 +1970,7 @@ add_builtin(
|
|
|
1716
1970
|
:param point: The point in space to query
|
|
1717
1971
|
:param max_dist: Mesh faces above this distance will not be considered by the query""",
|
|
1718
1972
|
require_original_output_arg=True,
|
|
1973
|
+
export=False,
|
|
1719
1974
|
)
|
|
1720
1975
|
|
|
1721
1976
|
add_builtin(
|
|
@@ -1740,6 +1995,7 @@ add_builtin(
|
|
|
1740
1995
|
:param face: Returns the index of the furthest face
|
|
1741
1996
|
:param bary_u: Returns the barycentric u coordinate of the furthest point
|
|
1742
1997
|
:param bary_v: Returns the barycentric v coordinate of the furthest point""",
|
|
1998
|
+
export=False,
|
|
1743
1999
|
hidden=True,
|
|
1744
2000
|
)
|
|
1745
2001
|
|
|
@@ -1750,7 +2006,7 @@ add_builtin(
|
|
|
1750
2006
|
"point": vec3,
|
|
1751
2007
|
"min_dist": float,
|
|
1752
2008
|
},
|
|
1753
|
-
|
|
2009
|
+
value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
|
|
1754
2010
|
group="Geometry",
|
|
1755
2011
|
doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
|
|
1756
2012
|
|
|
@@ -1760,6 +2016,7 @@ add_builtin(
|
|
|
1760
2016
|
:param point: The point in space to query
|
|
1761
2017
|
:param min_dist: Mesh faces below this distance will not be considered by the query""",
|
|
1762
2018
|
require_original_output_arg=True,
|
|
2019
|
+
export=False,
|
|
1763
2020
|
)
|
|
1764
2021
|
|
|
1765
2022
|
add_builtin(
|
|
@@ -1793,6 +2050,7 @@ add_builtin(
|
|
|
1793
2050
|
:param bary_v: Returns the barycentric v coordinate of the closest point
|
|
1794
2051
|
:param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
|
|
1795
2052
|
fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
|
|
2053
|
+
export=False,
|
|
1796
2054
|
hidden=True,
|
|
1797
2055
|
)
|
|
1798
2056
|
|
|
@@ -1805,7 +2063,7 @@ add_builtin(
|
|
|
1805
2063
|
"epsilon": float,
|
|
1806
2064
|
},
|
|
1807
2065
|
defaults={"epsilon": 1.0e-3},
|
|
1808
|
-
|
|
2066
|
+
value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
|
|
1809
2067
|
group="Geometry",
|
|
1810
2068
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
1811
2069
|
|
|
@@ -1819,6 +2077,7 @@ add_builtin(
|
|
|
1819
2077
|
:param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
|
|
1820
2078
|
fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3""",
|
|
1821
2079
|
require_original_output_arg=True,
|
|
2080
|
+
export=False,
|
|
1822
2081
|
)
|
|
1823
2082
|
|
|
1824
2083
|
add_builtin(
|
|
@@ -1855,6 +2114,7 @@ add_builtin(
|
|
|
1855
2114
|
:param bary_v: Returns the barycentric v coordinate of the closest point
|
|
1856
2115
|
:param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
|
|
1857
2116
|
:param threshold: The threshold of the winding number to be considered inside, default 0.5""",
|
|
2117
|
+
export=False,
|
|
1858
2118
|
hidden=True,
|
|
1859
2119
|
)
|
|
1860
2120
|
|
|
@@ -1868,7 +2128,7 @@ add_builtin(
|
|
|
1868
2128
|
"threshold": float,
|
|
1869
2129
|
},
|
|
1870
2130
|
defaults={"accuracy": 2.0, "threshold": 0.5},
|
|
1871
|
-
|
|
2131
|
+
value_func=lambda arg_types, _: MeshQueryPoint if arg_types is None else mesh_query_point_t,
|
|
1872
2132
|
group="Geometry",
|
|
1873
2133
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
|
|
1874
2134
|
|
|
@@ -1884,6 +2144,7 @@ add_builtin(
|
|
|
1884
2144
|
:param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
|
|
1885
2145
|
:param threshold: The threshold of the winding number to be considered inside, default 0.5""",
|
|
1886
2146
|
require_original_output_arg=True,
|
|
2147
|
+
export=False,
|
|
1887
2148
|
)
|
|
1888
2149
|
|
|
1889
2150
|
add_builtin(
|
|
@@ -1914,6 +2175,7 @@ add_builtin(
|
|
|
1914
2175
|
:param sign: Returns a value > 0 if the ray hit in front of the face, returns < 0 otherwise
|
|
1915
2176
|
:param normal: Returns the face normal
|
|
1916
2177
|
:param face: Returns the index of the hit face""",
|
|
2178
|
+
export=False,
|
|
1917
2179
|
hidden=True,
|
|
1918
2180
|
)
|
|
1919
2181
|
|
|
@@ -1925,7 +2187,7 @@ add_builtin(
|
|
|
1925
2187
|
"dir": vec3,
|
|
1926
2188
|
"max_t": float,
|
|
1927
2189
|
},
|
|
1928
|
-
|
|
2190
|
+
value_func=lambda arg_types, _: MeshQueryRay if arg_types is None else mesh_query_ray_t,
|
|
1929
2191
|
group="Geometry",
|
|
1930
2192
|
doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
|
|
1931
2193
|
|
|
@@ -1934,30 +2196,33 @@ add_builtin(
|
|
|
1934
2196
|
:param dir: The ray direction (should be normalized)
|
|
1935
2197
|
:param max_t: The maximum distance along the ray to check for intersections""",
|
|
1936
2198
|
require_original_output_arg=True,
|
|
2199
|
+
export=False,
|
|
1937
2200
|
)
|
|
1938
2201
|
|
|
1939
2202
|
add_builtin(
|
|
1940
2203
|
"mesh_query_aabb",
|
|
1941
|
-
input_types={"id": uint64, "
|
|
1942
|
-
|
|
2204
|
+
input_types={"id": uint64, "low": vec3, "high": vec3},
|
|
2205
|
+
value_func=lambda arg_types, _: MeshQueryAABB if arg_types is None else mesh_query_aabb_t,
|
|
1943
2206
|
group="Geometry",
|
|
1944
2207
|
doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
|
|
1945
2208
|
|
|
1946
2209
|
This query can be used to iterate over all triangles inside a volume.
|
|
1947
2210
|
|
|
1948
2211
|
:param id: The mesh identifier
|
|
1949
|
-
:param
|
|
1950
|
-
:param
|
|
2212
|
+
:param low: The lower bound of the bounding box in mesh space
|
|
2213
|
+
:param high: The upper bound of the bounding box in mesh space""",
|
|
2214
|
+
export=False,
|
|
1951
2215
|
)
|
|
1952
2216
|
|
|
1953
2217
|
add_builtin(
|
|
1954
2218
|
"mesh_query_aabb_next",
|
|
1955
|
-
input_types={"query":
|
|
2219
|
+
input_types={"query": MeshQueryAABB, "index": int},
|
|
1956
2220
|
value_type=builtins.bool,
|
|
1957
2221
|
group="Geometry",
|
|
1958
2222
|
doc="""Move to the next triangle overlapping the query bounding box.
|
|
1959
2223
|
|
|
1960
2224
|
The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
|
|
2225
|
+
export=False,
|
|
1961
2226
|
)
|
|
1962
2227
|
|
|
1963
2228
|
add_builtin(
|
|
@@ -1966,6 +2231,7 @@ add_builtin(
|
|
|
1966
2231
|
value_type=vec3,
|
|
1967
2232
|
group="Geometry",
|
|
1968
2233
|
doc="""Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates.""",
|
|
2234
|
+
export=False,
|
|
1969
2235
|
)
|
|
1970
2236
|
|
|
1971
2237
|
add_builtin(
|
|
@@ -1974,26 +2240,29 @@ add_builtin(
|
|
|
1974
2240
|
value_type=vec3,
|
|
1975
2241
|
group="Geometry",
|
|
1976
2242
|
doc="""Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates.""",
|
|
2243
|
+
export=False,
|
|
1977
2244
|
)
|
|
1978
2245
|
|
|
1979
2246
|
add_builtin(
|
|
1980
2247
|
"hash_grid_query",
|
|
1981
2248
|
input_types={"id": uint64, "point": vec3, "max_dist": float},
|
|
1982
|
-
|
|
2249
|
+
value_func=lambda arg_types, _: HashGridQuery if arg_types is None else hash_grid_query_t,
|
|
1983
2250
|
group="Geometry",
|
|
1984
2251
|
doc="""Construct a point query against a :class:`HashGrid`.
|
|
1985
2252
|
|
|
1986
2253
|
This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
|
|
2254
|
+
export=False,
|
|
1987
2255
|
)
|
|
1988
2256
|
|
|
1989
2257
|
add_builtin(
|
|
1990
2258
|
"hash_grid_query_next",
|
|
1991
|
-
input_types={"query":
|
|
2259
|
+
input_types={"query": HashGridQuery, "index": int},
|
|
1992
2260
|
value_type=builtins.bool,
|
|
1993
2261
|
group="Geometry",
|
|
1994
2262
|
doc="""Move to the next point in the hash grid query.
|
|
1995
2263
|
|
|
1996
2264
|
The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
|
|
2265
|
+
export=False,
|
|
1997
2266
|
)
|
|
1998
2267
|
|
|
1999
2268
|
add_builtin(
|
|
@@ -2006,6 +2275,7 @@ add_builtin(
|
|
|
2006
2275
|
This can be used to reorder threads such that grid traversal occurs in a spatially coherent order.
|
|
2007
2276
|
|
|
2008
2277
|
Returns -1 if the :class:`HashGrid` has not been reserved.""",
|
|
2278
|
+
export=False,
|
|
2009
2279
|
)
|
|
2010
2280
|
|
|
2011
2281
|
add_builtin(
|
|
@@ -2016,6 +2286,7 @@ add_builtin(
|
|
|
2016
2286
|
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
2017
2287
|
|
|
2018
2288
|
Returns > 0 if triangles intersect.""",
|
|
2289
|
+
export=False,
|
|
2019
2290
|
)
|
|
2020
2291
|
|
|
2021
2292
|
add_builtin(
|
|
@@ -2025,6 +2296,7 @@ add_builtin(
|
|
|
2025
2296
|
missing_grad=True,
|
|
2026
2297
|
group="Geometry",
|
|
2027
2298
|
doc="""Retrieves the mesh given its index.""",
|
|
2299
|
+
export=False,
|
|
2028
2300
|
)
|
|
2029
2301
|
|
|
2030
2302
|
add_builtin(
|
|
@@ -2033,6 +2305,7 @@ add_builtin(
|
|
|
2033
2305
|
value_type=vec3,
|
|
2034
2306
|
group="Geometry",
|
|
2035
2307
|
doc="""Evaluates the face normal the mesh given a face index.""",
|
|
2308
|
+
export=False,
|
|
2036
2309
|
)
|
|
2037
2310
|
|
|
2038
2311
|
add_builtin(
|
|
@@ -2041,6 +2314,7 @@ add_builtin(
|
|
|
2041
2314
|
value_type=vec3,
|
|
2042
2315
|
group="Geometry",
|
|
2043
2316
|
doc="""Returns the point of the mesh given a index.""",
|
|
2317
|
+
export=False,
|
|
2044
2318
|
)
|
|
2045
2319
|
|
|
2046
2320
|
add_builtin(
|
|
@@ -2049,6 +2323,7 @@ add_builtin(
|
|
|
2049
2323
|
value_type=vec3,
|
|
2050
2324
|
group="Geometry",
|
|
2051
2325
|
doc="""Returns the velocity of the mesh given a index.""",
|
|
2326
|
+
export=False,
|
|
2052
2327
|
)
|
|
2053
2328
|
|
|
2054
2329
|
add_builtin(
|
|
@@ -2057,6 +2332,7 @@ add_builtin(
|
|
|
2057
2332
|
value_type=int,
|
|
2058
2333
|
group="Geometry",
|
|
2059
2334
|
doc="""Returns the point-index of the mesh given a face-vertex index.""",
|
|
2335
|
+
export=False,
|
|
2060
2336
|
)
|
|
2061
2337
|
|
|
2062
2338
|
|
|
@@ -2075,6 +2351,7 @@ add_builtin(
|
|
|
2075
2351
|
:param q2: Second point of second edge
|
|
2076
2352
|
:param epsilon: Zero tolerance for determining if points in an edge are degenerate.
|
|
2077
2353
|
:param out: vec3 output containing (s,t,d), where `s` in [0,1] is the barycentric weight for the first edge, `t` is the barycentric weight for the second edge, and `d` is the distance between the two edges at these two closest points.""",
|
|
2354
|
+
export=False,
|
|
2078
2355
|
)
|
|
2079
2356
|
|
|
2080
2357
|
# ---------------------------------
|
|
@@ -2096,9 +2373,13 @@ add_builtin(
|
|
|
2096
2373
|
# ---------------------------------
|
|
2097
2374
|
# Iterators
|
|
2098
2375
|
|
|
2099
|
-
add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", hidden=True)
|
|
2100
|
-
add_builtin(
|
|
2101
|
-
|
|
2376
|
+
add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
|
|
2377
|
+
add_builtin(
|
|
2378
|
+
"iter_next", input_types={"query": hash_grid_query_t}, value_type=int, group="Utility", export=False, hidden=True
|
|
2379
|
+
)
|
|
2380
|
+
add_builtin(
|
|
2381
|
+
"iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", export=False, hidden=True
|
|
2382
|
+
)
|
|
2102
2383
|
|
|
2103
2384
|
# ---------------------------------
|
|
2104
2385
|
# Volumes
|
|
@@ -2116,26 +2397,46 @@ _volume_supported_value_types = {
|
|
|
2116
2397
|
}
|
|
2117
2398
|
|
|
2118
2399
|
|
|
2119
|
-
def
|
|
2120
|
-
|
|
2121
|
-
|
|
2122
|
-
|
|
2123
|
-
|
|
2124
|
-
"'dtype' keyword argument must be specified when calling generic volume lookup or sampling functions"
|
|
2125
|
-
) from err
|
|
2400
|
+
def check_volume_value_grad_compatibility(dtype, grad_dtype):
|
|
2401
|
+
if type_is_vector(dtype):
|
|
2402
|
+
expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
|
|
2403
|
+
else:
|
|
2404
|
+
expected = vector(length=3, dtype=dtype)
|
|
2126
2405
|
|
|
2127
|
-
if
|
|
2128
|
-
raise RuntimeError(f"
|
|
2406
|
+
if not types_equal(grad_dtype, expected):
|
|
2407
|
+
raise RuntimeError(f"Incompatible gradient type, expected {type_repr(expected)}, got {type_repr(grad_dtype)}")
|
|
2129
2408
|
|
|
2130
|
-
|
|
2409
|
+
|
|
2410
|
+
def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2411
|
+
if arg_types is None:
|
|
2412
|
+
return Any
|
|
2413
|
+
|
|
2414
|
+
dtype = arg_values["dtype"]
|
|
2415
|
+
|
|
2416
|
+
if dtype not in _volume_supported_value_types:
|
|
2417
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2131
2418
|
|
|
2132
2419
|
return dtype
|
|
2133
2420
|
|
|
2134
2421
|
|
|
2422
|
+
def volume_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2423
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
2424
|
+
# Further validate the given argument values if needed and map them
|
|
2425
|
+
# to the underlying C++ function's runtime and template params.
|
|
2426
|
+
|
|
2427
|
+
dtype = args["dtype"]
|
|
2428
|
+
|
|
2429
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
2430
|
+
template_args = (dtype,)
|
|
2431
|
+
return (func_args, template_args)
|
|
2432
|
+
|
|
2433
|
+
|
|
2135
2434
|
add_builtin(
|
|
2136
2435
|
"volume_sample",
|
|
2137
2436
|
input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "dtype": Any},
|
|
2138
2437
|
value_func=volume_value_func,
|
|
2438
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
2439
|
+
dispatch_func=volume_dispatch_func,
|
|
2139
2440
|
export=False,
|
|
2140
2441
|
group="Volumes",
|
|
2141
2442
|
doc="""Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
|
|
@@ -2144,31 +2445,38 @@ add_builtin(
|
|
|
2144
2445
|
)
|
|
2145
2446
|
|
|
2146
2447
|
|
|
2147
|
-
def
|
|
2148
|
-
if
|
|
2149
|
-
|
|
2150
|
-
else:
|
|
2151
|
-
expected = vector(length=3, dtype=dtype)
|
|
2152
|
-
|
|
2153
|
-
if not types_equal(grad_dtype, expected):
|
|
2154
|
-
raise RuntimeError(f"Incompatible gradient type, expected {type_repr(expected)}, got {type_repr(grad_dtype)}")
|
|
2448
|
+
def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2449
|
+
if arg_types is None:
|
|
2450
|
+
return Any
|
|
2155
2451
|
|
|
2452
|
+
dtype = arg_values["dtype"]
|
|
2156
2453
|
|
|
2157
|
-
|
|
2158
|
-
|
|
2454
|
+
if dtype not in _volume_supported_value_types:
|
|
2455
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2159
2456
|
|
|
2160
|
-
|
|
2161
|
-
raise RuntimeError("'volume_sample_grad' requires 4 positional arguments")
|
|
2457
|
+
check_volume_value_grad_compatibility(dtype, arg_types["grad"])
|
|
2162
2458
|
|
|
2163
|
-
grad_type = arg_types[3]
|
|
2164
|
-
check_volume_value_grad_compatibility(dtype, grad_type)
|
|
2165
2459
|
return dtype
|
|
2166
2460
|
|
|
2167
2461
|
|
|
2462
|
+
def volume_sample_grad_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2463
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
2464
|
+
# Further validate the given argument values if needed and map them
|
|
2465
|
+
# to the underlying C++ function's runtime and template params.
|
|
2466
|
+
|
|
2467
|
+
dtype = args["dtype"]
|
|
2468
|
+
|
|
2469
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
2470
|
+
template_args = (dtype,)
|
|
2471
|
+
return (func_args, template_args)
|
|
2472
|
+
|
|
2473
|
+
|
|
2168
2474
|
add_builtin(
|
|
2169
2475
|
"volume_sample_grad",
|
|
2170
2476
|
input_types={"id": uint64, "uvw": vec3, "sampling_mode": int, "grad": Any, "dtype": Any},
|
|
2171
2477
|
value_func=volume_sample_grad_value_func,
|
|
2478
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
2479
|
+
dispatch_func=volume_sample_grad_dispatch_func,
|
|
2172
2480
|
export=False,
|
|
2173
2481
|
group="Volumes",
|
|
2174
2482
|
doc="""Sample the volume given by ``id`` and its gradient at the volume local-space point ``uvw``.
|
|
@@ -2176,11 +2484,38 @@ add_builtin(
|
|
|
2176
2484
|
Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`""",
|
|
2177
2485
|
)
|
|
2178
2486
|
|
|
2487
|
+
|
|
2488
|
+
def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2489
|
+
if arg_types is None:
|
|
2490
|
+
return Any
|
|
2491
|
+
|
|
2492
|
+
dtype = arg_values["dtype"]
|
|
2493
|
+
|
|
2494
|
+
if dtype not in _volume_supported_value_types:
|
|
2495
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2496
|
+
|
|
2497
|
+
return dtype
|
|
2498
|
+
|
|
2499
|
+
|
|
2500
|
+
def volume_lookup_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2501
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
2502
|
+
# Further validate the given argument values if needed and map them
|
|
2503
|
+
# to the underlying C++ function's runtime and template params.
|
|
2504
|
+
|
|
2505
|
+
dtype = args["dtype"]
|
|
2506
|
+
|
|
2507
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
2508
|
+
template_args = (dtype,)
|
|
2509
|
+
return (func_args, template_args)
|
|
2510
|
+
|
|
2511
|
+
|
|
2179
2512
|
add_builtin(
|
|
2180
2513
|
"volume_lookup",
|
|
2181
2514
|
input_types={"id": uint64, "i": int, "j": int, "k": int, "dtype": Any},
|
|
2182
2515
|
value_type=int,
|
|
2183
|
-
value_func=
|
|
2516
|
+
value_func=volume_lookup_value_func,
|
|
2517
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
2518
|
+
dispatch_func=volume_lookup_dispatch_func,
|
|
2184
2519
|
export=False,
|
|
2185
2520
|
group="Volumes",
|
|
2186
2521
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
|
|
@@ -2189,13 +2524,14 @@ add_builtin(
|
|
|
2189
2524
|
)
|
|
2190
2525
|
|
|
2191
2526
|
|
|
2192
|
-
def volume_store_value_func(arg_types,
|
|
2193
|
-
if
|
|
2194
|
-
|
|
2527
|
+
def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2528
|
+
if arg_types is None:
|
|
2529
|
+
return None
|
|
2530
|
+
|
|
2531
|
+
dtype = arg_types["value"]
|
|
2195
2532
|
|
|
2196
|
-
dtype = arg_types[4]
|
|
2197
2533
|
if dtype not in _volume_supported_value_types:
|
|
2198
|
-
raise RuntimeError(f"
|
|
2534
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2199
2535
|
|
|
2200
2536
|
return None
|
|
2201
2537
|
|
|
@@ -2299,14 +2635,17 @@ add_builtin(
|
|
|
2299
2635
|
)
|
|
2300
2636
|
|
|
2301
2637
|
|
|
2302
|
-
def volume_sample_index_value_func(arg_types,
|
|
2303
|
-
if
|
|
2304
|
-
|
|
2638
|
+
def volume_sample_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2639
|
+
if arg_types is None:
|
|
2640
|
+
return Any
|
|
2641
|
+
|
|
2642
|
+
dtype = arg_types["voxel_data"].dtype
|
|
2305
2643
|
|
|
2306
|
-
dtype
|
|
2644
|
+
if dtype not in _volume_supported_value_types:
|
|
2645
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2307
2646
|
|
|
2308
|
-
if not types_equal(dtype, arg_types[
|
|
2309
|
-
raise RuntimeError("
|
|
2647
|
+
if not types_equal(dtype, arg_types["background"]):
|
|
2648
|
+
raise RuntimeError("the `voxel_data` array and the `background` value must have the same dtype")
|
|
2310
2649
|
|
|
2311
2650
|
return dtype
|
|
2312
2651
|
|
|
@@ -2326,17 +2665,20 @@ add_builtin(
|
|
|
2326
2665
|
)
|
|
2327
2666
|
|
|
2328
2667
|
|
|
2329
|
-
def volume_sample_grad_index_value_func(arg_types,
|
|
2330
|
-
if
|
|
2331
|
-
|
|
2668
|
+
def volume_sample_grad_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2669
|
+
if arg_types is None:
|
|
2670
|
+
return Any
|
|
2671
|
+
|
|
2672
|
+
dtype = arg_types["voxel_data"].dtype
|
|
2332
2673
|
|
|
2333
|
-
dtype
|
|
2674
|
+
if dtype not in _volume_supported_value_types:
|
|
2675
|
+
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
2334
2676
|
|
|
2335
|
-
if not types_equal(dtype, arg_types[
|
|
2336
|
-
raise RuntimeError("
|
|
2677
|
+
if not types_equal(dtype, arg_types["background"]):
|
|
2678
|
+
raise RuntimeError("the `voxel_data` array and the `background` value must have the same dtype")
|
|
2679
|
+
|
|
2680
|
+
check_volume_value_grad_compatibility(dtype, arg_types["grad"])
|
|
2337
2681
|
|
|
2338
|
-
grad_type = arg_types[5]
|
|
2339
|
-
check_volume_value_grad_compatibility(dtype, grad_type)
|
|
2340
2682
|
return dtype
|
|
2341
2683
|
|
|
2342
2684
|
|
|
@@ -2434,10 +2776,10 @@ add_builtin(
|
|
|
2434
2776
|
)
|
|
2435
2777
|
add_builtin(
|
|
2436
2778
|
"randi",
|
|
2437
|
-
input_types={"state": uint32, "
|
|
2779
|
+
input_types={"state": uint32, "low": int, "high": int},
|
|
2438
2780
|
value_type=int,
|
|
2439
2781
|
group="Random",
|
|
2440
|
-
doc="Return a random integer between [
|
|
2782
|
+
doc="Return a random integer between [low, high).",
|
|
2441
2783
|
)
|
|
2442
2784
|
add_builtin(
|
|
2443
2785
|
"randf",
|
|
@@ -2448,10 +2790,10 @@ add_builtin(
|
|
|
2448
2790
|
)
|
|
2449
2791
|
add_builtin(
|
|
2450
2792
|
"randf",
|
|
2451
|
-
input_types={"state": uint32, "
|
|
2793
|
+
input_types={"state": uint32, "low": float, "high": float},
|
|
2452
2794
|
value_type=float,
|
|
2453
2795
|
group="Random",
|
|
2454
|
-
doc="Return a random float between [
|
|
2796
|
+
doc="Return a random float between [low, high).",
|
|
2455
2797
|
)
|
|
2456
2798
|
add_builtin(
|
|
2457
2799
|
"randn", input_types={"state": uint32}, value_type=float, group="Random", doc="Sample a normal distribution."
|
|
@@ -2600,7 +2942,7 @@ add_builtin(
|
|
|
2600
2942
|
add_builtin(
|
|
2601
2943
|
"curlnoise",
|
|
2602
2944
|
input_types={"state": uint32, "xy": vec2, "octaves": uint32, "lacunarity": float, "gain": float},
|
|
2603
|
-
defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
|
|
2945
|
+
defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
|
|
2604
2946
|
value_type=vec2,
|
|
2605
2947
|
group="Random",
|
|
2606
2948
|
doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
|
|
@@ -2609,7 +2951,7 @@ add_builtin(
|
|
|
2609
2951
|
add_builtin(
|
|
2610
2952
|
"curlnoise",
|
|
2611
2953
|
input_types={"state": uint32, "xyz": vec3, "octaves": uint32, "lacunarity": float, "gain": float},
|
|
2612
|
-
defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
|
|
2954
|
+
defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
|
|
2613
2955
|
value_type=vec3,
|
|
2614
2956
|
group="Random",
|
|
2615
2957
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
@@ -2618,19 +2960,31 @@ add_builtin(
|
|
|
2618
2960
|
add_builtin(
|
|
2619
2961
|
"curlnoise",
|
|
2620
2962
|
input_types={"state": uint32, "xyzt": vec4, "octaves": uint32, "lacunarity": float, "gain": float},
|
|
2621
|
-
defaults={"octaves": 1, "lacunarity": 2.0, "gain": 0.5},
|
|
2963
|
+
defaults={"octaves": uint32(1), "lacunarity": 2.0, "gain": 0.5},
|
|
2622
2964
|
value_type=vec3,
|
|
2623
2965
|
group="Random",
|
|
2624
2966
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
2625
2967
|
missing_grad=True,
|
|
2626
2968
|
)
|
|
2627
2969
|
|
|
2970
|
+
|
|
2971
|
+
def printf_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2972
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
2973
|
+
# Further validate the given argument values if needed and map them
|
|
2974
|
+
# to the underlying C++ function's runtime and template params.
|
|
2975
|
+
|
|
2976
|
+
func_args = (args["fmt"], *args["args"])
|
|
2977
|
+
template_args = ()
|
|
2978
|
+
return (func_args, template_args)
|
|
2979
|
+
|
|
2980
|
+
|
|
2628
2981
|
# note printf calls directly to global CRT printf (no wp:: namespace prefix)
|
|
2629
2982
|
add_builtin(
|
|
2630
2983
|
"printf",
|
|
2631
|
-
input_types={},
|
|
2984
|
+
input_types={"fmt": str, "*args": Any},
|
|
2632
2985
|
namespace="",
|
|
2633
2986
|
variadic=True,
|
|
2987
|
+
dispatch_func=printf_dispatch_func,
|
|
2634
2988
|
group="Utility",
|
|
2635
2989
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
2636
2990
|
)
|
|
@@ -2709,189 +3063,309 @@ add_builtin(
|
|
|
2709
3063
|
|
|
2710
3064
|
add_builtin(
|
|
2711
3065
|
"copy",
|
|
2712
|
-
input_types={"
|
|
2713
|
-
value_func=lambda arg_types,
|
|
3066
|
+
input_types={"a": Any},
|
|
3067
|
+
value_func=lambda arg_types, arg_values: arg_types["a"],
|
|
3068
|
+
hidden=True,
|
|
3069
|
+
export=False,
|
|
3070
|
+
group="Utility",
|
|
3071
|
+
)
|
|
3072
|
+
add_builtin(
|
|
3073
|
+
"assign",
|
|
3074
|
+
input_types={"dest": Any, "src": Any},
|
|
2714
3075
|
hidden=True,
|
|
2715
3076
|
export=False,
|
|
2716
3077
|
group="Utility",
|
|
2717
3078
|
)
|
|
2718
|
-
add_builtin("assign", variadic=True, hidden=True, export=False, group="Utility")
|
|
2719
3079
|
add_builtin(
|
|
2720
3080
|
"select",
|
|
2721
|
-
input_types={"cond": builtins.bool, "
|
|
2722
|
-
value_func=lambda arg_types,
|
|
2723
|
-
doc="Select between two arguments, if ``cond`` is ``False`` then return ``
|
|
3081
|
+
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
3082
|
+
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3083
|
+
doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
2724
3084
|
group="Utility",
|
|
2725
3085
|
)
|
|
2726
3086
|
for t in int_types:
|
|
2727
3087
|
add_builtin(
|
|
2728
3088
|
"select",
|
|
2729
|
-
input_types={"cond": t, "
|
|
2730
|
-
value_func=lambda arg_types,
|
|
2731
|
-
doc="Select between two arguments, if ``cond`` is ``False`` then return ``
|
|
3089
|
+
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
3090
|
+
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3091
|
+
doc="Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
2732
3092
|
group="Utility",
|
|
2733
3093
|
)
|
|
2734
3094
|
add_builtin(
|
|
2735
3095
|
"select",
|
|
2736
|
-
input_types={"arr": array(dtype=Any), "
|
|
2737
|
-
value_func=lambda arg_types,
|
|
2738
|
-
doc="Select between two arguments, if ``arr`` is null then return ``
|
|
3096
|
+
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
3097
|
+
value_func=lambda arg_types, arg_values: arg_types["value_if_false"],
|
|
3098
|
+
doc="Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``",
|
|
3099
|
+
group="Utility",
|
|
3100
|
+
)
|
|
3101
|
+
|
|
3102
|
+
|
|
3103
|
+
def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
3104
|
+
if arg_types is None:
|
|
3105
|
+
return array(dtype=Scalar)
|
|
3106
|
+
|
|
3107
|
+
dtype = arg_values["dtype"]
|
|
3108
|
+
shape = arg_values["shape"]
|
|
3109
|
+
return array(dtype=dtype, ndim=len(shape))
|
|
3110
|
+
|
|
3111
|
+
|
|
3112
|
+
def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
3113
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
3114
|
+
# Further validate the given argument values if needed and map them
|
|
3115
|
+
# to the underlying C++ function's runtime and template params.
|
|
3116
|
+
|
|
3117
|
+
dtype = return_type.dtype
|
|
3118
|
+
|
|
3119
|
+
func_args = (args["ptr"], *args["shape"])
|
|
3120
|
+
template_args = (dtype,)
|
|
3121
|
+
return (func_args, template_args)
|
|
3122
|
+
|
|
3123
|
+
|
|
3124
|
+
add_builtin(
|
|
3125
|
+
"array",
|
|
3126
|
+
input_types={"ptr": warp.uint64, "shape": Tuple[int, ...], "dtype": Scalar},
|
|
3127
|
+
value_func=array_value_func,
|
|
3128
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
3129
|
+
dispatch_func=array_dispatch_func,
|
|
3130
|
+
native_func="array_t",
|
|
2739
3131
|
group="Utility",
|
|
3132
|
+
hidden=True,
|
|
3133
|
+
export=False,
|
|
3134
|
+
missing_grad=True,
|
|
2740
3135
|
)
|
|
2741
3136
|
|
|
2742
3137
|
|
|
2743
3138
|
# does argument checking and type propagation for address()
def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate an ``address()`` call and return a reference to the element type.

    ``address()`` takes an array plus up to four indices named ``i``, ``j``,
    ``k``, ``l``. Index slots whose type is ``None`` are treated as not supplied.
    Exactly one integer index per array dimension is required.

    Raises:
        RuntimeError: if ``arr`` is not an array, the index count differs from
            the array's dimensionality, or an index is not an integer type.
    """
    arr_type = arg_types["arr"]
    idx_types = tuple(arg_types[name] for name in "ijkl" if arg_types.get(name) is not None)

    if not is_array(arr_type):
        raise RuntimeError("address() first argument must be an array")

    idx_count = len(idx_types)

    # fewer indices than dimensions should have been lowered to view() earlier
    if idx_count < arr_type.ndim:
        raise RuntimeError(
            "Num indices < num dimensions for array load, this is a codegen error, should have generated a view instead"
        )
    if idx_count > arr_type.ndim:
        raise RuntimeError(
            f"Num indices > num dimensions for array load, received {idx_count}, but array only has {arr_type.ndim}"
        )

    # every supplied index has to be an integer type
    for index_type in idx_types:
        if type_is_int(index_type):
            continue
        raise RuntimeError(
            f"address() index arguments must be of integer type, got index of type {type_repr(index_type)}"
        )

    return Reference(arr_type.dtype)
|
|
3164
|
+
|
|
3165
|
+
|
|
3166
|
+
for array_type in array_types:
    # One hidden overload per entry in array_types. j, k and l default to None,
    # so a single signature accepts 1 to 4 indices; address_value_func checks
    # the count against the array's ndim.
    add_builtin(
        "address",
        input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int},
        defaults={"j": None, "k": None, "l": None},
        value_func=address_value_func,
        group="Utility",
        hidden=True,
    )
|
|
2767
3175
|
|
|
2768
3176
|
|
|
2769
3177
|
# does argument checking and type propagation for view()
def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate a ``view()`` call and compute the type of the resulting view.

    ``view()`` indexes the leading dimensions of ``arr`` with up to three
    indices (``i``, ``j``, ``k``). Index slots whose type is ``None`` are
    treated as not supplied. At least one dimension must remain, so the
    resulting array type has ``arr.ndim - idx_count`` dimensions.

    Raises:
        RuntimeError: if ``arr`` is not an array, if there are not fewer indices
            than array dimensions, or if an index is not an integer type.
    """
    arr_type = arg_types["arr"]
    idx_types = tuple(arg_types[x] for x in "ijk" if arg_types.get(x, None) is not None)

    if not is_array(arr_type):
        raise RuntimeError("view() first argument must be an array")

    idx_count = len(idx_types)

    if idx_count >= arr_type.ndim:
        raise RuntimeError(
            f"Trying to create an array view with {idx_count} indices, "
            f"but the array only has {arr_type.ndim} dimension(s). "
            "Ensure that the argument type on the function or kernel specifies "
            "the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float)): ..."
        )

    # check index types
    for t in idx_types:
        if not type_is_int(t):
            raise RuntimeError(f"view() index arguments must be of integer type, got index of type {type_repr(t)}")

    # create an array view with leading dimensions removed
    dtype = arr_type.dtype
    ndim = arr_type.ndim - idx_count
    if isinstance(arr_type, (fabricarray, indexedfabricarray)):
        # fabric array of arrays: return array attribute as a regular array
        return array(dtype=dtype, ndim=ndim)

    # otherwise keep the same array flavour as the input
    return type(arr_type)(dtype=dtype, ndim=ndim)
|
|
3208
|
+
|
|
3209
|
+
|
|
3210
|
+
for array_type in array_types:
    # One hidden overload per entry in array_types. j and k default to None, so
    # 1 to 3 leading indices can be given; view_value_func requires at least
    # one dimension to remain.
    add_builtin(
        "view",
        input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int},
        defaults={"j": None, "k": None},
        value_func=view_value_func,
        group="Utility",
        hidden=True,
    )
|
|
2796
3219
|
|
|
2797
3220
|
|
|
2798
3221
|
# does argument checking and type propagation for array_store()
def array_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate an ``array_store()`` call; the builtin produces no value.

    Requires an array, exactly one integer index per dimension (drawn from
    ``i``, ``j``, ``k``, ``l``), and a ``value`` whose type equals the array's dtype.

    Raises:
        RuntimeError: on any of the violations listed above.
    """
    arr_type = arg_types["arr"]
    value_type = arg_types["value"]
    idx_types = tuple(arg_types[name] for name in "ijkl" if arg_types.get(name) is not None)

    if not is_array(arr_type):
        raise RuntimeError("array_store() first argument must be an array")

    idx_count = len(idx_types)

    # fewer indices than dimensions should have been lowered to view() earlier
    if idx_count < arr_type.ndim:
        raise RuntimeError(
            "Num indices < num dimensions for array store, this is a codegen error, should have generated a view instead"
        )
    if idx_count > arr_type.ndim:
        raise RuntimeError(
            f"Num indices > num dimensions for array store, received {idx_count}, but array only has {arr_type.ndim}"
        )

    for index_type in idx_types:
        if type_is_int(index_type):
            continue
        raise RuntimeError(
            f"array_store() index arguments must be of integer type, got index of type {type_repr(index_type)}"
        )

    # the stored value must match the element type exactly
    if types_equal(arr_type.dtype, value_type):
        return None

    raise RuntimeError(
        f"array_store() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
    )
|
|
2828
3256
|
|
|
2829
3257
|
|
|
3258
|
+
for array_type in array_types:
    # one overload per array rank (1 to 4 indices), all validated by array_store_value_func

    # arr[i] = value
    add_builtin(
        "array_store",
        input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
        value_func=array_store_value_func,
        group="Utility",
        hidden=True,
        skip_replay=True,
    )
    # arr[i, j] = value
    add_builtin(
        "array_store",
        input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
        value_func=array_store_value_func,
        group="Utility",
        hidden=True,
        skip_replay=True,
    )
    # arr[i, j, k] = value
    add_builtin(
        "array_store",
        input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
        value_func=array_store_value_func,
        group="Utility",
        hidden=True,
        skip_replay=True,
    )
    # arr[i, j, k, l] = value
    add_builtin(
        "array_store",
        input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
        value_func=array_store_value_func,
        group="Utility",
        hidden=True,
        skip_replay=True,
    )
|
|
3291
|
+
|
|
3292
|
+
|
|
2830
3293
|
# does argument checking for store()
def store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate a ``store(address, value)`` call; the builtin produces no value.

    Raises:
        RuntimeError: if the type of ``value`` differs from the referenced type.
    """
    # we already stripped the Reference from the argument type prior to this call
    address_type = arg_types["address"]
    value_type = arg_types["value"]

    if not types_equal(address_type, value_type):
        # report both types through type_repr, consistent with the other builtin validators
        raise RuntimeError(
            f"store() value argument type ({type_repr(value_type)}) must be of the same type as the reference "
            f"({type_repr(address_type)})"
        )

    return None
|
|
2837
3302
|
|
|
2838
3303
|
|
|
2839
|
-
|
|
2840
|
-
|
|
2841
|
-
|
|
2842
|
-
return
|
|
3304
|
+
def store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Emit ``store(address, value)`` with the address wrapped in a ``Reference``.

    Returns:
        ``(func_args, template_args)``; there are no template arguments.
    """
    address_ref = Reference(args["address"])
    return (address_ref, args["value"]), ()
|
|
2843
3308
|
|
|
2844
3309
|
|
|
2845
|
-
# Hidden builtin that writes `value` through a reference. store_value_func
# checks the types; store_dispatch_func passes the address as a Reference.
add_builtin(
    "store",
    input_types={"address": Any, "value": Any},
    value_func=store_value_func,
    dispatch_func=store_dispatch_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
)
|
|
3319
|
+
|
|
3320
|
+
|
|
3321
|
+
def load_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Emit ``load(address)`` with the address wrapped in a ``Reference``.

    Returns:
        ``(func_args, template_args)``; there are no template arguments.
    """
    address_ref = Reference(args["address"])
    return (address_ref,), ()
|
|
3325
|
+
|
|
3326
|
+
|
|
2858
3327
|
# Hidden builtin that reads through a reference. The result type equals the
# type of the `address` argument.
add_builtin(
    "load",
    input_types={"address": Any},
    dispatch_func=load_dispatch_func,
    value_func=lambda arg_types, arg_values: arg_types["address"],
    group="Utility",
    hidden=True,
)
|
|
2865
3335
|
|
|
2866
3336
|
|
|
2867
|
-
def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate an atomic array operation (atomic_add/sub/min/max) and return the element type.

    Requires an array, exactly one integer index per dimension (from ``i``,
    ``j``, ``k``, ``l``), and a ``value`` of the same type as the array's dtype.

    Raises:
        RuntimeError: on any of the violations listed above.
    """
    arr_type = arg_types["arr"]
    value_type = arg_types["value"]
    idx_types = tuple(arg_types[name] for name in "ijkl" if arg_types.get(name) is not None)

    if not is_array(arr_type):
        raise RuntimeError("atomic() first argument must be an array")

    idx_count = len(idx_types)

    # fewer indices than dimensions should have been lowered to view() earlier
    if idx_count < arr_type.ndim:
        raise RuntimeError(
            "Num indices < num dimensions for atomic, this is a codegen error, should have generated a view instead"
        )
    if idx_count > arr_type.ndim:
        raise RuntimeError(
            f"Num indices > num dimensions for atomic, received {idx_count}, but array only has {arr_type.ndim}"
        )

    for index_type in idx_types:
        if type_is_int(index_type):
            continue
        raise RuntimeError(
            f"atomic() index arguments must be of integer type, got index of type {type_repr(index_type)}"
        )

    element_type = arr_type.dtype
    if not types_equal(element_type, value_type):
        raise RuntimeError(
            f"atomic() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
        )

    return element_type
|
|
2895
3369
|
|
|
2896
3370
|
|
|
2897
3371
|
for array_type in array_types:
|
|
@@ -2901,36 +3375,36 @@ for array_type in array_types:
|
|
|
2901
3375
|
add_builtin(
|
|
2902
3376
|
"atomic_add",
|
|
2903
3377
|
hidden=hidden,
|
|
2904
|
-
input_types={"
|
|
3378
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
2905
3379
|
value_func=atomic_op_value_func,
|
|
2906
|
-
doc="Atomically add ``value`` onto ``
|
|
3380
|
+
doc="Atomically add ``value`` onto ``arr[i]``.",
|
|
2907
3381
|
group="Utility",
|
|
2908
3382
|
skip_replay=True,
|
|
2909
3383
|
)
|
|
2910
3384
|
add_builtin(
|
|
2911
3385
|
"atomic_add",
|
|
2912
3386
|
hidden=hidden,
|
|
2913
|
-
input_types={"
|
|
3387
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
2914
3388
|
value_func=atomic_op_value_func,
|
|
2915
|
-
doc="Atomically add ``value`` onto ``
|
|
3389
|
+
doc="Atomically add ``value`` onto ``arr[i,j]``.",
|
|
2916
3390
|
group="Utility",
|
|
2917
3391
|
skip_replay=True,
|
|
2918
3392
|
)
|
|
2919
3393
|
add_builtin(
|
|
2920
3394
|
"atomic_add",
|
|
2921
3395
|
hidden=hidden,
|
|
2922
|
-
input_types={"
|
|
3396
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
2923
3397
|
value_func=atomic_op_value_func,
|
|
2924
|
-
doc="Atomically add ``value`` onto ``
|
|
3398
|
+
doc="Atomically add ``value`` onto ``arr[i,j,k]``.",
|
|
2925
3399
|
group="Utility",
|
|
2926
3400
|
skip_replay=True,
|
|
2927
3401
|
)
|
|
2928
3402
|
add_builtin(
|
|
2929
3403
|
"atomic_add",
|
|
2930
3404
|
hidden=hidden,
|
|
2931
|
-
input_types={"
|
|
3405
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
2932
3406
|
value_func=atomic_op_value_func,
|
|
2933
|
-
doc="Atomically add ``value`` onto ``
|
|
3407
|
+
doc="Atomically add ``value`` onto ``arr[i,j,k,l]``.",
|
|
2934
3408
|
group="Utility",
|
|
2935
3409
|
skip_replay=True,
|
|
2936
3410
|
)
|
|
@@ -2938,36 +3412,36 @@ for array_type in array_types:
|
|
|
2938
3412
|
add_builtin(
|
|
2939
3413
|
"atomic_sub",
|
|
2940
3414
|
hidden=hidden,
|
|
2941
|
-
input_types={"
|
|
3415
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
2942
3416
|
value_func=atomic_op_value_func,
|
|
2943
|
-
doc="Atomically subtract ``value`` onto ``
|
|
3417
|
+
doc="Atomically subtract ``value`` onto ``arr[i]``.",
|
|
2944
3418
|
group="Utility",
|
|
2945
3419
|
skip_replay=True,
|
|
2946
3420
|
)
|
|
2947
3421
|
add_builtin(
|
|
2948
3422
|
"atomic_sub",
|
|
2949
3423
|
hidden=hidden,
|
|
2950
|
-
input_types={"
|
|
3424
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
2951
3425
|
value_func=atomic_op_value_func,
|
|
2952
|
-
doc="Atomically subtract ``value`` onto ``
|
|
3426
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j]``.",
|
|
2953
3427
|
group="Utility",
|
|
2954
3428
|
skip_replay=True,
|
|
2955
3429
|
)
|
|
2956
3430
|
add_builtin(
|
|
2957
3431
|
"atomic_sub",
|
|
2958
3432
|
hidden=hidden,
|
|
2959
|
-
input_types={"
|
|
3433
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
2960
3434
|
value_func=atomic_op_value_func,
|
|
2961
|
-
doc="Atomically subtract ``value`` onto ``
|
|
3435
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j,k]``.",
|
|
2962
3436
|
group="Utility",
|
|
2963
3437
|
skip_replay=True,
|
|
2964
3438
|
)
|
|
2965
3439
|
add_builtin(
|
|
2966
3440
|
"atomic_sub",
|
|
2967
3441
|
hidden=hidden,
|
|
2968
|
-
input_types={"
|
|
3442
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
2969
3443
|
value_func=atomic_op_value_func,
|
|
2970
|
-
doc="Atomically subtract ``value`` onto ``
|
|
3444
|
+
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]``.",
|
|
2971
3445
|
group="Utility",
|
|
2972
3446
|
skip_replay=True,
|
|
2973
3447
|
)
|
|
@@ -2975,9 +3449,9 @@ for array_type in array_types:
|
|
|
2975
3449
|
add_builtin(
|
|
2976
3450
|
"atomic_min",
|
|
2977
3451
|
hidden=hidden,
|
|
2978
|
-
input_types={"
|
|
3452
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
2979
3453
|
value_func=atomic_op_value_func,
|
|
2980
|
-
doc="""Compute the minimum of ``value`` and ``
|
|
3454
|
+
doc="""Compute the minimum of ``value`` and ``arr[i]`` and atomically update the array.
|
|
2981
3455
|
|
|
2982
3456
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
2983
3457
|
group="Utility",
|
|
@@ -2986,9 +3460,9 @@ for array_type in array_types:
|
|
|
2986
3460
|
add_builtin(
|
|
2987
3461
|
"atomic_min",
|
|
2988
3462
|
hidden=hidden,
|
|
2989
|
-
input_types={"
|
|
3463
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
2990
3464
|
value_func=atomic_op_value_func,
|
|
2991
|
-
doc="""Compute the minimum of ``value`` and ``
|
|
3465
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j]`` and atomically update the array.
|
|
2992
3466
|
|
|
2993
3467
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
2994
3468
|
group="Utility",
|
|
@@ -2997,9 +3471,9 @@ for array_type in array_types:
|
|
|
2997
3471
|
add_builtin(
|
|
2998
3472
|
"atomic_min",
|
|
2999
3473
|
hidden=hidden,
|
|
3000
|
-
input_types={"
|
|
3474
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3001
3475
|
value_func=atomic_op_value_func,
|
|
3002
|
-
doc="""Compute the minimum of ``value`` and ``
|
|
3476
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]`` and atomically update the array.
|
|
3003
3477
|
|
|
3004
3478
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3005
3479
|
group="Utility",
|
|
@@ -3008,9 +3482,9 @@ for array_type in array_types:
|
|
|
3008
3482
|
add_builtin(
|
|
3009
3483
|
"atomic_min",
|
|
3010
3484
|
hidden=hidden,
|
|
3011
|
-
input_types={"
|
|
3485
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3012
3486
|
value_func=atomic_op_value_func,
|
|
3013
|
-
doc="""Compute the minimum of ``value`` and ``
|
|
3487
|
+
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]`` and atomically update the array.
|
|
3014
3488
|
|
|
3015
3489
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3016
3490
|
group="Utility",
|
|
@@ -3020,9 +3494,9 @@ for array_type in array_types:
|
|
|
3020
3494
|
add_builtin(
|
|
3021
3495
|
"atomic_max",
|
|
3022
3496
|
hidden=hidden,
|
|
3023
|
-
input_types={"
|
|
3497
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "value": Any},
|
|
3024
3498
|
value_func=atomic_op_value_func,
|
|
3025
|
-
doc="""Compute the maximum of ``value`` and ``
|
|
3499
|
+
doc="""Compute the maximum of ``value`` and ``arr[i]`` and atomically update the array.
|
|
3026
3500
|
|
|
3027
3501
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3028
3502
|
group="Utility",
|
|
@@ -3031,9 +3505,9 @@ for array_type in array_types:
|
|
|
3031
3505
|
add_builtin(
|
|
3032
3506
|
"atomic_max",
|
|
3033
3507
|
hidden=hidden,
|
|
3034
|
-
input_types={"
|
|
3508
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "value": Any},
|
|
3035
3509
|
value_func=atomic_op_value_func,
|
|
3036
|
-
doc="""Compute the maximum of ``value`` and ``
|
|
3510
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j]`` and atomically update the array.
|
|
3037
3511
|
|
|
3038
3512
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3039
3513
|
group="Utility",
|
|
@@ -3042,9 +3516,9 @@ for array_type in array_types:
|
|
|
3042
3516
|
add_builtin(
|
|
3043
3517
|
"atomic_max",
|
|
3044
3518
|
hidden=hidden,
|
|
3045
|
-
input_types={"
|
|
3519
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "value": Any},
|
|
3046
3520
|
value_func=atomic_op_value_func,
|
|
3047
|
-
doc="""Compute the maximum of ``value`` and ``
|
|
3521
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]`` and atomically update the array.
|
|
3048
3522
|
|
|
3049
3523
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3050
3524
|
group="Utility",
|
|
@@ -3053,9 +3527,9 @@ for array_type in array_types:
|
|
|
3053
3527
|
add_builtin(
|
|
3054
3528
|
"atomic_max",
|
|
3055
3529
|
hidden=hidden,
|
|
3056
|
-
input_types={"
|
|
3530
|
+
input_types={"arr": array_type(dtype=Any), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3057
3531
|
value_func=atomic_op_value_func,
|
|
3058
|
-
doc="""Compute the maximum of ``value`` and ``
|
|
3532
|
+
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]`` and atomically update the array.
|
|
3059
3533
|
|
|
3060
3534
|
.. note:: The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3061
3535
|
group="Utility",
|
|
@@ -3064,21 +3538,21 @@ for array_type in array_types:
|
|
|
3064
3538
|
|
|
3065
3539
|
|
|
3066
3540
|
# used to index into builtin types, i.e.: y = vec3[1]
def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return the scalar type of the composite argument ``a``."""
    composite_type = arg_types["a"]
    return composite_type._wp_scalar_type_
|
|
3069
3543
|
|
|
3070
3544
|
|
|
3071
3545
|
# vec[i] -> scalar
add_builtin(
    "extract",
    input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
    value_func=extract_value_func,
    group="Utility",
    hidden=True,
)
# quat[i] -> scalar
add_builtin(
    "extract",
    input_types={"a": quaternion(dtype=Scalar), "i": int},
    value_func=extract_value_func,
    group="Utility",
    hidden=True,
)
|
|
@@ -3086,14 +3560,16 @@ add_builtin(
|
|
|
3086
3560
|
# mat[i] -> row vector: its length is the matrix's column count (_shape_[1])
add_builtin(
    "extract",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
    value_func=lambda arg_types, arg_values: vector(
        length=arg_types["a"]._shape_[1],
        dtype=arg_types["a"]._wp_scalar_type_,
    ),
    group="Utility",
    hidden=True,
)
# mat[i, j] -> scalar
add_builtin(
    "extract",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
    value_func=extract_value_func,
    group="Utility",
    hidden=True,
)
|
|
@@ -3101,7 +3577,7 @@ add_builtin(
|
|
|
3101
3577
|
# transform[i] -> scalar
add_builtin(
    "extract",
    input_types={"a": transformation(dtype=Scalar), "i": int},
    value_func=extract_value_func,
    group="Utility",
    hidden=True,
)
|
|
@@ -3109,19 +3585,35 @@ add_builtin(
|
|
|
3109
3585
|
# shape[i] -> int
add_builtin(
    "extract",
    input_types={"s": shape_t, "i": int},
    value_type=int,
    group="Utility",
    hidden=True,
)
|
|
3110
3586
|
|
|
3111
3587
|
|
|
3112
|
-
def vector_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return a ``Reference`` to the scalar type of ``a``.

    Shared by the ``index``/``indexref`` overloads for vectors and quaternions.
    """
    scalar_type = arg_types["a"]._wp_scalar_type_
    return Reference(scalar_type)
|
|
3118
3593
|
|
|
3119
3594
|
|
|
3595
|
+
def vector_index_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Emit the indexing call with the container ``a`` passed as a ``Reference``.

    Returns:
        ``(func_args, template_args)``; there are no template arguments.
    """
    container_ref = Reference(args["a"])
    return (container_ref, args["i"]), ()
|
|
3599
|
+
|
|
3600
|
+
|
|
3120
3601
|
# implements &vector[index]
add_builtin(
    "index",
    input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
    dispatch_func=vector_index_dispatch_func,
    value_func=vector_index_value_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
)
|
|
3611
|
+
# implements &quaternion[index]
|
|
3612
|
+
add_builtin(
|
|
3613
|
+
"index",
|
|
3614
|
+
input_types={"a": quaternion(dtype=Float), "i": int},
|
|
3615
|
+
value_func=vector_index_value_func,
|
|
3616
|
+
dispatch_func=vector_index_dispatch_func,
|
|
3125
3617
|
hidden=True,
|
|
3126
3618
|
group="Utility",
|
|
3127
3619
|
skip_replay=True,
|
|
@@ -3129,27 +3621,28 @@ add_builtin(
|
|
|
3129
3621
|
# implements &(*vector)[index]
add_builtin(
    "indexref",
    input_types={"a": vector(length=Any, dtype=Scalar), "i": int},
    dispatch_func=vector_index_dispatch_func,
    value_func=vector_index_value_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
)
# implements &(*quaternion)[index]
add_builtin(
    "indexref",
    input_types={"a": quaternion(dtype=Float), "i": int},
    dispatch_func=vector_index_dispatch_func,
    value_func=vector_index_value_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
)
|
|
3138
3641
|
|
|
3139
3642
|
|
|
3140
|
-
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return a ``Reference`` to the row type of matrix ``a`` (used for ``a[i]``)."""
    row_type = arg_types["a"]._wp_row_type_
    return Reference(row_type)
|
|
3155
3648
|
|
|
@@ -3158,17 +3651,25 @@ def matrix_indexref_row_value_func(arg_types, kwds, _):
|
|
|
3158
3651
|
# &matrix[i]: reference to a whole row
add_builtin(
    "index",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int},
    value_func=matrix_index_row_value_func,
    group="Utility",
    hidden=True,
    skip_replay=True,
)
|
|
3166
3659
|
|
|
3660
|
+
|
|
3661
|
+
def matrix_index_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return a ``Reference`` to the scalar element type of matrix ``a`` (used for ``a[i, j]``)."""
    element_type = arg_types["a"]._wp_scalar_type_
    return Reference(element_type)
|
|
3666
|
+
|
|
3667
|
+
|
|
3167
3668
|
# implements matrix[i,j] = scalar
|
|
3168
3669
|
add_builtin(
|
|
3169
3670
|
"index",
|
|
3170
3671
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int},
|
|
3171
|
-
value_func=
|
|
3672
|
+
value_func=matrix_index_value_func,
|
|
3172
3673
|
hidden=True,
|
|
3173
3674
|
group="Utility",
|
|
3174
3675
|
skip_replay=True,
|
|
@@ -3177,56 +3678,58 @@ add_builtin(
|
|
|
3177
3678
|
for t in scalar_types + vector_types + (bool,):
    # types whose name contains "vec" or "mat" are skipped here; the vector and
    # matrix overloads are registered separately with a same-type constraint
    if not any(tag in t.__name__ for tag in ("vec", "mat")):
        add_builtin(
            "expect_eq",
            input_types={"a": t, "b": t},
            value_type=None,
            doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
            hidden=True,
            group="Utility",
        )
|
|
3188
3690
|
|
|
3189
3691
|
|
|
3190
|
-
def expect_eq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Check that ``a`` and ``b`` share a type; the expect builtins return nothing.

    Raises:
        RuntimeError: if the operand types differ.
    """
    if types_equal(arg_types["a"], arg_types["b"]):
        return None

    raise RuntimeError("Can't test equality for objects with different types")
|
|
3194
3697
|
|
|
3195
3698
|
|
|
3196
3699
|
# vector overloads
add_builtin(
    "expect_eq",
    input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
    value_func=expect_eq_value_func,
    constraint=sametypes,
    doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
    hidden=True,
    group="Utility",
)
add_builtin(
    "expect_neq",
    input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
    value_func=expect_eq_value_func,
    constraint=sametypes,
    doc="Prints an error to stdout if ``a`` and ``b`` are equal",
    hidden=True,
    group="Utility",
)

# matrix overloads
add_builtin(
    "expect_eq",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
    value_func=expect_eq_value_func,
    constraint=sametypes,
    doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
    hidden=True,
    group="Utility",
)
add_builtin(
    "expect_neq",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
    value_func=expect_eq_value_func,
    constraint=sametypes,
    doc="Prints an error to stdout if ``a`` and ``b`` are equal",
    hidden=True,
    group="Utility",
)
|
|
@@ -3234,35 +3737,36 @@ add_builtin(
|
|
|
3234
3737
|
# scalar float lerp: all three operands share one Float type
add_builtin(
    "lerp",
    input_types={"a": Float, "b": Float, "t": Float},
    value_func=sametypes_create_value_func(Float),
    group="Utility",
    doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
)
add_builtin(
    "smoothstep",
    input_types={"a": Float, "b": Float, "x": Float},
    value_func=sametypes_create_value_func(Float),
    group="Utility",
    doc="""Smoothly interpolate between two values ``a`` and ``b`` using a factor ``x``,
    and return a result between 0 and 1 using a cubic Hermite interpolation after clamping.""",
)
|
|
3249
3752
|
|
|
3250
3753
|
|
|
3251
|
-
def lerp_constraint(arg_types: Mapping[str, type]):
    """Return whether the interpolation endpoints ``a`` and ``b`` have equal types."""
    start_type, end_type = arg_types["a"], arg_types["b"]
    return types_equal(start_type, end_type)
|
|
3253
3756
|
|
|
3254
3757
|
|
|
3255
|
-
def
|
|
3256
|
-
def fn(arg_types,
|
|
3758
|
+
def lerp_create_value_func(default):
|
|
3759
|
+
def fn(arg_types, arg_values):
|
|
3257
3760
|
if arg_types is None:
|
|
3258
3761
|
return default
|
|
3259
|
-
|
|
3762
|
+
|
|
3260
3763
|
if not lerp_constraint(arg_types):
|
|
3261
3764
|
raise RuntimeError("Can't lerp between objects with different types")
|
|
3262
|
-
|
|
3765
|
+
|
|
3766
|
+
if arg_types["a"]._wp_scalar_type_ != arg_types["t"]:
|
|
3263
3767
|
raise RuntimeError("'t' parameter must have the same scalar type as objects you're lerping between")
|
|
3264
3768
|
|
|
3265
|
-
return arg_types[
|
|
3769
|
+
return arg_types["a"]
|
|
3266
3770
|
|
|
3267
3771
|
return fn
|
|
3268
3772
|
|
|
@@ -3271,7 +3775,7 @@ add_builtin(
|
|
|
3271
3775
|
"lerp",
|
|
3272
3776
|
input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "t": Float},
|
|
3273
3777
|
constraint=lerp_constraint,
|
|
3274
|
-
value_func=
|
|
3778
|
+
value_func=lerp_create_value_func(vector(length=Any, dtype=Float)),
|
|
3275
3779
|
doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
|
|
3276
3780
|
group="Utility",
|
|
3277
3781
|
)
|
|
@@ -3279,21 +3783,21 @@ add_builtin(
|
|
|
3279
3783
|
"lerp",
|
|
3280
3784
|
input_types={"a": matrix(shape=(Any, Any), dtype=Float), "b": matrix(shape=(Any, Any), dtype=Float), "t": Float},
|
|
3281
3785
|
constraint=lerp_constraint,
|
|
3282
|
-
value_func=
|
|
3786
|
+
value_func=lerp_create_value_func(matrix(shape=(Any, Any), dtype=Float)),
|
|
3283
3787
|
doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
|
|
3284
3788
|
group="Utility",
|
|
3285
3789
|
)
|
|
3286
3790
|
add_builtin(
|
|
3287
3791
|
"lerp",
|
|
3288
3792
|
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "t": Float},
|
|
3289
|
-
value_func=
|
|
3793
|
+
value_func=lerp_create_value_func(quaternion(dtype=Float)),
|
|
3290
3794
|
doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
|
|
3291
3795
|
group="Utility",
|
|
3292
3796
|
)
|
|
3293
3797
|
add_builtin(
|
|
3294
3798
|
"lerp",
|
|
3295
3799
|
input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float), "t": Float},
|
|
3296
|
-
value_func=
|
|
3800
|
+
value_func=lerp_create_value_func(transformation(dtype=Float)),
|
|
3297
3801
|
doc="Linearly interpolate two values ``a`` and ``b`` using factor ``t``, computed as ``a*(1-t) + b*t``",
|
|
3298
3802
|
group="Utility",
|
|
3299
3803
|
)
|
|
@@ -3301,17 +3805,18 @@ add_builtin(
|
|
|
3301
3805
|
# fuzzy compare for float values
|
|
3302
3806
|
add_builtin(
|
|
3303
3807
|
"expect_near",
|
|
3304
|
-
input_types={"
|
|
3808
|
+
input_types={"a": Float, "b": Float, "tolerance": Float},
|
|
3305
3809
|
defaults={"tolerance": 1.0e-6},
|
|
3306
3810
|
value_type=None,
|
|
3307
|
-
doc="Prints an error to stdout if ``
|
|
3811
|
+
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
3308
3812
|
group="Utility",
|
|
3309
3813
|
)
|
|
3310
3814
|
add_builtin(
|
|
3311
3815
|
"expect_near",
|
|
3312
|
-
input_types={"
|
|
3816
|
+
input_types={"a": vec3, "b": vec3, "tolerance": float},
|
|
3817
|
+
defaults={"tolerance": 1.0e-6},
|
|
3313
3818
|
value_type=None,
|
|
3314
|
-
doc="Prints an error to stdout if any element of ``
|
|
3819
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
3315
3820
|
group="Utility",
|
|
3316
3821
|
)
|
|
3317
3822
|
|
|
@@ -3335,359 +3840,378 @@ add_builtin(
|
|
|
3335
3840
|
# ---------------------------------
|
|
3336
3841
|
# Operators
|
|
3337
3842
|
|
|
3338
|
-
add_builtin(
|
|
3843
|
+
add_builtin(
|
|
3844
|
+
"add", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
3845
|
+
)
|
|
3339
3846
|
add_builtin(
|
|
3340
3847
|
"add",
|
|
3341
|
-
input_types={"
|
|
3848
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
3342
3849
|
constraint=sametypes,
|
|
3343
|
-
value_func=
|
|
3850
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3344
3851
|
doc="",
|
|
3345
3852
|
group="Operators",
|
|
3346
3853
|
)
|
|
3347
3854
|
add_builtin(
|
|
3348
3855
|
"add",
|
|
3349
|
-
input_types={"
|
|
3350
|
-
value_func=
|
|
3856
|
+
input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
|
|
3857
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
|
|
3351
3858
|
doc="",
|
|
3352
3859
|
group="Operators",
|
|
3353
3860
|
)
|
|
3354
3861
|
add_builtin(
|
|
3355
3862
|
"add",
|
|
3356
|
-
input_types={"
|
|
3863
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3357
3864
|
constraint=sametypes,
|
|
3358
|
-
value_func=
|
|
3865
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3359
3866
|
doc="",
|
|
3360
3867
|
group="Operators",
|
|
3361
3868
|
)
|
|
3362
3869
|
add_builtin(
|
|
3363
3870
|
"add",
|
|
3364
|
-
input_types={"
|
|
3365
|
-
value_func=
|
|
3871
|
+
input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
|
|
3872
|
+
value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
|
|
3366
3873
|
doc="",
|
|
3367
3874
|
group="Operators",
|
|
3368
3875
|
)
|
|
3369
3876
|
|
|
3370
|
-
add_builtin(
|
|
3877
|
+
add_builtin(
|
|
3878
|
+
"sub", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
3879
|
+
)
|
|
3371
3880
|
add_builtin(
|
|
3372
3881
|
"sub",
|
|
3373
|
-
input_types={"
|
|
3882
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
3374
3883
|
constraint=sametypes,
|
|
3375
|
-
value_func=
|
|
3884
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3376
3885
|
doc="",
|
|
3377
3886
|
group="Operators",
|
|
3378
3887
|
)
|
|
3379
3888
|
add_builtin(
|
|
3380
3889
|
"sub",
|
|
3381
|
-
input_types={"
|
|
3890
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3382
3891
|
constraint=sametypes,
|
|
3383
|
-
value_func=
|
|
3892
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3384
3893
|
doc="",
|
|
3385
3894
|
group="Operators",
|
|
3386
3895
|
)
|
|
3387
3896
|
add_builtin(
|
|
3388
3897
|
"sub",
|
|
3389
|
-
input_types={"
|
|
3390
|
-
value_func=
|
|
3898
|
+
input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
|
|
3899
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
|
|
3391
3900
|
doc="",
|
|
3392
3901
|
group="Operators",
|
|
3393
3902
|
)
|
|
3394
3903
|
add_builtin(
|
|
3395
3904
|
"sub",
|
|
3396
|
-
input_types={"
|
|
3397
|
-
value_func=
|
|
3905
|
+
input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
|
|
3906
|
+
value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
|
|
3398
3907
|
doc="",
|
|
3399
3908
|
group="Operators",
|
|
3400
3909
|
)
|
|
3401
3910
|
|
|
3402
3911
|
# bitwise operators
|
|
3403
|
-
add_builtin("bit_and", input_types={"
|
|
3404
|
-
add_builtin("bit_or", input_types={"
|
|
3405
|
-
add_builtin("bit_xor", input_types={"
|
|
3406
|
-
add_builtin("lshift", input_types={"
|
|
3407
|
-
add_builtin("rshift", input_types={"
|
|
3408
|
-
add_builtin("invert", input_types={"
|
|
3912
|
+
add_builtin("bit_and", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
|
|
3913
|
+
add_builtin("bit_or", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
|
|
3914
|
+
add_builtin("bit_xor", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
|
|
3915
|
+
add_builtin("lshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
|
|
3916
|
+
add_builtin("rshift", input_types={"a": Int, "b": Int}, value_func=sametypes_create_value_func(Int))
|
|
3917
|
+
add_builtin("invert", input_types={"a": Int}, value_func=sametypes_create_value_func(Int))
|
|
3918
|
+
|
|
3919
|
+
|
|
3920
|
+
add_builtin(
|
|
3921
|
+
"mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
3922
|
+
)
|
|
3409
3923
|
|
|
3410
3924
|
|
|
3411
|
-
def
|
|
3412
|
-
def fn(arg_types,
|
|
3925
|
+
def scalar_mul_create_value_func(default):
|
|
3926
|
+
def fn(arg_types, arg_values):
|
|
3413
3927
|
if arg_types is None:
|
|
3414
3928
|
return default
|
|
3415
|
-
|
|
3416
|
-
|
|
3929
|
+
|
|
3930
|
+
scalar = next(t for t in arg_types.values() if t in scalar_types)
|
|
3931
|
+
compound = next(t for t in arg_types.values() if t not in scalar_types)
|
|
3417
3932
|
if scalar != compound._wp_scalar_type_:
|
|
3418
3933
|
raise RuntimeError("Object and coefficient must have the same scalar type when multiplying by scalar")
|
|
3934
|
+
|
|
3419
3935
|
return compound
|
|
3420
3936
|
|
|
3421
3937
|
return fn
|
|
3422
3938
|
|
|
3423
3939
|
|
|
3424
|
-
def mul_matvec_constraint(arg_types):
|
|
3425
|
-
return arg_types[0]._shape_[1] == arg_types[1]._length_
|
|
3426
|
-
|
|
3427
|
-
|
|
3428
|
-
def mul_matvec_value_func(arg_types, kwds, _):
|
|
3429
|
-
if arg_types is None:
|
|
3430
|
-
return vector(length=Any, dtype=Scalar)
|
|
3431
|
-
|
|
3432
|
-
if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
|
|
3433
|
-
raise RuntimeError(
|
|
3434
|
-
f"Can't multiply matrix and vector with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
|
|
3435
|
-
)
|
|
3436
|
-
|
|
3437
|
-
if not mul_matmat_constraint(arg_types):
|
|
3438
|
-
raise RuntimeError(
|
|
3439
|
-
f"Can't multiply matrix of shape {arg_types[0]._shape_} and vector with length {arg_types[1]._length_}"
|
|
3440
|
-
)
|
|
3441
|
-
|
|
3442
|
-
return vector(length=arg_types[0]._shape_[0], dtype=arg_types[0]._wp_scalar_type_)
|
|
3443
|
-
|
|
3444
|
-
|
|
3445
|
-
def mul_vecmat_constraint(arg_types):
|
|
3446
|
-
return arg_types[1]._shape_[0] == arg_types[0]._length_
|
|
3447
|
-
|
|
3448
|
-
|
|
3449
|
-
def mul_vecmat_value_func(arg_types, kwds, _):
|
|
3450
|
-
if arg_types is None:
|
|
3451
|
-
return vector(length=Any, dtype=Scalar)
|
|
3452
|
-
|
|
3453
|
-
if arg_types[1]._wp_scalar_type_ != arg_types[0]._wp_scalar_type_:
|
|
3454
|
-
raise RuntimeError(
|
|
3455
|
-
f"Can't multiply vector and matrix with different types {arg_types[1]._wp_scalar_type_}, {arg_types[0]._wp_scalar_type_}"
|
|
3456
|
-
)
|
|
3457
|
-
|
|
3458
|
-
if not mul_vecmat_constraint(arg_types):
|
|
3459
|
-
raise RuntimeError(
|
|
3460
|
-
f"Can't multiply vector with length {arg_types[0]._length_} and matrix of shape {arg_types[1]._shape_}"
|
|
3461
|
-
)
|
|
3462
|
-
|
|
3463
|
-
return vector(length=arg_types[1]._shape_[1], dtype=arg_types[1]._wp_scalar_type_)
|
|
3464
|
-
|
|
3465
|
-
|
|
3466
|
-
def mul_matmat_constraint(arg_types):
|
|
3467
|
-
return arg_types[0]._shape_[1] == arg_types[1]._shape_[0]
|
|
3468
|
-
|
|
3469
|
-
|
|
3470
|
-
def mul_matmat_value_func(arg_types, kwds, _):
|
|
3471
|
-
if arg_types is None:
|
|
3472
|
-
return matrix(length=Any, dtype=Scalar)
|
|
3473
|
-
|
|
3474
|
-
if arg_types[0]._wp_scalar_type_ != arg_types[1]._wp_scalar_type_:
|
|
3475
|
-
raise RuntimeError(
|
|
3476
|
-
f"Can't multiply matrices with different types {arg_types[0]._wp_scalar_type_}, {arg_types[1]._wp_scalar_type_}"
|
|
3477
|
-
)
|
|
3478
|
-
|
|
3479
|
-
if not mul_matmat_constraint(arg_types):
|
|
3480
|
-
raise RuntimeError(f"Can't multiply matrix of shapes {arg_types[0]._shape_} and {arg_types[1]._shape_}")
|
|
3481
|
-
|
|
3482
|
-
return matrix(shape=(arg_types[0]._shape_[0], arg_types[1]._shape_[1]), dtype=arg_types[0]._wp_scalar_type_)
|
|
3483
|
-
|
|
3484
|
-
|
|
3485
|
-
add_builtin("mul", input_types={"x": Scalar, "y": Scalar}, value_func=sametype_value_func(Scalar), group="Operators")
|
|
3486
3940
|
add_builtin(
|
|
3487
3941
|
"mul",
|
|
3488
|
-
input_types={"
|
|
3489
|
-
value_func=
|
|
3942
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": Scalar},
|
|
3943
|
+
value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3490
3944
|
doc="",
|
|
3491
3945
|
group="Operators",
|
|
3492
3946
|
)
|
|
3493
3947
|
add_builtin(
|
|
3494
3948
|
"mul",
|
|
3495
|
-
input_types={"
|
|
3496
|
-
value_func=
|
|
3949
|
+
input_types={"a": Scalar, "b": vector(length=Any, dtype=Scalar)},
|
|
3950
|
+
value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3497
3951
|
doc="",
|
|
3498
3952
|
group="Operators",
|
|
3499
3953
|
)
|
|
3500
3954
|
add_builtin(
|
|
3501
3955
|
"mul",
|
|
3502
|
-
input_types={"
|
|
3503
|
-
value_func=
|
|
3956
|
+
input_types={"a": quaternion(dtype=Scalar), "b": Scalar},
|
|
3957
|
+
value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
|
|
3504
3958
|
doc="",
|
|
3505
3959
|
group="Operators",
|
|
3506
3960
|
)
|
|
3507
3961
|
add_builtin(
|
|
3508
3962
|
"mul",
|
|
3509
|
-
input_types={"
|
|
3510
|
-
value_func=
|
|
3963
|
+
input_types={"a": Scalar, "b": quaternion(dtype=Scalar)},
|
|
3964
|
+
value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
|
|
3511
3965
|
doc="",
|
|
3512
3966
|
group="Operators",
|
|
3513
3967
|
)
|
|
3514
3968
|
add_builtin(
|
|
3515
3969
|
"mul",
|
|
3516
|
-
input_types={"
|
|
3517
|
-
value_func=
|
|
3970
|
+
input_types={"a": quaternion(dtype=Scalar), "b": quaternion(dtype=Scalar)},
|
|
3971
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
|
|
3518
3972
|
doc="",
|
|
3519
3973
|
group="Operators",
|
|
3520
3974
|
)
|
|
3521
3975
|
add_builtin(
|
|
3522
3976
|
"mul",
|
|
3523
|
-
input_types={"
|
|
3524
|
-
value_func=
|
|
3977
|
+
input_types={"a": Scalar, "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3978
|
+
value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3525
3979
|
doc="",
|
|
3526
3980
|
group="Operators",
|
|
3527
3981
|
)
|
|
3528
3982
|
add_builtin(
|
|
3529
3983
|
"mul",
|
|
3530
|
-
input_types={"
|
|
3531
|
-
value_func=
|
|
3984
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": Scalar},
|
|
3985
|
+
value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3532
3986
|
doc="",
|
|
3533
3987
|
group="Operators",
|
|
3534
3988
|
)
|
|
3989
|
+
|
|
3990
|
+
|
|
3991
|
+
def matvec_mul_constraint(arg_types: Mapping[str, type]):
|
|
3992
|
+
return arg_types["a"]._shape_[1] == arg_types["b"]._length_
|
|
3993
|
+
|
|
3994
|
+
|
|
3995
|
+
def matvec_mul_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
3996
|
+
if arg_types is None:
|
|
3997
|
+
return vector(length=Any, dtype=Scalar)
|
|
3998
|
+
|
|
3999
|
+
if arg_types["a"]._wp_scalar_type_ != arg_types["b"]._wp_scalar_type_:
|
|
4000
|
+
raise RuntimeError(
|
|
4001
|
+
f"Can't multiply matrix and vector with different types {arg_types['a']._wp_scalar_type_}, {arg_types['b']._wp_scalar_type_}"
|
|
4002
|
+
)
|
|
4003
|
+
|
|
4004
|
+
if not matvec_mul_constraint(arg_types):
|
|
4005
|
+
raise RuntimeError(
|
|
4006
|
+
f"Can't multiply matrix of shape {arg_types['a']._shape_} and vector with length {arg_types['b']._length_}"
|
|
4007
|
+
)
|
|
4008
|
+
|
|
4009
|
+
return vector(length=arg_types["a"]._shape_[0], dtype=arg_types["a"]._wp_scalar_type_)
|
|
4010
|
+
|
|
4011
|
+
|
|
3535
4012
|
add_builtin(
|
|
3536
4013
|
"mul",
|
|
3537
|
-
input_types={"
|
|
3538
|
-
constraint=
|
|
3539
|
-
value_func=
|
|
4014
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": vector(length=Any, dtype=Scalar)},
|
|
4015
|
+
constraint=matvec_mul_constraint,
|
|
4016
|
+
value_func=matvec_mul_value_func,
|
|
3540
4017
|
doc="",
|
|
3541
4018
|
group="Operators",
|
|
3542
4019
|
)
|
|
4020
|
+
|
|
4021
|
+
|
|
4022
|
+
def mul_vecmat_constraint(arg_types: Mapping[str, type]):
|
|
4023
|
+
return arg_types["b"]._shape_[0] == arg_types["a"]._length_
|
|
4024
|
+
|
|
4025
|
+
|
|
4026
|
+
def mul_vecmat_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
4027
|
+
if arg_types is None:
|
|
4028
|
+
return vector(length=Any, dtype=Scalar)
|
|
4029
|
+
|
|
4030
|
+
if arg_types["b"]._wp_scalar_type_ != arg_types["a"]._wp_scalar_type_:
|
|
4031
|
+
raise RuntimeError(
|
|
4032
|
+
f"Can't multiply vector and matrix with different types {arg_types['b']._wp_scalar_type_}, {arg_types['a']._wp_scalar_type_}"
|
|
4033
|
+
)
|
|
4034
|
+
|
|
4035
|
+
if not mul_vecmat_constraint(arg_types):
|
|
4036
|
+
raise RuntimeError(
|
|
4037
|
+
f"Can't multiply vector with length {arg_types['a']._length_} and matrix of shape {arg_types['b']._shape_}"
|
|
4038
|
+
)
|
|
4039
|
+
|
|
4040
|
+
return vector(length=arg_types["b"]._shape_[1], dtype=arg_types["b"]._wp_scalar_type_)
|
|
4041
|
+
|
|
4042
|
+
|
|
3543
4043
|
add_builtin(
|
|
3544
4044
|
"mul",
|
|
3545
|
-
input_types={"
|
|
4045
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3546
4046
|
constraint=mul_vecmat_constraint,
|
|
3547
4047
|
value_func=mul_vecmat_value_func,
|
|
3548
4048
|
doc="",
|
|
3549
4049
|
group="Operators",
|
|
3550
4050
|
)
|
|
4051
|
+
|
|
4052
|
+
|
|
4053
|
+
def matmat_mul_constraint(arg_types: Mapping[str, type]):
|
|
4054
|
+
return arg_types["a"]._shape_[1] == arg_types["b"]._shape_[0]
|
|
4055
|
+
|
|
4056
|
+
|
|
4057
|
+
def matmat_mul_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
4058
|
+
if arg_types is None:
|
|
4059
|
+
return matrix(length=Any, dtype=Scalar)
|
|
4060
|
+
|
|
4061
|
+
if arg_types["a"]._wp_scalar_type_ != arg_types["b"]._wp_scalar_type_:
|
|
4062
|
+
raise RuntimeError(
|
|
4063
|
+
f"Can't multiply matrices with different types {arg_types['a']._wp_scalar_type_}, {arg_types['b']._wp_scalar_type_}"
|
|
4064
|
+
)
|
|
4065
|
+
|
|
4066
|
+
if not matmat_mul_constraint(arg_types):
|
|
4067
|
+
raise RuntimeError(f"Can't multiply matrix of shapes {arg_types['a']._shape_} and {arg_types['b']._shape_}")
|
|
4068
|
+
|
|
4069
|
+
return matrix(shape=(arg_types["a"]._shape_[0], arg_types["b"]._shape_[1]), dtype=arg_types["a"]._wp_scalar_type_)
|
|
4070
|
+
|
|
4071
|
+
|
|
3551
4072
|
add_builtin(
|
|
3552
4073
|
"mul",
|
|
3553
|
-
input_types={"
|
|
3554
|
-
constraint=
|
|
3555
|
-
value_func=
|
|
4074
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
4075
|
+
constraint=matmat_mul_constraint,
|
|
4076
|
+
value_func=matmat_mul_value_func,
|
|
3556
4077
|
doc="",
|
|
3557
4078
|
group="Operators",
|
|
3558
4079
|
)
|
|
3559
4080
|
|
|
4081
|
+
|
|
3560
4082
|
add_builtin(
|
|
3561
4083
|
"mul",
|
|
3562
|
-
input_types={"
|
|
3563
|
-
value_func=
|
|
4084
|
+
input_types={"a": transformation(dtype=Scalar), "b": transformation(dtype=Scalar)},
|
|
4085
|
+
value_func=sametypes_create_value_func(transformation(dtype=Scalar)),
|
|
3564
4086
|
doc="",
|
|
3565
4087
|
group="Operators",
|
|
3566
4088
|
)
|
|
3567
4089
|
add_builtin(
|
|
3568
4090
|
"mul",
|
|
3569
|
-
input_types={"
|
|
3570
|
-
value_func=
|
|
4091
|
+
input_types={"a": Scalar, "b": transformation(dtype=Scalar)},
|
|
4092
|
+
value_func=scalar_mul_create_value_func(transformation(dtype=Scalar)),
|
|
3571
4093
|
doc="",
|
|
3572
4094
|
group="Operators",
|
|
3573
4095
|
)
|
|
3574
4096
|
add_builtin(
|
|
3575
4097
|
"mul",
|
|
3576
|
-
input_types={"
|
|
3577
|
-
value_func=
|
|
4098
|
+
input_types={"a": transformation(dtype=Scalar), "b": Scalar},
|
|
4099
|
+
value_func=scalar_mul_create_value_func(transformation(dtype=Scalar)),
|
|
3578
4100
|
doc="",
|
|
3579
4101
|
group="Operators",
|
|
3580
4102
|
)
|
|
3581
4103
|
|
|
3582
|
-
add_builtin(
|
|
4104
|
+
add_builtin(
|
|
4105
|
+
"mod", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
4106
|
+
)
|
|
3583
4107
|
|
|
3584
4108
|
add_builtin(
|
|
3585
4109
|
"div",
|
|
3586
|
-
input_types={"
|
|
3587
|
-
value_func=
|
|
4110
|
+
input_types={"a": Scalar, "b": Scalar},
|
|
4111
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
3588
4112
|
doc="",
|
|
3589
4113
|
group="Operators",
|
|
3590
4114
|
require_original_output_arg=True,
|
|
3591
4115
|
)
|
|
3592
4116
|
add_builtin(
|
|
3593
4117
|
"div",
|
|
3594
|
-
input_types={"
|
|
3595
|
-
value_func=
|
|
4118
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "b": Scalar},
|
|
4119
|
+
value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3596
4120
|
doc="",
|
|
3597
4121
|
group="Operators",
|
|
3598
4122
|
)
|
|
3599
4123
|
add_builtin(
|
|
3600
4124
|
"div",
|
|
3601
|
-
input_types={"
|
|
3602
|
-
value_func=
|
|
4125
|
+
input_types={"a": Scalar, "b": vector(length=Any, dtype=Scalar)},
|
|
4126
|
+
value_func=scalar_mul_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3603
4127
|
doc="",
|
|
3604
4128
|
group="Operators",
|
|
3605
4129
|
)
|
|
3606
4130
|
add_builtin(
|
|
3607
4131
|
"div",
|
|
3608
|
-
input_types={"
|
|
3609
|
-
value_func=
|
|
4132
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "b": Scalar},
|
|
4133
|
+
value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3610
4134
|
doc="",
|
|
3611
4135
|
group="Operators",
|
|
3612
4136
|
)
|
|
3613
4137
|
add_builtin(
|
|
3614
4138
|
"div",
|
|
3615
|
-
input_types={"
|
|
3616
|
-
value_func=
|
|
4139
|
+
input_types={"a": Scalar, "b": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
4140
|
+
value_func=scalar_mul_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3617
4141
|
doc="",
|
|
3618
4142
|
group="Operators",
|
|
3619
4143
|
)
|
|
3620
4144
|
add_builtin(
|
|
3621
4145
|
"div",
|
|
3622
|
-
input_types={"
|
|
3623
|
-
value_func=
|
|
4146
|
+
input_types={"a": quaternion(dtype=Scalar), "b": Scalar},
|
|
4147
|
+
value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
|
|
3624
4148
|
doc="",
|
|
3625
4149
|
group="Operators",
|
|
3626
4150
|
)
|
|
3627
4151
|
add_builtin(
|
|
3628
4152
|
"div",
|
|
3629
|
-
input_types={"
|
|
3630
|
-
value_func=
|
|
4153
|
+
input_types={"a": Scalar, "b": quaternion(dtype=Scalar)},
|
|
4154
|
+
value_func=scalar_mul_create_value_func(quaternion(dtype=Scalar)),
|
|
3631
4155
|
doc="",
|
|
3632
4156
|
group="Operators",
|
|
3633
4157
|
)
|
|
3634
4158
|
|
|
3635
4159
|
add_builtin(
|
|
3636
4160
|
"floordiv",
|
|
3637
|
-
input_types={"
|
|
3638
|
-
value_func=
|
|
4161
|
+
input_types={"a": Scalar, "b": Scalar},
|
|
4162
|
+
value_func=sametypes_create_value_func(Scalar),
|
|
3639
4163
|
doc="",
|
|
3640
4164
|
group="Operators",
|
|
3641
4165
|
)
|
|
3642
4166
|
|
|
3643
|
-
add_builtin("pos", input_types={"x": Scalar}, value_func=
|
|
4167
|
+
add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
3644
4168
|
add_builtin(
|
|
3645
4169
|
"pos",
|
|
3646
4170
|
input_types={"x": vector(length=Any, dtype=Scalar)},
|
|
3647
|
-
value_func=
|
|
4171
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3648
4172
|
doc="",
|
|
3649
4173
|
group="Operators",
|
|
3650
4174
|
)
|
|
3651
4175
|
add_builtin(
|
|
3652
4176
|
"pos",
|
|
3653
4177
|
input_types={"x": quaternion(dtype=Scalar)},
|
|
3654
|
-
value_func=
|
|
4178
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
|
|
3655
4179
|
doc="",
|
|
3656
4180
|
group="Operators",
|
|
3657
4181
|
)
|
|
3658
4182
|
add_builtin(
|
|
3659
4183
|
"pos",
|
|
3660
4184
|
input_types={"x": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3661
|
-
value_func=
|
|
4185
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3662
4186
|
doc="",
|
|
3663
4187
|
group="Operators",
|
|
3664
4188
|
)
|
|
3665
|
-
add_builtin("neg", input_types={"x": Scalar}, value_func=
|
|
4189
|
+
add_builtin("neg", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
3666
4190
|
add_builtin(
|
|
3667
4191
|
"neg",
|
|
3668
4192
|
input_types={"x": vector(length=Any, dtype=Scalar)},
|
|
3669
|
-
value_func=
|
|
4193
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
3670
4194
|
doc="",
|
|
3671
4195
|
group="Operators",
|
|
3672
4196
|
)
|
|
3673
4197
|
add_builtin(
|
|
3674
4198
|
"neg",
|
|
3675
4199
|
input_types={"x": quaternion(dtype=Scalar)},
|
|
3676
|
-
value_func=
|
|
4200
|
+
value_func=sametypes_create_value_func(quaternion(dtype=Scalar)),
|
|
3677
4201
|
doc="",
|
|
3678
4202
|
group="Operators",
|
|
3679
4203
|
)
|
|
3680
4204
|
add_builtin(
|
|
3681
4205
|
"neg",
|
|
3682
4206
|
input_types={"x": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
3683
|
-
value_func=
|
|
4207
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Scalar)),
|
|
3684
4208
|
doc="",
|
|
3685
4209
|
group="Operators",
|
|
3686
4210
|
)
|
|
3687
4211
|
|
|
3688
|
-
add_builtin("unot", input_types={"
|
|
4212
|
+
add_builtin("unot", input_types={"a": builtins.bool}, value_type=builtins.bool, doc="", group="Operators")
|
|
3689
4213
|
for t in int_types:
|
|
3690
|
-
add_builtin("unot", input_types={"
|
|
4214
|
+
add_builtin("unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators")
|
|
3691
4215
|
|
|
3692
4216
|
|
|
3693
4217
|
add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")
|