warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -13,13 +13,15 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
16
18
|
import builtins
|
|
17
19
|
import functools
|
|
18
20
|
from typing import Any, Callable, Mapping, Sequence
|
|
19
21
|
|
|
20
22
|
import warp.build
|
|
21
23
|
import warp.context
|
|
22
|
-
from warp.codegen import Reference, Var, strip_reference
|
|
24
|
+
from warp.codegen import Reference, Var, get_arg_value, strip_reference
|
|
23
25
|
from warp.types import *
|
|
24
26
|
|
|
25
27
|
from .context import add_builtin
|
|
@@ -55,6 +57,33 @@ def sametypes_create_value_func(default: TypeVar):
|
|
|
55
57
|
return fn
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
def extract_tuple(arg, as_constant=False):
|
|
61
|
+
if isinstance(arg, Var):
|
|
62
|
+
if isinstance(arg.type, warp.types.tuple_t):
|
|
63
|
+
out = arg.type.values
|
|
64
|
+
else:
|
|
65
|
+
out = (arg,)
|
|
66
|
+
elif isinstance(arg, warp.types.tuple_t):
|
|
67
|
+
out = arg.values
|
|
68
|
+
elif not isinstance(arg, Sequence):
|
|
69
|
+
out = (arg,)
|
|
70
|
+
else:
|
|
71
|
+
out = arg
|
|
72
|
+
|
|
73
|
+
if as_constant:
|
|
74
|
+
return tuple(x.constant if isinstance(x, Var) else x for x in out)
|
|
75
|
+
|
|
76
|
+
return out
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
80
|
+
if arg_types is None:
|
|
81
|
+
return int
|
|
82
|
+
|
|
83
|
+
length = warp.types.type_length(arg_types["a"])
|
|
84
|
+
return Var(None, type=int, constant=length)
|
|
85
|
+
|
|
86
|
+
|
|
58
87
|
# ---------------------------------
|
|
59
88
|
# Scalar Math
|
|
60
89
|
|
|
@@ -399,7 +428,7 @@ add_builtin(
|
|
|
399
428
|
)
|
|
400
429
|
|
|
401
430
|
|
|
402
|
-
def scalar_infer_type(arg_types:
|
|
431
|
+
def scalar_infer_type(arg_types: Mapping[str, type] | tuple[type, ...] | None):
|
|
403
432
|
if arg_types is None:
|
|
404
433
|
return Scalar
|
|
405
434
|
|
|
@@ -1155,6 +1184,11 @@ add_builtin(
|
|
|
1155
1184
|
|
|
1156
1185
|
|
|
1157
1186
|
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1187
|
+
warp.utils.warn(
|
|
1188
|
+
"the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1189
|
+
"and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
|
|
1190
|
+
DeprecationWarning,
|
|
1191
|
+
)
|
|
1158
1192
|
if arg_types is None:
|
|
1159
1193
|
return matrix(shape=(4, 4), dtype=Float)
|
|
1160
1194
|
|
|
@@ -1204,21 +1238,47 @@ add_builtin(
|
|
|
1204
1238
|
dispatch_func=matrix_transform_dispatch_func,
|
|
1205
1239
|
native_func="mat_t",
|
|
1206
1240
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
1207
|
-
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1241
|
+
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1242
|
+
|
|
1243
|
+
.. warning::
|
|
1244
|
+
This function has been deprecated in favor of :func:`warp.math.transform_compose()`.""",
|
|
1208
1245
|
group="Vector Math",
|
|
1209
1246
|
export=False,
|
|
1210
1247
|
)
|
|
1211
1248
|
|
|
1212
1249
|
|
|
1213
|
-
|
|
1214
|
-
|
|
1250
|
+
def svd3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1251
|
+
if arg_types is None:
|
|
1252
|
+
return (
|
|
1253
|
+
matrix(shape=(3, 3), dtype=Float),
|
|
1254
|
+
vector(length=3, dtype=Float),
|
|
1255
|
+
matrix(shape=(3, 3), dtype=Float),
|
|
1256
|
+
)
|
|
1257
|
+
|
|
1258
|
+
dtype = arg_types["A"]._wp_scalar_type_
|
|
1259
|
+
return (
|
|
1260
|
+
matrix(shape=(3, 3), dtype=dtype),
|
|
1261
|
+
vector(length=3, dtype=dtype),
|
|
1262
|
+
matrix(shape=(3, 3), dtype=dtype),
|
|
1263
|
+
)
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
add_builtin(
|
|
1267
|
+
"svd3",
|
|
1268
|
+
input_types={"A": matrix(shape=(3, 3), dtype=Float)},
|
|
1269
|
+
value_func=svd3_value_func,
|
|
1270
|
+
group="Vector Math",
|
|
1271
|
+
doc="""Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
|
|
1272
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1273
|
+
)
|
|
1274
|
+
|
|
1215
1275
|
add_builtin(
|
|
1216
1276
|
"svd3",
|
|
1217
1277
|
input_types={
|
|
1218
1278
|
"A": matrix(shape=(3, 3), dtype=Float),
|
|
1219
1279
|
"U": matrix(shape=(3, 3), dtype=Float),
|
|
1220
1280
|
"sigma": vector(length=3, dtype=Float),
|
|
1221
|
-
"V": matrix(shape=(3, 3), dtype=
|
|
1281
|
+
"V": matrix(shape=(3, 3), dtype=Float),
|
|
1222
1282
|
},
|
|
1223
1283
|
value_type=None,
|
|
1224
1284
|
group="Vector Math",
|
|
@@ -1227,13 +1287,39 @@ add_builtin(
|
|
|
1227
1287
|
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1228
1288
|
)
|
|
1229
1289
|
|
|
1290
|
+
|
|
1291
|
+
def svd2_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1292
|
+
if arg_types is None:
|
|
1293
|
+
return (
|
|
1294
|
+
matrix(shape=(2, 2), dtype=Float),
|
|
1295
|
+
vector(length=2, dtype=Float),
|
|
1296
|
+
matrix(shape=(2, 2), dtype=Float),
|
|
1297
|
+
)
|
|
1298
|
+
|
|
1299
|
+
dtype = arg_types["A"]._wp_scalar_type_
|
|
1300
|
+
return (
|
|
1301
|
+
matrix(shape=(2, 2), dtype=dtype),
|
|
1302
|
+
vector(length=2, dtype=dtype),
|
|
1303
|
+
matrix(shape=(2, 2), dtype=dtype),
|
|
1304
|
+
)
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
add_builtin(
|
|
1308
|
+
"svd2",
|
|
1309
|
+
input_types={"A": matrix(shape=(2, 2), dtype=Float)},
|
|
1310
|
+
value_func=svd2_value_func,
|
|
1311
|
+
group="Vector Math",
|
|
1312
|
+
doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
|
|
1313
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1314
|
+
)
|
|
1315
|
+
|
|
1230
1316
|
add_builtin(
|
|
1231
1317
|
"svd2",
|
|
1232
1318
|
input_types={
|
|
1233
1319
|
"A": matrix(shape=(2, 2), dtype=Float),
|
|
1234
1320
|
"U": matrix(shape=(2, 2), dtype=Float),
|
|
1235
1321
|
"sigma": vector(length=2, dtype=Float),
|
|
1236
|
-
"V": matrix(shape=(2, 2), dtype=
|
|
1322
|
+
"V": matrix(shape=(2, 2), dtype=Float),
|
|
1237
1323
|
},
|
|
1238
1324
|
value_type=None,
|
|
1239
1325
|
group="Vector Math",
|
|
@@ -1242,6 +1328,30 @@ add_builtin(
|
|
|
1242
1328
|
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1243
1329
|
)
|
|
1244
1330
|
|
|
1331
|
+
|
|
1332
|
+
def qr3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1333
|
+
if arg_types is None:
|
|
1334
|
+
return (
|
|
1335
|
+
matrix(shape=(3, 3), dtype=Float),
|
|
1336
|
+
matrix(shape=(3, 3), dtype=Float),
|
|
1337
|
+
)
|
|
1338
|
+
|
|
1339
|
+
dtype = arg_types["A"]._wp_scalar_type_
|
|
1340
|
+
return (
|
|
1341
|
+
matrix(shape=(3, 3), dtype=dtype),
|
|
1342
|
+
matrix(shape=(3, 3), dtype=dtype),
|
|
1343
|
+
)
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
add_builtin(
|
|
1347
|
+
"qr3",
|
|
1348
|
+
input_types={"A": matrix(shape=(3, 3), dtype=Float)},
|
|
1349
|
+
value_func=qr3_value_func,
|
|
1350
|
+
group="Vector Math",
|
|
1351
|
+
doc="""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
|
|
1352
|
+
while the upper triangular matrix is returned in ``R``.""",
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1245
1355
|
add_builtin(
|
|
1246
1356
|
"qr3",
|
|
1247
1357
|
input_types={
|
|
@@ -1256,6 +1366,27 @@ add_builtin(
|
|
|
1256
1366
|
while the upper triangular matrix is returned in ``R``.""",
|
|
1257
1367
|
)
|
|
1258
1368
|
|
|
1369
|
+
|
|
1370
|
+
def eig3_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1371
|
+
if arg_types is None:
|
|
1372
|
+
return (matrix(shape=(3, 3), dtype=Float), vector(length=3, dtype=Float))
|
|
1373
|
+
|
|
1374
|
+
dtype = arg_types["A"]._wp_scalar_type_
|
|
1375
|
+
return (
|
|
1376
|
+
matrix(shape=(3, 3), dtype=dtype),
|
|
1377
|
+
vector(length=3, dtype=dtype),
|
|
1378
|
+
)
|
|
1379
|
+
|
|
1380
|
+
|
|
1381
|
+
add_builtin(
|
|
1382
|
+
"eig3",
|
|
1383
|
+
input_types={"A": matrix(shape=(3, 3), dtype=Float)},
|
|
1384
|
+
value_func=eig3_value_func,
|
|
1385
|
+
group="Vector Math",
|
|
1386
|
+
doc="""Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
|
|
1387
|
+
while the corresponding eigenvalues are returned in ``d``.""",
|
|
1388
|
+
)
|
|
1389
|
+
|
|
1259
1390
|
add_builtin(
|
|
1260
1391
|
"eig3",
|
|
1261
1392
|
input_types={
|
|
@@ -1422,13 +1553,34 @@ add_builtin(
|
|
|
1422
1553
|
group="Quaternion Math",
|
|
1423
1554
|
doc="Construct a quaternion representing a rotation of angle radians around the given axis.",
|
|
1424
1555
|
)
|
|
1556
|
+
|
|
1557
|
+
|
|
1558
|
+
def quat_to_axis_angle_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1559
|
+
if arg_types is None:
|
|
1560
|
+
return (vector(length=3, dtype=Float), Float)
|
|
1561
|
+
|
|
1562
|
+
dtype = arg_types["quat"]._wp_scalar_type_
|
|
1563
|
+
return (vector(length=3, dtype=dtype), dtype)
|
|
1564
|
+
|
|
1565
|
+
|
|
1566
|
+
add_builtin(
|
|
1567
|
+
"quat_to_axis_angle",
|
|
1568
|
+
input_types={"quat": quaternion(dtype=Float)},
|
|
1569
|
+
value_func=quat_to_axis_angle_value_func,
|
|
1570
|
+
group="Quaternion Math",
|
|
1571
|
+
doc="Extract the rotation axis and angle radians a quaternion represents.",
|
|
1572
|
+
)
|
|
1573
|
+
|
|
1425
1574
|
add_builtin(
|
|
1426
1575
|
"quat_to_axis_angle",
|
|
1427
1576
|
input_types={"quat": quaternion(dtype=Float), "axis": vector(length=3, dtype=Float), "angle": Float},
|
|
1428
1577
|
value_type=None,
|
|
1429
1578
|
group="Quaternion Math",
|
|
1430
1579
|
doc="Extract the rotation axis and angle radians a quaternion represents.",
|
|
1580
|
+
export=False,
|
|
1431
1581
|
)
|
|
1582
|
+
|
|
1583
|
+
|
|
1432
1584
|
add_builtin(
|
|
1433
1585
|
"quat_from_matrix",
|
|
1434
1586
|
input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
|
|
@@ -1506,6 +1658,48 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1506
1658
|
if arg_types is None:
|
|
1507
1659
|
return transformation(dtype=Float)
|
|
1508
1660
|
|
|
1661
|
+
dtype = arg_values.get("dtype", None)
|
|
1662
|
+
|
|
1663
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
1664
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1665
|
+
if variadic_arg_count == 0:
|
|
1666
|
+
# Zero-initialization, e.g.: `wp.transform()`, `wp.transformation(dtype=wp.float16)`.
|
|
1667
|
+
if dtype is None:
|
|
1668
|
+
dtype = float32
|
|
1669
|
+
elif variadic_arg_count == 1:
|
|
1670
|
+
# Initialization by filling a value, e.g.: `wp.transform(123)`,
|
|
1671
|
+
# `wp.transformation(123)`.
|
|
1672
|
+
value_type = strip_reference(variadic_arg_types[0])
|
|
1673
|
+
if dtype is None:
|
|
1674
|
+
dtype = value_type
|
|
1675
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1676
|
+
raise RuntimeError(
|
|
1677
|
+
f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
|
|
1678
|
+
)
|
|
1679
|
+
elif variadic_arg_count == 7:
|
|
1680
|
+
# Initializing by value, e.g.: `wp.transform(1, 2, 3, 4, 5, 6, 7)`.
|
|
1681
|
+
try:
|
|
1682
|
+
value_type = scalar_infer_type(variadic_arg_types)
|
|
1683
|
+
except RuntimeError:
|
|
1684
|
+
raise RuntimeError("all values given when constructing a transform must have the same type") from None
|
|
1685
|
+
|
|
1686
|
+
if dtype is None:
|
|
1687
|
+
dtype = value_type
|
|
1688
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1689
|
+
raise RuntimeError(
|
|
1690
|
+
f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
|
|
1691
|
+
)
|
|
1692
|
+
|
|
1693
|
+
if dtype is None:
|
|
1694
|
+
raise RuntimeError("could not infer the `dtype` argument when calling the `wp.transform()` function")
|
|
1695
|
+
|
|
1696
|
+
return transformation(dtype=dtype)
|
|
1697
|
+
|
|
1698
|
+
|
|
1699
|
+
def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1700
|
+
if arg_types is None:
|
|
1701
|
+
return transformation(dtype=Float)
|
|
1702
|
+
|
|
1509
1703
|
try:
|
|
1510
1704
|
value_type = float_infer_type(arg_types)
|
|
1511
1705
|
except RuntimeError:
|
|
@@ -1540,20 +1734,35 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1540
1734
|
|
|
1541
1735
|
add_builtin(
|
|
1542
1736
|
"transformation",
|
|
1543
|
-
input_types={"
|
|
1737
|
+
input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
|
|
1544
1738
|
defaults={"dtype": None},
|
|
1545
|
-
value_func=
|
|
1739
|
+
value_func=transformation_pq_value_func,
|
|
1546
1740
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1547
1741
|
dispatch_func=transformation_dispatch_func,
|
|
1548
1742
|
native_func="transform_t",
|
|
1549
1743
|
group="Transformations",
|
|
1550
|
-
doc="Construct a rigid-body transformation with translation part ``
|
|
1744
|
+
doc="Construct a rigid-body transformation with translation part ``p`` and rotation ``q``.",
|
|
1745
|
+
export=False,
|
|
1746
|
+
)
|
|
1747
|
+
|
|
1748
|
+
|
|
1749
|
+
add_builtin(
|
|
1750
|
+
"transformation",
|
|
1751
|
+
input_types={"*args": Float, "dtype": Float},
|
|
1752
|
+
defaults={"dtype": None},
|
|
1753
|
+
variadic=True,
|
|
1754
|
+
initializer_list_func=lambda arg_types, arg_values: len(arg_types.get("args", ())) > 1,
|
|
1755
|
+
value_func=transformation_value_func,
|
|
1756
|
+
export_func=lambda input_types: {k: v for k, v in input_types.items() if k not in ("dtype")},
|
|
1757
|
+
dispatch_func=transformation_dispatch_func,
|
|
1758
|
+
native_func="transform_t",
|
|
1759
|
+
doc="Construct a spatial transform vector of given dtype.",
|
|
1760
|
+
group="Spatial Math",
|
|
1551
1761
|
export=False,
|
|
1552
1762
|
)
|
|
1553
1763
|
|
|
1554
1764
|
|
|
1555
1765
|
def transform_identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1556
|
-
# if arg_types is None then we are in 'export' mode
|
|
1557
1766
|
if arg_types is None:
|
|
1558
1767
|
# return transformation(dtype=Float)
|
|
1559
1768
|
return transformf
|
|
@@ -1600,6 +1809,40 @@ add_builtin(
|
|
|
1600
1809
|
group="Transformations",
|
|
1601
1810
|
doc="Return the rotational part of a transform ``xform``.",
|
|
1602
1811
|
)
|
|
1812
|
+
add_builtin(
|
|
1813
|
+
"transform_set_translation",
|
|
1814
|
+
input_types={"xform": transformation(dtype=Float), "p": vector(length=3, dtype=Float)},
|
|
1815
|
+
value_type=None,
|
|
1816
|
+
group="Transformations",
|
|
1817
|
+
doc="Set the translational part of a transform ``xform``.",
|
|
1818
|
+
)
|
|
1819
|
+
add_builtin(
|
|
1820
|
+
"transform_set_rotation",
|
|
1821
|
+
input_types={"xform": transformation(dtype=Float), "q": quaternion(dtype=Float)},
|
|
1822
|
+
value_type=None,
|
|
1823
|
+
group="Transformations",
|
|
1824
|
+
doc="Set the rotational part of a transform ``xform``.",
|
|
1825
|
+
)
|
|
1826
|
+
# performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
1827
|
+
add_builtin(
|
|
1828
|
+
"transform_set_translation_copy",
|
|
1829
|
+
input_types={"xform": transformation(dtype=Float), "p": vector(length=3, dtype=Float)},
|
|
1830
|
+
value_type=transformation(dtype=Float),
|
|
1831
|
+
group="Transformations",
|
|
1832
|
+
doc="Set the translational part of a transform ``xform``.",
|
|
1833
|
+
hidden=True,
|
|
1834
|
+
export=False,
|
|
1835
|
+
)
|
|
1836
|
+
# performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
1837
|
+
add_builtin(
|
|
1838
|
+
"transform_set_rotation_copy",
|
|
1839
|
+
input_types={"xform": transformation(dtype=Float), "q": quaternion(dtype=Float)},
|
|
1840
|
+
value_type=transformation(dtype=Float),
|
|
1841
|
+
group="Transformations",
|
|
1842
|
+
doc="Set the rotational part of a transform ``xform``.",
|
|
1843
|
+
hidden=True,
|
|
1844
|
+
export=False,
|
|
1845
|
+
)
|
|
1603
1846
|
add_builtin(
|
|
1604
1847
|
"transform_multiply",
|
|
1605
1848
|
input_types={"a": transformation(dtype=Float), "b": transformation(dtype=Float)},
|
|
@@ -1831,40 +2074,15 @@ add_builtin(
|
|
|
1831
2074
|
# Tile-based primitives
|
|
1832
2075
|
|
|
1833
2076
|
|
|
1834
|
-
def tile_unpack_shape(arg_values):
|
|
1835
|
-
shape = arg_values["shape"]
|
|
1836
|
-
|
|
1837
|
-
if not isinstance(shape, tuple):
|
|
1838
|
-
# promote to tuple
|
|
1839
|
-
shape = (shape,)
|
|
1840
|
-
|
|
1841
|
-
# check that components are constants
|
|
1842
|
-
for d in shape:
|
|
1843
|
-
if d is None:
|
|
1844
|
-
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
1845
|
-
|
|
1846
|
-
return shape
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
def tile_unpack_offset(arg_values, ndim=0):
|
|
1850
|
-
if "offset" in arg_values:
|
|
1851
|
-
offset = arg_values["offset"]
|
|
1852
|
-
else:
|
|
1853
|
-
offset = (0,) * ndim
|
|
1854
|
-
|
|
1855
|
-
if isinstance(offset, tuple):
|
|
1856
|
-
return offset
|
|
1857
|
-
else:
|
|
1858
|
-
# promote to tuple
|
|
1859
|
-
return (offset,)
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
2077
|
def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1863
2078
|
# return generic type (for doc builds)
|
|
1864
2079
|
if arg_types is None:
|
|
1865
|
-
return
|
|
2080
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2081
|
+
|
|
2082
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
1866
2083
|
|
|
1867
|
-
|
|
2084
|
+
if None in shape:
|
|
2085
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
1868
2086
|
|
|
1869
2087
|
if "dtype" not in arg_values:
|
|
1870
2088
|
raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
|
|
@@ -1877,17 +2095,20 @@ def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1877
2095
|
|
|
1878
2096
|
dtype = arg_values["dtype"]
|
|
1879
2097
|
|
|
1880
|
-
return
|
|
2098
|
+
return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1881
2099
|
|
|
1882
2100
|
|
|
1883
2101
|
def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1884
|
-
shape =
|
|
2102
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2103
|
+
|
|
2104
|
+
if None in shape:
|
|
2105
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2106
|
+
|
|
1885
2107
|
dtype = arg_values["dtype"]
|
|
1886
2108
|
|
|
1887
2109
|
template_args = []
|
|
1888
2110
|
template_args.append(dtype)
|
|
1889
|
-
|
|
1890
|
-
template_args.append(d.constant)
|
|
2111
|
+
template_args.extend(shape)
|
|
1891
2112
|
|
|
1892
2113
|
return ([], template_args)
|
|
1893
2114
|
|
|
@@ -1929,9 +2150,12 @@ add_builtin(
|
|
|
1929
2150
|
def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1930
2151
|
# return generic type (for doc builds)
|
|
1931
2152
|
if arg_types is None:
|
|
1932
|
-
return
|
|
2153
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
1933
2154
|
|
|
1934
|
-
shape =
|
|
2155
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2156
|
+
|
|
2157
|
+
if None in shape:
|
|
2158
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
1935
2159
|
|
|
1936
2160
|
if "dtype" not in arg_values:
|
|
1937
2161
|
raise TypeError("tile_ones() missing required keyword argument 'dtype'")
|
|
@@ -1944,17 +2168,20 @@ def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
|
|
|
1944
2168
|
|
|
1945
2169
|
dtype = arg_values["dtype"]
|
|
1946
2170
|
|
|
1947
|
-
return
|
|
2171
|
+
return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1948
2172
|
|
|
1949
2173
|
|
|
1950
2174
|
def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1951
|
-
shape =
|
|
2175
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2176
|
+
|
|
2177
|
+
if None in shape:
|
|
2178
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2179
|
+
|
|
1952
2180
|
dtype = arg_values["dtype"]
|
|
1953
2181
|
|
|
1954
2182
|
template_args = []
|
|
1955
2183
|
template_args.append(dtype)
|
|
1956
|
-
|
|
1957
|
-
template_args.append(d.constant)
|
|
2184
|
+
template_args.extend(shape)
|
|
1958
2185
|
|
|
1959
2186
|
return ([], template_args)
|
|
1960
2187
|
|
|
@@ -1994,7 +2221,7 @@ add_builtin(
|
|
|
1994
2221
|
def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1995
2222
|
# return generic type (for doc builds)
|
|
1996
2223
|
if arg_types is None:
|
|
1997
|
-
return
|
|
2224
|
+
return tile(dtype=Scalar, shape=Tuple[int])
|
|
1998
2225
|
|
|
1999
2226
|
if "args" not in arg_values:
|
|
2000
2227
|
raise TypeError("tile_arange() requires at least one positional argument specifying the range")
|
|
@@ -2029,7 +2256,8 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
2029
2256
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
2030
2257
|
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
2031
2258
|
|
|
2032
|
-
|
|
2259
|
+
n = int((stop - start) / step)
|
|
2260
|
+
return tile(dtype=dtype, shape=(n,), storage=arg_values["storage"])
|
|
2033
2261
|
|
|
2034
2262
|
|
|
2035
2263
|
def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
@@ -2045,13 +2273,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2045
2273
|
args = arg_values["args"]
|
|
2046
2274
|
|
|
2047
2275
|
if len(args) == 1:
|
|
2048
|
-
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=
|
|
2276
|
+
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2049
2277
|
stop = args[0]
|
|
2050
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=
|
|
2278
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2051
2279
|
elif len(args) == 2:
|
|
2052
2280
|
start = args[0]
|
|
2053
2281
|
stop = args[1]
|
|
2054
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=
|
|
2282
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2055
2283
|
elif len(args) == 3:
|
|
2056
2284
|
start = args[0]
|
|
2057
2285
|
stop = args[1]
|
|
@@ -2069,7 +2297,7 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2069
2297
|
|
|
2070
2298
|
add_builtin(
|
|
2071
2299
|
"tile_arange",
|
|
2072
|
-
input_types={"*args": Scalar, "dtype":
|
|
2300
|
+
input_types={"*args": Scalar, "dtype": Scalar, "storage": str},
|
|
2073
2301
|
defaults={"dtype": None, "storage": "register"},
|
|
2074
2302
|
value_func=tile_arange_value_func,
|
|
2075
2303
|
dispatch_func=tile_arange_dispatch_func,
|
|
@@ -2094,12 +2322,19 @@ add_builtin(
|
|
|
2094
2322
|
|
|
2095
2323
|
def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2096
2324
|
if arg_types is None:
|
|
2097
|
-
return
|
|
2325
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2098
2326
|
|
|
2099
2327
|
a = arg_types["a"]
|
|
2100
2328
|
|
|
2101
|
-
shape =
|
|
2102
|
-
|
|
2329
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2330
|
+
|
|
2331
|
+
if None in shape:
|
|
2332
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2333
|
+
|
|
2334
|
+
if "offset" in arg_values:
|
|
2335
|
+
offset = extract_tuple(arg_values["offset"])
|
|
2336
|
+
else:
|
|
2337
|
+
offset = (0,) * a.ndim
|
|
2103
2338
|
|
|
2104
2339
|
if a.ndim != len(shape):
|
|
2105
2340
|
raise ValueError(
|
|
@@ -2114,16 +2349,23 @@ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mappin
|
|
|
2114
2349
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
2115
2350
|
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
2116
2351
|
|
|
2117
|
-
return
|
|
2352
|
+
return tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
|
|
2118
2353
|
|
|
2119
2354
|
|
|
2120
2355
|
def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2121
2356
|
a = args["a"]
|
|
2122
|
-
shape =
|
|
2123
|
-
|
|
2357
|
+
shape = extract_tuple(args["shape"], as_constant=True)
|
|
2358
|
+
|
|
2359
|
+
if None in shape:
|
|
2360
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2361
|
+
|
|
2362
|
+
if "offset" in args:
|
|
2363
|
+
offset = extract_tuple(args["offset"])
|
|
2364
|
+
else:
|
|
2365
|
+
offset = (0,) * a.type.ndim
|
|
2124
2366
|
|
|
2125
2367
|
func_args = (a, *offset)
|
|
2126
|
-
template_args =
|
|
2368
|
+
template_args = shape
|
|
2127
2369
|
|
|
2128
2370
|
return (func_args, template_args)
|
|
2129
2371
|
|
|
@@ -2170,7 +2412,10 @@ def tile_store_value_func(arg_types, arg_values):
|
|
|
2170
2412
|
a = arg_types["a"]
|
|
2171
2413
|
t = arg_types["t"]
|
|
2172
2414
|
|
|
2173
|
-
|
|
2415
|
+
if "offset" in arg_types:
|
|
2416
|
+
c = extract_tuple(arg_values["offset"])
|
|
2417
|
+
else:
|
|
2418
|
+
c = (0,) * a.ndim
|
|
2174
2419
|
|
|
2175
2420
|
if len(c) != a.ndim:
|
|
2176
2421
|
raise ValueError(
|
|
@@ -2196,7 +2441,10 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
|
|
|
2196
2441
|
a = args["a"]
|
|
2197
2442
|
t = args["t"]
|
|
2198
2443
|
|
|
2199
|
-
offset
|
|
2444
|
+
if "offset" in args:
|
|
2445
|
+
offset = extract_tuple(args["offset"])
|
|
2446
|
+
else:
|
|
2447
|
+
offset = (0,) * a.type.ndim
|
|
2200
2448
|
|
|
2201
2449
|
func_args = (a, *offset, t)
|
|
2202
2450
|
template_args = []
|
|
@@ -2206,7 +2454,7 @@ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any,
|
|
|
2206
2454
|
|
|
2207
2455
|
add_builtin(
|
|
2208
2456
|
"tile_store",
|
|
2209
|
-
input_types={"a": array(dtype=Any), "t":
|
|
2457
|
+
input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
|
|
2210
2458
|
value_func=tile_store_value_func,
|
|
2211
2459
|
dispatch_func=tile_store_dispatch_func,
|
|
2212
2460
|
defaults={"offset": None},
|
|
@@ -2226,7 +2474,7 @@ add_builtin(
|
|
|
2226
2474
|
# overload for scalar offset
|
|
2227
2475
|
add_builtin(
|
|
2228
2476
|
"tile_store",
|
|
2229
|
-
input_types={"a": array(dtype=Any), "t":
|
|
2477
|
+
input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
|
|
2230
2478
|
value_func=tile_store_value_func,
|
|
2231
2479
|
dispatch_func=tile_store_dispatch_func,
|
|
2232
2480
|
defaults={"offset": None},
|
|
@@ -2241,12 +2489,16 @@ add_builtin(
|
|
|
2241
2489
|
def tile_atomic_add_value_func(arg_types, arg_values):
|
|
2242
2490
|
# return generic type (for doc builds)
|
|
2243
2491
|
if arg_types is None:
|
|
2244
|
-
return
|
|
2492
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2245
2493
|
|
|
2246
2494
|
a = arg_types["a"]
|
|
2247
2495
|
t = arg_types["t"]
|
|
2248
2496
|
|
|
2249
|
-
|
|
2497
|
+
if "offset" in arg_types:
|
|
2498
|
+
c = extract_tuple(arg_values["offset"])
|
|
2499
|
+
else:
|
|
2500
|
+
c = (0,) * a.ndim
|
|
2501
|
+
|
|
2250
2502
|
if len(c) != a.ndim:
|
|
2251
2503
|
raise ValueError(
|
|
2252
2504
|
f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
|
|
@@ -2264,14 +2516,21 @@ def tile_atomic_add_value_func(arg_types, arg_values):
|
|
|
2264
2516
|
f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2265
2517
|
)
|
|
2266
2518
|
|
|
2267
|
-
return
|
|
2519
|
+
return tile(
|
|
2520
|
+
dtype=arg_types["t"].dtype,
|
|
2521
|
+
shape=arg_types["t"].shape,
|
|
2522
|
+
storage=arg_types["t"].storage,
|
|
2523
|
+
)
|
|
2268
2524
|
|
|
2269
2525
|
|
|
2270
2526
|
def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2271
2527
|
a = args["a"]
|
|
2272
2528
|
t = args["t"]
|
|
2273
2529
|
|
|
2274
|
-
offset
|
|
2530
|
+
if "offset" in args:
|
|
2531
|
+
offset = extract_tuple(args["offset"])
|
|
2532
|
+
else:
|
|
2533
|
+
offset = (0,) * a.type.ndim
|
|
2275
2534
|
|
|
2276
2535
|
func_args = (a, *offset, t)
|
|
2277
2536
|
template_args = []
|
|
@@ -2281,13 +2540,13 @@ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type:
|
|
|
2281
2540
|
|
|
2282
2541
|
add_builtin(
|
|
2283
2542
|
"tile_atomic_add",
|
|
2284
|
-
input_types={"a": array(dtype=Any), "t":
|
|
2543
|
+
input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...]},
|
|
2285
2544
|
value_func=tile_atomic_add_value_func,
|
|
2286
2545
|
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2287
2546
|
defaults={"offset": None},
|
|
2288
2547
|
variadic=False,
|
|
2289
2548
|
skip_replay=True,
|
|
2290
|
-
doc="""Atomically add a
|
|
2549
|
+
doc="""Atomically add a tile onto the array `a`, each element will be updated atomically.
|
|
2291
2550
|
|
|
2292
2551
|
:param a: Array in global memory, should have the same ``dtype`` as the input tile
|
|
2293
2552
|
:param t: Source tile to add to the destination array
|
|
@@ -2300,7 +2559,7 @@ add_builtin(
|
|
|
2300
2559
|
# overload for scalar offset
|
|
2301
2560
|
add_builtin(
|
|
2302
2561
|
"tile_atomic_add",
|
|
2303
|
-
input_types={"a": array(dtype=Any), "t":
|
|
2562
|
+
input_types={"a": array(dtype=Any), "t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": int},
|
|
2304
2563
|
value_func=tile_atomic_add_value_func,
|
|
2305
2564
|
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2306
2565
|
defaults={"offset": None},
|
|
@@ -2315,54 +2574,59 @@ add_builtin(
|
|
|
2315
2574
|
def tile_view_value_func(arg_types, arg_values):
|
|
2316
2575
|
# return generic type (for doc builds)
|
|
2317
2576
|
if arg_types is None:
|
|
2318
|
-
return
|
|
2577
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2319
2578
|
|
|
2320
|
-
|
|
2321
|
-
offset =
|
|
2579
|
+
tile_type = arg_types["t"]
|
|
2580
|
+
offset = extract_tuple(arg_values["offset"])
|
|
2322
2581
|
|
|
2323
|
-
if len(offset) > len(
|
|
2324
|
-
raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(
|
|
2582
|
+
if len(offset) > len(tile_type.shape):
|
|
2583
|
+
raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile_type.shape)}")
|
|
2325
2584
|
|
|
2326
2585
|
if "shape" in arg_values:
|
|
2327
2586
|
# if shape is specified take it directly, e.g.:
|
|
2328
2587
|
# tile_view(t, offset=(i,j), shape=(m,n))
|
|
2329
|
-
shape = arg_values["shape"]
|
|
2330
|
-
strides =
|
|
2588
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2589
|
+
strides = tile_type.strides
|
|
2331
2590
|
|
|
2332
|
-
if len(shape) != len(
|
|
2591
|
+
if len(shape) != len(tile_type.shape):
|
|
2333
2592
|
raise ValueError(
|
|
2334
|
-
f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(
|
|
2593
|
+
f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile_type.shape)}, got {len(shape)}"
|
|
2335
2594
|
)
|
|
2336
2595
|
else:
|
|
2337
2596
|
# if not specified, then take output shape from unspecified src dimensions
|
|
2338
2597
|
# e.g.: tile[i] will return a whole row of a 2D tile
|
|
2339
|
-
shape =
|
|
2340
|
-
strides =
|
|
2598
|
+
shape = tile_type.shape[len(offset) :]
|
|
2599
|
+
strides = tile_type.strides[len(offset) :]
|
|
2341
2600
|
|
|
2342
2601
|
assert len(shape) == len(strides)
|
|
2343
2602
|
|
|
2344
2603
|
# force source tile to shared memory
|
|
2345
|
-
|
|
2604
|
+
tile_type.storage = "shared"
|
|
2346
2605
|
|
|
2347
|
-
output =
|
|
2606
|
+
output = tile(
|
|
2607
|
+
dtype=tile_type.dtype, shape=shape, strides=strides, layout=tile_type.layout, storage="shared", owner=False
|
|
2608
|
+
)
|
|
2348
2609
|
return output
|
|
2349
2610
|
|
|
2350
2611
|
|
|
2351
2612
|
def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2352
2613
|
tile = arg_values["t"]
|
|
2353
|
-
coord = arg_values["offset"]
|
|
2614
|
+
coord = extract_tuple(arg_values["offset"])
|
|
2354
2615
|
|
|
2355
2616
|
# zero-pad coord to match source array
|
|
2356
2617
|
view_coord = [0] * len(tile.type.shape)
|
|
2357
2618
|
for i in range(len(coord)):
|
|
2358
2619
|
view_coord[i] = coord[i]
|
|
2359
2620
|
|
|
2360
|
-
|
|
2621
|
+
func_args = (tile, *view_coord)
|
|
2622
|
+
template_args = (return_type,)
|
|
2623
|
+
|
|
2624
|
+
return (func_args, template_args)
|
|
2361
2625
|
|
|
2362
2626
|
|
|
2363
2627
|
add_builtin(
|
|
2364
2628
|
"tile_view",
|
|
2365
|
-
input_types={"t":
|
|
2629
|
+
input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
|
|
2366
2630
|
value_func=tile_view_value_func,
|
|
2367
2631
|
dispatch_func=tile_view_dispatch_func,
|
|
2368
2632
|
defaults={"shape": None},
|
|
@@ -2379,116 +2643,363 @@ add_builtin(
|
|
|
2379
2643
|
)
|
|
2380
2644
|
|
|
2381
2645
|
|
|
2382
|
-
def
|
|
2646
|
+
def tile_squeeze_value_func(arg_types, arg_values):
|
|
2647
|
+
# return generic type (for doc builds)
|
|
2383
2648
|
if arg_types is None:
|
|
2384
|
-
return
|
|
2649
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2385
2650
|
|
|
2386
|
-
|
|
2387
|
-
|
|
2388
|
-
|
|
2651
|
+
tile_type = arg_types["t"]
|
|
2652
|
+
shape = tile_type.shape
|
|
2653
|
+
strides = tile_type.strides
|
|
2654
|
+
ndim = len(shape)
|
|
2389
2655
|
|
|
2656
|
+
if "axis" in arg_values:
|
|
2657
|
+
axis = arg_values["axis"]
|
|
2390
2658
|
|
|
2391
|
-
|
|
2392
|
-
|
|
2393
|
-
|
|
2659
|
+
if not isinstance(axis, Sequence):
|
|
2660
|
+
# promote to tuple
|
|
2661
|
+
axis = (axis,)
|
|
2394
2662
|
|
|
2395
|
-
|
|
2663
|
+
# promote negative indices to their positive equivalents
|
|
2664
|
+
axis = tuple([a if a >= 0 else a + ndim for a in axis])
|
|
2396
2665
|
|
|
2397
|
-
|
|
2398
|
-
|
|
2666
|
+
# validate that specified axes are size 1
|
|
2667
|
+
for a in axis:
|
|
2668
|
+
if shape[a] != 1:
|
|
2669
|
+
raise ValueError(
|
|
2670
|
+
f"Cannot select an axis to squeeze out which has size not equal to one, axis={a}, size={shape[a]}"
|
|
2671
|
+
)
|
|
2399
2672
|
|
|
2400
|
-
|
|
2673
|
+
# build new shape by skipping specified axes (if size is 1)
|
|
2674
|
+
new_shape = tuple(dim for i, dim in enumerate(shape) if i not in axis)
|
|
2675
|
+
new_strides = tuple(stride for i, stride in enumerate(strides) if i not in axis)
|
|
2401
2676
|
|
|
2677
|
+
else:
|
|
2678
|
+
# no axis specified: remove all singleton dimensions
|
|
2679
|
+
new_shape = tuple(dim for dim in shape if dim != 1)
|
|
2680
|
+
new_strides = tuple(stride for i, stride in enumerate(strides) if shape[i] != 1)
|
|
2402
2681
|
|
|
2403
|
-
|
|
2404
|
-
"
|
|
2405
|
-
|
|
2406
|
-
|
|
2407
|
-
|
|
2408
|
-
|
|
2409
|
-
|
|
2682
|
+
# force source tile to shared memory
|
|
2683
|
+
tile_type.storage = "shared"
|
|
2684
|
+
|
|
2685
|
+
output = tile(
|
|
2686
|
+
dtype=tile_type.dtype,
|
|
2687
|
+
shape=new_shape,
|
|
2688
|
+
strides=new_strides,
|
|
2689
|
+
layout=tile_type.layout,
|
|
2690
|
+
storage="shared",
|
|
2691
|
+
owner=False,
|
|
2692
|
+
)
|
|
2693
|
+
return output
|
|
2410
2694
|
|
|
2411
|
-
:param dst: The destination tile to assign to
|
|
2412
|
-
:param src: The source tile to read values from
|
|
2413
|
-
:param offset: Offset in the destination tile to write to""",
|
|
2414
|
-
group="Tile Primitives",
|
|
2415
|
-
export=False,
|
|
2416
|
-
)
|
|
2417
2695
|
|
|
2418
|
-
|
|
2419
|
-
|
|
2420
|
-
"assign",
|
|
2421
|
-
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
|
|
2422
|
-
value_func=tile_assign_value_func,
|
|
2423
|
-
group="Tile Primitives",
|
|
2424
|
-
export=False,
|
|
2425
|
-
hidden=True,
|
|
2426
|
-
)
|
|
2696
|
+
def tile_squeeze_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2697
|
+
source_tile = arg_values["t"]
|
|
2427
2698
|
|
|
2428
|
-
|
|
2429
|
-
"assign",
|
|
2430
|
-
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
|
|
2431
|
-
value_func=tile_assign_value_func,
|
|
2432
|
-
group="Tile Primitives",
|
|
2433
|
-
export=False,
|
|
2434
|
-
hidden=True,
|
|
2435
|
-
)
|
|
2699
|
+
return ((source_tile,), (return_type,))
|
|
2436
2700
|
|
|
2437
|
-
add_builtin(
|
|
2438
|
-
"assign",
|
|
2439
|
-
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
|
|
2440
|
-
value_func=tile_assign_value_func,
|
|
2441
|
-
group="Tile Primitives",
|
|
2442
|
-
export=False,
|
|
2443
|
-
hidden=True,
|
|
2444
|
-
)
|
|
2445
2701
|
|
|
2446
2702
|
add_builtin(
|
|
2447
|
-
"
|
|
2448
|
-
input_types={"
|
|
2449
|
-
|
|
2703
|
+
"tile_squeeze",
|
|
2704
|
+
input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "axis": Tuple[int, ...]},
|
|
2705
|
+
defaults={"axis": None},
|
|
2706
|
+
value_func=tile_squeeze_value_func,
|
|
2707
|
+
dispatch_func=tile_squeeze_dispatch_func,
|
|
2708
|
+
variadic=False,
|
|
2709
|
+
doc="""Return a squeezed view of a tile with the same data.
|
|
2710
|
+
|
|
2711
|
+
:param t: Input tile to squeeze
|
|
2712
|
+
:param axis: A subset of the entries of length one in the shape (optional)
|
|
2713
|
+
:returns: The input tile but with all or a subset of the dimensions of length one removed.""",
|
|
2450
2714
|
group="Tile Primitives",
|
|
2451
2715
|
export=False,
|
|
2452
|
-
hidden=True,
|
|
2453
2716
|
)
|
|
2454
2717
|
|
|
2455
2718
|
|
|
2456
|
-
def
|
|
2719
|
+
def tile_reshape_value_func(arg_types, arg_values):
|
|
2457
2720
|
# return generic type (for doc builds)
|
|
2458
2721
|
if arg_types is None:
|
|
2459
|
-
return
|
|
2460
|
-
|
|
2461
|
-
if len(arg_types) != 1:
|
|
2462
|
-
raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2722
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2463
2723
|
|
|
2464
|
-
|
|
2465
|
-
length = None
|
|
2724
|
+
tile_type = arg_types["t"]
|
|
2466
2725
|
|
|
2467
|
-
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
else:
|
|
2472
|
-
dtype = arg_types["x"]
|
|
2473
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
2726
|
+
# calculate total size of tile_type
|
|
2727
|
+
size = 1
|
|
2728
|
+
for s in tile_type.shape:
|
|
2729
|
+
size *= int(s)
|
|
2474
2730
|
|
|
2475
|
-
|
|
2731
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2476
2732
|
|
|
2733
|
+
if None in shape:
|
|
2734
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2477
2735
|
|
|
2478
|
-
|
|
2479
|
-
|
|
2480
|
-
|
|
2481
|
-
|
|
2482
|
-
|
|
2483
|
-
|
|
2736
|
+
# check for -1 dimension and reformat
|
|
2737
|
+
if -1 in shape:
|
|
2738
|
+
idx = size
|
|
2739
|
+
denom = 1
|
|
2740
|
+
minus_one_count = 0
|
|
2741
|
+
for i, d in enumerate(shape):
|
|
2742
|
+
if d == -1:
|
|
2743
|
+
idx = i
|
|
2744
|
+
minus_one_count += 1
|
|
2745
|
+
else:
|
|
2746
|
+
denom *= d
|
|
2747
|
+
if minus_one_count > 1:
|
|
2748
|
+
raise RuntimeError("Cannot infer shape if more than one index is -1.")
|
|
2749
|
+
new_shape = list(shape)
|
|
2750
|
+
new_shape[idx] = int(size / denom)
|
|
2751
|
+
shape = tuple(new_shape)
|
|
2752
|
+
|
|
2753
|
+
# calculate total size of new shape
|
|
2754
|
+
new_size = 1
|
|
2755
|
+
for s in shape:
|
|
2756
|
+
new_size *= int(s)
|
|
2757
|
+
|
|
2758
|
+
if new_size != size:
|
|
2759
|
+
raise ValueError(f"New shape {shape} has total size {new_size} which does not match original size {size}")
|
|
2760
|
+
|
|
2761
|
+
# compute new strides matching shape
|
|
2762
|
+
strides = []
|
|
2763
|
+
stride = 1
|
|
2764
|
+
for s in reversed(shape):
|
|
2765
|
+
strides.append(stride)
|
|
2766
|
+
stride *= s
|
|
2767
|
+
strides = tuple(reversed(strides))
|
|
2484
2768
|
|
|
2485
|
-
|
|
2769
|
+
# force source tile to shared memory
|
|
2770
|
+
tile_type.storage = "shared"
|
|
2771
|
+
|
|
2772
|
+
output = tile(
|
|
2773
|
+
dtype=tile_type.dtype, shape=shape, strides=strides, layout=tile_type.layout, storage="shared", owner=False
|
|
2774
|
+
)
|
|
2775
|
+
return output
|
|
2776
|
+
|
|
2777
|
+
|
|
2778
|
+
def tile_reshape_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2779
|
+
tile = arg_values["t"]
|
|
2780
|
+
|
|
2781
|
+
return ((tile,), (return_type,))
|
|
2782
|
+
|
|
2783
|
+
|
|
2784
|
+
add_builtin(
|
|
2785
|
+
"tile_reshape",
|
|
2786
|
+
input_types={"t": tile(dtype=Any, shape=Tuple[int, ...]), "shape": Tuple[int, ...]},
|
|
2787
|
+
value_func=tile_reshape_value_func,
|
|
2788
|
+
dispatch_func=tile_reshape_dispatch_func,
|
|
2789
|
+
variadic=False,
|
|
2790
|
+
doc="""Return a reshaped view of a tile with the same data.
|
|
2791
|
+
|
|
2792
|
+
:param t: Input tile to reshape
|
|
2793
|
+
:param shape: New shape for the tile
|
|
2794
|
+
:returns: A tile containing the same data as the input tile, but arranged in a new shape.""",
|
|
2795
|
+
group="Tile Primitives",
|
|
2796
|
+
export=False,
|
|
2797
|
+
)
|
|
2798
|
+
|
|
2799
|
+
|
|
2800
|
+
def tile_astype_value_func(arg_types, arg_values):
|
|
2801
|
+
# return generic type (for doc builds)
|
|
2802
|
+
if arg_types is None:
|
|
2803
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2804
|
+
|
|
2805
|
+
tile_type = arg_types["t"]
|
|
2806
|
+
dtype = arg_values["dtype"]
|
|
2807
|
+
|
|
2808
|
+
return tile(dtype=dtype, shape=tile_type.shape)
|
|
2809
|
+
|
|
2810
|
+
|
|
2811
|
+
def tile_astype_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2812
|
+
tile = arg_values["t"]
|
|
2813
|
+
|
|
2814
|
+
return ((tile,), (return_type,))
|
|
2815
|
+
|
|
2816
|
+
|
|
2817
|
+
add_builtin(
|
|
2818
|
+
"tile_astype",
|
|
2819
|
+
input_types={"t": tile(dtype=Scalar, shape=Tuple[int, ...]), "dtype": Scalar},
|
|
2820
|
+
value_func=tile_astype_value_func,
|
|
2821
|
+
dispatch_func=tile_astype_dispatch_func,
|
|
2822
|
+
variadic=False,
|
|
2823
|
+
doc="""Return a new tile with the same data as the input tile, but with a different data type.
|
|
2824
|
+
|
|
2825
|
+
:param t: Input tile
|
|
2826
|
+
:param dtype: New data type for the tile
|
|
2827
|
+
:returns: A tile with the same data as the input tile, but with a different data type""",
|
|
2828
|
+
group="Tile Primitives",
|
|
2829
|
+
export=False,
|
|
2830
|
+
)
|
|
2831
|
+
|
|
2832
|
+
|
|
2833
|
+
def tile_assign_value_func(arg_types, arg_values):
|
|
2834
|
+
if arg_types is None:
|
|
2835
|
+
return None
|
|
2836
|
+
|
|
2837
|
+
# force the destination tile to shared memory
|
|
2838
|
+
arg_types["dst"].storage = "shared"
|
|
2839
|
+
return None
|
|
2840
|
+
|
|
2841
|
+
|
|
2842
|
+
def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2843
|
+
dst = args["dst"]
|
|
2844
|
+
src = args["src"]
|
|
2845
|
+
|
|
2846
|
+
if "offset" in args:
|
|
2847
|
+
offset = extract_tuple(args["offset"])
|
|
2848
|
+
else:
|
|
2849
|
+
offset = (0,) * len(dst.type.shape)
|
|
2850
|
+
|
|
2851
|
+
func_args = (dst, src, *offset)
|
|
2852
|
+
template_args = []
|
|
2853
|
+
|
|
2854
|
+
return (func_args, template_args)
|
|
2855
|
+
|
|
2856
|
+
|
|
2857
|
+
add_builtin(
|
|
2858
|
+
"tile_assign",
|
|
2859
|
+
input_types={
|
|
2860
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
2861
|
+
"src": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
2862
|
+
"offset": Tuple[int, ...],
|
|
2863
|
+
},
|
|
2864
|
+
value_func=tile_assign_value_func,
|
|
2865
|
+
dispatch_func=tile_assign_dispatch_func,
|
|
2866
|
+
defaults={"offset": None},
|
|
2867
|
+
doc="""Assign a tile to a subrange of a destination tile.
|
|
2868
|
+
|
|
2869
|
+
:param dst: The destination tile to assign to
|
|
2870
|
+
:param src: The source tile to read values from
|
|
2871
|
+
:param offset: Offset in the destination tile to write to""",
|
|
2872
|
+
group="Tile Primitives",
|
|
2873
|
+
export=False,
|
|
2874
|
+
)
|
|
2875
|
+
|
|
2876
|
+
# handles expressions like tile[i,j] = 1.0
|
|
2877
|
+
add_builtin(
|
|
2878
|
+
"assign",
|
|
2879
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int]), "i": int, "src": Any},
|
|
2880
|
+
value_func=tile_assign_value_func,
|
|
2881
|
+
group="Tile Primitives",
|
|
2882
|
+
export=False,
|
|
2883
|
+
hidden=True,
|
|
2884
|
+
)
|
|
2885
|
+
|
|
2886
|
+
add_builtin(
|
|
2887
|
+
"assign",
|
|
2888
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int, "src": Any},
|
|
2889
|
+
value_func=tile_assign_value_func,
|
|
2890
|
+
group="Tile Primitives",
|
|
2891
|
+
export=False,
|
|
2892
|
+
hidden=True,
|
|
2893
|
+
)
|
|
2894
|
+
|
|
2895
|
+
add_builtin(
|
|
2896
|
+
"assign",
|
|
2897
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int, "src": Any},
|
|
2898
|
+
value_func=tile_assign_value_func,
|
|
2899
|
+
group="Tile Primitives",
|
|
2900
|
+
export=False,
|
|
2901
|
+
hidden=True,
|
|
2902
|
+
)
|
|
2903
|
+
|
|
2904
|
+
add_builtin(
|
|
2905
|
+
"assign",
|
|
2906
|
+
input_types={
|
|
2907
|
+
"dst": tile(dtype=Any, shape=Tuple[int, int, int, int]),
|
|
2908
|
+
"i": int,
|
|
2909
|
+
"j": int,
|
|
2910
|
+
"k": int,
|
|
2911
|
+
"l": int,
|
|
2912
|
+
"src": Any,
|
|
2913
|
+
},
|
|
2914
|
+
value_func=tile_assign_value_func,
|
|
2915
|
+
group="Tile Primitives",
|
|
2916
|
+
export=False,
|
|
2917
|
+
hidden=True,
|
|
2918
|
+
)
|
|
2919
|
+
|
|
2920
|
+
|
|
2921
|
+
def tile_value_func(arg_types, arg_values):
|
|
2922
|
+
# return generic type (for doc builds)
|
|
2923
|
+
if arg_types is None:
|
|
2924
|
+
return tile(dtype=Any, shape=Tuple)
|
|
2925
|
+
|
|
2926
|
+
if len(arg_types) > 2:
|
|
2927
|
+
raise TypeError(f"tile() takes 1 positional argument and 1 optional argument but {len(arg_types)} were given")
|
|
2928
|
+
|
|
2929
|
+
preserve_type = arg_values["preserve_type"]
|
|
2930
|
+
|
|
2931
|
+
if preserve_type:
|
|
2932
|
+
dtype = arg_types["x"]
|
|
2933
|
+
shape = (warp.codegen.options["block_dim"],)
|
|
2934
|
+
|
|
2935
|
+
return tile(dtype=dtype, shape=shape)
|
|
2936
|
+
|
|
2937
|
+
else:
|
|
2938
|
+
if type_is_vector(arg_types["x"]):
|
|
2939
|
+
dtype = arg_types["x"]._wp_scalar_type_
|
|
2940
|
+
length = arg_types["x"]._shape_[0]
|
|
2941
|
+
shape = (length, warp.codegen.options["block_dim"])
|
|
2942
|
+
elif type_is_quaternion(arg_types["x"]):
|
|
2943
|
+
dtype = arg_types["x"]._wp_scalar_type_
|
|
2944
|
+
shape = (4, warp.codegen.options["block_dim"])
|
|
2945
|
+
elif type_is_matrix(arg_types["x"]):
|
|
2946
|
+
dtype = arg_types["x"]._wp_scalar_type_
|
|
2947
|
+
rows = arg_types["x"]._shape_[0]
|
|
2948
|
+
cols = arg_types["x"]._shape_[1]
|
|
2949
|
+
shape = (rows, cols, warp.codegen.options["block_dim"])
|
|
2950
|
+
else:
|
|
2951
|
+
dtype = arg_types["x"]
|
|
2952
|
+
shape = (warp.codegen.options["block_dim"],)
|
|
2953
|
+
|
|
2954
|
+
return tile(dtype=dtype, shape=shape)
|
|
2955
|
+
|
|
2956
|
+
|
|
2957
|
+
def tile_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2958
|
+
x = arg_values["x"]
|
|
2959
|
+
preserve_type = get_arg_value(arg_values["preserve_type"])
|
|
2960
|
+
|
|
2961
|
+
if preserve_type:
|
|
2962
|
+
dtype = x.type
|
|
2963
|
+
return ((x,), (dtype,))
|
|
2964
|
+
|
|
2965
|
+
else:
|
|
2966
|
+
if type_is_vector(x.type):
|
|
2967
|
+
dtype = x.type._wp_scalar_type_
|
|
2968
|
+
length = x.type._shape_[0]
|
|
2969
|
+
return ((x,), (dtype, length))
|
|
2970
|
+
elif type_is_quaternion(x.type):
|
|
2971
|
+
dtype = x.type._wp_scalar_type_
|
|
2972
|
+
return ((x,), (dtype, 4))
|
|
2973
|
+
elif type_is_matrix(x.type):
|
|
2974
|
+
dtype = x.type._wp_scalar_type_
|
|
2975
|
+
rows = x.type._shape_[0]
|
|
2976
|
+
cols = x.type._shape_[1]
|
|
2977
|
+
return ((x,), (rows, cols, dtype))
|
|
2978
|
+
else:
|
|
2979
|
+
dtype = x.type
|
|
2980
|
+
return ((x,), (dtype,))
|
|
2981
|
+
|
|
2982
|
+
|
|
2983
|
+
add_builtin(
|
|
2984
|
+
"tile",
|
|
2985
|
+
input_types={"x": Any, "preserve_type": bool},
|
|
2986
|
+
value_func=tile_value_func,
|
|
2987
|
+
dispatch_func=tile_dispatch_func,
|
|
2988
|
+
variadic=True,
|
|
2989
|
+
defaults={"preserve_type": False},
|
|
2990
|
+
doc="""Construct a new tile from per-thread kernel values.
|
|
2991
|
+
|
|
2992
|
+
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2486
2993
|
|
|
2487
2994
|
* If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
|
|
2488
2995
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2996
|
+
* If the input value is a vector, and ``preserve_type=True``, then the resulting tile has ``dtype=vector`` and ``shape=(block_dim,)``
|
|
2997
|
+
* If the input value is a matrix, then the resulting tile has ``shape=(rows, cols, block_dim)``
|
|
2998
|
+
* If the input value is a matrix, and ``preserve_type=True``, then the resulting tile has ``dtype=matrix`` and ``shape=(block_dim,)``
|
|
2489
2999
|
|
|
2490
3000
|
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
2491
|
-
:
|
|
3001
|
+
:param preserve_type: If true, the tile will have the same data type as the input value.
|
|
3002
|
+
:returns: If ``preserve_type=True``, a tile of type ``x.type`` of length ``block_dim``. Otherwise, an N-dimensional tile such that the first N-1 dimensions match the shape of ``x`` and the final dimension is of size ``block_dim``.
|
|
2492
3003
|
|
|
2493
3004
|
This example shows how to create a linear sequence from thread variables:
|
|
2494
3005
|
|
|
@@ -2511,13 +3022,14 @@ add_builtin(
|
|
|
2511
3022
|
""",
|
|
2512
3023
|
group="Tile Primitives",
|
|
2513
3024
|
export=False,
|
|
3025
|
+
hidden=True,
|
|
2514
3026
|
)
|
|
2515
3027
|
|
|
2516
3028
|
|
|
2517
3029
|
def untile_value_func(arg_types, arg_values):
|
|
2518
3030
|
# return generic type (for doc builds)
|
|
2519
3031
|
if arg_types is None:
|
|
2520
|
-
return
|
|
3032
|
+
return Any
|
|
2521
3033
|
|
|
2522
3034
|
if len(arg_types) != 1:
|
|
2523
3035
|
raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -2536,13 +3048,15 @@ def untile_value_func(arg_types, arg_values):
|
|
|
2536
3048
|
return t.dtype
|
|
2537
3049
|
elif len(t.shape) == 2:
|
|
2538
3050
|
return warp.types.vector(t.shape[0], t.dtype)
|
|
3051
|
+
elif len(t.shape) == 3:
|
|
3052
|
+
return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
2539
3053
|
else:
|
|
2540
3054
|
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
2541
3055
|
|
|
2542
3056
|
|
|
2543
3057
|
add_builtin(
|
|
2544
3058
|
"untile",
|
|
2545
|
-
input_types={"a":
|
|
3059
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
2546
3060
|
value_func=untile_value_func,
|
|
2547
3061
|
variadic=True,
|
|
2548
3062
|
doc="""Convert a tile back to per-thread values.
|
|
@@ -2592,7 +3106,7 @@ add_builtin(
|
|
|
2592
3106
|
def tile_extract_value_func(arg_types, arg_values):
|
|
2593
3107
|
# return generic type (for doc builds)
|
|
2594
3108
|
if arg_types is None:
|
|
2595
|
-
return
|
|
3109
|
+
return Any
|
|
2596
3110
|
|
|
2597
3111
|
# force the input tile to shared memory
|
|
2598
3112
|
arg_types["a"].storage = "shared"
|
|
@@ -2602,10 +3116,10 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
2602
3116
|
|
|
2603
3117
|
add_builtin(
|
|
2604
3118
|
"tile_extract",
|
|
2605
|
-
input_types={"a":
|
|
3119
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int]), "i": int},
|
|
2606
3120
|
value_func=tile_extract_value_func,
|
|
2607
3121
|
variadic=False,
|
|
2608
|
-
doc="""Extract a single element from the tile
|
|
3122
|
+
doc="""Extract a single element from the tile.
|
|
2609
3123
|
|
|
2610
3124
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2611
3125
|
|
|
@@ -2619,13 +3133,12 @@ add_builtin(
|
|
|
2619
3133
|
export=False,
|
|
2620
3134
|
)
|
|
2621
3135
|
|
|
2622
|
-
|
|
2623
3136
|
add_builtin(
|
|
2624
3137
|
"tile_extract",
|
|
2625
|
-
input_types={"a":
|
|
3138
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "i": int, "j": int},
|
|
2626
3139
|
value_func=tile_extract_value_func,
|
|
2627
3140
|
variadic=False,
|
|
2628
|
-
doc="""Extract a single element from the tile
|
|
3141
|
+
doc="""Extract a single element from the tile.
|
|
2629
3142
|
|
|
2630
3143
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2631
3144
|
|
|
@@ -2642,10 +3155,10 @@ add_builtin(
|
|
|
2642
3155
|
|
|
2643
3156
|
add_builtin(
|
|
2644
3157
|
"tile_extract",
|
|
2645
|
-
input_types={"a":
|
|
3158
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int]), "i": int, "j": int, "k": int},
|
|
2646
3159
|
value_func=tile_extract_value_func,
|
|
2647
3160
|
variadic=False,
|
|
2648
|
-
doc="""Extract a single element from the tile
|
|
3161
|
+
doc="""Extract a single element from the tile.
|
|
2649
3162
|
|
|
2650
3163
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2651
3164
|
|
|
@@ -2663,10 +3176,10 @@ add_builtin(
|
|
|
2663
3176
|
|
|
2664
3177
|
add_builtin(
|
|
2665
3178
|
"tile_extract",
|
|
2666
|
-
input_types={"a":
|
|
3179
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int, int, int]), "i": int, "j": int, "k": int, "l": int},
|
|
2667
3180
|
value_func=tile_extract_value_func,
|
|
2668
3181
|
variadic=False,
|
|
2669
|
-
doc="""Extract a single element from the tile
|
|
3182
|
+
doc="""Extract a single element from the tile.
|
|
2670
3183
|
|
|
2671
3184
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2672
3185
|
|
|
@@ -2684,10 +3197,90 @@ add_builtin(
|
|
|
2684
3197
|
)
|
|
2685
3198
|
|
|
2686
3199
|
|
|
3200
|
+
def tile_inplace_value_func(arg_types, arg_values):
|
|
3201
|
+
if not types_equal(arg_types["a"].dtype, arg_types["value"]):
|
|
3202
|
+
raise TypeError(
|
|
3203
|
+
f"'value' must have the same dtype as target tile for inplace ops, got {arg_types['a'].dtype} and {arg_types['value']}"
|
|
3204
|
+
)
|
|
3205
|
+
|
|
3206
|
+
# force the input tile to shared memory
|
|
3207
|
+
# as inplace addition/subtraction relies on shared memory atomics
|
|
3208
|
+
arg_types["a"].storage = "shared"
|
|
3209
|
+
|
|
3210
|
+
return None
|
|
3211
|
+
|
|
3212
|
+
|
|
3213
|
+
add_builtin(
|
|
3214
|
+
"tile_add_inplace",
|
|
3215
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3216
|
+
value_func=tile_inplace_value_func,
|
|
3217
|
+
group="Tile Primitives",
|
|
3218
|
+
hidden=True,
|
|
3219
|
+
export=False,
|
|
3220
|
+
)
|
|
3221
|
+
add_builtin(
|
|
3222
|
+
"tile_add_inplace",
|
|
3223
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3224
|
+
value_func=tile_inplace_value_func,
|
|
3225
|
+
group="Tile Primitives",
|
|
3226
|
+
hidden=True,
|
|
3227
|
+
export=False,
|
|
3228
|
+
)
|
|
3229
|
+
add_builtin(
|
|
3230
|
+
"tile_add_inplace",
|
|
3231
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
3232
|
+
value_func=tile_inplace_value_func,
|
|
3233
|
+
group="Tile Primitives",
|
|
3234
|
+
hidden=True,
|
|
3235
|
+
export=False,
|
|
3236
|
+
)
|
|
3237
|
+
add_builtin(
|
|
3238
|
+
"tile_add_inplace",
|
|
3239
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3240
|
+
value_func=tile_inplace_value_func,
|
|
3241
|
+
group="Tile Primitives",
|
|
3242
|
+
hidden=True,
|
|
3243
|
+
export=False,
|
|
3244
|
+
)
|
|
3245
|
+
|
|
3246
|
+
add_builtin(
|
|
3247
|
+
"tile_sub_inplace",
|
|
3248
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3249
|
+
value_func=tile_inplace_value_func,
|
|
3250
|
+
group="Tile Primitives",
|
|
3251
|
+
hidden=True,
|
|
3252
|
+
export=False,
|
|
3253
|
+
)
|
|
3254
|
+
add_builtin(
|
|
3255
|
+
"tile_sub_inplace",
|
|
3256
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3257
|
+
value_func=tile_inplace_value_func,
|
|
3258
|
+
group="Tile Primitives",
|
|
3259
|
+
hidden=True,
|
|
3260
|
+
export=False,
|
|
3261
|
+
)
|
|
3262
|
+
add_builtin(
|
|
3263
|
+
"tile_sub_inplace",
|
|
3264
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
3265
|
+
value_func=tile_inplace_value_func,
|
|
3266
|
+
group="Tile Primitives",
|
|
3267
|
+
hidden=True,
|
|
3268
|
+
export=False,
|
|
3269
|
+
)
|
|
3270
|
+
add_builtin(
|
|
3271
|
+
"tile_sub_inplace",
|
|
3272
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
3273
|
+
value_func=tile_inplace_value_func,
|
|
3274
|
+
group="Tile Primitives",
|
|
3275
|
+
hidden=True,
|
|
3276
|
+
export=False,
|
|
3277
|
+
)
|
|
3278
|
+
|
|
3279
|
+
|
|
2687
3280
|
def tile_transpose_value_func(arg_types, arg_values):
|
|
2688
3281
|
# return generic type (for doc builds)
|
|
2689
3282
|
if arg_types is None:
|
|
2690
|
-
return
|
|
3283
|
+
return tile(dtype=Any, shape=Tuple[int, int])
|
|
2691
3284
|
|
|
2692
3285
|
if len(arg_types) != 1:
|
|
2693
3286
|
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -2708,10 +3301,9 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2708
3301
|
# force the input tile to shared memory
|
|
2709
3302
|
t.storage = "shared"
|
|
2710
3303
|
|
|
2711
|
-
return
|
|
3304
|
+
return tile(
|
|
2712
3305
|
dtype=t.dtype,
|
|
2713
3306
|
shape=t.shape[::-1],
|
|
2714
|
-
op="transpose",
|
|
2715
3307
|
storage=t.storage,
|
|
2716
3308
|
strides=t.strides[::-1],
|
|
2717
3309
|
layout=layout,
|
|
@@ -2721,7 +3313,7 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2721
3313
|
|
|
2722
3314
|
add_builtin(
|
|
2723
3315
|
"tile_transpose",
|
|
2724
|
-
input_types={"a":
|
|
3316
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
|
|
2725
3317
|
value_func=tile_transpose_value_func,
|
|
2726
3318
|
variadic=True,
|
|
2727
3319
|
doc="""Transpose a tile.
|
|
@@ -2739,12 +3331,16 @@ add_builtin(
|
|
|
2739
3331
|
def tile_broadcast_value_func(arg_types, arg_values):
|
|
2740
3332
|
# return generic type (for doc builds)
|
|
2741
3333
|
if arg_types is None:
|
|
2742
|
-
return
|
|
3334
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2743
3335
|
|
|
2744
3336
|
t = arg_types["a"]
|
|
2745
3337
|
|
|
2746
3338
|
# target shape and strides
|
|
2747
|
-
target_shape =
|
|
3339
|
+
target_shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
3340
|
+
|
|
3341
|
+
if None in target_shape:
|
|
3342
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
3343
|
+
|
|
2748
3344
|
target_strides = [0] * len(target_shape)
|
|
2749
3345
|
|
|
2750
3346
|
offset = len(target_shape) - len(t.shape)
|
|
@@ -2769,10 +3365,7 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2769
3365
|
# force the input tile to shared memory
|
|
2770
3366
|
t.storage = "shared"
|
|
2771
3367
|
|
|
2772
|
-
|
|
2773
|
-
dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
|
|
2774
|
-
)
|
|
2775
|
-
return tile_type
|
|
3368
|
+
return tile(dtype=t.dtype, shape=target_shape, storage=t.storage, strides=target_strides, owner=False)
|
|
2776
3369
|
|
|
2777
3370
|
|
|
2778
3371
|
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
@@ -2787,7 +3380,7 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
|
|
|
2787
3380
|
|
|
2788
3381
|
add_builtin(
|
|
2789
3382
|
"tile_broadcast",
|
|
2790
|
-
input_types={"a":
|
|
3383
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "shape": Tuple[int, ...]},
|
|
2791
3384
|
value_func=tile_broadcast_value_func,
|
|
2792
3385
|
dispatch_func=tile_broadcast_dispatch_func,
|
|
2793
3386
|
variadic=False,
|
|
@@ -2807,7 +3400,7 @@ add_builtin(
|
|
|
2807
3400
|
def tile_sum_value_func(arg_types, arg_values):
|
|
2808
3401
|
# return generic type (for doc builds)
|
|
2809
3402
|
if arg_types is None:
|
|
2810
|
-
return
|
|
3403
|
+
return tile(dtype=Scalar, shape=(1,))
|
|
2811
3404
|
|
|
2812
3405
|
if len(arg_types) != 1:
|
|
2813
3406
|
raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -2817,12 +3410,12 @@ def tile_sum_value_func(arg_types, arg_values):
|
|
|
2817
3410
|
if not is_tile(a):
|
|
2818
3411
|
raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
|
|
2819
3412
|
|
|
2820
|
-
return
|
|
3413
|
+
return tile(dtype=a.dtype, shape=(1,))
|
|
2821
3414
|
|
|
2822
3415
|
|
|
2823
3416
|
add_builtin(
|
|
2824
3417
|
"tile_sum",
|
|
2825
|
-
input_types={"a":
|
|
3418
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
2826
3419
|
value_func=tile_sum_value_func,
|
|
2827
3420
|
variadic=True,
|
|
2828
3421
|
doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
|
|
@@ -2856,10 +3449,89 @@ add_builtin(
|
|
|
2856
3449
|
)
|
|
2857
3450
|
|
|
2858
3451
|
|
|
3452
|
+
def tile_sort_value_func(arg_types, arg_values):
|
|
3453
|
+
# return generic type (for doc builds)
|
|
3454
|
+
if arg_types is None:
|
|
3455
|
+
return None
|
|
3456
|
+
|
|
3457
|
+
if len(arg_types) != 2:
|
|
3458
|
+
raise TypeError(
|
|
3459
|
+
f"tile_sort() takes exactly 2 positional arguments (keys and values) but {len(arg_types)} were given"
|
|
3460
|
+
)
|
|
3461
|
+
|
|
3462
|
+
a = arg_types["keys"]
|
|
3463
|
+
b = arg_types["values"]
|
|
3464
|
+
|
|
3465
|
+
if not is_tile(a):
|
|
3466
|
+
raise TypeError(f"First tile_sort() argument must be a tile, got {a!r}")
|
|
3467
|
+
|
|
3468
|
+
if not is_tile(b):
|
|
3469
|
+
raise TypeError(f"Second tile_sort() argument must be a tile, got {b!r}")
|
|
3470
|
+
|
|
3471
|
+
if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
|
|
3472
|
+
raise TypeError(f"First tile_sort() argument must be a tile of type float or int, got {a.dtype}")
|
|
3473
|
+
|
|
3474
|
+
# set the storage type to the inputs to shared
|
|
3475
|
+
a.storage = "shared"
|
|
3476
|
+
b.storage = "shared"
|
|
3477
|
+
|
|
3478
|
+
if len(a.shape) != len(b.shape):
|
|
3479
|
+
raise ValueError(
|
|
3480
|
+
f"tile_sort() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
|
|
3481
|
+
)
|
|
3482
|
+
|
|
3483
|
+
for i in range(len(a.shape)):
|
|
3484
|
+
if a.shape[i] != b.shape[i]:
|
|
3485
|
+
raise ValueError(f"tile_sort() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
|
|
3486
|
+
|
|
3487
|
+
return None
|
|
3488
|
+
|
|
3489
|
+
|
|
3490
|
+
add_builtin(
|
|
3491
|
+
"tile_sort",
|
|
3492
|
+
input_types={"keys": tile(dtype=Any, shape=Tuple[int]), "values": tile(dtype=Any, shape=Tuple[int])},
|
|
3493
|
+
value_func=tile_sort_value_func,
|
|
3494
|
+
variadic=True,
|
|
3495
|
+
doc="""Cooperatively sort the elements of two tiles in ascending order based on the keys, using all threads in the block.
|
|
3496
|
+
|
|
3497
|
+
:param keys: Keys to sort by. Supported key types: :class:`float32`, :class:`int32`, :class:`uint32`. Must be in shared memory.
|
|
3498
|
+
:param values: Values to sort along with keys. No type restrictions. Must be in shared memory.
|
|
3499
|
+
:returns: No return value. Sorts both tiles in-place.
|
|
3500
|
+
|
|
3501
|
+
Example:
|
|
3502
|
+
|
|
3503
|
+
.. code-block:: python
|
|
3504
|
+
|
|
3505
|
+
@wp.kernel
|
|
3506
|
+
def compute():
|
|
3507
|
+
|
|
3508
|
+
keys = wp.tile_arange(32, 0, -1, dtype=int, storage="shared")
|
|
3509
|
+
values = wp.tile_arange(0, 32, 1, dtype=int, storage="shared")
|
|
3510
|
+
wp.tile_sort(keys, values)
|
|
3511
|
+
|
|
3512
|
+
print(keys)
|
|
3513
|
+
print(values)
|
|
3514
|
+
|
|
3515
|
+
|
|
3516
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
3517
|
+
|
|
3518
|
+
Prints:
|
|
3519
|
+
|
|
3520
|
+
.. code-block:: text
|
|
3521
|
+
|
|
3522
|
+
[1, 2, ..., 32] = tile(shape=(32), storage=shared)
|
|
3523
|
+
[31, 30, 29, ..., 0] = tile(shape=(32), storage=shared)
|
|
3524
|
+
|
|
3525
|
+
""",
|
|
3526
|
+
group="Tile Primitives",
|
|
3527
|
+
export=False,
|
|
3528
|
+
)
|
|
3529
|
+
|
|
3530
|
+
|
|
2859
3531
|
def tile_min_value_func(arg_types, arg_values):
|
|
2860
3532
|
# return generic type (for doc builds)
|
|
2861
3533
|
if arg_types is None:
|
|
2862
|
-
return
|
|
3534
|
+
return tile(dtype=Scalar, shape=(1,))
|
|
2863
3535
|
|
|
2864
3536
|
if len(arg_types) != 1:
|
|
2865
3537
|
raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -2869,12 +3541,12 @@ def tile_min_value_func(arg_types, arg_values):
|
|
|
2869
3541
|
if not is_tile(a):
|
|
2870
3542
|
raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
|
|
2871
3543
|
|
|
2872
|
-
return
|
|
3544
|
+
return tile(dtype=a.dtype, shape=(1,))
|
|
2873
3545
|
|
|
2874
3546
|
|
|
2875
3547
|
add_builtin(
|
|
2876
3548
|
"tile_min",
|
|
2877
|
-
input_types={"a":
|
|
3549
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
2878
3550
|
value_func=tile_min_value_func,
|
|
2879
3551
|
variadic=True,
|
|
2880
3552
|
doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
|
|
@@ -2909,10 +3581,63 @@ add_builtin(
|
|
|
2909
3581
|
)
|
|
2910
3582
|
|
|
2911
3583
|
|
|
3584
|
+
def tile_argmin_value_func(arg_types, arg_values):
|
|
3585
|
+
# return generic type (for doc builds)
|
|
3586
|
+
if arg_types is None:
|
|
3587
|
+
return tile(dtype=Int, shape=(1,))
|
|
3588
|
+
|
|
3589
|
+
if len(arg_types) != 1:
|
|
3590
|
+
raise TypeError(f"tile_argmin() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3591
|
+
|
|
3592
|
+
a = arg_types["a"]
|
|
3593
|
+
|
|
3594
|
+
if not is_tile(a):
|
|
3595
|
+
raise TypeError(f"tile_argmin() argument must be a tile, got {a!r}")
|
|
3596
|
+
|
|
3597
|
+
return tile(dtype=warp.int32, shape=(1,))
|
|
3598
|
+
|
|
3599
|
+
|
|
3600
|
+
add_builtin(
|
|
3601
|
+
"tile_argmin",
|
|
3602
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
3603
|
+
value_func=tile_argmin_value_func,
|
|
3604
|
+
variadic=True,
|
|
3605
|
+
doc="""Cooperatively compute the index of the minimum element in the tile using all threads in the block.
|
|
3606
|
+
|
|
3607
|
+
:param a: The tile to compute the argmin from
|
|
3608
|
+
:returns: A single-element tile holding the index of the minimum value
|
|
3609
|
+
|
|
3610
|
+
Example:
|
|
3611
|
+
|
|
3612
|
+
.. code-block:: python
|
|
3613
|
+
|
|
3614
|
+
@wp.kernel
|
|
3615
|
+
def compute():
|
|
3616
|
+
|
|
3617
|
+
t = wp.tile_arange(64, 128)
|
|
3618
|
+
s = wp.tile_argmin(t)
|
|
3619
|
+
|
|
3620
|
+
print(s)
|
|
3621
|
+
|
|
3622
|
+
|
|
3623
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
3624
|
+
|
|
3625
|
+
Prints:
|
|
3626
|
+
|
|
3627
|
+
.. code-block:: text
|
|
3628
|
+
|
|
3629
|
+
[0] = tile(shape=(1), storage=register)
|
|
3630
|
+
|
|
3631
|
+
""",
|
|
3632
|
+
group="Tile Primitives",
|
|
3633
|
+
export=False,
|
|
3634
|
+
)
|
|
3635
|
+
|
|
3636
|
+
|
|
2912
3637
|
def tile_max_value_func(arg_types, arg_values):
|
|
2913
3638
|
# return generic type (for doc builds)
|
|
2914
3639
|
if arg_types is None:
|
|
2915
|
-
return
|
|
3640
|
+
return tile(dtype=Scalar, shape=(1,))
|
|
2916
3641
|
|
|
2917
3642
|
if len(arg_types) != 1:
|
|
2918
3643
|
raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -2922,12 +3647,12 @@ def tile_max_value_func(arg_types, arg_values):
|
|
|
2922
3647
|
if not is_tile(a):
|
|
2923
3648
|
raise TypeError(f"tile_max() argument must be a tile, got {a!r}")
|
|
2924
3649
|
|
|
2925
|
-
return
|
|
3650
|
+
return tile(dtype=a.dtype, shape=(1,))
|
|
2926
3651
|
|
|
2927
3652
|
|
|
2928
3653
|
add_builtin(
|
|
2929
3654
|
"tile_max",
|
|
2930
|
-
input_types={"a":
|
|
3655
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
2931
3656
|
value_func=tile_max_value_func,
|
|
2932
3657
|
variadic=False,
|
|
2933
3658
|
doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
|
|
@@ -2961,17 +3686,69 @@ add_builtin(
|
|
|
2961
3686
|
)
|
|
2962
3687
|
|
|
2963
3688
|
|
|
3689
|
+
def tile_argmax_value_func(arg_types, arg_values):
|
|
3690
|
+
# return generic type (for doc builds)
|
|
3691
|
+
if arg_types is None:
|
|
3692
|
+
return tile(dtype=Int, shape=(1,))
|
|
3693
|
+
|
|
3694
|
+
if len(arg_types) != 1:
|
|
3695
|
+
raise TypeError(f"tile_argmax() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3696
|
+
|
|
3697
|
+
a = arg_types["a"]
|
|
3698
|
+
|
|
3699
|
+
if not is_tile(a):
|
|
3700
|
+
raise TypeError(f"tile_argmax() argument must be a tile, got {a!r}")
|
|
3701
|
+
|
|
3702
|
+
return tile(dtype=warp.int32, shape=(1,))
|
|
3703
|
+
|
|
3704
|
+
|
|
3705
|
+
add_builtin(
|
|
3706
|
+
"tile_argmax",
|
|
3707
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
3708
|
+
value_func=tile_argmax_value_func,
|
|
3709
|
+
variadic=False,
|
|
3710
|
+
doc="""Cooperatively compute the index of the maximum element in the tile using all threads in the block.
|
|
3711
|
+
|
|
3712
|
+
:param a: The tile to compute the argmax from
|
|
3713
|
+
:returns: A single-element tile holding the index of the maximum value
|
|
3714
|
+
|
|
3715
|
+
Example:
|
|
3716
|
+
|
|
3717
|
+
.. code-block:: python
|
|
3718
|
+
|
|
3719
|
+
@wp.kernel
|
|
3720
|
+
def compute():
|
|
3721
|
+
|
|
3722
|
+
t = wp.tile_arange(64, 128)
|
|
3723
|
+
s = wp.tile_argmax(t)
|
|
3724
|
+
|
|
3725
|
+
print(s)
|
|
3726
|
+
|
|
3727
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
3728
|
+
|
|
3729
|
+
Prints:
|
|
3730
|
+
|
|
3731
|
+
.. code-block:: text
|
|
3732
|
+
|
|
3733
|
+
[63] = tile(shape=(1), storage=register)
|
|
3734
|
+
|
|
3735
|
+
""",
|
|
3736
|
+
group="Tile Primitives",
|
|
3737
|
+
export=False,
|
|
3738
|
+
)
|
|
3739
|
+
|
|
3740
|
+
|
|
2964
3741
|
# does type propagation for load()
|
|
2965
3742
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
2966
3743
|
if arg_types is None:
|
|
2967
|
-
return
|
|
3744
|
+
return tile(dtype=Scalar, shape=(1,))
|
|
2968
3745
|
|
|
2969
3746
|
a = arg_types["a"]
|
|
2970
3747
|
|
|
2971
3748
|
if not is_tile(a):
|
|
2972
3749
|
raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
|
|
2973
3750
|
|
|
2974
|
-
return
|
|
3751
|
+
return tile(dtype=a.dtype, shape=(1,))
|
|
2975
3752
|
|
|
2976
3753
|
|
|
2977
3754
|
def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
@@ -2982,7 +3759,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
|
|
|
2982
3759
|
|
|
2983
3760
|
add_builtin(
|
|
2984
3761
|
"tile_reduce",
|
|
2985
|
-
input_types={"op": Callable, "a":
|
|
3762
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
2986
3763
|
value_func=tile_reduce_value_func,
|
|
2987
3764
|
native_func="tile_reduce",
|
|
2988
3765
|
doc="""Apply a custom reduction operator across the tile.
|
|
@@ -3005,37 +3782,164 @@ add_builtin(
|
|
|
3005
3782
|
|
|
3006
3783
|
print(s)
|
|
3007
3784
|
|
|
3008
|
-
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
|
|
3785
|
+
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
|
|
3786
|
+
|
|
3787
|
+
Prints:
|
|
3788
|
+
|
|
3789
|
+
.. code-block:: text
|
|
3790
|
+
|
|
3791
|
+
[362880] = tile(shape=(1), storage=register)
|
|
3792
|
+
""",
|
|
3793
|
+
group="Tile Primitives",
|
|
3794
|
+
export=False,
|
|
3795
|
+
)
|
|
3796
|
+
|
|
3797
|
+
|
|
3798
|
+
def tile_scan_inclusive_value_func(arg_types, arg_values):
|
|
3799
|
+
# Return type is the same as input type
|
|
3800
|
+
if arg_types is None:
|
|
3801
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
3802
|
+
|
|
3803
|
+
if len(arg_types) != 1:
|
|
3804
|
+
raise TypeError(f"tile_scan_inclusive() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3805
|
+
|
|
3806
|
+
a = arg_types["a"]
|
|
3807
|
+
|
|
3808
|
+
if not is_tile(a):
|
|
3809
|
+
raise TypeError(f"tile_scan_inclusive() argument must be a tile, got {a!r}")
|
|
3810
|
+
|
|
3811
|
+
# Only allow float32, int32, or uint32 for scan (like tile_sort)
|
|
3812
|
+
if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
|
|
3813
|
+
raise TypeError(
|
|
3814
|
+
f"tile_scan_inclusive() argument must be a tile of type float32, int32, or uint32, got {a.dtype}"
|
|
3815
|
+
)
|
|
3816
|
+
|
|
3817
|
+
return tile(dtype=a.dtype, shape=a.shape)
|
|
3818
|
+
|
|
3819
|
+
|
|
3820
|
+
def tile_scan_inclusive_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
3821
|
+
func_args = (args["a"],)
|
|
3822
|
+
template_args = ()
|
|
3823
|
+
return (func_args, template_args)
|
|
3824
|
+
|
|
3825
|
+
|
|
3826
|
+
add_builtin(
|
|
3827
|
+
"tile_scan_inclusive",
|
|
3828
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
3829
|
+
value_func=tile_scan_inclusive_value_func,
|
|
3830
|
+
native_func="tile_scan_inclusive",
|
|
3831
|
+
doc="""Inclusive scan (prefix sum) across the tile.
|
|
3832
|
+
|
|
3833
|
+
This function cooperatively performs an inclusive scan (cumulative sum) across the tile.
|
|
3834
|
+
|
|
3835
|
+
:param a: The input tile. Must be a tile of type float32, int32, or uint32.
|
|
3836
|
+
:returns: A new tile containing the inclusive scan result.
|
|
3837
|
+
|
|
3838
|
+
Example:
|
|
3839
|
+
|
|
3840
|
+
.. code-block:: python
|
|
3841
|
+
|
|
3842
|
+
@wp.kernel
|
|
3843
|
+
def scan_example():
|
|
3844
|
+
t = wp.tile_arange(1, 5, dtype=int)
|
|
3845
|
+
s = wp.tile_scan_inclusive(t)
|
|
3846
|
+
print(s)
|
|
3847
|
+
|
|
3848
|
+
wp.launch_tiled(scan_example, dim=[1], inputs=[], block_dim=16)
|
|
3849
|
+
|
|
3850
|
+
Prints:
|
|
3851
|
+
|
|
3852
|
+
.. code-block:: text
|
|
3853
|
+
|
|
3854
|
+
[1, 3, 6, 10] = tile(shape=(4), storage=register)
|
|
3855
|
+
""",
|
|
3856
|
+
group="Tile Primitives",
|
|
3857
|
+
export=False,
|
|
3858
|
+
)
|
|
3859
|
+
|
|
3860
|
+
|
|
3861
|
+
def tile_scan_exclusive_value_func(arg_types, arg_values):
|
|
3862
|
+
# return generic type (for doc builds)
|
|
3863
|
+
if arg_types is None:
|
|
3864
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
3865
|
+
|
|
3866
|
+
if len(arg_types) != 1:
|
|
3867
|
+
raise TypeError(f"tile_scan_exclusive() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3868
|
+
|
|
3869
|
+
a = arg_types["a"]
|
|
3870
|
+
|
|
3871
|
+
if not is_tile(a):
|
|
3872
|
+
raise TypeError(f"tile_scan_exclusive() argument must be a tile, got {a!r}")
|
|
3873
|
+
|
|
3874
|
+
# Only allow float32, int32, or uint32 for scan (like tile_sort)
|
|
3875
|
+
if not (a.dtype is warp.float32 or a.dtype is warp.int32 or a.dtype is warp.uint32):
|
|
3876
|
+
raise TypeError(
|
|
3877
|
+
f"tile_scan_exclusive() argument must be a tile of type float32, int32, or uint32, got {a.dtype}"
|
|
3878
|
+
)
|
|
3879
|
+
|
|
3880
|
+
return tile(dtype=a.dtype, shape=a.shape)
|
|
3881
|
+
|
|
3882
|
+
|
|
3883
|
+
def tile_scan_exclusive_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
3884
|
+
func_args = (args["a"],)
|
|
3885
|
+
template_args = ()
|
|
3886
|
+
return (func_args, template_args)
|
|
3887
|
+
|
|
3888
|
+
|
|
3889
|
+
add_builtin(
|
|
3890
|
+
"tile_scan_exclusive",
|
|
3891
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
3892
|
+
value_func=tile_scan_exclusive_value_func,
|
|
3893
|
+
native_func="tile_scan_exclusive",
|
|
3894
|
+
doc="""Exclusive scan (prefix sum) across the tile.
|
|
3895
|
+
|
|
3896
|
+
This function cooperatively performs an exclusive scan (cumulative sum) across the tile.
|
|
3897
|
+
|
|
3898
|
+
:param a: The input tile. Must be a tile of type float32, int32, or uint32.
|
|
3899
|
+
:returns: A new tile containing the exclusive scan result.
|
|
3900
|
+
|
|
3901
|
+
Example:
|
|
3902
|
+
|
|
3903
|
+
.. code-block:: python
|
|
3904
|
+
|
|
3905
|
+
@wp.kernel
|
|
3906
|
+
def scan_example():
|
|
3907
|
+
t = wp.tile_arange(1, 5, dtype=int)
|
|
3908
|
+
s = wp.tile_scan_exclusive(t)
|
|
3909
|
+
print(s)
|
|
3910
|
+
|
|
3911
|
+
wp.launch_tiled(scan_example, dim=[1], inputs=[], block_dim=16)
|
|
3009
3912
|
|
|
3010
3913
|
Prints:
|
|
3011
3914
|
|
|
3012
3915
|
.. code-block:: text
|
|
3013
3916
|
|
|
3014
|
-
[
|
|
3917
|
+
[0, 1, 3, 6] = tile(shape=(4), storage=register)
|
|
3015
3918
|
""",
|
|
3016
3919
|
group="Tile Primitives",
|
|
3017
3920
|
export=False,
|
|
3018
3921
|
)
|
|
3019
3922
|
|
|
3923
|
+
|
|
3020
3924
|
# maps
|
|
3021
3925
|
|
|
3022
3926
|
|
|
3023
3927
|
# does type propagation for load()
|
|
3024
3928
|
def tile_unary_map_value_func(arg_types, arg_values):
|
|
3025
3929
|
if arg_types is None:
|
|
3026
|
-
return
|
|
3930
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
3027
3931
|
|
|
3028
3932
|
a = arg_types["a"]
|
|
3029
3933
|
|
|
3030
3934
|
if not is_tile(a):
|
|
3031
3935
|
raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
|
|
3032
3936
|
|
|
3033
|
-
return
|
|
3937
|
+
return tile(dtype=a.dtype, shape=a.shape)
|
|
3034
3938
|
|
|
3035
3939
|
|
|
3036
3940
|
add_builtin(
|
|
3037
3941
|
"tile_map",
|
|
3038
|
-
input_types={"op": Callable, "a":
|
|
3942
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...])},
|
|
3039
3943
|
value_func=tile_unary_map_value_func,
|
|
3040
3944
|
# dispatch_func=tile_map_dispatch_func,
|
|
3041
3945
|
# variadic=True,
|
|
@@ -3075,7 +3979,7 @@ add_builtin(
|
|
|
3075
3979
|
|
|
3076
3980
|
def tile_binary_map_value_func(arg_types, arg_values):
|
|
3077
3981
|
if arg_types is None:
|
|
3078
|
-
return
|
|
3982
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
3079
3983
|
|
|
3080
3984
|
a = arg_types["a"]
|
|
3081
3985
|
b = arg_types["b"]
|
|
@@ -3100,12 +4004,16 @@ def tile_binary_map_value_func(arg_types, arg_values):
|
|
|
3100
4004
|
if a.shape[i] != b.shape[i]:
|
|
3101
4005
|
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
|
|
3102
4006
|
|
|
3103
|
-
return
|
|
4007
|
+
return tile(dtype=a.dtype, shape=a.shape)
|
|
3104
4008
|
|
|
3105
4009
|
|
|
3106
4010
|
add_builtin(
|
|
3107
4011
|
"tile_map",
|
|
3108
|
-
input_types={
|
|
4012
|
+
input_types={
|
|
4013
|
+
"op": Callable,
|
|
4014
|
+
"a": tile(dtype=Scalar, shape=Tuple[int, ...]),
|
|
4015
|
+
"b": tile(dtype=Scalar, shape=Tuple[int, ...]),
|
|
4016
|
+
},
|
|
3109
4017
|
value_func=tile_binary_map_value_func,
|
|
3110
4018
|
# dispatch_func=tile_map_dispatch_func,
|
|
3111
4019
|
# variadic=True,
|
|
@@ -3255,57 +4163,13 @@ add_builtin(
|
|
|
3255
4163
|
)
|
|
3256
4164
|
|
|
3257
4165
|
|
|
3258
|
-
def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
3259
|
-
warp.utils.warn(
|
|
3260
|
-
"wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
|
|
3261
|
-
category=DeprecationWarning,
|
|
3262
|
-
)
|
|
3263
|
-
|
|
3264
|
-
func_args = tuple(args.values())
|
|
3265
|
-
template_args = ()
|
|
3266
|
-
|
|
3267
|
-
return (func_args, template_args)
|
|
3268
|
-
|
|
3269
|
-
|
|
3270
|
-
add_builtin(
|
|
3271
|
-
"mlp",
|
|
3272
|
-
input_types={
|
|
3273
|
-
"weights": array(dtype=float, ndim=2),
|
|
3274
|
-
"bias": array(dtype=float, ndim=1),
|
|
3275
|
-
"activation": Callable,
|
|
3276
|
-
"index": int,
|
|
3277
|
-
"x": array(dtype=float, ndim=2),
|
|
3278
|
-
"out": array(dtype=float, ndim=2),
|
|
3279
|
-
},
|
|
3280
|
-
value_type=None,
|
|
3281
|
-
dispatch_func=mlp_dispatch_func,
|
|
3282
|
-
skip_replay=True,
|
|
3283
|
-
doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
|
|
3284
|
-
|
|
3285
|
-
.. deprecated:: 1.6
|
|
3286
|
-
Use :doc:`tile primitives </modules/tiles>` instead.
|
|
3287
|
-
|
|
3288
|
-
:param weights: A layer's network weights with dimensions ``(m, n)``.
|
|
3289
|
-
:param bias: An array with dimensions ``(n)``.
|
|
3290
|
-
:param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
|
|
3291
|
-
:param index: The batch item to process, typically each thread will process one item in the batch, in which case
|
|
3292
|
-
index should be ``wp.tid()``
|
|
3293
|
-
:param x: The feature matrix with dimensions ``(n, b)``
|
|
3294
|
-
:param out: The network output with dimensions ``(m, b)``
|
|
3295
|
-
|
|
3296
|
-
:note: Feature and output matrices are transposed compared to some other frameworks such as PyTorch.
|
|
3297
|
-
All matrices are assumed to be stored in flattened row-major memory layout (NumPy default).""",
|
|
3298
|
-
group="Utility",
|
|
3299
|
-
)
|
|
3300
|
-
|
|
3301
|
-
|
|
3302
4166
|
# ---------------------------------
|
|
3303
4167
|
# Geometry
|
|
3304
4168
|
|
|
3305
4169
|
add_builtin(
|
|
3306
4170
|
"bvh_query_aabb",
|
|
3307
4171
|
input_types={"id": uint64, "low": vec3, "high": vec3},
|
|
3308
|
-
|
|
4172
|
+
value_type=BvhQuery,
|
|
3309
4173
|
group="Geometry",
|
|
3310
4174
|
doc="""Construct an axis-aligned bounding box query against a BVH object.
|
|
3311
4175
|
|
|
@@ -3320,7 +4184,7 @@ add_builtin(
|
|
|
3320
4184
|
add_builtin(
|
|
3321
4185
|
"bvh_query_ray",
|
|
3322
4186
|
input_types={"id": uint64, "start": vec3, "dir": vec3},
|
|
3323
|
-
|
|
4187
|
+
value_type=BvhQuery,
|
|
3324
4188
|
group="Geometry",
|
|
3325
4189
|
doc="""Construct a ray query against a BVH object.
|
|
3326
4190
|
|
|
@@ -3380,7 +4244,7 @@ add_builtin(
|
|
|
3380
4244
|
"point": vec3,
|
|
3381
4245
|
"max_dist": float,
|
|
3382
4246
|
},
|
|
3383
|
-
|
|
4247
|
+
value_type=MeshQueryPoint,
|
|
3384
4248
|
group="Geometry",
|
|
3385
4249
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
3386
4250
|
|
|
@@ -3428,7 +4292,7 @@ add_builtin(
|
|
|
3428
4292
|
"point": vec3,
|
|
3429
4293
|
"max_dist": float,
|
|
3430
4294
|
},
|
|
3431
|
-
|
|
4295
|
+
value_type=MeshQueryPoint,
|
|
3432
4296
|
group="Geometry",
|
|
3433
4297
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
3434
4298
|
|
|
@@ -3474,7 +4338,7 @@ add_builtin(
|
|
|
3474
4338
|
"point": vec3,
|
|
3475
4339
|
"min_dist": float,
|
|
3476
4340
|
},
|
|
3477
|
-
|
|
4341
|
+
value_type=MeshQueryPoint,
|
|
3478
4342
|
group="Geometry",
|
|
3479
4343
|
doc="""Computes the furthest point on the mesh with identifier `id` to the given point in space.
|
|
3480
4344
|
|
|
@@ -3531,7 +4395,7 @@ add_builtin(
|
|
|
3531
4395
|
"epsilon": float,
|
|
3532
4396
|
},
|
|
3533
4397
|
defaults={"epsilon": 1.0e-3},
|
|
3534
|
-
|
|
4398
|
+
value_type=MeshQueryPoint,
|
|
3535
4399
|
group="Geometry",
|
|
3536
4400
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
|
|
3537
4401
|
|
|
@@ -3596,7 +4460,7 @@ add_builtin(
|
|
|
3596
4460
|
"threshold": float,
|
|
3597
4461
|
},
|
|
3598
4462
|
defaults={"accuracy": 2.0, "threshold": 0.5},
|
|
3599
|
-
|
|
4463
|
+
value_type=MeshQueryPoint,
|
|
3600
4464
|
group="Geometry",
|
|
3601
4465
|
doc="""Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
|
|
3602
4466
|
|
|
@@ -3655,7 +4519,7 @@ add_builtin(
|
|
|
3655
4519
|
"dir": vec3,
|
|
3656
4520
|
"max_t": float,
|
|
3657
4521
|
},
|
|
3658
|
-
|
|
4522
|
+
value_type=MeshQueryRay,
|
|
3659
4523
|
group="Geometry",
|
|
3660
4524
|
doc="""Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
|
|
3661
4525
|
|
|
@@ -3670,7 +4534,7 @@ add_builtin(
|
|
|
3670
4534
|
add_builtin(
|
|
3671
4535
|
"mesh_query_aabb",
|
|
3672
4536
|
input_types={"id": uint64, "low": vec3, "high": vec3},
|
|
3673
|
-
|
|
4537
|
+
value_type=MeshQueryAABB,
|
|
3674
4538
|
group="Geometry",
|
|
3675
4539
|
doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
|
|
3676
4540
|
|
|
@@ -3714,7 +4578,7 @@ add_builtin(
|
|
|
3714
4578
|
add_builtin(
|
|
3715
4579
|
"hash_grid_query",
|
|
3716
4580
|
input_types={"id": uint64, "point": vec3, "max_dist": float},
|
|
3717
|
-
|
|
4581
|
+
value_type=HashGridQuery,
|
|
3718
4582
|
group="Geometry",
|
|
3719
4583
|
doc="""Construct a point query against a :class:`HashGrid`.
|
|
3720
4584
|
|
|
@@ -3843,10 +4707,10 @@ add_builtin(
|
|
|
3843
4707
|
|
|
3844
4708
|
add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
|
|
3845
4709
|
add_builtin(
|
|
3846
|
-
"iter_next", input_types={"query":
|
|
4710
|
+
"iter_next", input_types={"query": HashGridQuery}, value_type=int, group="Utility", export=False, hidden=True
|
|
3847
4711
|
)
|
|
3848
4712
|
add_builtin(
|
|
3849
|
-
"iter_next", input_types={"query":
|
|
4713
|
+
"iter_next", input_types={"query": MeshQueryAABB}, value_type=int, group="Utility", export=False, hidden=True
|
|
3850
4714
|
)
|
|
3851
4715
|
|
|
3852
4716
|
add_builtin(
|
|
@@ -3889,7 +4753,7 @@ def _check_volume_type_is_supported(dtype):
|
|
|
3889
4753
|
|
|
3890
4754
|
def check_volume_value_grad_compatibility(dtype, grad_dtype):
|
|
3891
4755
|
if type_is_vector(dtype):
|
|
3892
|
-
expected = matrix(shape=(
|
|
4756
|
+
expected = matrix(shape=(type_size(dtype), 3), dtype=type_scalar_type(dtype))
|
|
3893
4757
|
else:
|
|
3894
4758
|
expected = vector(length=3, dtype=dtype)
|
|
3895
4759
|
|
|
@@ -4062,6 +4926,7 @@ add_builtin(
|
|
|
4062
4926
|
input_types={"id": uint64, "i": int, "j": int, "k": int, "value": float},
|
|
4063
4927
|
group="Volumes",
|
|
4064
4928
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
4929
|
+
export=False,
|
|
4065
4930
|
)
|
|
4066
4931
|
|
|
4067
4932
|
add_builtin(
|
|
@@ -4089,6 +4954,7 @@ add_builtin(
|
|
|
4089
4954
|
input_types={"id": uint64, "i": int, "j": int, "k": int, "value": vec3},
|
|
4090
4955
|
group="Volumes",
|
|
4091
4956
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
4957
|
+
export=False,
|
|
4092
4958
|
)
|
|
4093
4959
|
|
|
4094
4960
|
add_builtin(
|
|
@@ -4114,6 +4980,7 @@ add_builtin(
|
|
|
4114
4980
|
input_types={"id": uint64, "i": int, "j": int, "k": int, "value": int},
|
|
4115
4981
|
group="Volumes",
|
|
4116
4982
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
4983
|
+
export=False,
|
|
4117
4984
|
)
|
|
4118
4985
|
|
|
4119
4986
|
|
|
@@ -4527,6 +5394,16 @@ add_builtin(
|
|
|
4527
5394
|
native_func="builtin_tid1d",
|
|
4528
5395
|
)
|
|
4529
5396
|
|
|
5397
|
+
add_builtin(
|
|
5398
|
+
"block_dim",
|
|
5399
|
+
input_types={},
|
|
5400
|
+
value_type=int,
|
|
5401
|
+
group="Utility",
|
|
5402
|
+
doc="Returns the number of threads in the current block.",
|
|
5403
|
+
namespace="",
|
|
5404
|
+
native_func="builtin_block_dim",
|
|
5405
|
+
)
|
|
5406
|
+
|
|
4530
5407
|
add_builtin(
|
|
4531
5408
|
"tid",
|
|
4532
5409
|
input_types={},
|
|
@@ -4667,7 +5544,7 @@ def array_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any
|
|
|
4667
5544
|
return array(dtype=Scalar)
|
|
4668
5545
|
|
|
4669
5546
|
dtype = arg_values["dtype"]
|
|
4670
|
-
shape = arg_values["shape"]
|
|
5547
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
4671
5548
|
return array(dtype=dtype, ndim=len(shape))
|
|
4672
5549
|
|
|
4673
5550
|
|
|
@@ -4677,8 +5554,9 @@ def array_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
|
|
|
4677
5554
|
# to the underlying C++ function's runtime and template params.
|
|
4678
5555
|
|
|
4679
5556
|
dtype = return_type.dtype
|
|
5557
|
+
shape = extract_tuple(args["shape"], as_constant=True)
|
|
4680
5558
|
|
|
4681
|
-
func_args = (args["ptr"], *
|
|
5559
|
+
func_args = (args["ptr"], *shape)
|
|
4682
5560
|
template_args = (dtype,)
|
|
4683
5561
|
return (func_args, template_args)
|
|
4684
5562
|
|
|
@@ -4958,6 +5836,12 @@ def create_atomic_op_value_func(op: str):
|
|
|
4958
5836
|
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
4959
5837
|
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
4960
5838
|
)
|
|
5839
|
+
elif op in ("cas", "exch"):
|
|
5840
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
|
|
5841
|
+
raise RuntimeError(
|
|
5842
|
+
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
5843
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
5844
|
+
)
|
|
4961
5845
|
else:
|
|
4962
5846
|
raise NotImplementedError
|
|
4963
5847
|
|
|
@@ -5187,6 +6071,120 @@ for array_type in array_types:
|
|
|
5187
6071
|
skip_replay=True,
|
|
5188
6072
|
)
|
|
5189
6073
|
|
|
6074
|
+
add_builtin(
|
|
6075
|
+
"atomic_cas",
|
|
6076
|
+
hidden=hidden,
|
|
6077
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "compare": Any, "value": Any},
|
|
6078
|
+
constraint=atomic_op_constraint,
|
|
6079
|
+
value_func=create_atomic_op_value_func("cas"),
|
|
6080
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6081
|
+
doc="""Atomically compare and swap ``value`` with ``arr[i]`` if ``arr[i]`` equals ``compare``, and return the old value.
|
|
6082
|
+
|
|
6083
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6084
|
+
group="Utility",
|
|
6085
|
+
skip_replay=True,
|
|
6086
|
+
)
|
|
6087
|
+
add_builtin(
|
|
6088
|
+
"atomic_cas",
|
|
6089
|
+
hidden=hidden,
|
|
6090
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "compare": Any, "value": Any},
|
|
6091
|
+
constraint=atomic_op_constraint,
|
|
6092
|
+
value_func=create_atomic_op_value_func("cas"),
|
|
6093
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6094
|
+
doc="""Atomically compare and swap ``value`` with ``arr[i,j]`` if ``arr[i,j]`` equals ``compare``, and return the old value.
|
|
6095
|
+
|
|
6096
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6097
|
+
group="Utility",
|
|
6098
|
+
skip_replay=True,
|
|
6099
|
+
)
|
|
6100
|
+
add_builtin(
|
|
6101
|
+
"atomic_cas",
|
|
6102
|
+
hidden=hidden,
|
|
6103
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "compare": Any, "value": Any},
|
|
6104
|
+
constraint=atomic_op_constraint,
|
|
6105
|
+
value_func=create_atomic_op_value_func("cas"),
|
|
6106
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6107
|
+
doc="""Atomically compare and swap ``value`` with ``arr[i,j,k]`` if ``arr[i,j,k]`` equals ``compare``, and return the old value.
|
|
6108
|
+
|
|
6109
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6110
|
+
group="Utility",
|
|
6111
|
+
skip_replay=True,
|
|
6112
|
+
)
|
|
6113
|
+
add_builtin(
|
|
6114
|
+
"atomic_cas",
|
|
6115
|
+
hidden=hidden,
|
|
6116
|
+
input_types={
|
|
6117
|
+
"arr": array_type(dtype=Any),
|
|
6118
|
+
"i": Int,
|
|
6119
|
+
"j": Int,
|
|
6120
|
+
"k": Int,
|
|
6121
|
+
"l": Int,
|
|
6122
|
+
"compare": Any,
|
|
6123
|
+
"value": Any,
|
|
6124
|
+
},
|
|
6125
|
+
constraint=atomic_op_constraint,
|
|
6126
|
+
value_func=create_atomic_op_value_func("cas"),
|
|
6127
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6128
|
+
doc="""Atomically compare and swap ``value`` with ``arr[i,j,k,l]`` if ``arr[i,j,k,l]`` equals ``compare``, and return the old value.
|
|
6129
|
+
|
|
6130
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6131
|
+
group="Utility",
|
|
6132
|
+
skip_replay=True,
|
|
6133
|
+
)
|
|
6134
|
+
|
|
6135
|
+
add_builtin(
|
|
6136
|
+
"atomic_exch",
|
|
6137
|
+
hidden=hidden,
|
|
6138
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
6139
|
+
constraint=atomic_op_constraint,
|
|
6140
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
6141
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6142
|
+
doc="""Atomically exchange ``value`` with ``arr[i]`` and return the old value.
|
|
6143
|
+
|
|
6144
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6145
|
+
group="Utility",
|
|
6146
|
+
skip_replay=True,
|
|
6147
|
+
)
|
|
6148
|
+
add_builtin(
|
|
6149
|
+
"atomic_exch",
|
|
6150
|
+
hidden=hidden,
|
|
6151
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
6152
|
+
constraint=atomic_op_constraint,
|
|
6153
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
6154
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6155
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j]`` and return the old value.
|
|
6156
|
+
|
|
6157
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6158
|
+
group="Utility",
|
|
6159
|
+
skip_replay=True,
|
|
6160
|
+
)
|
|
6161
|
+
add_builtin(
|
|
6162
|
+
"atomic_exch",
|
|
6163
|
+
hidden=hidden,
|
|
6164
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
6165
|
+
constraint=atomic_op_constraint,
|
|
6166
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
6167
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6168
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
|
|
6169
|
+
|
|
6170
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6171
|
+
group="Utility",
|
|
6172
|
+
skip_replay=True,
|
|
6173
|
+
)
|
|
6174
|
+
add_builtin(
|
|
6175
|
+
"atomic_exch",
|
|
6176
|
+
hidden=hidden,
|
|
6177
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
6178
|
+
constraint=atomic_op_constraint,
|
|
6179
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
6180
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
6181
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
|
|
6182
|
+
|
|
6183
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6184
|
+
group="Utility",
|
|
6185
|
+
skip_replay=True,
|
|
6186
|
+
)
|
|
6187
|
+
|
|
5190
6188
|
|
|
5191
6189
|
# used to index into builtin types, i.e.: y = vec3[1]
|
|
5192
6190
|
def extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
@@ -5269,6 +6267,16 @@ add_builtin(
|
|
|
5269
6267
|
group="Utility",
|
|
5270
6268
|
skip_replay=True,
|
|
5271
6269
|
)
|
|
6270
|
+
# implements &transformation[index]
|
|
6271
|
+
add_builtin(
|
|
6272
|
+
"index",
|
|
6273
|
+
input_types={"a": transformation(dtype=Float), "i": int},
|
|
6274
|
+
value_func=vector_index_value_func,
|
|
6275
|
+
dispatch_func=vector_index_dispatch_func,
|
|
6276
|
+
hidden=True,
|
|
6277
|
+
group="Utility",
|
|
6278
|
+
skip_replay=True,
|
|
6279
|
+
)
|
|
5272
6280
|
# implements &(*vector)[index]
|
|
5273
6281
|
add_builtin(
|
|
5274
6282
|
"indexref",
|
|
@@ -5289,6 +6297,16 @@ add_builtin(
|
|
|
5289
6297
|
group="Utility",
|
|
5290
6298
|
skip_replay=True,
|
|
5291
6299
|
)
|
|
6300
|
+
# implements &(*transformation)[index]
|
|
6301
|
+
add_builtin(
|
|
6302
|
+
"indexref",
|
|
6303
|
+
input_types={"a": transformation(dtype=Float), "i": int},
|
|
6304
|
+
value_func=vector_index_value_func,
|
|
6305
|
+
dispatch_func=vector_index_dispatch_func,
|
|
6306
|
+
hidden=True,
|
|
6307
|
+
group="Utility",
|
|
6308
|
+
skip_replay=True,
|
|
6309
|
+
)
|
|
5292
6310
|
|
|
5293
6311
|
|
|
5294
6312
|
# implements vector[index] = value
|
|
@@ -5297,6 +6315,7 @@ add_builtin(
|
|
|
5297
6315
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5298
6316
|
value_type=None,
|
|
5299
6317
|
hidden=True,
|
|
6318
|
+
export=False,
|
|
5300
6319
|
group="Utility",
|
|
5301
6320
|
)
|
|
5302
6321
|
|
|
@@ -5306,6 +6325,16 @@ add_builtin(
|
|
|
5306
6325
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5307
6326
|
value_type=None,
|
|
5308
6327
|
hidden=True,
|
|
6328
|
+
export=False,
|
|
6329
|
+
group="Utility",
|
|
6330
|
+
)
|
|
6331
|
+
# implements transformation[index] = value
|
|
6332
|
+
add_builtin(
|
|
6333
|
+
"assign_inplace",
|
|
6334
|
+
input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
|
|
6335
|
+
value_type=None,
|
|
6336
|
+
hidden=True,
|
|
6337
|
+
export=False,
|
|
5309
6338
|
group="Utility",
|
|
5310
6339
|
)
|
|
5311
6340
|
|
|
@@ -5321,6 +6350,7 @@ add_builtin(
|
|
|
5321
6350
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5322
6351
|
value_func=vector_assign_value_func,
|
|
5323
6352
|
hidden=True,
|
|
6353
|
+
export=False,
|
|
5324
6354
|
group="Utility",
|
|
5325
6355
|
)
|
|
5326
6356
|
|
|
@@ -5330,6 +6360,17 @@ add_builtin(
|
|
|
5330
6360
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5331
6361
|
value_func=vector_assign_value_func,
|
|
5332
6362
|
hidden=True,
|
|
6363
|
+
export=False,
|
|
6364
|
+
group="Utility",
|
|
6365
|
+
)
|
|
6366
|
+
|
|
6367
|
+
# implements transformation[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
6368
|
+
add_builtin(
|
|
6369
|
+
"assign_copy",
|
|
6370
|
+
input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
|
|
6371
|
+
value_func=vector_assign_value_func,
|
|
6372
|
+
hidden=True,
|
|
6373
|
+
export=False,
|
|
5333
6374
|
group="Utility",
|
|
5334
6375
|
)
|
|
5335
6376
|
|
|
@@ -5339,6 +6380,7 @@ add_builtin(
|
|
|
5339
6380
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5340
6381
|
value_type=None,
|
|
5341
6382
|
hidden=True,
|
|
6383
|
+
export=False,
|
|
5342
6384
|
group="Utility",
|
|
5343
6385
|
)
|
|
5344
6386
|
|
|
@@ -5348,6 +6390,27 @@ add_builtin(
|
|
|
5348
6390
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5349
6391
|
value_type=None,
|
|
5350
6392
|
hidden=True,
|
|
6393
|
+
export=False,
|
|
6394
|
+
group="Utility",
|
|
6395
|
+
)
|
|
6396
|
+
|
|
6397
|
+
# implements transformation[idx] += scalar
|
|
6398
|
+
add_builtin(
|
|
6399
|
+
"add_inplace",
|
|
6400
|
+
input_types={"a": transformation(dtype=Float), "i": int, "value": Float},
|
|
6401
|
+
value_type=None,
|
|
6402
|
+
hidden=True,
|
|
6403
|
+
export=False,
|
|
6404
|
+
group="Utility",
|
|
6405
|
+
)
|
|
6406
|
+
|
|
6407
|
+
# implements transformation.p += vec3
|
|
6408
|
+
add_builtin(
|
|
6409
|
+
"transform_add_inplace",
|
|
6410
|
+
input_types={"a": transformation(dtype=Float), "value": vector(length=3, dtype=Float)},
|
|
6411
|
+
value_type=None,
|
|
6412
|
+
hidden=True,
|
|
6413
|
+
export=False,
|
|
5351
6414
|
group="Utility",
|
|
5352
6415
|
)
|
|
5353
6416
|
|
|
@@ -5357,6 +6420,7 @@ add_builtin(
|
|
|
5357
6420
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5358
6421
|
value_type=None,
|
|
5359
6422
|
hidden=True,
|
|
6423
|
+
export=False,
|
|
5360
6424
|
group="Utility",
|
|
5361
6425
|
)
|
|
5362
6426
|
|
|
@@ -5366,6 +6430,27 @@ add_builtin(
|
|
|
5366
6430
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5367
6431
|
value_type=None,
|
|
5368
6432
|
hidden=True,
|
|
6433
|
+
export=False,
|
|
6434
|
+
group="Utility",
|
|
6435
|
+
)
|
|
6436
|
+
|
|
6437
|
+
# implements transformation[idx] -= scalar
|
|
6438
|
+
add_builtin(
|
|
6439
|
+
"sub_inplace",
|
|
6440
|
+
input_types={"a": transformation(dtype=Scalar), "i": int, "value": Scalar},
|
|
6441
|
+
value_type=None,
|
|
6442
|
+
hidden=True,
|
|
6443
|
+
export=False,
|
|
6444
|
+
group="Utility",
|
|
6445
|
+
)
|
|
6446
|
+
|
|
6447
|
+
# implements transformation.p -= vec3
|
|
6448
|
+
add_builtin(
|
|
6449
|
+
"transform_sub_inplace",
|
|
6450
|
+
input_types={"a": transformation(dtype=Float), "value": vector(length=3, dtype=Float)},
|
|
6451
|
+
value_type=None,
|
|
6452
|
+
hidden=True,
|
|
6453
|
+
export=False,
|
|
5369
6454
|
group="Utility",
|
|
5370
6455
|
)
|
|
5371
6456
|
|
|
@@ -5407,7 +6492,7 @@ add_builtin(
|
|
|
5407
6492
|
|
|
5408
6493
|
|
|
5409
6494
|
def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
5410
|
-
mat_size = arg_types["a"]._shape_[
|
|
6495
|
+
mat_size = arg_types["a"]._shape_[1]
|
|
5411
6496
|
vec_size = arg_types["value"]._length_
|
|
5412
6497
|
mat_type = arg_types["a"]._type_
|
|
5413
6498
|
vec_type = arg_types["value"]._type_
|
|
@@ -5420,6 +6505,7 @@ add_builtin(
|
|
|
5420
6505
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5421
6506
|
value_type=None,
|
|
5422
6507
|
hidden=True,
|
|
6508
|
+
export=False,
|
|
5423
6509
|
group="Utility",
|
|
5424
6510
|
)
|
|
5425
6511
|
|
|
@@ -5431,6 +6517,7 @@ add_builtin(
|
|
|
5431
6517
|
constraint=matrix_vector_sametype,
|
|
5432
6518
|
value_type=None,
|
|
5433
6519
|
hidden=True,
|
|
6520
|
+
export=False,
|
|
5434
6521
|
group="Utility",
|
|
5435
6522
|
)
|
|
5436
6523
|
|
|
@@ -5446,6 +6533,7 @@ add_builtin(
|
|
|
5446
6533
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5447
6534
|
value_func=matrix_assign_value_func,
|
|
5448
6535
|
hidden=True,
|
|
6536
|
+
export=False,
|
|
5449
6537
|
group="Utility",
|
|
5450
6538
|
)
|
|
5451
6539
|
|
|
@@ -5457,6 +6545,7 @@ add_builtin(
|
|
|
5457
6545
|
constraint=matrix_vector_sametype,
|
|
5458
6546
|
value_func=matrix_assign_value_func,
|
|
5459
6547
|
hidden=True,
|
|
6548
|
+
export=False,
|
|
5460
6549
|
group="Utility",
|
|
5461
6550
|
)
|
|
5462
6551
|
|
|
@@ -5467,6 +6556,7 @@ add_builtin(
|
|
|
5467
6556
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5468
6557
|
value_type=None,
|
|
5469
6558
|
hidden=True,
|
|
6559
|
+
export=False,
|
|
5470
6560
|
group="Utility",
|
|
5471
6561
|
)
|
|
5472
6562
|
|
|
@@ -5478,6 +6568,7 @@ add_builtin(
|
|
|
5478
6568
|
constraint=matrix_vector_sametype,
|
|
5479
6569
|
value_type=None,
|
|
5480
6570
|
hidden=True,
|
|
6571
|
+
export=False,
|
|
5481
6572
|
group="Utility",
|
|
5482
6573
|
)
|
|
5483
6574
|
|
|
@@ -5488,6 +6579,7 @@ add_builtin(
|
|
|
5488
6579
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5489
6580
|
value_type=None,
|
|
5490
6581
|
hidden=True,
|
|
6582
|
+
export=False,
|
|
5491
6583
|
group="Utility",
|
|
5492
6584
|
)
|
|
5493
6585
|
|
|
@@ -5498,6 +6590,7 @@ add_builtin(
|
|
|
5498
6590
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5499
6591
|
value_type=None,
|
|
5500
6592
|
hidden=True,
|
|
6593
|
+
export=False,
|
|
5501
6594
|
group="Utility",
|
|
5502
6595
|
)
|
|
5503
6596
|
|
|
@@ -5522,6 +6615,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
5522
6615
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
5523
6616
|
group="Utility",
|
|
5524
6617
|
hidden=True,
|
|
6618
|
+
export=False,
|
|
5525
6619
|
)
|
|
5526
6620
|
|
|
5527
6621
|
|
|
@@ -5549,6 +6643,7 @@ add_builtin(
|
|
|
5549
6643
|
doc="Prints an error to stdout if ``a`` and ``b`` are equal",
|
|
5550
6644
|
group="Utility",
|
|
5551
6645
|
hidden=True,
|
|
6646
|
+
export=False,
|
|
5552
6647
|
)
|
|
5553
6648
|
|
|
5554
6649
|
add_builtin(
|
|
@@ -5568,6 +6663,7 @@ add_builtin(
|
|
|
5568
6663
|
doc="Prints an error to stdout if ``a`` and ``b`` are equal",
|
|
5569
6664
|
group="Utility",
|
|
5570
6665
|
hidden=True,
|
|
6666
|
+
export=False,
|
|
5571
6667
|
)
|
|
5572
6668
|
|
|
5573
6669
|
add_builtin(
|
|
@@ -5638,11 +6734,23 @@ add_builtin(
|
|
|
5638
6734
|
group="Utility",
|
|
5639
6735
|
)
|
|
5640
6736
|
|
|
6737
|
+
|
|
5641
6738
|
# fuzzy compare for float values
|
|
6739
|
+
def expect_near_constraint(arg_types: Mapping[str, type]):
|
|
6740
|
+
if not types_equal(arg_types["a"], arg_types["b"]):
|
|
6741
|
+
return False
|
|
6742
|
+
|
|
6743
|
+
if hasattr(arg_types["a"], "_wp_scalar_type_"):
|
|
6744
|
+
return types_equal(arg_types["a"]._wp_scalar_type_, arg_types["tolerance"])
|
|
6745
|
+
|
|
6746
|
+
return types_equal(arg_types["a"], arg_types["tolerance"])
|
|
6747
|
+
|
|
6748
|
+
|
|
5642
6749
|
add_builtin(
|
|
5643
6750
|
"expect_near",
|
|
5644
6751
|
input_types={"a": Float, "b": Float, "tolerance": Float},
|
|
5645
6752
|
defaults={"tolerance": 1.0e-6},
|
|
6753
|
+
constraint=expect_near_constraint,
|
|
5646
6754
|
value_type=None,
|
|
5647
6755
|
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5648
6756
|
group="Utility",
|
|
@@ -5651,6 +6759,7 @@ add_builtin(
|
|
|
5651
6759
|
"expect_near",
|
|
5652
6760
|
input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
|
|
5653
6761
|
defaults={"tolerance": 1.0e-6},
|
|
6762
|
+
constraint=expect_near_constraint,
|
|
5654
6763
|
value_type=None,
|
|
5655
6764
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5656
6765
|
group="Utility",
|
|
@@ -5659,6 +6768,7 @@ add_builtin(
|
|
|
5659
6768
|
"expect_near",
|
|
5660
6769
|
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
|
|
5661
6770
|
defaults={"tolerance": 1.0e-6},
|
|
6771
|
+
constraint=expect_near_constraint,
|
|
5662
6772
|
value_type=None,
|
|
5663
6773
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5664
6774
|
group="Utility",
|
|
@@ -5671,6 +6781,7 @@ add_builtin(
|
|
|
5671
6781
|
"tolerance": Float,
|
|
5672
6782
|
},
|
|
5673
6783
|
defaults={"tolerance": 1.0e-6},
|
|
6784
|
+
constraint=expect_near_constraint,
|
|
5674
6785
|
value_type=None,
|
|
5675
6786
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5676
6787
|
group="Utility",
|
|
@@ -6088,19 +7199,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
|
|
|
6088
7199
|
# Tile operators
|
|
6089
7200
|
def tile_unary_value_func(arg_types, arg_values):
|
|
6090
7201
|
if arg_types is None:
|
|
6091
|
-
return
|
|
7202
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
6092
7203
|
|
|
6093
7204
|
t = arg_types["x"]
|
|
6094
7205
|
|
|
6095
7206
|
if not is_tile(t):
|
|
6096
7207
|
raise TypeError(f"Expected tile for unary expression, got {t}")
|
|
6097
7208
|
|
|
6098
|
-
return
|
|
7209
|
+
return tile(dtype=t.dtype, shape=t.shape)
|
|
6099
7210
|
|
|
6100
7211
|
|
|
6101
7212
|
def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
6102
7213
|
if arg_types is None:
|
|
6103
|
-
return
|
|
7214
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
6104
7215
|
|
|
6105
7216
|
x = arg_types["x"]
|
|
6106
7217
|
y = arg_types["y"]
|
|
@@ -6110,19 +7221,19 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
|
6110
7221
|
if x.dtype != y:
|
|
6111
7222
|
raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
|
|
6112
7223
|
|
|
6113
|
-
return
|
|
7224
|
+
return tile(dtype=x.dtype, shape=x.shape)
|
|
6114
7225
|
|
|
6115
7226
|
# scalar*tile
|
|
6116
7227
|
if is_tile(y):
|
|
6117
7228
|
if y.dtype != x:
|
|
6118
7229
|
raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
|
|
6119
7230
|
|
|
6120
|
-
return
|
|
7231
|
+
return tile(dtype=y.dtype, shape=y.shape)
|
|
6121
7232
|
|
|
6122
7233
|
|
|
6123
7234
|
add_builtin(
|
|
6124
7235
|
"neg",
|
|
6125
|
-
input_types={"x":
|
|
7236
|
+
input_types={"x": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
6126
7237
|
value_func=tile_unary_value_func,
|
|
6127
7238
|
doc="Negate each element of a tile",
|
|
6128
7239
|
export=False,
|
|
@@ -6132,7 +7243,7 @@ add_builtin(
|
|
|
6132
7243
|
|
|
6133
7244
|
add_builtin(
|
|
6134
7245
|
"add",
|
|
6135
|
-
input_types={"a":
|
|
7246
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
6136
7247
|
value_func=tile_binary_map_value_func,
|
|
6137
7248
|
# dispatch_func=tile_map_dispatch_func,
|
|
6138
7249
|
# variadic=True,
|
|
@@ -6144,7 +7255,7 @@ add_builtin(
|
|
|
6144
7255
|
|
|
6145
7256
|
add_builtin(
|
|
6146
7257
|
"sub",
|
|
6147
|
-
input_types={"a":
|
|
7258
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
6148
7259
|
value_func=tile_binary_map_value_func,
|
|
6149
7260
|
# dispatch_func=tile_map_dispatch_func,
|
|
6150
7261
|
# variadic=True,
|
|
@@ -6157,7 +7268,7 @@ add_builtin(
|
|
|
6157
7268
|
|
|
6158
7269
|
add_builtin(
|
|
6159
7270
|
"mul",
|
|
6160
|
-
input_types={"x":
|
|
7271
|
+
input_types={"x": tile(dtype=Any, shape=Tuple[int, ...]), "y": Scalar},
|
|
6161
7272
|
value_func=tile_scalar_mul_value_func,
|
|
6162
7273
|
doc="Multiply each element of a tile by a scalar",
|
|
6163
7274
|
export=False,
|
|
@@ -6167,7 +7278,7 @@ add_builtin(
|
|
|
6167
7278
|
|
|
6168
7279
|
add_builtin(
|
|
6169
7280
|
"mul",
|
|
6170
|
-
input_types={"x": Scalar, "y":
|
|
7281
|
+
input_types={"x": Scalar, "y": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
6171
7282
|
value_func=tile_scalar_mul_value_func,
|
|
6172
7283
|
doc="Multiply each element of a tile by a scalar",
|
|
6173
7284
|
export=False,
|
|
@@ -6176,9 +7287,48 @@ add_builtin(
|
|
|
6176
7287
|
)
|
|
6177
7288
|
|
|
6178
7289
|
|
|
7290
|
+
def tile_inplace_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
7291
|
+
a = args["a"]
|
|
7292
|
+
b = args["b"]
|
|
7293
|
+
|
|
7294
|
+
a_type = input_types["a"]
|
|
7295
|
+
b_type = input_types["b"]
|
|
7296
|
+
|
|
7297
|
+
if a_type.shape != b_type.shape:
|
|
7298
|
+
raise ValueError(f"Tile inplace arguments must have the same shape, got {a_type.shape} and {b_type.shape}")
|
|
7299
|
+
|
|
7300
|
+
func_args = (a, b)
|
|
7301
|
+
template_args = ()
|
|
7302
|
+
return (func_args, template_args)
|
|
7303
|
+
|
|
7304
|
+
|
|
7305
|
+
add_builtin(
|
|
7306
|
+
"add_inplace",
|
|
7307
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
7308
|
+
value_type=None,
|
|
7309
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
7310
|
+
export=False,
|
|
7311
|
+
hidden=True,
|
|
7312
|
+
native_func="tile_add_inplace",
|
|
7313
|
+
group="Operators",
|
|
7314
|
+
)
|
|
7315
|
+
|
|
7316
|
+
|
|
7317
|
+
add_builtin(
|
|
7318
|
+
"sub_inplace",
|
|
7319
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
7320
|
+
value_type=None,
|
|
7321
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
7322
|
+
export=False,
|
|
7323
|
+
hidden=True,
|
|
7324
|
+
native_func="tile_sub_inplace",
|
|
7325
|
+
group="Operators",
|
|
7326
|
+
)
|
|
7327
|
+
|
|
7328
|
+
|
|
6179
7329
|
def tile_diag_add_value_func(arg_types, arg_values):
|
|
6180
7330
|
if arg_types is None:
|
|
6181
|
-
return
|
|
7331
|
+
return tile(dtype=Any, shape=Tuple[int, int])
|
|
6182
7332
|
|
|
6183
7333
|
a = arg_types["a"]
|
|
6184
7334
|
d = arg_types["d"]
|
|
@@ -6208,7 +7358,7 @@ def tile_diag_add_value_func(arg_types, arg_values):
|
|
|
6208
7358
|
)
|
|
6209
7359
|
|
|
6210
7360
|
# use first argument to define output type
|
|
6211
|
-
return
|
|
7361
|
+
return tile(dtype=a.dtype, shape=a.shape, layout=a.layout, strides=a.strides, storage="shared")
|
|
6212
7362
|
|
|
6213
7363
|
|
|
6214
7364
|
def tile_diag_add_lto_dispatch_func(
|
|
@@ -6230,7 +7380,7 @@ def tile_diag_add_lto_dispatch_func(
|
|
|
6230
7380
|
|
|
6231
7381
|
add_builtin(
|
|
6232
7382
|
"tile_diag_add",
|
|
6233
|
-
input_types={"a":
|
|
7383
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int]), "d": tile(dtype=Any, shape=Tuple[int])},
|
|
6234
7384
|
value_func=tile_diag_add_value_func,
|
|
6235
7385
|
lto_dispatch_func=tile_diag_add_lto_dispatch_func,
|
|
6236
7386
|
native_func="tile_diag_add",
|
|
@@ -6240,18 +7390,40 @@ add_builtin(
|
|
|
6240
7390
|
)
|
|
6241
7391
|
|
|
6242
7392
|
|
|
6243
|
-
##
|
|
6244
|
-
## MathDx, LTOIR-based, Tile functions
|
|
6245
|
-
##
|
|
7393
|
+
##
|
|
7394
|
+
## MathDx, LTOIR-based, Tile functions
|
|
7395
|
+
##
|
|
7396
|
+
|
|
7397
|
+
|
|
7398
|
+
##
|
|
7399
|
+
## Matmul
|
|
7400
|
+
##
|
|
7401
|
+
|
|
7402
|
+
|
|
7403
|
+
def tile_matmul_out_value_func(arg_types, arg_values):
|
|
7404
|
+
# return generic type (for doc builds)
|
|
7405
|
+
if arg_types is None:
|
|
7406
|
+
return None
|
|
7407
|
+
|
|
7408
|
+
a = arg_types["a"]
|
|
7409
|
+
b = arg_types["b"]
|
|
7410
|
+
|
|
7411
|
+
if not is_tile(a):
|
|
7412
|
+
raise TypeError(f"tile_matmul() 'a' argument must be a tile, got {a!r}")
|
|
7413
|
+
|
|
7414
|
+
if not is_tile(b):
|
|
7415
|
+
raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
|
|
7416
|
+
|
|
7417
|
+
if not is_tile(arg_types["out"]):
|
|
7418
|
+
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
|
|
7419
|
+
|
|
7420
|
+
return None
|
|
6246
7421
|
|
|
6247
7422
|
|
|
6248
|
-
##
|
|
6249
|
-
## Matmul
|
|
6250
|
-
##
|
|
6251
7423
|
def tile_matmul_value_func(arg_types, arg_values):
|
|
6252
7424
|
# return generic type (for doc builds)
|
|
6253
7425
|
if arg_types is None:
|
|
6254
|
-
return
|
|
7426
|
+
return tile(dtype=Float, shape=Tuple[int, int])
|
|
6255
7427
|
|
|
6256
7428
|
a = arg_types["a"]
|
|
6257
7429
|
b = arg_types["b"]
|
|
@@ -6262,16 +7434,7 @@ def tile_matmul_value_func(arg_types, arg_values):
|
|
|
6262
7434
|
if not is_tile(b):
|
|
6263
7435
|
raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
|
|
6264
7436
|
|
|
6265
|
-
|
|
6266
|
-
if len(arg_types) == 2:
|
|
6267
|
-
return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
|
|
6268
|
-
|
|
6269
|
-
# wp.tile_matmul(a, b, out)
|
|
6270
|
-
elif len(arg_types) == 3:
|
|
6271
|
-
if not is_tile(arg_types["out"]):
|
|
6272
|
-
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
|
|
6273
|
-
|
|
6274
|
-
return None
|
|
7437
|
+
return tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
|
|
6275
7438
|
|
|
6276
7439
|
|
|
6277
7440
|
def tile_matmul_lto_dispatch_func(
|
|
@@ -6345,36 +7508,41 @@ def tile_matmul_lto_dispatch_func(
|
|
|
6345
7508
|
num_threads,
|
|
6346
7509
|
builder,
|
|
6347
7510
|
)
|
|
6348
|
-
|
|
6349
|
-
|
|
6350
|
-
|
|
6351
|
-
|
|
6352
|
-
|
|
6353
|
-
|
|
6354
|
-
|
|
6355
|
-
|
|
6356
|
-
|
|
6357
|
-
|
|
6358
|
-
|
|
6359
|
-
|
|
6360
|
-
|
|
6361
|
-
|
|
6362
|
-
|
|
6363
|
-
|
|
6364
|
-
|
|
6365
|
-
|
|
6366
|
-
|
|
6367
|
-
|
|
6368
|
-
|
|
6369
|
-
|
|
6370
|
-
|
|
6371
|
-
|
|
6372
|
-
|
|
6373
|
-
|
|
6374
|
-
|
|
6375
|
-
|
|
6376
|
-
|
|
6377
|
-
|
|
7511
|
+
if warp.config.enable_backward:
|
|
7512
|
+
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
7513
|
+
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
7514
|
+
M,
|
|
7515
|
+
K,
|
|
7516
|
+
N,
|
|
7517
|
+
out.type.dtype,
|
|
7518
|
+
b.type.dtype,
|
|
7519
|
+
a.type.dtype,
|
|
7520
|
+
out.type.layout,
|
|
7521
|
+
tile_flip_layout(b.type.layout),
|
|
7522
|
+
a.type.layout,
|
|
7523
|
+
arch,
|
|
7524
|
+
num_threads,
|
|
7525
|
+
builder,
|
|
7526
|
+
)
|
|
7527
|
+
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
7528
|
+
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
7529
|
+
K,
|
|
7530
|
+
N,
|
|
7531
|
+
M,
|
|
7532
|
+
a.type.dtype,
|
|
7533
|
+
out.type.dtype,
|
|
7534
|
+
b.type.dtype,
|
|
7535
|
+
tile_flip_layout(a.type.layout),
|
|
7536
|
+
out.type.layout,
|
|
7537
|
+
b.type.layout,
|
|
7538
|
+
arch,
|
|
7539
|
+
num_threads,
|
|
7540
|
+
builder,
|
|
7541
|
+
)
|
|
7542
|
+
else:
|
|
7543
|
+
# adjoints aren't computed, so we reuse fun_forward as a dummy arg
|
|
7544
|
+
(fun_backward_A, lto_backward_A) = (fun_forward, None)
|
|
7545
|
+
(fun_backward_B, lto_backward_B) = (fun_forward, None)
|
|
6378
7546
|
|
|
6379
7547
|
return (
|
|
6380
7548
|
(
|
|
@@ -6394,11 +7562,11 @@ def tile_matmul_lto_dispatch_func(
|
|
|
6394
7562
|
add_builtin(
|
|
6395
7563
|
"tile_matmul",
|
|
6396
7564
|
input_types={
|
|
6397
|
-
"a":
|
|
6398
|
-
"b":
|
|
6399
|
-
"out":
|
|
7565
|
+
"a": tile(dtype=Float, shape=Tuple[int, int]),
|
|
7566
|
+
"b": tile(dtype=Float, shape=Tuple[int, int]),
|
|
7567
|
+
"out": tile(dtype=Float, shape=Tuple[int, int]),
|
|
6400
7568
|
},
|
|
6401
|
-
value_func=
|
|
7569
|
+
value_func=tile_matmul_out_value_func,
|
|
6402
7570
|
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6403
7571
|
variadic=False,
|
|
6404
7572
|
doc="""Computes the matrix product and accumulates ``out += a*b``.
|
|
@@ -6420,7 +7588,7 @@ add_builtin(
|
|
|
6420
7588
|
|
|
6421
7589
|
add_builtin(
|
|
6422
7590
|
"tile_matmul",
|
|
6423
|
-
input_types={"a":
|
|
7591
|
+
input_types={"a": tile(dtype=Float, shape=Tuple[int, int]), "b": tile(dtype=Float, shape=Tuple[int, int])},
|
|
6424
7592
|
value_func=tile_matmul_value_func,
|
|
6425
7593
|
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6426
7594
|
variadic=False,
|
|
@@ -6447,7 +7615,7 @@ add_builtin(
|
|
|
6447
7615
|
##
|
|
6448
7616
|
def tile_fft_generic_value_func(arg_types, arg_values):
|
|
6449
7617
|
if arg_types is None:
|
|
6450
|
-
return
|
|
7618
|
+
return tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])
|
|
6451
7619
|
|
|
6452
7620
|
if len(arg_types) != 1:
|
|
6453
7621
|
raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
@@ -6475,7 +7643,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6475
7643
|
arg_values: Mapping[str, Var],
|
|
6476
7644
|
options: Mapping[str, Any],
|
|
6477
7645
|
builder: warp.context.ModuleBuilder,
|
|
6478
|
-
direction: str = None,
|
|
7646
|
+
direction: str | None = None,
|
|
6479
7647
|
):
|
|
6480
7648
|
inout = arg_values["inout"]
|
|
6481
7649
|
inout.type.storage = "register"
|
|
@@ -6529,7 +7697,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6529
7697
|
|
|
6530
7698
|
add_builtin(
|
|
6531
7699
|
"tile_fft",
|
|
6532
|
-
input_types={"inout":
|
|
7700
|
+
input_types={"inout": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])},
|
|
6533
7701
|
value_func=tile_fft_generic_value_func,
|
|
6534
7702
|
lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="forward"),
|
|
6535
7703
|
variadic=True,
|
|
@@ -6550,7 +7718,7 @@ add_builtin(
|
|
|
6550
7718
|
|
|
6551
7719
|
add_builtin(
|
|
6552
7720
|
"tile_ifft",
|
|
6553
|
-
input_types={"inout":
|
|
7721
|
+
input_types={"inout": tile(dtype=vector(length=2, dtype=Float), shape=Tuple[int, int])},
|
|
6554
7722
|
value_func=tile_fft_generic_value_func,
|
|
6555
7723
|
lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="inverse"),
|
|
6556
7724
|
variadic=True,
|
|
@@ -6575,7 +7743,7 @@ add_builtin(
|
|
|
6575
7743
|
##
|
|
6576
7744
|
def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
6577
7745
|
if arg_types is None:
|
|
6578
|
-
return
|
|
7746
|
+
return tile(dtype=Float, shape=Tuple[int, int])
|
|
6579
7747
|
|
|
6580
7748
|
if len(arg_types) != 1:
|
|
6581
7749
|
raise TypeError("tile_cholesky() requires 1 positional args")
|
|
@@ -6591,15 +7759,19 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
|
6591
7759
|
if a.shape[0] != a.shape[1]:
|
|
6592
7760
|
raise ValueError("tile_cholesky() argument must be square")
|
|
6593
7761
|
|
|
6594
|
-
return
|
|
7762
|
+
return tile(dtype=a.dtype, shape=a.shape, layout=a.layout, strides=a.strides, storage="shared")
|
|
6595
7763
|
|
|
6596
7764
|
|
|
6597
|
-
cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3}
|
|
7765
|
+
cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3, "trsm": 4}
|
|
6598
7766
|
|
|
6599
7767
|
cusolver_type_map = {float32: ("wp::float32", 5), float64: ("wp::float64", 6)}
|
|
6600
7768
|
|
|
6601
7769
|
cusolver_fill_mode_map = {"upper": 0, "lower": 1}
|
|
6602
7770
|
|
|
7771
|
+
cusolver_side_map = {"-": -1, "left": 0, "right": 1}
|
|
7772
|
+
|
|
7773
|
+
cusolver_diag_map = {"-": -1, "unit": 0, "nounit": 1}
|
|
7774
|
+
|
|
6603
7775
|
|
|
6604
7776
|
def tile_cholesky_generic_lto_dispatch_func(
|
|
6605
7777
|
arg_types: Mapping[str, type],
|
|
@@ -6623,20 +7795,20 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6623
7795
|
dtype, precision_enum = cusolver_type_map[a.type.dtype]
|
|
6624
7796
|
|
|
6625
7797
|
# We already ensured a is square in tile_cholesky_generic_value_func()
|
|
6626
|
-
M, N = a.type.shape
|
|
7798
|
+
M, N = a.type.shape
|
|
6627
7799
|
if out.type.shape[0] != M or out.type.shape[1] != M:
|
|
6628
7800
|
raise ValueError("tile_cholesky() output tile must be square")
|
|
6629
7801
|
|
|
6630
7802
|
solver = "potrf"
|
|
6631
7803
|
solver_enum = cusolver_function_map[solver]
|
|
6632
7804
|
|
|
6633
|
-
|
|
6634
|
-
|
|
6635
|
-
fill_mode = cusolver_fill_mode_map["
|
|
7805
|
+
side_enum = cusolver_side_map["-"]
|
|
7806
|
+
diag_enum = cusolver_diag_map["-"]
|
|
7807
|
+
fill_mode = cusolver_fill_mode_map["lower"]
|
|
6636
7808
|
|
|
6637
7809
|
arch = options["output_arch"]
|
|
6638
7810
|
num_threads = options["block_dim"]
|
|
6639
|
-
parameter_list = f"({dtype}*,
|
|
7811
|
+
parameter_list = f"({dtype}*, int*)"
|
|
6640
7812
|
|
|
6641
7813
|
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6642
7814
|
# CPU/no-MathDx dispatch
|
|
@@ -6646,8 +7818,13 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6646
7818
|
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6647
7819
|
M,
|
|
6648
7820
|
N,
|
|
7821
|
+
1,
|
|
6649
7822
|
solver,
|
|
6650
7823
|
solver_enum,
|
|
7824
|
+
side_enum,
|
|
7825
|
+
diag_enum,
|
|
7826
|
+
a.type.layout,
|
|
7827
|
+
out.type.layout,
|
|
6651
7828
|
fill_mode,
|
|
6652
7829
|
arch,
|
|
6653
7830
|
precision_enum,
|
|
@@ -6661,20 +7838,23 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6661
7838
|
|
|
6662
7839
|
add_builtin(
|
|
6663
7840
|
"tile_cholesky",
|
|
6664
|
-
input_types={"A":
|
|
7841
|
+
input_types={"A": tile(dtype=Float, shape=Tuple[int, int])},
|
|
6665
7842
|
value_func=tile_cholesky_generic_value_func,
|
|
6666
7843
|
lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
|
|
6667
7844
|
variadic=True,
|
|
6668
7845
|
doc="""Compute the Cholesky factorization L of a matrix A.
|
|
6669
7846
|
L is lower triangular and satisfies LL^T = A.
|
|
6670
7847
|
|
|
7848
|
+
Only the lower triangular portion of A is used for the decomposition;
|
|
7849
|
+
the upper triangular part may be left unspecified.
|
|
7850
|
+
|
|
6671
7851
|
Note that computing the adjoint is not yet supported.
|
|
6672
7852
|
|
|
6673
7853
|
Supported datatypes are:
|
|
6674
7854
|
* float32
|
|
6675
7855
|
* float64
|
|
6676
7856
|
|
|
6677
|
-
:param A: A square, symmetric positive-definite, matrix.
|
|
7857
|
+
:param A: A square, symmetric positive-definite, matrix. Only the lower triangular part of A is needed; the upper part is ignored.
|
|
6678
7858
|
:returns L: A square, lower triangular, matrix, such that LL^T = A""",
|
|
6679
7859
|
group="Tile Primitives",
|
|
6680
7860
|
export=False,
|
|
@@ -6690,30 +7870,30 @@ def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
|
|
|
6690
7870
|
raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")
|
|
6691
7871
|
|
|
6692
7872
|
l = arg_types["L"]
|
|
6693
|
-
|
|
7873
|
+
y = arg_types["y"]
|
|
6694
7874
|
|
|
6695
7875
|
if not is_tile(l):
|
|
6696
7876
|
raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l!r}")
|
|
6697
7877
|
|
|
6698
|
-
if not is_tile(
|
|
6699
|
-
raise TypeError(f"tile_cholesky_solve() '
|
|
7878
|
+
if not is_tile(y):
|
|
7879
|
+
raise TypeError(f"tile_cholesky_solve() 'y' argument must be a tile, got {l!r}")
|
|
6700
7880
|
|
|
6701
|
-
if not types_equal(l.dtype,
|
|
6702
|
-
raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {
|
|
7881
|
+
if not types_equal(l.dtype, y.dtype):
|
|
7882
|
+
raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {y.dtype}")
|
|
6703
7883
|
|
|
6704
7884
|
if l.shape[0] != l.shape[1]:
|
|
6705
7885
|
raise ValueError("tile_cholesky_solve() 'L' argument must be square")
|
|
6706
7886
|
|
|
6707
|
-
if len(
|
|
6708
|
-
raise TypeError("tile_cholesky_solve() '
|
|
7887
|
+
if len(y.shape) > 2 or len(y.shape) < 1:
|
|
7888
|
+
raise TypeError("tile_cholesky_solve() 'y' argument must be a 1D or 2D tile")
|
|
6709
7889
|
|
|
6710
|
-
if
|
|
7890
|
+
if y.shape[0] != l.shape[0]:
|
|
6711
7891
|
raise ValueError(
|
|
6712
|
-
f"tile_cholesky_solve() '
|
|
6713
|
-
f"got {
|
|
7892
|
+
f"tile_cholesky_solve() 'y' argument must have the same number of elements as the number of rows in 'L', "
|
|
7893
|
+
f"got {y.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
|
|
6714
7894
|
)
|
|
6715
7895
|
|
|
6716
|
-
return
|
|
7896
|
+
return tile(dtype=l.dtype, shape=y.shape, layout=y.layout, strides=y.strides, storage="shared")
|
|
6717
7897
|
|
|
6718
7898
|
|
|
6719
7899
|
def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
@@ -6725,37 +7905,38 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6725
7905
|
builder: warp.context.ModuleBuilder,
|
|
6726
7906
|
):
|
|
6727
7907
|
L = arg_values["L"]
|
|
6728
|
-
|
|
7908
|
+
y = arg_values["y"]
|
|
6729
7909
|
# force the storage type of the input variables to shared memory
|
|
6730
7910
|
L.type.storage = "shared"
|
|
6731
|
-
|
|
7911
|
+
y.type.storage = "shared"
|
|
6732
7912
|
|
|
6733
7913
|
if len(return_values) != 1:
|
|
6734
7914
|
raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")
|
|
6735
7915
|
|
|
6736
|
-
|
|
7916
|
+
x = return_values[0]
|
|
6737
7917
|
|
|
6738
|
-
if any(T not in cusolver_type_map.keys() for T in [
|
|
7918
|
+
if any(T not in cusolver_type_map.keys() for T in [y.type.dtype, L.type.dtype]):
|
|
6739
7919
|
raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")
|
|
6740
7920
|
|
|
6741
7921
|
dtype, precision_enum = cusolver_type_map[L.type.dtype]
|
|
6742
|
-
M, N = L.type.shape
|
|
7922
|
+
M, N = L.type.shape
|
|
7923
|
+
NRHS = x.type.shape[1] if len(x.type.shape) > 1 else 1
|
|
6743
7924
|
|
|
6744
|
-
if len(
|
|
6745
|
-
raise TypeError("tile_cholesky_solve() output vector must be 1D")
|
|
7925
|
+
if len(x.type.shape) > 2 or len(x.type.shape) < 1:
|
|
7926
|
+
raise TypeError(f"tile_cholesky_solve() output vector must be 1D or 2D, got {len(x.type.shape)}-D")
|
|
6746
7927
|
|
|
6747
|
-
if
|
|
7928
|
+
if x.type.shape[0] != M:
|
|
6748
7929
|
raise ValueError(
|
|
6749
7930
|
"tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L' "
|
|
6750
|
-
f"got {
|
|
7931
|
+
f"got {x.type.shape[0]} elements in output and {M} rows in 'L'"
|
|
6751
7932
|
)
|
|
6752
7933
|
|
|
6753
7934
|
solver = "potrs"
|
|
6754
7935
|
solver_enum = cusolver_function_map[solver]
|
|
6755
7936
|
|
|
6756
|
-
|
|
6757
|
-
|
|
6758
|
-
fill_mode = cusolver_fill_mode_map["
|
|
7937
|
+
side_enum = cusolver_side_map["-"]
|
|
7938
|
+
diag_enum = cusolver_diag_map["-"]
|
|
7939
|
+
fill_mode = cusolver_fill_mode_map["lower"]
|
|
6759
7940
|
|
|
6760
7941
|
arch = options["output_arch"]
|
|
6761
7942
|
num_threads = options["block_dim"]
|
|
@@ -6763,14 +7944,19 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6763
7944
|
|
|
6764
7945
|
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6765
7946
|
# CPU/no-MathDx dispatch
|
|
6766
|
-
return ((0, L,
|
|
7947
|
+
return ((0, L, y, x), [], [], 0)
|
|
6767
7948
|
else:
|
|
6768
7949
|
# generate the LTO
|
|
6769
7950
|
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6770
7951
|
M,
|
|
6771
7952
|
N,
|
|
7953
|
+
NRHS,
|
|
6772
7954
|
solver,
|
|
6773
7955
|
solver_enum,
|
|
7956
|
+
side_enum,
|
|
7957
|
+
diag_enum,
|
|
7958
|
+
L.type.layout,
|
|
7959
|
+
y.type.layout,
|
|
6774
7960
|
fill_mode,
|
|
6775
7961
|
arch,
|
|
6776
7962
|
precision_enum,
|
|
@@ -6779,12 +7965,12 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6779
7965
|
builder,
|
|
6780
7966
|
)
|
|
6781
7967
|
|
|
6782
|
-
return ((Var(lto_symbol, str, False, True, False), L,
|
|
7968
|
+
return ((Var(lto_symbol, str, False, True, False), L, y, x), [], [lto_code_data], 0)
|
|
6783
7969
|
|
|
6784
7970
|
|
|
6785
7971
|
add_builtin(
|
|
6786
7972
|
"tile_cholesky_solve",
|
|
6787
|
-
input_types={"L":
|
|
7973
|
+
input_types={"L": tile(dtype=Float, shape=Tuple[int, int]), "y": tile(dtype=Float, shape=Tuple[int])},
|
|
6788
7974
|
value_func=tile_cholesky_solve_generic_value_func,
|
|
6789
7975
|
lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
|
|
6790
7976
|
variadic=True,
|
|
@@ -6797,13 +7983,276 @@ add_builtin(
|
|
|
6797
7983
|
* float64
|
|
6798
7984
|
|
|
6799
7985
|
:param L: A square, lower triangular, matrix, such that LL^T = A
|
|
6800
|
-
:param
|
|
6801
|
-
:returns
|
|
7986
|
+
:param y: A 1D or 2D tile of length M
|
|
7987
|
+
:returns x: A tile of the same shape as y such that LL^T x = y""",
|
|
7988
|
+
group="Tile Primitives",
|
|
7989
|
+
export=False,
|
|
7990
|
+
namespace="",
|
|
7991
|
+
)
|
|
7992
|
+
|
|
7993
|
+
|
|
7994
|
+
def tile_lower_solve_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Prepare the call arguments for ``tile_lower_solve()`` (solve ``L z = y``).

    Reads the ``"L"`` and ``"y"`` variables from ``arg_values``, marks both as
    shared-memory tiles, validates dtypes and the single output tile, and then
    either returns a placeholder (no LTO) or builds a ``trsm`` solver through
    ``warp.build.build_lto_solver``.

    :param arg_types: Unused here.
    :param return_type: Unused here.
    :param return_values: Must contain exactly one variable, the output tile ``z``.
    :param arg_values: Mapping providing the ``"L"`` and ``"y"`` variables.
    :param options: Must provide ``"output_arch"`` and ``"block_dim"``.
    :param builder: Module builder forwarded to ``build_lto_solver``.
    :returns: ``((symbol, L, y, z), [], lto_blobs, 0)`` where ``symbol`` is ``0``
        and ``lto_blobs`` is empty when no LTO is generated.
    :raises TypeError: Unsupported dtype, wrong number of outputs, or an output
        that is neither 1D nor 2D.
    :raises ValueError: Output row count differs from the row count of ``L``.
    """
    L, y = arg_values["L"], arg_values["y"]

    # force the storage type of the input variables to shared memory
    for operand in (L, y):
        operand.type.storage = "shared"

    if not all(T in cusolver_type_map for T in (y.type.dtype, L.type.dtype)):
        raise TypeError("tile_lower_solve() arguments must be tiles of float64 or float32")

    output_count = len(return_values)
    if output_count != 1:
        raise TypeError(f"tile_lower_solve() must return exactly one value, got {output_count}")

    z = return_values[0]

    dtype, precision_enum = cusolver_type_map[L.type.dtype]
    M, N = L.type.shape

    out_shape = z.type.shape
    out_rank = len(out_shape)
    NRHS = out_shape[1] if out_rank > 1 else 1

    if not 1 <= out_rank <= 2:
        raise TypeError(f"tile_lower_solve() output vector must be 1D or 2D, got {out_rank}-D")

    if out_shape[0] != M:
        raise ValueError(
            "tile_lower_solve() output vector must have same number of elements as the number of rows in 'L' "
            f"got {out_shape[0]} elements in output and {M} rows in 'L'"
        )

    # Fixed configuration of the triangular solve: left side, non-unit diagonal, lower fill.
    solver = "trsm"
    solver_enum = cusolver_function_map[solver]
    side_enum = cusolver_side_map["left"]
    diag_enum = cusolver_diag_map["nounit"]
    fill_mode = cusolver_fill_mode_map["lower"]

    arch = options["output_arch"]
    num_threads = options["block_dim"]
    parameter_list = f"({dtype}*, {dtype}*)"

    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
        # CPU/no-MathDx dispatch
        return ((0, L, y, z), [], [], 0)

    # generate the LTO
    lto_symbol, lto_code_data = warp.build.build_lto_solver(
        M,
        N,
        NRHS,
        solver,
        solver_enum,
        side_enum,
        diag_enum,
        L.type.layout,
        y.type.layout,
        fill_mode,
        arch,
        precision_enum,
        num_threads,
        parameter_list,
        builder,
    )

    symbol_var = Var(lto_symbol, str, False, True, False)
    return ((symbol_var, L, y, z), [], [lto_code_data], 0)
|
|
8064
|
+
|
|
8065
|
+
|
|
8066
|
+
def tile_lower_solve_generic_value_func(arg_types, arg_values):
    """Compute the return type of ``tile_lower_solve(L, y)``.

    :param arg_types: Mapping with the types of ``"L"`` and ``"y"``, or ``None``
        to request the generic return type.
    :param arg_values: Unused here.
    :returns: A tile type with the dtype of ``L`` and the shape, layout and
        strides of ``y``, stored in shared memory. When ``arg_types`` is ``None``
        a generic ``tile(dtype=Float, shape=Tuple[int])`` is returned.
    :raises TypeError: Wrong argument count, non-tile argument, mismatched
        dtypes, or ``y`` that is neither 1D nor 2D.
    :raises ValueError: ``L`` is not a square 2D tile, or ``y`` has a different
        number of rows than ``L``.
    """
    if arg_types is None:
        return tile(dtype=Float, shape=Tuple[int])

    if len(arg_types) != 2:
        raise TypeError("tile_lower_solve() requires exactly 2 positional args")

    l = arg_types["L"]
    y = arg_types["y"]

    if not is_tile(l):
        raise TypeError(f"tile_lower_solve() 'L' argument must be a tile, got {l!r}")

    if not is_tile(y):
        raise TypeError(f"tile_lower_solve() 'y' argument must be a tile, got {y!r}")

    if not types_equal(l.dtype, y.dtype):
        raise TypeError(f"tile_lower_solve() arguments must have the same dtype, got {l.dtype} and {y.dtype}")

    # Check the rank first: indexing shape[1] on a 1D tile would otherwise
    # escape as a bare IndexError instead of the intended diagnostic.
    if len(l.shape) != 2 or l.shape[0] != l.shape[1]:
        raise ValueError("tile_lower_solve() 'L' argument must be square")

    if len(y.shape) > 2 or len(y.shape) < 1:
        raise TypeError("tile_lower_solve() 'y' argument must be a 1D or 2D tile")

    if y.shape[0] != l.shape[0]:
        raise ValueError(
            f"tile_lower_solve() 'y' argument must have the same number of elements as the number of rows in 'L', "
            f"got {y.shape[0]} elements in 'y' and {l.shape[0]} rows in 'L'"
        )

    return tile(dtype=l.dtype, shape=y.shape, layout=y.layout, strides=y.strides, storage="shared")
|
|
8098
|
+
|
|
8099
|
+
|
|
8100
|
+
add_builtin(
|
|
8101
|
+
"tile_lower_solve",
|
|
8102
|
+
input_types={"L": tile(dtype=Float, shape=Tuple[int, int]), "y": tile(dtype=Float, shape=Tuple[int])},
|
|
8103
|
+
value_func=tile_lower_solve_generic_value_func,
|
|
8104
|
+
lto_dispatch_func=tile_lower_solve_generic_lto_dispatch_func,
|
|
8105
|
+
variadic=True,
|
|
8106
|
+
doc="""Solve for z in Lz = y, where L is a lower triangular matrix.
|
|
8107
|
+
|
|
8108
|
+
This performs general forward substitution for a lower triangular system.
|
|
8109
|
+
|
|
8110
|
+
Note that computing the adjoint is not yet supported.
|
|
8111
|
+
|
|
8112
|
+
Supported datatypes are:
|
|
8113
|
+
* float32
|
|
8114
|
+
* float64
|
|
8115
|
+
|
|
8116
|
+
:param L: A square, non-singular, lower triangular matrix
|
|
8117
|
+
:param y: A 1D or 2D tile with compatible shape
|
|
8118
|
+
:returns z: A tile of the same shape as y such that Lz = y""",
|
|
8119
|
+
group="Tile Primitives",
|
|
8120
|
+
export=False,
|
|
8121
|
+
namespace="",
|
|
8122
|
+
)
|
|
8123
|
+
|
|
8124
|
+
|
|
8125
|
+
def tile_upper_solve_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Prepare the call arguments for ``tile_upper_solve()`` (solve ``U x = z``).

    Reads the ``"U"`` and ``"z"`` variables from ``arg_values``, marks both as
    shared-memory tiles, validates dtypes and the single output tile ``x``, and
    then either returns a placeholder (no LTO) or builds a ``trsm`` solver
    through ``warp.build.build_lto_solver``.

    :param arg_types: Unused here.
    :param return_type: Unused here.
    :param return_values: Must contain exactly one variable, the output tile ``x``.
    :param arg_values: Mapping providing the ``"U"`` and ``"z"`` variables.
    :param options: Must provide ``"output_arch"`` and ``"block_dim"``.
    :param builder: Module builder forwarded to ``build_lto_solver``.
    :returns: ``((symbol, U, z, x), [], lto_blobs, 0)`` where ``symbol`` is ``0``
        and ``lto_blobs`` is empty when no LTO is generated.
    :raises TypeError: Unsupported dtype, wrong number of outputs, or an output
        that is neither 1D nor 2D.
    :raises ValueError: Output row count differs from the row count of ``U``.
    """
    U = arg_values["U"]
    z = arg_values["z"]
    # force the storage type of the input variables to shared memory
    U.type.storage = "shared"
    z.type.storage = "shared"

    if any(T not in cusolver_type_map for T in (z.type.dtype, U.type.dtype)):
        raise TypeError("tile_upper_solve() arguments must be tiles of float64 or float32")

    if len(return_values) != 1:
        raise TypeError(f"tile_upper_solve() must return exactly one value, got {len(return_values)}")

    x = return_values[0]

    dtype, precision_enum = cusolver_type_map[U.type.dtype]
    M, N = U.type.shape

    # Validate the output tile `x`, matching the error messages below and the
    # sibling tile_lower_solve() dispatch (previously the input `z` was checked).
    x_shape = x.type.shape
    NRHS = x_shape[1] if len(x_shape) > 1 else 1

    if len(x_shape) > 2 or len(x_shape) < 1:
        raise TypeError(f"tile_upper_solve() output tile must be 1D or 2D, got {len(x_shape)}-D")

    if x_shape[0] != M:
        raise ValueError(
            "tile_upper_solve() output tile must have same number of elements as the number of rows in 'U' "
            f"got {x_shape[0]} elements in output and {M} rows in 'U'"
        )

    solver = "trsm"
    solver_enum = cusolver_function_map[solver]

    side_enum = cusolver_side_map["left"]
    diag_enum = cusolver_diag_map["nounit"]
    fill_mode = cusolver_fill_mode_map["upper"]

    arch = options["output_arch"]
    num_threads = options["block_dim"]
    parameter_list = f"({dtype}*, {dtype}*)"

    if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
        # CPU/no-MathDx dispatch
        return ((0, U, z, x), [], [], 0)
    else:
        # generate the LTO
        lto_symbol, lto_code_data = warp.build.build_lto_solver(
            M,
            N,
            NRHS,
            solver,
            solver_enum,
            side_enum,
            diag_enum,
            U.type.layout,
            z.type.layout,
            fill_mode,
            arch,
            precision_enum,
            num_threads,
            parameter_list,
            builder,
        )

        return ((Var(lto_symbol, str, False, True, False), U, z, x), [], [lto_code_data], 0)
|
|
8195
|
+
|
|
8196
|
+
|
|
8197
|
+
def tile_upper_solve_generic_value_func(arg_types, arg_values):
    """Compute the return type of ``tile_upper_solve(U, z)``.

    :param arg_types: Mapping with the types of ``"U"`` and ``"z"``, or ``None``
        to request the generic return type.
    :param arg_values: Unused here.
    :returns: A tile type with the dtype of ``U`` and the shape, layout and
        strides of ``z``, stored in shared memory. When ``arg_types`` is ``None``
        a generic ``tile(dtype=Float, shape=Tuple[int])`` is returned.
    :raises TypeError: Wrong argument count, non-tile argument, mismatched
        dtypes, or ``z`` that is neither 1D nor 2D.
    :raises ValueError: ``U`` is not a square 2D tile, or ``z`` has a different
        number of rows than ``U``.
    """
    if arg_types is None:
        return tile(dtype=Float, shape=Tuple[int])

    if len(arg_types) != 2:
        raise TypeError("tile_upper_solve() requires exactly 2 positional args")

    u = arg_types["U"]
    z = arg_types["z"]

    if not is_tile(u):
        raise TypeError(f"tile_upper_solve() 'U' argument must be a tile, got {u!r}")

    if not is_tile(z):
        raise TypeError(f"tile_upper_solve() 'z' argument must be a tile, got {z!r}")

    if not types_equal(u.dtype, z.dtype):
        raise TypeError(f"tile_upper_solve() arguments must have the same dtype, got {u.dtype} and {z.dtype}")

    # Check the rank first: indexing shape[1] on a 1D tile would otherwise
    # escape as a bare IndexError instead of the intended diagnostic.
    if len(u.shape) != 2 or u.shape[0] != u.shape[1]:
        raise ValueError("tile_upper_solve() 'U' argument must be square")

    if len(z.shape) > 2 or len(z.shape) < 1:
        raise TypeError("tile_upper_solve() 'z' argument must be a 1D or 2D tile")

    if z.shape[0] != u.shape[0]:
        raise ValueError(
            f"tile_upper_solve() 'z' argument must have the same number of elements as the number of rows in 'U', "
            f"got {z.shape[0]} elements in 'z' and {u.shape[0]} rows in 'U'"
        )

    return tile(dtype=u.dtype, shape=z.shape, layout=z.layout, strides=z.strides, storage="shared")
|
|
8229
|
+
|
|
8230
|
+
|
|
8231
|
+
add_builtin(
|
|
8232
|
+
"tile_upper_solve",
|
|
8233
|
+
input_types={"U": tile(dtype=Float, shape=Tuple[int, int]), "z": tile(dtype=Float, shape=Tuple[int])},
|
|
8234
|
+
value_func=tile_upper_solve_generic_value_func,
|
|
8235
|
+
lto_dispatch_func=tile_upper_solve_generic_lto_dispatch_func,
|
|
8236
|
+
variadic=True,
|
|
8237
|
+
doc="""Solve for x in U x = z, where U is an upper triangular matrix.
|
|
8238
|
+
|
|
8239
|
+
This performs general back substitution for upper triangular systems.
|
|
8240
|
+
|
|
8241
|
+
Note that computing the adjoint is not yet supported.
|
|
8242
|
+
|
|
8243
|
+
Supported datatypes are:
|
|
8244
|
+
* float32
|
|
8245
|
+
* float64
|
|
8246
|
+
|
|
8247
|
+
:param U: A square, non-singular, upper triangular matrix
|
|
8248
|
+
:param z: A 1D or 2D tile with compatible shape
|
|
8249
|
+
:returns x: A tile of the same shape as z such that U x = z""",
|
|
6802
8250
|
group="Tile Primitives",
|
|
6803
8251
|
export=False,
|
|
6804
8252
|
namespace="",
|
|
6805
8253
|
)
|
|
6806
8254
|
|
|
8255
|
+
|
|
6807
8256
|
# ---------------------------------
|
|
6808
8257
|
# Code Generation
|
|
6809
8258
|
|
|
@@ -6840,7 +8289,7 @@ def static(expr):
|
|
|
6840
8289
|
add_builtin(
|
|
6841
8290
|
"len",
|
|
6842
8291
|
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
6843
|
-
|
|
8292
|
+
value_func=static_len_value_func,
|
|
6844
8293
|
doc="Return the number of elements in a vector.",
|
|
6845
8294
|
group="Utility",
|
|
6846
8295
|
export=False,
|
|
@@ -6849,7 +8298,7 @@ add_builtin(
|
|
|
6849
8298
|
add_builtin(
|
|
6850
8299
|
"len",
|
|
6851
8300
|
input_types={"a": quaternion(dtype=Scalar)},
|
|
6852
|
-
|
|
8301
|
+
value_func=static_len_value_func,
|
|
6853
8302
|
doc="Return the number of elements in a quaternion.",
|
|
6854
8303
|
group="Utility",
|
|
6855
8304
|
export=False,
|
|
@@ -6858,7 +8307,7 @@ add_builtin(
|
|
|
6858
8307
|
add_builtin(
|
|
6859
8308
|
"len",
|
|
6860
8309
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
6861
|
-
|
|
8310
|
+
value_func=static_len_value_func,
|
|
6862
8311
|
doc="Return the number of rows in a matrix.",
|
|
6863
8312
|
group="Utility",
|
|
6864
8313
|
export=False,
|
|
@@ -6867,7 +8316,7 @@ add_builtin(
|
|
|
6867
8316
|
add_builtin(
|
|
6868
8317
|
"len",
|
|
6869
8318
|
input_types={"a": transformation(dtype=Float)},
|
|
6870
|
-
|
|
8319
|
+
value_func=static_len_value_func,
|
|
6871
8320
|
doc="Return the number of elements in a transformation.",
|
|
6872
8321
|
group="Utility",
|
|
6873
8322
|
export=False,
|
|
@@ -6884,9 +8333,83 @@ add_builtin(
|
|
|
6884
8333
|
|
|
6885
8334
|
add_builtin(
|
|
6886
8335
|
"len",
|
|
6887
|
-
input_types={"a":
|
|
6888
|
-
|
|
8336
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
8337
|
+
value_func=static_len_value_func,
|
|
6889
8338
|
doc="Return the number of rows in a tile.",
|
|
6890
8339
|
group="Utility",
|
|
6891
8340
|
export=False,
|
|
6892
8341
|
)
|
|
8342
|
+
|
|
8343
|
+
|
|
8344
|
+
# ---------------------------------
|
|
8345
|
+
# Tuple
|
|
8346
|
+
|
|
8347
|
+
|
|
8348
|
+
def tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return the ``tuple_t`` built from the ``"args"`` entries of both mappings."""
    element_types = arg_types["args"]
    element_values = arg_values["args"]
    return tuple_t(element_types, element_values)
|
|
8350
|
+
|
|
8351
|
+
|
|
8352
|
+
def tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Return ``(func_args, template_args)`` for the ``tuple`` builtin.

    The function arguments are the ``"args"`` entry of ``args`` (an empty tuple
    when that entry is absent); there are no template arguments.
    """
    return (args.get("args", ()), ())
|
|
8356
|
+
|
|
8357
|
+
|
|
8358
|
+
# Hidden utility builtin that packs its positional arguments into a tuple.
add_builtin(
    "tuple",
    input_types={
        "*args": Any,
    },
    variadic=True,
    value_func=tuple_value_func,
    dispatch_func=tuple_dispatch_func,
    doc="Construct a tuple from a list of values",
    group="Utility",
    hidden=True,
    export=False,
    missing_grad=True,
)
|
|
8370
|
+
|
|
8371
|
+
|
|
8372
|
+
def tuple_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Return the type of element ``i`` of the tuple argument ``a``.

    :param arg_types: Mapping whose ``"a"`` entry is either a tuple type (its
        ``types`` attribute is used) or an object that is indexed directly.
    :param arg_values: Mapping whose ``"i"`` entry must be present and must not
        be a ``Var``.
    :returns: The element type found at position ``i``.
    :raises RuntimeError: ``i`` is missing, is a ``Var``, or is ``>=`` the
        number of elements.
    """
    container = arg_types["a"]
    if is_tuple(container):
        members = container.types
    else:
        members = container

    # The index has to be known while the kernel is being built.
    if "i" not in arg_values or isinstance(arg_values["i"], Var):
        raise RuntimeError("Tuple index must be a compile time expression.")

    index = arg_values["i"]
    count = len(members)
    if index >= count:
        raise RuntimeError(f"Tuple index out of bounds, {index} >= {count}")

    return members[index]
|
|
8389
|
+
|
|
8390
|
+
|
|
8391
|
+
def tuple_extract_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Return ``(func_args, template_args)`` for the ``extract`` builtin.

    The tuple variable ``a`` is the only function argument; the ``constant``
    attribute of the index variable ``i`` is the only template argument.
    """
    source = args["a"]
    index_constant = args["i"].constant
    return ((source,), (index_constant,))
|
|
8395
|
+
|
|
8396
|
+
|
|
8397
|
+
# Hidden utility builtin that reads one element out of a tuple.
add_builtin(
    "extract",
    input_types={
        "a": Tuple,
        "i": int,
    },
    value_func=tuple_extract_value_func,
    dispatch_func=tuple_extract_dispatch_func,
    group="Utility",
    missing_grad=True,
    hidden=True,
)
|
|
8406
|
+
|
|
8407
|
+
|
|
8408
|
+
# `len()` overload for tuples; the result comes from static_len_value_func.
add_builtin(
    "len",
    input_types={
        "a": Tuple,
    },
    value_func=static_len_value_func,
    group="Utility",
    doc="Return the number of elements in a tuple.",
    export=False,
)
|