warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -15,10 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import builtins
|
|
17
17
|
import functools
|
|
18
|
-
import tempfile
|
|
19
|
-
from pathlib import Path
|
|
20
18
|
from typing import Any, Callable, Mapping, Sequence
|
|
21
19
|
|
|
20
|
+
import warp.build
|
|
21
|
+
import warp.context
|
|
22
22
|
from warp.codegen import Reference, Var, strip_reference
|
|
23
23
|
from warp.types import *
|
|
24
24
|
|
|
@@ -41,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
|
|
|
41
41
|
return all(types_equal(arg_type_0, t) for t in arg_types_iter)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def sametypes_create_value_func(default):
|
|
44
|
+
def sametypes_create_value_func(default: TypeVar):
|
|
45
45
|
def fn(arg_types, arg_values):
|
|
46
46
|
if arg_types is None:
|
|
47
47
|
return default
|
|
@@ -399,7 +399,7 @@ add_builtin(
|
|
|
399
399
|
)
|
|
400
400
|
|
|
401
401
|
|
|
402
|
-
def scalar_infer_type(arg_types: Mapping[str, type]):
|
|
402
|
+
def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
|
|
403
403
|
if arg_types is None:
|
|
404
404
|
return Scalar
|
|
405
405
|
|
|
@@ -950,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
950
950
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
951
951
|
|
|
952
952
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
953
|
+
warp.utils.warn(
|
|
954
|
+
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
955
|
+
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
956
|
+
DeprecationWarning,
|
|
957
|
+
)
|
|
958
|
+
|
|
953
959
|
if shape[1] != variadic_arg_count:
|
|
954
960
|
raise RuntimeError(
|
|
955
961
|
f"incompatible number of column vectors given ({variadic_arg_count}) "
|
|
@@ -1030,6 +1036,86 @@ add_builtin(
|
|
|
1030
1036
|
)
|
|
1031
1037
|
|
|
1032
1038
|
|
|
1039
|
+
def matrix_from_vecs_create_value_func(cols: bool):
|
|
1040
|
+
def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1041
|
+
if arg_types is None:
|
|
1042
|
+
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
1043
|
+
|
|
1044
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
1045
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1046
|
+
|
|
1047
|
+
if not all(type_is_vector(x) for x in variadic_arg_types):
|
|
1048
|
+
raise RuntimeError("all arguments are expected to be vectors")
|
|
1049
|
+
|
|
1050
|
+
length = variadic_arg_types[0]._length_
|
|
1051
|
+
if any(x._length_ != length for x in variadic_arg_types):
|
|
1052
|
+
raise RuntimeError("all vectors are expected to have the same length")
|
|
1053
|
+
|
|
1054
|
+
dtype = variadic_arg_types[0]._wp_scalar_type_
|
|
1055
|
+
if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
|
|
1056
|
+
raise RuntimeError("all vectors are expected to have the same dtype")
|
|
1057
|
+
|
|
1058
|
+
shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
|
|
1059
|
+
return matrix(shape=shape, dtype=dtype)
|
|
1060
|
+
|
|
1061
|
+
return fn
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1065
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1066
|
+
# Further validate the given argument values if needed and map them
|
|
1067
|
+
# to the underlying C++ function's runtime and template params.
|
|
1068
|
+
|
|
1069
|
+
shape = return_type._shape_
|
|
1070
|
+
dtype = return_type._wp_scalar_type_
|
|
1071
|
+
|
|
1072
|
+
variadic_args = args.get("args", ())
|
|
1073
|
+
|
|
1074
|
+
func_args = variadic_args
|
|
1075
|
+
|
|
1076
|
+
if shape in ((2, 2), (3, 3), (4, 4)):
|
|
1077
|
+
# Template specializations exist for these shapes, don't pass them
|
|
1078
|
+
# as template parameters.
|
|
1079
|
+
template_args = (dtype,)
|
|
1080
|
+
else:
|
|
1081
|
+
template_args = (*shape, dtype)
|
|
1082
|
+
|
|
1083
|
+
return (func_args, template_args)
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
def matrix_from_vecs_initializer_list_func(args, return_type):
|
|
1087
|
+
shape = return_type._shape_
|
|
1088
|
+
|
|
1089
|
+
return shape[0] != shape[1] or shape[0] > 4
|
|
1090
|
+
|
|
1091
|
+
|
|
1092
|
+
add_builtin(
|
|
1093
|
+
"matrix_from_cols",
|
|
1094
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1095
|
+
variadic=True,
|
|
1096
|
+
value_func=matrix_from_vecs_create_value_func(cols=True),
|
|
1097
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1098
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1099
|
+
native_func="matrix_from_cols",
|
|
1100
|
+
doc="Construct a matrix from column vectors.",
|
|
1101
|
+
group="Vector Math",
|
|
1102
|
+
export=False,
|
|
1103
|
+
)
|
|
1104
|
+
|
|
1105
|
+
add_builtin(
|
|
1106
|
+
"matrix_from_rows",
|
|
1107
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1108
|
+
variadic=True,
|
|
1109
|
+
value_func=matrix_from_vecs_create_value_func(cols=False),
|
|
1110
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1111
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1112
|
+
native_func="matrix_from_rows",
|
|
1113
|
+
doc="Construct a matrix from row vectors.",
|
|
1114
|
+
group="Vector Math",
|
|
1115
|
+
export=False,
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
|
|
1033
1119
|
def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1034
1120
|
if arg_types is None:
|
|
1035
1121
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
@@ -1141,6 +1227,21 @@ add_builtin(
|
|
|
1141
1227
|
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1142
1228
|
)
|
|
1143
1229
|
|
|
1230
|
+
add_builtin(
|
|
1231
|
+
"svd2",
|
|
1232
|
+
input_types={
|
|
1233
|
+
"A": matrix(shape=(2, 2), dtype=Float),
|
|
1234
|
+
"U": matrix(shape=(2, 2), dtype=Float),
|
|
1235
|
+
"sigma": vector(length=2, dtype=Float),
|
|
1236
|
+
"V": matrix(shape=(2, 2), dtype=Scalar),
|
|
1237
|
+
},
|
|
1238
|
+
value_type=None,
|
|
1239
|
+
group="Vector Math",
|
|
1240
|
+
export=False,
|
|
1241
|
+
doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
|
|
1242
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1243
|
+
)
|
|
1244
|
+
|
|
1144
1245
|
add_builtin(
|
|
1145
1246
|
"qr3",
|
|
1146
1247
|
input_types={
|
|
@@ -1332,7 +1433,18 @@ add_builtin(
|
|
|
1332
1433
|
input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
|
|
1333
1434
|
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1334
1435
|
group="Quaternion Math",
|
|
1335
|
-
doc="Construct a quaternion from a 3x3 matrix.
|
|
1436
|
+
doc="""Construct a quaternion from a 3x3 matrix.
|
|
1437
|
+
|
|
1438
|
+
If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1439
|
+
)
|
|
1440
|
+
add_builtin(
|
|
1441
|
+
"quat_from_matrix",
|
|
1442
|
+
input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
|
|
1443
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1444
|
+
group="Quaternion Math",
|
|
1445
|
+
doc="""Construct a quaternion from a 4x4 matrix.
|
|
1446
|
+
|
|
1447
|
+
If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1336
1448
|
)
|
|
1337
1449
|
add_builtin(
|
|
1338
1450
|
"quat_rpy",
|
|
@@ -2375,7 +2487,7 @@ add_builtin(
|
|
|
2375
2487
|
|
|
2376
2488
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2377
2489
|
|
|
2378
|
-
* If the input value is a scalar, then the resulting tile has ``shape=(
|
|
2490
|
+
* If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
|
|
2379
2491
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2380
2492
|
|
|
2381
2493
|
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
@@ -2669,11 +2781,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2669
2781
|
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2670
2782
|
tile = arg_values["a"]
|
|
2671
2783
|
|
|
2672
|
-
|
|
2673
|
-
|
|
2674
|
-
template_args.
|
|
2675
|
-
template_args.append(return_type.strides[0])
|
|
2676
|
-
template_args.append(return_type.strides[1])
|
|
2784
|
+
assert len(return_type.shape) == len(return_type.strides)
|
|
2785
|
+
assert 1 <= len(return_type.shape) <= 4
|
|
2786
|
+
template_args = [*return_type.shape, *return_type.strides]
|
|
2677
2787
|
|
|
2678
2788
|
return ((tile,), template_args)
|
|
2679
2789
|
|
|
@@ -2686,52 +2796,13 @@ add_builtin(
|
|
|
2686
2796
|
variadic=False,
|
|
2687
2797
|
doc="""Broadcast a tile.
|
|
2688
2798
|
|
|
2689
|
-
|
|
2690
|
-
|
|
2799
|
+
Broadcasts the input tile ``a`` to the destination shape.
|
|
2691
2800
|
Broadcasting follows NumPy broadcast rules.
|
|
2692
2801
|
|
|
2693
2802
|
:param a: Tile to broadcast
|
|
2694
2803
|
:param shape: The shape to broadcast to
|
|
2695
|
-
:returns: Tile with broadcast
|
|
2696
|
-
group="Tile Primitives",
|
|
2697
|
-
export=False,
|
|
2698
|
-
)
|
|
2699
|
-
|
|
2700
|
-
|
|
2701
|
-
def tile_matmul_value_func(arg_types, arg_values):
|
|
2702
|
-
# return generic type (for doc builds)
|
|
2703
|
-
if arg_types is None:
|
|
2704
|
-
return Tile(dtype=Any, shape=Any)
|
|
2705
|
-
|
|
2706
|
-
if len(arg_types) != 3:
|
|
2707
|
-
raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
|
|
2708
|
-
|
|
2709
|
-
return None
|
|
2710
|
-
|
|
2711
|
-
|
|
2712
|
-
def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2713
|
-
a = arg_values["a"]
|
|
2714
|
-
b = arg_values["b"]
|
|
2715
|
-
out = arg_values["out"]
|
|
2716
|
-
|
|
2717
|
-
# force the storage type of the input variables to shared memory
|
|
2718
|
-
a.type.storage = "shared"
|
|
2719
|
-
b.type.storage = "shared"
|
|
2720
|
-
out.type.storage = "shared"
|
|
2721
|
-
|
|
2722
|
-
template_args = []
|
|
2723
|
-
return ((a, b, out), template_args)
|
|
2724
|
-
|
|
2725
|
-
|
|
2726
|
-
add_builtin(
|
|
2727
|
-
"tile_matmul_scalar",
|
|
2728
|
-
input_types={"a": Tile, "b": Tile, "out": Tile},
|
|
2729
|
-
value_func=tile_matmul_value_func,
|
|
2730
|
-
dispatch_func=tile_matmul_dispatch_func,
|
|
2731
|
-
variadic=True,
|
|
2732
|
-
doc="Compute matrix product and accumulate out += a*b.",
|
|
2804
|
+
:returns: Tile with broadcast shape""",
|
|
2733
2805
|
group="Tile Primitives",
|
|
2734
|
-
hidden=True,
|
|
2735
2806
|
export=False,
|
|
2736
2807
|
)
|
|
2737
2808
|
|
|
@@ -3030,7 +3101,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
|
|
|
3030
3101
|
|
|
3031
3102
|
for i in range(len(a.shape)):
|
|
3032
3103
|
if a.shape[i] != b.shape[i]:
|
|
3033
|
-
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape
|
|
3104
|
+
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
|
|
3034
3105
|
|
|
3035
3106
|
return TileBinaryMap(a, b)
|
|
3036
3107
|
|
|
@@ -3807,6 +3878,18 @@ _volume_supported_value_types = {
|
|
|
3807
3878
|
}
|
|
3808
3879
|
|
|
3809
3880
|
|
|
3881
|
+
def _is_volume_type_supported(dtype):
|
|
3882
|
+
for typ in _volume_supported_value_types:
|
|
3883
|
+
if types_equal(typ, dtype):
|
|
3884
|
+
return True
|
|
3885
|
+
return False
|
|
3886
|
+
|
|
3887
|
+
|
|
3888
|
+
def _check_volume_type_is_supported(dtype):
|
|
3889
|
+
if not _is_volume_type_supported(dtype):
|
|
3890
|
+
raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
|
|
3891
|
+
|
|
3892
|
+
|
|
3810
3893
|
def check_volume_value_grad_compatibility(dtype, grad_dtype):
|
|
3811
3894
|
if type_is_vector(dtype):
|
|
3812
3895
|
expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
|
|
@@ -3822,9 +3905,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
3822
3905
|
return Any
|
|
3823
3906
|
|
|
3824
3907
|
dtype = arg_values["dtype"]
|
|
3825
|
-
|
|
3826
|
-
if dtype not in _volume_supported_value_types:
|
|
3827
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3908
|
+
_check_volume_type_is_supported(dtype)
|
|
3828
3909
|
|
|
3829
3910
|
return dtype
|
|
3830
3911
|
|
|
@@ -3860,9 +3941,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
|
|
|
3860
3941
|
return Any
|
|
3861
3942
|
|
|
3862
3943
|
dtype = arg_values["dtype"]
|
|
3863
|
-
|
|
3864
|
-
if dtype not in _volume_supported_value_types:
|
|
3865
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3944
|
+
_check_volume_type_is_supported(dtype)
|
|
3866
3945
|
|
|
3867
3946
|
check_volume_value_grad_compatibility(dtype, arg_types["grad"])
|
|
3868
3947
|
|
|
@@ -3900,9 +3979,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
|
|
|
3900
3979
|
return Any
|
|
3901
3980
|
|
|
3902
3981
|
dtype = arg_values["dtype"]
|
|
3903
|
-
|
|
3904
|
-
if dtype not in _volume_supported_value_types:
|
|
3905
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3982
|
+
_check_volume_type_is_supported(dtype)
|
|
3906
3983
|
|
|
3907
3984
|
return dtype
|
|
3908
3985
|
|
|
@@ -3939,9 +4016,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
|
|
|
3939
4016
|
return None
|
|
3940
4017
|
|
|
3941
4018
|
dtype = arg_types["value"]
|
|
3942
|
-
|
|
3943
|
-
if dtype not in _volume_supported_value_types:
|
|
3944
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
4019
|
+
_check_volume_type_is_supported(dtype)
|
|
3945
4020
|
|
|
3946
4021
|
return None
|
|
3947
4022
|
|
|
@@ -4191,6 +4266,20 @@ add_builtin(
|
|
|
4191
4266
|
group="Random",
|
|
4192
4267
|
doc="Return a random integer between [low, high).",
|
|
4193
4268
|
)
|
|
4269
|
+
add_builtin(
|
|
4270
|
+
"randu",
|
|
4271
|
+
input_types={"state": uint32},
|
|
4272
|
+
value_type=uint32,
|
|
4273
|
+
group="Random",
|
|
4274
|
+
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
4275
|
+
)
|
|
4276
|
+
add_builtin(
|
|
4277
|
+
"randu",
|
|
4278
|
+
input_types={"state": uint32, "low": uint32, "high": uint32},
|
|
4279
|
+
value_type=uint32,
|
|
4280
|
+
group="Random",
|
|
4281
|
+
doc="Return a random unsigned integer between [low, high).",
|
|
4282
|
+
)
|
|
4194
4283
|
add_builtin(
|
|
4195
4284
|
"randf",
|
|
4196
4285
|
input_types={"state": uint32},
|
|
@@ -4499,11 +4588,31 @@ add_builtin(
|
|
|
4499
4588
|
export=False,
|
|
4500
4589
|
group="Utility",
|
|
4501
4590
|
)
|
|
4591
|
+
|
|
4592
|
+
|
|
4593
|
+
def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
4594
|
+
warp.utils.warn(
|
|
4595
|
+
"wp.select() is deprecated and will be removed in a future\n"
|
|
4596
|
+
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
4597
|
+
category=DeprecationWarning,
|
|
4598
|
+
)
|
|
4599
|
+
|
|
4600
|
+
func_args = tuple(args.values())
|
|
4601
|
+
template_args = ()
|
|
4602
|
+
|
|
4603
|
+
return (func_args, template_args)
|
|
4604
|
+
|
|
4605
|
+
|
|
4502
4606
|
add_builtin(
|
|
4503
4607
|
"select",
|
|
4504
4608
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
4505
4609
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4506
|
-
|
|
4610
|
+
dispatch_func=select_dispatch_func,
|
|
4611
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4612
|
+
|
|
4613
|
+
.. deprecated:: 1.7
|
|
4614
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4615
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4507
4616
|
group="Utility",
|
|
4508
4617
|
)
|
|
4509
4618
|
for t in int_types:
|
|
@@ -4511,14 +4620,47 @@ for t in int_types:
|
|
|
4511
4620
|
"select",
|
|
4512
4621
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
4513
4622
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4514
|
-
|
|
4623
|
+
dispatch_func=select_dispatch_func,
|
|
4624
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4625
|
+
|
|
4626
|
+
.. deprecated:: 1.7
|
|
4627
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4628
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4515
4629
|
group="Utility",
|
|
4516
4630
|
)
|
|
4517
4631
|
add_builtin(
|
|
4518
4632
|
"select",
|
|
4519
4633
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
4520
4634
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4521
|
-
|
|
4635
|
+
dispatch_func=select_dispatch_func,
|
|
4636
|
+
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4637
|
+
|
|
4638
|
+
.. deprecated:: 1.7
|
|
4639
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4640
|
+
``where(arr, value_if_true, value_if_false)``.""",
|
|
4641
|
+
group="Utility",
|
|
4642
|
+
)
|
|
4643
|
+
|
|
4644
|
+
add_builtin(
|
|
4645
|
+
"where",
|
|
4646
|
+
input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
|
|
4647
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4648
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4649
|
+
group="Utility",
|
|
4650
|
+
)
|
|
4651
|
+
for t in int_types:
|
|
4652
|
+
add_builtin(
|
|
4653
|
+
"where",
|
|
4654
|
+
input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
|
|
4655
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4656
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4657
|
+
group="Utility",
|
|
4658
|
+
)
|
|
4659
|
+
add_builtin(
|
|
4660
|
+
"where",
|
|
4661
|
+
input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
|
|
4662
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4663
|
+
doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4522
4664
|
group="Utility",
|
|
4523
4665
|
)
|
|
4524
4666
|
|
|
@@ -5112,33 +5254,51 @@ add_builtin(
|
|
|
5112
5254
|
)
|
|
5113
5255
|
|
|
5114
5256
|
|
|
5257
|
+
# implements vector[index] = value
|
|
5258
|
+
add_builtin(
|
|
5259
|
+
"assign_inplace",
|
|
5260
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5261
|
+
value_type=None,
|
|
5262
|
+
hidden=True,
|
|
5263
|
+
group="Utility",
|
|
5264
|
+
)
|
|
5265
|
+
|
|
5266
|
+
# implements quaternion[index] = value
|
|
5267
|
+
add_builtin(
|
|
5268
|
+
"assign_inplace",
|
|
5269
|
+
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5270
|
+
value_type=None,
|
|
5271
|
+
hidden=True,
|
|
5272
|
+
group="Utility",
|
|
5273
|
+
)
|
|
5274
|
+
|
|
5275
|
+
|
|
5115
5276
|
def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5116
5277
|
vec_type = arg_types["a"]
|
|
5117
5278
|
return vec_type
|
|
5118
5279
|
|
|
5119
5280
|
|
|
5120
|
-
# implements vector[index] = value
|
|
5281
|
+
# implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5121
5282
|
add_builtin(
|
|
5122
|
-
"
|
|
5283
|
+
"assign_copy",
|
|
5123
5284
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5124
5285
|
value_func=vector_assign_value_func,
|
|
5125
5286
|
hidden=True,
|
|
5126
5287
|
group="Utility",
|
|
5127
5288
|
)
|
|
5128
5289
|
|
|
5129
|
-
# implements quaternion[index] = value
|
|
5290
|
+
# implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5130
5291
|
add_builtin(
|
|
5131
|
-
"
|
|
5292
|
+
"assign_copy",
|
|
5132
5293
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5133
5294
|
value_func=vector_assign_value_func,
|
|
5134
5295
|
hidden=True,
|
|
5135
5296
|
group="Utility",
|
|
5136
5297
|
)
|
|
5137
5298
|
|
|
5138
|
-
|
|
5139
5299
|
# implements vector[idx] += scalar
|
|
5140
5300
|
add_builtin(
|
|
5141
|
-
"
|
|
5301
|
+
"add_inplace",
|
|
5142
5302
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5143
5303
|
value_type=None,
|
|
5144
5304
|
hidden=True,
|
|
@@ -5147,7 +5307,7 @@ add_builtin(
|
|
|
5147
5307
|
|
|
5148
5308
|
# implements quaternion[idx] += scalar
|
|
5149
5309
|
add_builtin(
|
|
5150
|
-
"
|
|
5310
|
+
"add_inplace",
|
|
5151
5311
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5152
5312
|
value_type=None,
|
|
5153
5313
|
hidden=True,
|
|
@@ -5156,7 +5316,7 @@ add_builtin(
|
|
|
5156
5316
|
|
|
5157
5317
|
# implements vector[idx] -= scalar
|
|
5158
5318
|
add_builtin(
|
|
5159
|
-
"
|
|
5319
|
+
"sub_inplace",
|
|
5160
5320
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5161
5321
|
value_type=None,
|
|
5162
5322
|
hidden=True,
|
|
@@ -5165,7 +5325,7 @@ add_builtin(
|
|
|
5165
5325
|
|
|
5166
5326
|
# implements quaternion[idx] -= scalar
|
|
5167
5327
|
add_builtin(
|
|
5168
|
-
"
|
|
5328
|
+
"sub_inplace",
|
|
5169
5329
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5170
5330
|
value_type=None,
|
|
5171
5331
|
hidden=True,
|
|
@@ -5209,11 +5369,6 @@ add_builtin(
|
|
|
5209
5369
|
)
|
|
5210
5370
|
|
|
5211
5371
|
|
|
5212
|
-
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5213
|
-
mat_type = arg_types["a"]
|
|
5214
|
-
return mat_type
|
|
5215
|
-
|
|
5216
|
-
|
|
5217
5372
|
def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
5218
5373
|
mat_size = arg_types["a"]._shape_[0]
|
|
5219
5374
|
vec_size = arg_types["value"]._length_
|
|
@@ -5224,7 +5379,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
|
5224
5379
|
|
|
5225
5380
|
# implements matrix[i,j] = scalar
|
|
5226
5381
|
add_builtin(
|
|
5227
|
-
"
|
|
5382
|
+
"assign_inplace",
|
|
5383
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5384
|
+
value_type=None,
|
|
5385
|
+
hidden=True,
|
|
5386
|
+
group="Utility",
|
|
5387
|
+
)
|
|
5388
|
+
|
|
5389
|
+
|
|
5390
|
+
# implements matrix[i] = vector
|
|
5391
|
+
add_builtin(
|
|
5392
|
+
"assign_inplace",
|
|
5393
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5394
|
+
constraint=matrix_vector_sametype,
|
|
5395
|
+
value_type=None,
|
|
5396
|
+
hidden=True,
|
|
5397
|
+
group="Utility",
|
|
5398
|
+
)
|
|
5399
|
+
|
|
5400
|
+
|
|
5401
|
+
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5402
|
+
mat_type = arg_types["a"]
|
|
5403
|
+
return mat_type
|
|
5404
|
+
|
|
5405
|
+
|
|
5406
|
+
# implements matrix[i,j] = scalar
|
|
5407
|
+
add_builtin(
|
|
5408
|
+
"assign_copy",
|
|
5228
5409
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5229
5410
|
value_func=matrix_assign_value_func,
|
|
5230
5411
|
hidden=True,
|
|
@@ -5234,7 +5415,7 @@ add_builtin(
|
|
|
5234
5415
|
|
|
5235
5416
|
# implements matrix[i] = vector
|
|
5236
5417
|
add_builtin(
|
|
5237
|
-
"
|
|
5418
|
+
"assign_copy",
|
|
5238
5419
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5239
5420
|
constraint=matrix_vector_sametype,
|
|
5240
5421
|
value_func=matrix_assign_value_func,
|
|
@@ -5245,7 +5426,7 @@ add_builtin(
|
|
|
5245
5426
|
|
|
5246
5427
|
# implements matrix[i,j] += scalar
|
|
5247
5428
|
add_builtin(
|
|
5248
|
-
"
|
|
5429
|
+
"add_inplace",
|
|
5249
5430
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5250
5431
|
value_type=None,
|
|
5251
5432
|
hidden=True,
|
|
@@ -5253,9 +5434,20 @@ add_builtin(
|
|
|
5253
5434
|
)
|
|
5254
5435
|
|
|
5255
5436
|
|
|
5437
|
+
# implements matrix[i] += vector
|
|
5438
|
+
add_builtin(
|
|
5439
|
+
"add_inplace",
|
|
5440
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5441
|
+
constraint=matrix_vector_sametype,
|
|
5442
|
+
value_type=None,
|
|
5443
|
+
hidden=True,
|
|
5444
|
+
group="Utility",
|
|
5445
|
+
)
|
|
5446
|
+
|
|
5447
|
+
|
|
5256
5448
|
# implements matrix[i,j] -= scalar
|
|
5257
5449
|
add_builtin(
|
|
5258
|
-
"
|
|
5450
|
+
"sub_inplace",
|
|
5259
5451
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5260
5452
|
value_type=None,
|
|
5261
5453
|
hidden=True,
|
|
@@ -5263,6 +5455,16 @@ add_builtin(
|
|
|
5263
5455
|
)
|
|
5264
5456
|
|
|
5265
5457
|
|
|
5458
|
+
# implements matrix[i] -= vector
|
|
5459
|
+
add_builtin(
|
|
5460
|
+
"sub_inplace",
|
|
5461
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5462
|
+
value_type=None,
|
|
5463
|
+
hidden=True,
|
|
5464
|
+
group="Utility",
|
|
5465
|
+
)
|
|
5466
|
+
|
|
5467
|
+
|
|
5266
5468
|
for t in scalar_types + vector_types + (bool,):
|
|
5267
5469
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
5268
5470
|
continue
|
|
@@ -5410,7 +5612,27 @@ add_builtin(
|
|
|
5410
5612
|
)
|
|
5411
5613
|
add_builtin(
|
|
5412
5614
|
"expect_near",
|
|
5413
|
-
input_types={"a":
|
|
5615
|
+
input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
|
|
5616
|
+
defaults={"tolerance": 1.0e-6},
|
|
5617
|
+
value_type=None,
|
|
5618
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5619
|
+
group="Utility",
|
|
5620
|
+
)
|
|
5621
|
+
add_builtin(
|
|
5622
|
+
"expect_near",
|
|
5623
|
+
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
|
|
5624
|
+
defaults={"tolerance": 1.0e-6},
|
|
5625
|
+
value_type=None,
|
|
5626
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5627
|
+
group="Utility",
|
|
5628
|
+
)
|
|
5629
|
+
add_builtin(
|
|
5630
|
+
"expect_near",
|
|
5631
|
+
input_types={
|
|
5632
|
+
"a": matrix(shape=(Any, Any), dtype=Float),
|
|
5633
|
+
"b": matrix(shape=(Any, Any), dtype=Float),
|
|
5634
|
+
"tolerance": Float,
|
|
5635
|
+
},
|
|
5414
5636
|
defaults={"tolerance": 1.0e-6},
|
|
5415
5637
|
value_type=None,
|
|
5416
5638
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
@@ -5989,7 +6211,7 @@ add_builtin(
|
|
|
5989
6211
|
##
|
|
5990
6212
|
## Matmul
|
|
5991
6213
|
##
|
|
5992
|
-
def
|
|
6214
|
+
def tile_matmul_value_func(arg_types, arg_values):
|
|
5993
6215
|
# return generic type (for doc builds)
|
|
5994
6216
|
if arg_types is None:
|
|
5995
6217
|
return Tile(dtype=Any, shape=Any)
|
|
@@ -6015,7 +6237,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
|
|
|
6015
6237
|
return None
|
|
6016
6238
|
|
|
6017
6239
|
|
|
6018
|
-
def
|
|
6240
|
+
def tile_matmul_lto_dispatch_func(
|
|
6019
6241
|
arg_types: Mapping[str, type],
|
|
6020
6242
|
return_type: Any,
|
|
6021
6243
|
return_values: List[Var],
|
|
@@ -6054,142 +6276,82 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
6054
6276
|
out.type.storage = "shared"
|
|
6055
6277
|
template_args = [accumulate]
|
|
6056
6278
|
|
|
6057
|
-
# Maps Python/Warp types to C++ types and enums
|
|
6058
|
-
def cublasdx_type_map(dtype):
|
|
6059
|
-
if dtype == float16:
|
|
6060
|
-
return ("wp::float16", 3, 0)
|
|
6061
|
-
if dtype == float32:
|
|
6062
|
-
return ("wp::float32", 5, 0)
|
|
6063
|
-
if dtype == float64:
|
|
6064
|
-
return ("wp::float64", 6, 0)
|
|
6065
|
-
if dtype == vec2h:
|
|
6066
|
-
return ("wp::vec2h", 3, 1)
|
|
6067
|
-
if dtype == vec2f:
|
|
6068
|
-
return ("wp::vec2f", 5, 1)
|
|
6069
|
-
if dtype == vec2d:
|
|
6070
|
-
return ("wp::vec2d", 6, 1)
|
|
6071
|
-
raise TypeError("Unsupported input type in tile_matmul")
|
|
6072
|
-
|
|
6073
|
-
def cublasdx_arrangement_map(layout):
|
|
6074
|
-
if layout == "colmajor":
|
|
6075
|
-
return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
|
|
6076
|
-
if layout == "rowmajor":
|
|
6077
|
-
return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
|
|
6078
|
-
raise ValueError("Unsupported layout in tile_matmul")
|
|
6079
|
-
|
|
6080
|
-
# generate the LTO
|
|
6081
6279
|
M, K = a.type.shape[0], a.type.shape[1]
|
|
6082
6280
|
_, N = b.type.shape[0], b.type.shape[1]
|
|
6083
6281
|
num_threads = options["block_dim"]
|
|
6084
6282
|
arch = options["output_arch"]
|
|
6085
6283
|
|
|
6086
|
-
|
|
6087
|
-
|
|
6088
|
-
(
|
|
6089
|
-
|
|
6090
|
-
a_arrangement = cublasdx_arrangement_map(alayout)
|
|
6091
|
-
b_arrangement = cublasdx_arrangement_map(blayout)
|
|
6092
|
-
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
6093
|
-
|
|
6094
|
-
if a_type != b_type or a_type != c_type:
|
|
6095
|
-
raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
|
|
6096
|
-
|
|
6097
|
-
element_type = a_type
|
|
6098
|
-
|
|
6099
|
-
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
6284
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6285
|
+
# CPU/no-MathDx dispatch
|
|
6286
|
+
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
6287
|
+
else:
|
|
6100
6288
|
|
|
6101
|
-
|
|
6102
|
-
|
|
6103
|
-
|
|
6289
|
+
def tile_flip_layout(layout):
|
|
6290
|
+
if layout == "rowmajor":
|
|
6291
|
+
return "colmajor"
|
|
6292
|
+
elif layout == "colmajor":
|
|
6293
|
+
return "rowmajor"
|
|
6104
6294
|
|
|
6105
|
-
#
|
|
6106
|
-
|
|
6107
|
-
|
|
6108
|
-
|
|
6109
|
-
|
|
6110
|
-
|
|
6111
|
-
|
|
6112
|
-
|
|
6295
|
+
# generate the LTOs
|
|
6296
|
+
# C += A * B
|
|
6297
|
+
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
6298
|
+
M,
|
|
6299
|
+
N,
|
|
6300
|
+
K,
|
|
6301
|
+
a.type.dtype,
|
|
6302
|
+
b.type.dtype,
|
|
6303
|
+
out.type.dtype,
|
|
6304
|
+
a.type.layout,
|
|
6305
|
+
b.type.layout,
|
|
6306
|
+
out.type.layout,
|
|
6113
6307
|
arch,
|
|
6308
|
+
num_threads,
|
|
6309
|
+
builder,
|
|
6310
|
+
)
|
|
6311
|
+
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
6312
|
+
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
6114
6313
|
M,
|
|
6314
|
+
K,
|
|
6115
6315
|
N,
|
|
6316
|
+
out.type.dtype,
|
|
6317
|
+
b.type.dtype,
|
|
6318
|
+
a.type.dtype,
|
|
6319
|
+
out.type.layout,
|
|
6320
|
+
tile_flip_layout(b.type.layout),
|
|
6321
|
+
a.type.layout,
|
|
6322
|
+
arch,
|
|
6323
|
+
num_threads,
|
|
6324
|
+
builder,
|
|
6325
|
+
)
|
|
6326
|
+
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6327
|
+
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
6116
6328
|
K,
|
|
6117
|
-
|
|
6118
|
-
|
|
6119
|
-
|
|
6120
|
-
|
|
6121
|
-
|
|
6122
|
-
|
|
6123
|
-
|
|
6329
|
+
N,
|
|
6330
|
+
M,
|
|
6331
|
+
a.type.dtype,
|
|
6332
|
+
out.type.dtype,
|
|
6333
|
+
b.type.dtype,
|
|
6334
|
+
tile_flip_layout(a.type.layout),
|
|
6335
|
+
out.type.layout,
|
|
6336
|
+
b.type.layout,
|
|
6337
|
+
arch,
|
|
6124
6338
|
num_threads,
|
|
6339
|
+
builder,
|
|
6125
6340
|
)
|
|
6126
|
-
lto_code_path = Path(lto_code.name)
|
|
6127
|
-
if not result:
|
|
6128
|
-
lto_code.close()
|
|
6129
|
-
if lto_code_path.exists():
|
|
6130
|
-
lto_code_path.unlink()
|
|
6131
|
-
raise RuntimeError("Failed to compile tile_matmul")
|
|
6132
|
-
else:
|
|
6133
|
-
with open(lto_code.name, "rb") as f:
|
|
6134
|
-
lto_code_data = f.read()
|
|
6135
|
-
lto_code.close()
|
|
6136
|
-
lto_code_path.unlink()
|
|
6137
|
-
|
|
6138
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6139
|
-
builder.ltoirs_decl[lto_symbol] = (
|
|
6140
|
-
f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
|
|
6141
|
-
)
|
|
6142
|
-
|
|
6143
|
-
return lto_symbol, lto_code_data
|
|
6144
6341
|
|
|
6145
|
-
|
|
6146
|
-
|
|
6147
|
-
|
|
6148
|
-
|
|
6149
|
-
|
|
6150
|
-
|
|
6151
|
-
|
|
6152
|
-
|
|
6153
|
-
|
|
6154
|
-
|
|
6155
|
-
|
|
6156
|
-
|
|
6157
|
-
|
|
6158
|
-
K,
|
|
6159
|
-
N,
|
|
6160
|
-
out.type.dtype,
|
|
6161
|
-
b.type.dtype,
|
|
6162
|
-
a.type.dtype,
|
|
6163
|
-
out.type.layout,
|
|
6164
|
-
tile_flip_layout(b.type.layout),
|
|
6165
|
-
a.type.layout,
|
|
6166
|
-
)
|
|
6167
|
-
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6168
|
-
(fun_backward_B, lto_backward_B) = make_function(
|
|
6169
|
-
K,
|
|
6170
|
-
N,
|
|
6171
|
-
M,
|
|
6172
|
-
a.type.dtype,
|
|
6173
|
-
out.type.dtype,
|
|
6174
|
-
b.type.dtype,
|
|
6175
|
-
tile_flip_layout(a.type.layout),
|
|
6176
|
-
out.type.layout,
|
|
6177
|
-
b.type.layout,
|
|
6178
|
-
)
|
|
6179
|
-
|
|
6180
|
-
return (
|
|
6181
|
-
(
|
|
6182
|
-
Var(fun_forward, str, False, True, False),
|
|
6183
|
-
Var(fun_backward_A, str, False, True, False),
|
|
6184
|
-
Var(fun_backward_B, str, False, True, False),
|
|
6185
|
-
a,
|
|
6186
|
-
b,
|
|
6187
|
-
out,
|
|
6188
|
-
),
|
|
6189
|
-
template_args,
|
|
6190
|
-
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6191
|
-
0,
|
|
6192
|
-
)
|
|
6342
|
+
return (
|
|
6343
|
+
(
|
|
6344
|
+
Var(fun_forward, str, False, True, False),
|
|
6345
|
+
Var(fun_backward_A, str, False, True, False),
|
|
6346
|
+
Var(fun_backward_B, str, False, True, False),
|
|
6347
|
+
a,
|
|
6348
|
+
b,
|
|
6349
|
+
out,
|
|
6350
|
+
),
|
|
6351
|
+
template_args,
|
|
6352
|
+
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6353
|
+
0,
|
|
6354
|
+
)
|
|
6193
6355
|
|
|
6194
6356
|
|
|
6195
6357
|
add_builtin(
|
|
@@ -6199,8 +6361,8 @@ add_builtin(
|
|
|
6199
6361
|
"b": Tile(dtype=Any, shape=Any),
|
|
6200
6362
|
"out": Tile(dtype=Any, shape=Any),
|
|
6201
6363
|
},
|
|
6202
|
-
value_func=
|
|
6203
|
-
lto_dispatch_func=
|
|
6364
|
+
value_func=tile_matmul_value_func,
|
|
6365
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6204
6366
|
variadic=False,
|
|
6205
6367
|
doc="""Computes the matrix product and accumulates ``out += a*b``.
|
|
6206
6368
|
|
|
@@ -6208,7 +6370,7 @@ add_builtin(
|
|
|
6208
6370
|
* fp16, fp32, fp64 (real)
|
|
6209
6371
|
* vec2h, vec2f, vec2d (complex)
|
|
6210
6372
|
|
|
6211
|
-
All input and output tiles must have the same datatype. Tile data will
|
|
6373
|
+
All input and output tiles must have the same datatype. Tile data will automatically be migrated
|
|
6212
6374
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6213
6375
|
|
|
6214
6376
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6222,8 +6384,8 @@ add_builtin(
|
|
|
6222
6384
|
add_builtin(
|
|
6223
6385
|
"tile_matmul",
|
|
6224
6386
|
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
6225
|
-
value_func=
|
|
6226
|
-
lto_dispatch_func=
|
|
6387
|
+
value_func=tile_matmul_value_func,
|
|
6388
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6227
6389
|
variadic=False,
|
|
6228
6390
|
doc="""Computes the matrix product ``out = a*b``.
|
|
6229
6391
|
|
|
@@ -6231,7 +6393,7 @@ add_builtin(
|
|
|
6231
6393
|
* fp16, fp32, fp64 (real)
|
|
6232
6394
|
* vec2h, vec2f, vec2d (complex)
|
|
6233
6395
|
|
|
6234
|
-
Both input tiles must have the same datatype. Tile data will
|
|
6396
|
+
Both input tiles must have the same datatype. Tile data will automatically be migrated
|
|
6235
6397
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6236
6398
|
|
|
6237
6399
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6303,59 +6465,29 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6303
6465
|
num_threads = options["block_dim"]
|
|
6304
6466
|
arch = options["output_arch"]
|
|
6305
6467
|
ept = size // num_threads
|
|
6306
|
-
|
|
6307
|
-
|
|
6308
|
-
|
|
6309
|
-
|
|
6310
|
-
|
|
6311
|
-
|
|
6312
|
-
|
|
6313
|
-
|
|
6314
|
-
|
|
6315
|
-
|
|
6316
|
-
|
|
6317
|
-
|
|
6318
|
-
|
|
6319
|
-
|
|
6320
|
-
|
|
6321
|
-
|
|
6322
|
-
|
|
6323
|
-
|
|
6324
|
-
|
|
6325
|
-
|
|
6326
|
-
|
|
6327
|
-
|
|
6328
|
-
|
|
6329
|
-
lto_code_path = Path(lto_code.name)
|
|
6330
|
-
if not result:
|
|
6331
|
-
lto_code.close()
|
|
6332
|
-
if lto_code_path.exists():
|
|
6333
|
-
lto_code_path.unlink()
|
|
6334
|
-
raise RuntimeError("Failed to compile tile_fft")
|
|
6335
|
-
|
|
6336
|
-
with open(lto_code.name, "rb") as f:
|
|
6337
|
-
lto_code_data = f.read()
|
|
6338
|
-
|
|
6339
|
-
lto_code.close()
|
|
6340
|
-
lto_code_path.unlink()
|
|
6341
|
-
|
|
6342
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6343
|
-
|
|
6344
|
-
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
6345
|
-
|
|
6346
|
-
return (
|
|
6347
|
-
(
|
|
6348
|
-
Var(lto_symbol, str, False, True, False),
|
|
6349
|
-
Var(dtype, str, False, True, False),
|
|
6350
|
-
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6351
|
-
Var(str(batch), str, False, True, False),
|
|
6352
|
-
Var(str(ept), str, False, True, False),
|
|
6353
|
-
inout,
|
|
6354
|
-
),
|
|
6355
|
-
[],
|
|
6356
|
-
[lto_code_data],
|
|
6357
|
-
shared_memory_bytes,
|
|
6358
|
-
)
|
|
6468
|
+
|
|
6469
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6470
|
+
# CPU/no-MathDx dispatch
|
|
6471
|
+
return ([], [], [], 0)
|
|
6472
|
+
else:
|
|
6473
|
+
# generate the LTO
|
|
6474
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
6475
|
+
arch, size, ept, direction, dir, precision, builder
|
|
6476
|
+
)
|
|
6477
|
+
|
|
6478
|
+
return (
|
|
6479
|
+
(
|
|
6480
|
+
Var(lto_symbol, str, False, True, False),
|
|
6481
|
+
Var(dtype, str, False, True, False),
|
|
6482
|
+
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6483
|
+
Var(str(batch), str, False, True, False),
|
|
6484
|
+
Var(str(ept), str, False, True, False),
|
|
6485
|
+
inout,
|
|
6486
|
+
),
|
|
6487
|
+
[],
|
|
6488
|
+
[lto_code_data],
|
|
6489
|
+
shared_memory_bytes,
|
|
6490
|
+
)
|
|
6359
6491
|
|
|
6360
6492
|
|
|
6361
6493
|
add_builtin(
|
|
@@ -6417,7 +6549,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
|
6417
6549
|
raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
|
|
6418
6550
|
|
|
6419
6551
|
if len(a.shape) != 2:
|
|
6420
|
-
raise ValueError("tile_cholesky()
|
|
6552
|
+
raise ValueError("tile_cholesky() argument must be a 2D tile")
|
|
6421
6553
|
|
|
6422
6554
|
if a.shape[0] != a.shape[1]:
|
|
6423
6555
|
raise ValueError("tile_cholesky() argument must be square")
|
|
@@ -6458,57 +6590,36 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6458
6590
|
if out.type.shape[0] != M or out.type.shape[1] != M:
|
|
6459
6591
|
raise ValueError("tile_cholesky() output tile must be square")
|
|
6460
6592
|
|
|
6461
|
-
|
|
6462
|
-
|
|
6463
|
-
lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
|
|
6464
|
-
|
|
6465
|
-
# early out if LTO for this combination already exists for this module
|
|
6466
|
-
if lto_symbol in builder.ltoirs:
|
|
6467
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6468
|
-
|
|
6469
|
-
# otherwise compile LTO
|
|
6470
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6471
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6593
|
+
solver = "potrf"
|
|
6594
|
+
solver_enum = cusolver_function_map[solver]
|
|
6472
6595
|
|
|
6473
|
-
# cuSOLVERDx only
|
|
6596
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6474
6597
|
# so we use upper to mimic a row-major input
|
|
6475
|
-
|
|
6476
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6477
|
-
lto_code.name.encode("utf-8"),
|
|
6478
|
-
lto_symbol.encode("utf-8"),
|
|
6479
|
-
0,
|
|
6480
|
-
None,
|
|
6481
|
-
None,
|
|
6482
|
-
arch,
|
|
6483
|
-
M,
|
|
6484
|
-
N,
|
|
6485
|
-
cusolver_function_map["potrf"],
|
|
6486
|
-
precision_enum,
|
|
6487
|
-
cusolver_fill_mode_map["upper"],
|
|
6488
|
-
num_threads,
|
|
6489
|
-
)
|
|
6598
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6490
6599
|
|
|
6491
|
-
|
|
6492
|
-
|
|
6493
|
-
|
|
6494
|
-
if Path(f.name).exists():
|
|
6495
|
-
Path(f.name).unlink()
|
|
6496
|
-
raise RuntimeError("Failed to compile tile_cholesky")
|
|
6600
|
+
arch = options["output_arch"]
|
|
6601
|
+
num_threads = options["block_dim"]
|
|
6602
|
+
parameter_list = f"({dtype}*, unsigned)"
|
|
6497
6603
|
|
|
6604
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6605
|
+
# CPU/no-MathDx dispatch
|
|
6606
|
+
return ((0, a, out), [], [], 0)
|
|
6498
6607
|
else:
|
|
6499
|
-
|
|
6500
|
-
|
|
6501
|
-
|
|
6502
|
-
|
|
6503
|
-
|
|
6504
|
-
|
|
6505
|
-
|
|
6506
|
-
|
|
6507
|
-
|
|
6508
|
-
|
|
6509
|
-
|
|
6608
|
+
# generate the LTO
|
|
6609
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6610
|
+
M,
|
|
6611
|
+
N,
|
|
6612
|
+
solver,
|
|
6613
|
+
solver_enum,
|
|
6614
|
+
fill_mode,
|
|
6615
|
+
arch,
|
|
6616
|
+
precision_enum,
|
|
6617
|
+
num_threads,
|
|
6618
|
+
parameter_list,
|
|
6619
|
+
builder,
|
|
6620
|
+
)
|
|
6510
6621
|
|
|
6511
|
-
|
|
6622
|
+
return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
|
|
6512
6623
|
|
|
6513
6624
|
|
|
6514
6625
|
add_builtin(
|
|
@@ -6602,57 +6713,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6602
6713
|
f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
|
|
6603
6714
|
)
|
|
6604
6715
|
|
|
6605
|
-
|
|
6606
|
-
|
|
6607
|
-
lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
|
|
6608
|
-
|
|
6609
|
-
# early out if LTO for this combination already exists for this module
|
|
6610
|
-
if lto_symbol in builder.ltoirs:
|
|
6611
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6612
|
-
|
|
6613
|
-
# otherwise compile LTO
|
|
6614
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6615
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6716
|
+
solver = "potrs"
|
|
6717
|
+
solver_enum = cusolver_function_map[solver]
|
|
6616
6718
|
|
|
6617
|
-
# cuSOLVERDx only
|
|
6719
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6618
6720
|
# so we use upper to mimic a row-major input
|
|
6619
|
-
|
|
6620
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6621
|
-
lto_code.name.encode("utf-8"),
|
|
6622
|
-
lto_symbol.encode("utf-8"),
|
|
6623
|
-
0,
|
|
6624
|
-
None,
|
|
6625
|
-
None,
|
|
6626
|
-
arch,
|
|
6627
|
-
M,
|
|
6628
|
-
N,
|
|
6629
|
-
cusolver_function_map["potrs"],
|
|
6630
|
-
precision_enum,
|
|
6631
|
-
cusolver_fill_mode_map["upper"],
|
|
6632
|
-
num_threads,
|
|
6633
|
-
)
|
|
6721
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6634
6722
|
|
|
6635
|
-
|
|
6636
|
-
|
|
6637
|
-
|
|
6638
|
-
if Path(f.name).exists():
|
|
6639
|
-
Path(f.name).unlink()
|
|
6640
|
-
raise RuntimeError("Failed to compile tile_cholesky_solve")
|
|
6723
|
+
arch = options["output_arch"]
|
|
6724
|
+
num_threads = options["block_dim"]
|
|
6725
|
+
parameter_list = f"({dtype}*, {dtype}*)"
|
|
6641
6726
|
|
|
6727
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6728
|
+
# CPU/no-MathDx dispatch
|
|
6729
|
+
return ((0, L, x, y), [], [], 0)
|
|
6642
6730
|
else:
|
|
6643
|
-
|
|
6644
|
-
|
|
6645
|
-
|
|
6646
|
-
|
|
6647
|
-
|
|
6648
|
-
|
|
6649
|
-
|
|
6650
|
-
|
|
6651
|
-
|
|
6652
|
-
|
|
6653
|
-
|
|
6654
|
-
|
|
6655
|
-
|
|
6731
|
+
# generate the LTO
|
|
6732
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6733
|
+
M,
|
|
6734
|
+
N,
|
|
6735
|
+
solver,
|
|
6736
|
+
solver_enum,
|
|
6737
|
+
fill_mode,
|
|
6738
|
+
arch,
|
|
6739
|
+
precision_enum,
|
|
6740
|
+
num_threads,
|
|
6741
|
+
parameter_list,
|
|
6742
|
+
builder,
|
|
6743
|
+
)
|
|
6744
|
+
|
|
6745
|
+
return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
|
|
6656
6746
|
|
|
6657
6747
|
|
|
6658
6748
|
add_builtin(
|