warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -15,10 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import builtins
|
|
17
17
|
import functools
|
|
18
|
-
import tempfile
|
|
19
|
-
from pathlib import Path
|
|
20
18
|
from typing import Any, Callable, Mapping, Sequence
|
|
21
19
|
|
|
20
|
+
import warp.build
|
|
21
|
+
import warp.context
|
|
22
22
|
from warp.codegen import Reference, Var, strip_reference
|
|
23
23
|
from warp.types import *
|
|
24
24
|
|
|
@@ -41,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
|
|
|
41
41
|
return all(types_equal(arg_type_0, t) for t in arg_types_iter)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def sametypes_create_value_func(default):
|
|
44
|
+
def sametypes_create_value_func(default: TypeVar):
|
|
45
45
|
def fn(arg_types, arg_values):
|
|
46
46
|
if arg_types is None:
|
|
47
47
|
return default
|
|
@@ -399,7 +399,7 @@ add_builtin(
|
|
|
399
399
|
)
|
|
400
400
|
|
|
401
401
|
|
|
402
|
-
def scalar_infer_type(arg_types: Mapping[str, type]):
|
|
402
|
+
def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
|
|
403
403
|
if arg_types is None:
|
|
404
404
|
return Scalar
|
|
405
405
|
|
|
@@ -836,7 +836,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
836
836
|
|
|
837
837
|
if dtype is None:
|
|
838
838
|
dtype = value_type
|
|
839
|
-
elif value_type
|
|
839
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
840
840
|
raise RuntimeError(
|
|
841
841
|
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
842
842
|
)
|
|
@@ -857,9 +857,9 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
857
857
|
|
|
858
858
|
if dtype is None:
|
|
859
859
|
dtype = value_type
|
|
860
|
-
elif value_type
|
|
860
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
861
861
|
raise RuntimeError(
|
|
862
|
-
f"all values used to initialize this vector
|
|
862
|
+
f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
|
|
863
863
|
)
|
|
864
864
|
|
|
865
865
|
if length is None:
|
|
@@ -940,7 +940,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
940
940
|
|
|
941
941
|
if dtype is None:
|
|
942
942
|
dtype = value_type
|
|
943
|
-
elif value_type
|
|
943
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
944
944
|
raise RuntimeError(
|
|
945
945
|
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
946
946
|
)
|
|
@@ -950,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
950
950
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
951
951
|
|
|
952
952
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
953
|
+
warp.utils.warn(
|
|
954
|
+
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
955
|
+
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
956
|
+
DeprecationWarning,
|
|
957
|
+
)
|
|
958
|
+
|
|
953
959
|
if shape[1] != variadic_arg_count:
|
|
954
960
|
raise RuntimeError(
|
|
955
961
|
f"incompatible number of column vectors given ({variadic_arg_count}) "
|
|
@@ -973,7 +979,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
973
979
|
|
|
974
980
|
if dtype is None:
|
|
975
981
|
dtype = value_type
|
|
976
|
-
elif value_type
|
|
982
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
977
983
|
raise RuntimeError(
|
|
978
984
|
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
979
985
|
)
|
|
@@ -1030,6 +1036,86 @@ add_builtin(
|
|
|
1030
1036
|
)
|
|
1031
1037
|
|
|
1032
1038
|
|
|
1039
|
+
def matrix_from_vecs_create_value_func(cols: bool):
|
|
1040
|
+
def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1041
|
+
if arg_types is None:
|
|
1042
|
+
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
1043
|
+
|
|
1044
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
1045
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1046
|
+
|
|
1047
|
+
if not all(type_is_vector(x) for x in variadic_arg_types):
|
|
1048
|
+
raise RuntimeError("all arguments are expected to be vectors")
|
|
1049
|
+
|
|
1050
|
+
length = variadic_arg_types[0]._length_
|
|
1051
|
+
if any(x._length_ != length for x in variadic_arg_types):
|
|
1052
|
+
raise RuntimeError("all vectors are expected to have the same length")
|
|
1053
|
+
|
|
1054
|
+
dtype = variadic_arg_types[0]._wp_scalar_type_
|
|
1055
|
+
if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
|
|
1056
|
+
raise RuntimeError("all vectors are expected to have the same dtype")
|
|
1057
|
+
|
|
1058
|
+
shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
|
|
1059
|
+
return matrix(shape=shape, dtype=dtype)
|
|
1060
|
+
|
|
1061
|
+
return fn
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1065
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1066
|
+
# Further validate the given argument values if needed and map them
|
|
1067
|
+
# to the underlying C++ function's runtime and template params.
|
|
1068
|
+
|
|
1069
|
+
shape = return_type._shape_
|
|
1070
|
+
dtype = return_type._wp_scalar_type_
|
|
1071
|
+
|
|
1072
|
+
variadic_args = args.get("args", ())
|
|
1073
|
+
|
|
1074
|
+
func_args = variadic_args
|
|
1075
|
+
|
|
1076
|
+
if shape in ((2, 2), (3, 3), (4, 4)):
|
|
1077
|
+
# Template specializations exist for these shapes, don't pass them
|
|
1078
|
+
# as template parameters.
|
|
1079
|
+
template_args = (dtype,)
|
|
1080
|
+
else:
|
|
1081
|
+
template_args = (*shape, dtype)
|
|
1082
|
+
|
|
1083
|
+
return (func_args, template_args)
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
def matrix_from_vecs_initializer_list_func(args, return_type):
|
|
1087
|
+
shape = return_type._shape_
|
|
1088
|
+
|
|
1089
|
+
return shape[0] != shape[1] or shape[0] > 4
|
|
1090
|
+
|
|
1091
|
+
|
|
1092
|
+
add_builtin(
|
|
1093
|
+
"matrix_from_cols",
|
|
1094
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1095
|
+
variadic=True,
|
|
1096
|
+
value_func=matrix_from_vecs_create_value_func(cols=True),
|
|
1097
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1098
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1099
|
+
native_func="matrix_from_cols",
|
|
1100
|
+
doc="Construct a matrix from column vectors.",
|
|
1101
|
+
group="Vector Math",
|
|
1102
|
+
export=False,
|
|
1103
|
+
)
|
|
1104
|
+
|
|
1105
|
+
add_builtin(
|
|
1106
|
+
"matrix_from_rows",
|
|
1107
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1108
|
+
variadic=True,
|
|
1109
|
+
value_func=matrix_from_vecs_create_value_func(cols=False),
|
|
1110
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1111
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1112
|
+
native_func="matrix_from_rows",
|
|
1113
|
+
doc="Construct a matrix from row vectors.",
|
|
1114
|
+
group="Vector Math",
|
|
1115
|
+
export=False,
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
|
|
1033
1119
|
def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1034
1120
|
if arg_types is None:
|
|
1035
1121
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
@@ -1084,7 +1170,7 @@ def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mappi
|
|
|
1084
1170
|
|
|
1085
1171
|
if dtype is None:
|
|
1086
1172
|
dtype = value_type
|
|
1087
|
-
elif value_type
|
|
1173
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1088
1174
|
raise RuntimeError(
|
|
1089
1175
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1090
1176
|
)
|
|
@@ -1141,6 +1227,21 @@ add_builtin(
|
|
|
1141
1227
|
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1142
1228
|
)
|
|
1143
1229
|
|
|
1230
|
+
add_builtin(
|
|
1231
|
+
"svd2",
|
|
1232
|
+
input_types={
|
|
1233
|
+
"A": matrix(shape=(2, 2), dtype=Float),
|
|
1234
|
+
"U": matrix(shape=(2, 2), dtype=Float),
|
|
1235
|
+
"sigma": vector(length=2, dtype=Float),
|
|
1236
|
+
"V": matrix(shape=(2, 2), dtype=Scalar),
|
|
1237
|
+
},
|
|
1238
|
+
value_type=None,
|
|
1239
|
+
group="Vector Math",
|
|
1240
|
+
export=False,
|
|
1241
|
+
doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
|
|
1242
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1243
|
+
)
|
|
1244
|
+
|
|
1144
1245
|
add_builtin(
|
|
1145
1246
|
"qr3",
|
|
1146
1247
|
input_types={
|
|
@@ -1204,7 +1305,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1204
1305
|
|
|
1205
1306
|
if dtype is None:
|
|
1206
1307
|
dtype = value_type
|
|
1207
|
-
elif value_type
|
|
1308
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1208
1309
|
raise RuntimeError(
|
|
1209
1310
|
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1210
1311
|
)
|
|
@@ -1244,7 +1345,8 @@ add_builtin(
|
|
|
1244
1345
|
)
|
|
1245
1346
|
add_builtin(
|
|
1246
1347
|
"quaternion",
|
|
1247
|
-
input_types={"x": Float, "y": Float, "z": Float, "w": Float},
|
|
1348
|
+
input_types={"x": Float, "y": Float, "z": Float, "w": Float, "dtype": Scalar},
|
|
1349
|
+
defaults={"dtype": None},
|
|
1248
1350
|
value_func=quaternion_value_func,
|
|
1249
1351
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1250
1352
|
dispatch_func=quaternion_dispatch_func,
|
|
@@ -1332,7 +1434,18 @@ add_builtin(
|
|
|
1332
1434
|
input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
|
|
1333
1435
|
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1334
1436
|
group="Quaternion Math",
|
|
1335
|
-
doc="Construct a quaternion from a 3x3 matrix.
|
|
1437
|
+
doc="""Construct a quaternion from a 3x3 matrix.
|
|
1438
|
+
|
|
1439
|
+
If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1440
|
+
)
|
|
1441
|
+
add_builtin(
|
|
1442
|
+
"quat_from_matrix",
|
|
1443
|
+
input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
|
|
1444
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1445
|
+
group="Quaternion Math",
|
|
1446
|
+
doc="""Construct a quaternion from a 4x4 matrix.
|
|
1447
|
+
|
|
1448
|
+
If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1336
1449
|
)
|
|
1337
1450
|
add_builtin(
|
|
1338
1451
|
"quat_rpy",
|
|
@@ -1403,7 +1516,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1403
1516
|
dtype = arg_values.get("dtype", None)
|
|
1404
1517
|
if dtype is None:
|
|
1405
1518
|
dtype = value_type
|
|
1406
|
-
elif value_type
|
|
1519
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1407
1520
|
raise RuntimeError(
|
|
1408
1521
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1409
1522
|
)
|
|
@@ -1570,7 +1683,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1570
1683
|
|
|
1571
1684
|
if dtype is None:
|
|
1572
1685
|
dtype = value_type
|
|
1573
|
-
elif value_type
|
|
1686
|
+
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1574
1687
|
raise RuntimeError(
|
|
1575
1688
|
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1576
1689
|
)
|
|
@@ -2375,7 +2488,7 @@ add_builtin(
|
|
|
2375
2488
|
|
|
2376
2489
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2377
2490
|
|
|
2378
|
-
* If the input value is a scalar, then the resulting tile has ``shape=(
|
|
2491
|
+
* If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
|
|
2379
2492
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2380
2493
|
|
|
2381
2494
|
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
@@ -2669,11 +2782,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2669
2782
|
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2670
2783
|
tile = arg_values["a"]
|
|
2671
2784
|
|
|
2672
|
-
|
|
2673
|
-
|
|
2674
|
-
template_args.
|
|
2675
|
-
template_args.append(return_type.strides[0])
|
|
2676
|
-
template_args.append(return_type.strides[1])
|
|
2785
|
+
assert len(return_type.shape) == len(return_type.strides)
|
|
2786
|
+
assert 1 <= len(return_type.shape) <= 4
|
|
2787
|
+
template_args = [*return_type.shape, *return_type.strides]
|
|
2677
2788
|
|
|
2678
2789
|
return ((tile,), template_args)
|
|
2679
2790
|
|
|
@@ -2686,56 +2797,17 @@ add_builtin(
|
|
|
2686
2797
|
variadic=False,
|
|
2687
2798
|
doc="""Broadcast a tile.
|
|
2688
2799
|
|
|
2689
|
-
|
|
2690
|
-
|
|
2800
|
+
Broadcasts the input tile ``a`` to the destination shape.
|
|
2691
2801
|
Broadcasting follows NumPy broadcast rules.
|
|
2692
2802
|
|
|
2693
2803
|
:param a: Tile to broadcast
|
|
2694
2804
|
:param shape: The shape to broadcast to
|
|
2695
|
-
:returns: Tile with broadcast
|
|
2805
|
+
:returns: Tile with broadcast shape""",
|
|
2696
2806
|
group="Tile Primitives",
|
|
2697
2807
|
export=False,
|
|
2698
2808
|
)
|
|
2699
2809
|
|
|
2700
2810
|
|
|
2701
|
-
def tile_matmul_value_func(arg_types, arg_values):
|
|
2702
|
-
# return generic type (for doc builds)
|
|
2703
|
-
if arg_types is None:
|
|
2704
|
-
return Tile(dtype=Any, shape=Any)
|
|
2705
|
-
|
|
2706
|
-
if len(arg_types) != 3:
|
|
2707
|
-
raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
|
|
2708
|
-
|
|
2709
|
-
return None
|
|
2710
|
-
|
|
2711
|
-
|
|
2712
|
-
def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2713
|
-
a = arg_values["a"]
|
|
2714
|
-
b = arg_values["b"]
|
|
2715
|
-
out = arg_values["out"]
|
|
2716
|
-
|
|
2717
|
-
# force the storage type of the input variables to shared memory
|
|
2718
|
-
a.type.storage = "shared"
|
|
2719
|
-
b.type.storage = "shared"
|
|
2720
|
-
out.type.storage = "shared"
|
|
2721
|
-
|
|
2722
|
-
template_args = []
|
|
2723
|
-
return ((a, b, out), template_args)
|
|
2724
|
-
|
|
2725
|
-
|
|
2726
|
-
add_builtin(
|
|
2727
|
-
"tile_matmul_scalar",
|
|
2728
|
-
input_types={"a": Tile, "b": Tile, "out": Tile},
|
|
2729
|
-
value_func=tile_matmul_value_func,
|
|
2730
|
-
dispatch_func=tile_matmul_dispatch_func,
|
|
2731
|
-
variadic=True,
|
|
2732
|
-
doc="Compute matrix product and accumulate out += a*b.",
|
|
2733
|
-
group="Tile Primitives",
|
|
2734
|
-
hidden=True,
|
|
2735
|
-
export=False,
|
|
2736
|
-
)
|
|
2737
|
-
|
|
2738
|
-
|
|
2739
2811
|
def tile_sum_value_func(arg_types, arg_values):
|
|
2740
2812
|
# return generic type (for doc builds)
|
|
2741
2813
|
if arg_types is None:
|
|
@@ -3030,7 +3102,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
|
|
|
3030
3102
|
|
|
3031
3103
|
for i in range(len(a.shape)):
|
|
3032
3104
|
if a.shape[i] != b.shape[i]:
|
|
3033
|
-
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape
|
|
3105
|
+
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
|
|
3034
3106
|
|
|
3035
3107
|
return TileBinaryMap(a, b)
|
|
3036
3108
|
|
|
@@ -3807,6 +3879,18 @@ _volume_supported_value_types = {
|
|
|
3807
3879
|
}
|
|
3808
3880
|
|
|
3809
3881
|
|
|
3882
|
+
def _is_volume_type_supported(dtype):
|
|
3883
|
+
for typ in _volume_supported_value_types:
|
|
3884
|
+
if types_equal(typ, dtype):
|
|
3885
|
+
return True
|
|
3886
|
+
return False
|
|
3887
|
+
|
|
3888
|
+
|
|
3889
|
+
def _check_volume_type_is_supported(dtype):
|
|
3890
|
+
if not _is_volume_type_supported(dtype):
|
|
3891
|
+
raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
|
|
3892
|
+
|
|
3893
|
+
|
|
3810
3894
|
def check_volume_value_grad_compatibility(dtype, grad_dtype):
|
|
3811
3895
|
if type_is_vector(dtype):
|
|
3812
3896
|
expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
|
|
@@ -3822,9 +3906,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
3822
3906
|
return Any
|
|
3823
3907
|
|
|
3824
3908
|
dtype = arg_values["dtype"]
|
|
3825
|
-
|
|
3826
|
-
if dtype not in _volume_supported_value_types:
|
|
3827
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3909
|
+
_check_volume_type_is_supported(dtype)
|
|
3828
3910
|
|
|
3829
3911
|
return dtype
|
|
3830
3912
|
|
|
@@ -3860,9 +3942,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
|
|
|
3860
3942
|
return Any
|
|
3861
3943
|
|
|
3862
3944
|
dtype = arg_values["dtype"]
|
|
3863
|
-
|
|
3864
|
-
if dtype not in _volume_supported_value_types:
|
|
3865
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3945
|
+
_check_volume_type_is_supported(dtype)
|
|
3866
3946
|
|
|
3867
3947
|
check_volume_value_grad_compatibility(dtype, arg_types["grad"])
|
|
3868
3948
|
|
|
@@ -3900,9 +3980,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
|
|
|
3900
3980
|
return Any
|
|
3901
3981
|
|
|
3902
3982
|
dtype = arg_values["dtype"]
|
|
3903
|
-
|
|
3904
|
-
if dtype not in _volume_supported_value_types:
|
|
3905
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3983
|
+
_check_volume_type_is_supported(dtype)
|
|
3906
3984
|
|
|
3907
3985
|
return dtype
|
|
3908
3986
|
|
|
@@ -3939,9 +4017,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
|
|
|
3939
4017
|
return None
|
|
3940
4018
|
|
|
3941
4019
|
dtype = arg_types["value"]
|
|
3942
|
-
|
|
3943
|
-
if dtype not in _volume_supported_value_types:
|
|
3944
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
4020
|
+
_check_volume_type_is_supported(dtype)
|
|
3945
4021
|
|
|
3946
4022
|
return None
|
|
3947
4023
|
|
|
@@ -4191,6 +4267,20 @@ add_builtin(
|
|
|
4191
4267
|
group="Random",
|
|
4192
4268
|
doc="Return a random integer between [low, high).",
|
|
4193
4269
|
)
|
|
4270
|
+
add_builtin(
|
|
4271
|
+
"randu",
|
|
4272
|
+
input_types={"state": uint32},
|
|
4273
|
+
value_type=uint32,
|
|
4274
|
+
group="Random",
|
|
4275
|
+
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
4276
|
+
)
|
|
4277
|
+
add_builtin(
|
|
4278
|
+
"randu",
|
|
4279
|
+
input_types={"state": uint32, "low": uint32, "high": uint32},
|
|
4280
|
+
value_type=uint32,
|
|
4281
|
+
group="Random",
|
|
4282
|
+
doc="Return a random unsigned integer between [low, high).",
|
|
4283
|
+
)
|
|
4194
4284
|
add_builtin(
|
|
4195
4285
|
"randf",
|
|
4196
4286
|
input_types={"state": uint32},
|
|
@@ -4499,11 +4589,31 @@ add_builtin(
|
|
|
4499
4589
|
export=False,
|
|
4500
4590
|
group="Utility",
|
|
4501
4591
|
)
|
|
4592
|
+
|
|
4593
|
+
|
|
4594
|
+
def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
4595
|
+
warp.utils.warn(
|
|
4596
|
+
"wp.select() is deprecated and will be removed in a future\n"
|
|
4597
|
+
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
4598
|
+
category=DeprecationWarning,
|
|
4599
|
+
)
|
|
4600
|
+
|
|
4601
|
+
func_args = tuple(args.values())
|
|
4602
|
+
template_args = ()
|
|
4603
|
+
|
|
4604
|
+
return (func_args, template_args)
|
|
4605
|
+
|
|
4606
|
+
|
|
4502
4607
|
add_builtin(
|
|
4503
4608
|
"select",
|
|
4504
4609
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
4505
4610
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4506
|
-
|
|
4611
|
+
dispatch_func=select_dispatch_func,
|
|
4612
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4613
|
+
|
|
4614
|
+
.. deprecated:: 1.7
|
|
4615
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4616
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4507
4617
|
group="Utility",
|
|
4508
4618
|
)
|
|
4509
4619
|
for t in int_types:
|
|
@@ -4511,14 +4621,47 @@ for t in int_types:
|
|
|
4511
4621
|
"select",
|
|
4512
4622
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
4513
4623
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4514
|
-
|
|
4624
|
+
dispatch_func=select_dispatch_func,
|
|
4625
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4626
|
+
|
|
4627
|
+
.. deprecated:: 1.7
|
|
4628
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4629
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4515
4630
|
group="Utility",
|
|
4516
4631
|
)
|
|
4517
4632
|
add_builtin(
|
|
4518
4633
|
"select",
|
|
4519
4634
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
4520
4635
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4521
|
-
|
|
4636
|
+
dispatch_func=select_dispatch_func,
|
|
4637
|
+
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4638
|
+
|
|
4639
|
+
.. deprecated:: 1.7
|
|
4640
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4641
|
+
``where(arr, value_if_true, value_if_false)``.""",
|
|
4642
|
+
group="Utility",
|
|
4643
|
+
)
|
|
4644
|
+
|
|
4645
|
+
add_builtin(
|
|
4646
|
+
"where",
|
|
4647
|
+
input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
|
|
4648
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4649
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4650
|
+
group="Utility",
|
|
4651
|
+
)
|
|
4652
|
+
for t in int_types:
|
|
4653
|
+
add_builtin(
|
|
4654
|
+
"where",
|
|
4655
|
+
input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
|
|
4656
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4657
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4658
|
+
group="Utility",
|
|
4659
|
+
)
|
|
4660
|
+
add_builtin(
|
|
4661
|
+
"where",
|
|
4662
|
+
input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
|
|
4663
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4664
|
+
doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4522
4665
|
group="Utility",
|
|
4523
4666
|
)
|
|
4524
4667
|
|
|
@@ -5112,33 +5255,51 @@ add_builtin(
|
|
|
5112
5255
|
)
|
|
5113
5256
|
|
|
5114
5257
|
|
|
5258
|
+
# implements vector[index] = value
|
|
5259
|
+
add_builtin(
|
|
5260
|
+
"assign_inplace",
|
|
5261
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5262
|
+
value_type=None,
|
|
5263
|
+
hidden=True,
|
|
5264
|
+
group="Utility",
|
|
5265
|
+
)
|
|
5266
|
+
|
|
5267
|
+
# implements quaternion[index] = value
|
|
5268
|
+
add_builtin(
|
|
5269
|
+
"assign_inplace",
|
|
5270
|
+
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5271
|
+
value_type=None,
|
|
5272
|
+
hidden=True,
|
|
5273
|
+
group="Utility",
|
|
5274
|
+
)
|
|
5275
|
+
|
|
5276
|
+
|
|
5115
5277
|
def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5116
5278
|
vec_type = arg_types["a"]
|
|
5117
5279
|
return vec_type
|
|
5118
5280
|
|
|
5119
5281
|
|
|
5120
|
-
# implements vector[index] = value
|
|
5282
|
+
# implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5121
5283
|
add_builtin(
|
|
5122
|
-
"
|
|
5284
|
+
"assign_copy",
|
|
5123
5285
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5124
5286
|
value_func=vector_assign_value_func,
|
|
5125
5287
|
hidden=True,
|
|
5126
5288
|
group="Utility",
|
|
5127
5289
|
)
|
|
5128
5290
|
|
|
5129
|
-
# implements quaternion[index] = value
|
|
5291
|
+
# implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5130
5292
|
add_builtin(
|
|
5131
|
-
"
|
|
5293
|
+
"assign_copy",
|
|
5132
5294
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5133
5295
|
value_func=vector_assign_value_func,
|
|
5134
5296
|
hidden=True,
|
|
5135
5297
|
group="Utility",
|
|
5136
5298
|
)
|
|
5137
5299
|
|
|
5138
|
-
|
|
5139
5300
|
# implements vector[idx] += scalar
|
|
5140
5301
|
add_builtin(
|
|
5141
|
-
"
|
|
5302
|
+
"add_inplace",
|
|
5142
5303
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5143
5304
|
value_type=None,
|
|
5144
5305
|
hidden=True,
|
|
@@ -5147,7 +5308,7 @@ add_builtin(
|
|
|
5147
5308
|
|
|
5148
5309
|
# implements quaternion[idx] += scalar
|
|
5149
5310
|
add_builtin(
|
|
5150
|
-
"
|
|
5311
|
+
"add_inplace",
|
|
5151
5312
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5152
5313
|
value_type=None,
|
|
5153
5314
|
hidden=True,
|
|
@@ -5156,7 +5317,7 @@ add_builtin(
|
|
|
5156
5317
|
|
|
5157
5318
|
# implements vector[idx] -= scalar
|
|
5158
5319
|
add_builtin(
|
|
5159
|
-
"
|
|
5320
|
+
"sub_inplace",
|
|
5160
5321
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5161
5322
|
value_type=None,
|
|
5162
5323
|
hidden=True,
|
|
@@ -5165,7 +5326,7 @@ add_builtin(
|
|
|
5165
5326
|
|
|
5166
5327
|
# implements quaternion[idx] -= scalar
|
|
5167
5328
|
add_builtin(
|
|
5168
|
-
"
|
|
5329
|
+
"sub_inplace",
|
|
5169
5330
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5170
5331
|
value_type=None,
|
|
5171
5332
|
hidden=True,
|
|
@@ -5209,11 +5370,6 @@ add_builtin(
|
|
|
5209
5370
|
)
|
|
5210
5371
|
|
|
5211
5372
|
|
|
5212
|
-
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5213
|
-
mat_type = arg_types["a"]
|
|
5214
|
-
return mat_type
|
|
5215
|
-
|
|
5216
|
-
|
|
5217
5373
|
def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
5218
5374
|
mat_size = arg_types["a"]._shape_[0]
|
|
5219
5375
|
vec_size = arg_types["value"]._length_
|
|
@@ -5224,7 +5380,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
|
5224
5380
|
|
|
5225
5381
|
# implements matrix[i,j] = scalar
|
|
5226
5382
|
add_builtin(
|
|
5227
|
-
"
|
|
5383
|
+
"assign_inplace",
|
|
5384
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5385
|
+
value_type=None,
|
|
5386
|
+
hidden=True,
|
|
5387
|
+
group="Utility",
|
|
5388
|
+
)
|
|
5389
|
+
|
|
5390
|
+
|
|
5391
|
+
# implements matrix[i] = vector
|
|
5392
|
+
add_builtin(
|
|
5393
|
+
"assign_inplace",
|
|
5394
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5395
|
+
constraint=matrix_vector_sametype,
|
|
5396
|
+
value_type=None,
|
|
5397
|
+
hidden=True,
|
|
5398
|
+
group="Utility",
|
|
5399
|
+
)
|
|
5400
|
+
|
|
5401
|
+
|
|
5402
|
+
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5403
|
+
mat_type = arg_types["a"]
|
|
5404
|
+
return mat_type
|
|
5405
|
+
|
|
5406
|
+
|
|
5407
|
+
# implements matrix[i,j] = scalar
|
|
5408
|
+
add_builtin(
|
|
5409
|
+
"assign_copy",
|
|
5228
5410
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5229
5411
|
value_func=matrix_assign_value_func,
|
|
5230
5412
|
hidden=True,
|
|
@@ -5234,7 +5416,7 @@ add_builtin(
|
|
|
5234
5416
|
|
|
5235
5417
|
# implements matrix[i] = vector
|
|
5236
5418
|
add_builtin(
|
|
5237
|
-
"
|
|
5419
|
+
"assign_copy",
|
|
5238
5420
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5239
5421
|
constraint=matrix_vector_sametype,
|
|
5240
5422
|
value_func=matrix_assign_value_func,
|
|
@@ -5245,7 +5427,7 @@ add_builtin(
|
|
|
5245
5427
|
|
|
5246
5428
|
# implements matrix[i,j] += scalar
|
|
5247
5429
|
add_builtin(
|
|
5248
|
-
"
|
|
5430
|
+
"add_inplace",
|
|
5249
5431
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5250
5432
|
value_type=None,
|
|
5251
5433
|
hidden=True,
|
|
@@ -5253,9 +5435,20 @@ add_builtin(
|
|
|
5253
5435
|
)
|
|
5254
5436
|
|
|
5255
5437
|
|
|
5438
|
+
# implements matrix[i] += vector
|
|
5439
|
+
add_builtin(
|
|
5440
|
+
"add_inplace",
|
|
5441
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5442
|
+
constraint=matrix_vector_sametype,
|
|
5443
|
+
value_type=None,
|
|
5444
|
+
hidden=True,
|
|
5445
|
+
group="Utility",
|
|
5446
|
+
)
|
|
5447
|
+
|
|
5448
|
+
|
|
5256
5449
|
# implements matrix[i,j] -= scalar
|
|
5257
5450
|
add_builtin(
|
|
5258
|
-
"
|
|
5451
|
+
"sub_inplace",
|
|
5259
5452
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5260
5453
|
value_type=None,
|
|
5261
5454
|
hidden=True,
|
|
@@ -5263,6 +5456,16 @@ add_builtin(
|
|
|
5263
5456
|
)
|
|
5264
5457
|
|
|
5265
5458
|
|
|
5459
|
+
# implements matrix[i] -= vector
|
|
5460
|
+
add_builtin(
|
|
5461
|
+
"sub_inplace",
|
|
5462
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5463
|
+
value_type=None,
|
|
5464
|
+
hidden=True,
|
|
5465
|
+
group="Utility",
|
|
5466
|
+
)
|
|
5467
|
+
|
|
5468
|
+
|
|
5266
5469
|
for t in scalar_types + vector_types + (bool,):
|
|
5267
5470
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
5268
5471
|
continue
|
|
@@ -5410,7 +5613,27 @@ add_builtin(
|
|
|
5410
5613
|
)
|
|
5411
5614
|
add_builtin(
|
|
5412
5615
|
"expect_near",
|
|
5413
|
-
input_types={"a":
|
|
5616
|
+
input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
|
|
5617
|
+
defaults={"tolerance": 1.0e-6},
|
|
5618
|
+
value_type=None,
|
|
5619
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5620
|
+
group="Utility",
|
|
5621
|
+
)
|
|
5622
|
+
add_builtin(
|
|
5623
|
+
"expect_near",
|
|
5624
|
+
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
|
|
5625
|
+
defaults={"tolerance": 1.0e-6},
|
|
5626
|
+
value_type=None,
|
|
5627
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5628
|
+
group="Utility",
|
|
5629
|
+
)
|
|
5630
|
+
add_builtin(
|
|
5631
|
+
"expect_near",
|
|
5632
|
+
input_types={
|
|
5633
|
+
"a": matrix(shape=(Any, Any), dtype=Float),
|
|
5634
|
+
"b": matrix(shape=(Any, Any), dtype=Float),
|
|
5635
|
+
"tolerance": Float,
|
|
5636
|
+
},
|
|
5414
5637
|
defaults={"tolerance": 1.0e-6},
|
|
5415
5638
|
value_type=None,
|
|
5416
5639
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
@@ -5989,7 +6212,7 @@ add_builtin(
|
|
|
5989
6212
|
##
|
|
5990
6213
|
## Matmul
|
|
5991
6214
|
##
|
|
5992
|
-
def
|
|
6215
|
+
def tile_matmul_value_func(arg_types, arg_values):
|
|
5993
6216
|
# return generic type (for doc builds)
|
|
5994
6217
|
if arg_types is None:
|
|
5995
6218
|
return Tile(dtype=Any, shape=Any)
|
|
@@ -6015,7 +6238,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
|
|
|
6015
6238
|
return None
|
|
6016
6239
|
|
|
6017
6240
|
|
|
6018
|
-
def
|
|
6241
|
+
def tile_matmul_lto_dispatch_func(
|
|
6019
6242
|
arg_types: Mapping[str, type],
|
|
6020
6243
|
return_type: Any,
|
|
6021
6244
|
return_values: List[Var],
|
|
@@ -6054,142 +6277,82 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
6054
6277
|
out.type.storage = "shared"
|
|
6055
6278
|
template_args = [accumulate]
|
|
6056
6279
|
|
|
6057
|
-
# Maps Python/Warp types to C++ types and enums
|
|
6058
|
-
def cublasdx_type_map(dtype):
|
|
6059
|
-
if dtype == float16:
|
|
6060
|
-
return ("wp::float16", 3, 0)
|
|
6061
|
-
if dtype == float32:
|
|
6062
|
-
return ("wp::float32", 5, 0)
|
|
6063
|
-
if dtype == float64:
|
|
6064
|
-
return ("wp::float64", 6, 0)
|
|
6065
|
-
if dtype == vec2h:
|
|
6066
|
-
return ("wp::vec2h", 3, 1)
|
|
6067
|
-
if dtype == vec2f:
|
|
6068
|
-
return ("wp::vec2f", 5, 1)
|
|
6069
|
-
if dtype == vec2d:
|
|
6070
|
-
return ("wp::vec2d", 6, 1)
|
|
6071
|
-
raise TypeError("Unsupported input type in tile_matmul")
|
|
6072
|
-
|
|
6073
|
-
def cublasdx_arrangement_map(layout):
|
|
6074
|
-
if layout == "colmajor":
|
|
6075
|
-
return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
|
|
6076
|
-
if layout == "rowmajor":
|
|
6077
|
-
return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
|
|
6078
|
-
raise ValueError("Unsupported layout in tile_matmul")
|
|
6079
|
-
|
|
6080
|
-
# generate the LTO
|
|
6081
6280
|
M, K = a.type.shape[0], a.type.shape[1]
|
|
6082
6281
|
_, N = b.type.shape[0], b.type.shape[1]
|
|
6083
6282
|
num_threads = options["block_dim"]
|
|
6084
6283
|
arch = options["output_arch"]
|
|
6085
6284
|
|
|
6086
|
-
|
|
6087
|
-
|
|
6088
|
-
(
|
|
6089
|
-
|
|
6090
|
-
a_arrangement = cublasdx_arrangement_map(alayout)
|
|
6091
|
-
b_arrangement = cublasdx_arrangement_map(blayout)
|
|
6092
|
-
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
6093
|
-
|
|
6094
|
-
if a_type != b_type or a_type != c_type:
|
|
6095
|
-
raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
|
|
6096
|
-
|
|
6097
|
-
element_type = a_type
|
|
6098
|
-
|
|
6099
|
-
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
6285
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6286
|
+
# CPU/no-MathDx dispatch
|
|
6287
|
+
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
6288
|
+
else:
|
|
6100
6289
|
|
|
6101
|
-
|
|
6102
|
-
|
|
6103
|
-
|
|
6290
|
+
def tile_flip_layout(layout):
|
|
6291
|
+
if layout == "rowmajor":
|
|
6292
|
+
return "colmajor"
|
|
6293
|
+
elif layout == "colmajor":
|
|
6294
|
+
return "rowmajor"
|
|
6104
6295
|
|
|
6105
|
-
#
|
|
6106
|
-
|
|
6107
|
-
|
|
6108
|
-
|
|
6109
|
-
|
|
6110
|
-
|
|
6111
|
-
|
|
6112
|
-
|
|
6296
|
+
# generate the LTOs
|
|
6297
|
+
# C += A * B
|
|
6298
|
+
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
6299
|
+
M,
|
|
6300
|
+
N,
|
|
6301
|
+
K,
|
|
6302
|
+
a.type.dtype,
|
|
6303
|
+
b.type.dtype,
|
|
6304
|
+
out.type.dtype,
|
|
6305
|
+
a.type.layout,
|
|
6306
|
+
b.type.layout,
|
|
6307
|
+
out.type.layout,
|
|
6113
6308
|
arch,
|
|
6309
|
+
num_threads,
|
|
6310
|
+
builder,
|
|
6311
|
+
)
|
|
6312
|
+
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
6313
|
+
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
6114
6314
|
M,
|
|
6315
|
+
K,
|
|
6115
6316
|
N,
|
|
6317
|
+
out.type.dtype,
|
|
6318
|
+
b.type.dtype,
|
|
6319
|
+
a.type.dtype,
|
|
6320
|
+
out.type.layout,
|
|
6321
|
+
tile_flip_layout(b.type.layout),
|
|
6322
|
+
a.type.layout,
|
|
6323
|
+
arch,
|
|
6324
|
+
num_threads,
|
|
6325
|
+
builder,
|
|
6326
|
+
)
|
|
6327
|
+
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6328
|
+
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
6116
6329
|
K,
|
|
6117
|
-
|
|
6118
|
-
|
|
6119
|
-
|
|
6120
|
-
|
|
6121
|
-
|
|
6122
|
-
|
|
6123
|
-
|
|
6330
|
+
N,
|
|
6331
|
+
M,
|
|
6332
|
+
a.type.dtype,
|
|
6333
|
+
out.type.dtype,
|
|
6334
|
+
b.type.dtype,
|
|
6335
|
+
tile_flip_layout(a.type.layout),
|
|
6336
|
+
out.type.layout,
|
|
6337
|
+
b.type.layout,
|
|
6338
|
+
arch,
|
|
6124
6339
|
num_threads,
|
|
6340
|
+
builder,
|
|
6125
6341
|
)
|
|
6126
|
-
lto_code_path = Path(lto_code.name)
|
|
6127
|
-
if not result:
|
|
6128
|
-
lto_code.close()
|
|
6129
|
-
if lto_code_path.exists():
|
|
6130
|
-
lto_code_path.unlink()
|
|
6131
|
-
raise RuntimeError("Failed to compile tile_matmul")
|
|
6132
|
-
else:
|
|
6133
|
-
with open(lto_code.name, "rb") as f:
|
|
6134
|
-
lto_code_data = f.read()
|
|
6135
|
-
lto_code.close()
|
|
6136
|
-
lto_code_path.unlink()
|
|
6137
|
-
|
|
6138
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6139
|
-
builder.ltoirs_decl[lto_symbol] = (
|
|
6140
|
-
f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
|
|
6141
|
-
)
|
|
6142
|
-
|
|
6143
|
-
return lto_symbol, lto_code_data
|
|
6144
6342
|
|
|
6145
|
-
|
|
6146
|
-
|
|
6147
|
-
|
|
6148
|
-
|
|
6149
|
-
|
|
6150
|
-
|
|
6151
|
-
|
|
6152
|
-
|
|
6153
|
-
|
|
6154
|
-
|
|
6155
|
-
|
|
6156
|
-
|
|
6157
|
-
|
|
6158
|
-
K,
|
|
6159
|
-
N,
|
|
6160
|
-
out.type.dtype,
|
|
6161
|
-
b.type.dtype,
|
|
6162
|
-
a.type.dtype,
|
|
6163
|
-
out.type.layout,
|
|
6164
|
-
tile_flip_layout(b.type.layout),
|
|
6165
|
-
a.type.layout,
|
|
6166
|
-
)
|
|
6167
|
-
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6168
|
-
(fun_backward_B, lto_backward_B) = make_function(
|
|
6169
|
-
K,
|
|
6170
|
-
N,
|
|
6171
|
-
M,
|
|
6172
|
-
a.type.dtype,
|
|
6173
|
-
out.type.dtype,
|
|
6174
|
-
b.type.dtype,
|
|
6175
|
-
tile_flip_layout(a.type.layout),
|
|
6176
|
-
out.type.layout,
|
|
6177
|
-
b.type.layout,
|
|
6178
|
-
)
|
|
6179
|
-
|
|
6180
|
-
return (
|
|
6181
|
-
(
|
|
6182
|
-
Var(fun_forward, str, False, True, False),
|
|
6183
|
-
Var(fun_backward_A, str, False, True, False),
|
|
6184
|
-
Var(fun_backward_B, str, False, True, False),
|
|
6185
|
-
a,
|
|
6186
|
-
b,
|
|
6187
|
-
out,
|
|
6188
|
-
),
|
|
6189
|
-
template_args,
|
|
6190
|
-
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6191
|
-
0,
|
|
6192
|
-
)
|
|
6343
|
+
return (
|
|
6344
|
+
(
|
|
6345
|
+
Var(fun_forward, str, False, True, False),
|
|
6346
|
+
Var(fun_backward_A, str, False, True, False),
|
|
6347
|
+
Var(fun_backward_B, str, False, True, False),
|
|
6348
|
+
a,
|
|
6349
|
+
b,
|
|
6350
|
+
out,
|
|
6351
|
+
),
|
|
6352
|
+
template_args,
|
|
6353
|
+
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6354
|
+
0,
|
|
6355
|
+
)
|
|
6193
6356
|
|
|
6194
6357
|
|
|
6195
6358
|
add_builtin(
|
|
@@ -6199,8 +6362,8 @@ add_builtin(
|
|
|
6199
6362
|
"b": Tile(dtype=Any, shape=Any),
|
|
6200
6363
|
"out": Tile(dtype=Any, shape=Any),
|
|
6201
6364
|
},
|
|
6202
|
-
value_func=
|
|
6203
|
-
lto_dispatch_func=
|
|
6365
|
+
value_func=tile_matmul_value_func,
|
|
6366
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6204
6367
|
variadic=False,
|
|
6205
6368
|
doc="""Computes the matrix product and accumulates ``out += a*b``.
|
|
6206
6369
|
|
|
@@ -6208,7 +6371,7 @@ add_builtin(
|
|
|
6208
6371
|
* fp16, fp32, fp64 (real)
|
|
6209
6372
|
* vec2h, vec2f, vec2d (complex)
|
|
6210
6373
|
|
|
6211
|
-
All input and output tiles must have the same datatype. Tile data will
|
|
6374
|
+
All input and output tiles must have the same datatype. Tile data will automatically be migrated
|
|
6212
6375
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6213
6376
|
|
|
6214
6377
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6222,8 +6385,8 @@ add_builtin(
|
|
|
6222
6385
|
add_builtin(
|
|
6223
6386
|
"tile_matmul",
|
|
6224
6387
|
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
6225
|
-
value_func=
|
|
6226
|
-
lto_dispatch_func=
|
|
6388
|
+
value_func=tile_matmul_value_func,
|
|
6389
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6227
6390
|
variadic=False,
|
|
6228
6391
|
doc="""Computes the matrix product ``out = a*b``.
|
|
6229
6392
|
|
|
@@ -6231,7 +6394,7 @@ add_builtin(
|
|
|
6231
6394
|
* fp16, fp32, fp64 (real)
|
|
6232
6395
|
* vec2h, vec2f, vec2d (complex)
|
|
6233
6396
|
|
|
6234
|
-
Both input tiles must have the same datatype. Tile data will
|
|
6397
|
+
Both input tiles must have the same datatype. Tile data will automatically be migrated
|
|
6235
6398
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6236
6399
|
|
|
6237
6400
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6303,59 +6466,29 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6303
6466
|
num_threads = options["block_dim"]
|
|
6304
6467
|
arch = options["output_arch"]
|
|
6305
6468
|
ept = size // num_threads
|
|
6306
|
-
|
|
6307
|
-
|
|
6308
|
-
|
|
6309
|
-
|
|
6310
|
-
|
|
6311
|
-
|
|
6312
|
-
|
|
6313
|
-
|
|
6314
|
-
|
|
6315
|
-
|
|
6316
|
-
|
|
6317
|
-
|
|
6318
|
-
|
|
6319
|
-
|
|
6320
|
-
|
|
6321
|
-
|
|
6322
|
-
|
|
6323
|
-
|
|
6324
|
-
|
|
6325
|
-
|
|
6326
|
-
|
|
6327
|
-
|
|
6328
|
-
|
|
6329
|
-
lto_code_path = Path(lto_code.name)
|
|
6330
|
-
if not result:
|
|
6331
|
-
lto_code.close()
|
|
6332
|
-
if lto_code_path.exists():
|
|
6333
|
-
lto_code_path.unlink()
|
|
6334
|
-
raise RuntimeError("Failed to compile tile_fft")
|
|
6335
|
-
|
|
6336
|
-
with open(lto_code.name, "rb") as f:
|
|
6337
|
-
lto_code_data = f.read()
|
|
6338
|
-
|
|
6339
|
-
lto_code.close()
|
|
6340
|
-
lto_code_path.unlink()
|
|
6341
|
-
|
|
6342
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6343
|
-
|
|
6344
|
-
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
6345
|
-
|
|
6346
|
-
return (
|
|
6347
|
-
(
|
|
6348
|
-
Var(lto_symbol, str, False, True, False),
|
|
6349
|
-
Var(dtype, str, False, True, False),
|
|
6350
|
-
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6351
|
-
Var(str(batch), str, False, True, False),
|
|
6352
|
-
Var(str(ept), str, False, True, False),
|
|
6353
|
-
inout,
|
|
6354
|
-
),
|
|
6355
|
-
[],
|
|
6356
|
-
[lto_code_data],
|
|
6357
|
-
shared_memory_bytes,
|
|
6358
|
-
)
|
|
6469
|
+
|
|
6470
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6471
|
+
# CPU/no-MathDx dispatch
|
|
6472
|
+
return ([], [], [], 0)
|
|
6473
|
+
else:
|
|
6474
|
+
# generate the LTO
|
|
6475
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
6476
|
+
arch, size, ept, direction, dir, precision, builder
|
|
6477
|
+
)
|
|
6478
|
+
|
|
6479
|
+
return (
|
|
6480
|
+
(
|
|
6481
|
+
Var(lto_symbol, str, False, True, False),
|
|
6482
|
+
Var(dtype, str, False, True, False),
|
|
6483
|
+
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6484
|
+
Var(str(batch), str, False, True, False),
|
|
6485
|
+
Var(str(ept), str, False, True, False),
|
|
6486
|
+
inout,
|
|
6487
|
+
),
|
|
6488
|
+
[],
|
|
6489
|
+
[lto_code_data],
|
|
6490
|
+
shared_memory_bytes,
|
|
6491
|
+
)
|
|
6359
6492
|
|
|
6360
6493
|
|
|
6361
6494
|
add_builtin(
|
|
@@ -6417,7 +6550,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
|
6417
6550
|
raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
|
|
6418
6551
|
|
|
6419
6552
|
if len(a.shape) != 2:
|
|
6420
|
-
raise ValueError("tile_cholesky()
|
|
6553
|
+
raise ValueError("tile_cholesky() argument must be a 2D tile")
|
|
6421
6554
|
|
|
6422
6555
|
if a.shape[0] != a.shape[1]:
|
|
6423
6556
|
raise ValueError("tile_cholesky() argument must be square")
|
|
@@ -6458,57 +6591,36 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6458
6591
|
if out.type.shape[0] != M or out.type.shape[1] != M:
|
|
6459
6592
|
raise ValueError("tile_cholesky() output tile must be square")
|
|
6460
6593
|
|
|
6461
|
-
|
|
6462
|
-
|
|
6463
|
-
lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
|
|
6464
|
-
|
|
6465
|
-
# early out if LTO for this combination already exists for this module
|
|
6466
|
-
if lto_symbol in builder.ltoirs:
|
|
6467
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6468
|
-
|
|
6469
|
-
# otherwise compile LTO
|
|
6470
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6471
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6594
|
+
solver = "potrf"
|
|
6595
|
+
solver_enum = cusolver_function_map[solver]
|
|
6472
6596
|
|
|
6473
|
-
# cuSOLVERDx only
|
|
6597
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6474
6598
|
# so we use upper to mimic a row-major input
|
|
6475
|
-
|
|
6476
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6477
|
-
lto_code.name.encode("utf-8"),
|
|
6478
|
-
lto_symbol.encode("utf-8"),
|
|
6479
|
-
0,
|
|
6480
|
-
None,
|
|
6481
|
-
None,
|
|
6482
|
-
arch,
|
|
6483
|
-
M,
|
|
6484
|
-
N,
|
|
6485
|
-
cusolver_function_map["potrf"],
|
|
6486
|
-
precision_enum,
|
|
6487
|
-
cusolver_fill_mode_map["upper"],
|
|
6488
|
-
num_threads,
|
|
6489
|
-
)
|
|
6599
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6490
6600
|
|
|
6491
|
-
|
|
6492
|
-
|
|
6493
|
-
|
|
6494
|
-
if Path(f.name).exists():
|
|
6495
|
-
Path(f.name).unlink()
|
|
6496
|
-
raise RuntimeError("Failed to compile tile_cholesky")
|
|
6601
|
+
arch = options["output_arch"]
|
|
6602
|
+
num_threads = options["block_dim"]
|
|
6603
|
+
parameter_list = f"({dtype}*, unsigned)"
|
|
6497
6604
|
|
|
6605
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6606
|
+
# CPU/no-MathDx dispatch
|
|
6607
|
+
return ((0, a, out), [], [], 0)
|
|
6498
6608
|
else:
|
|
6499
|
-
|
|
6500
|
-
|
|
6501
|
-
|
|
6502
|
-
|
|
6503
|
-
|
|
6504
|
-
|
|
6505
|
-
|
|
6506
|
-
|
|
6507
|
-
|
|
6508
|
-
|
|
6509
|
-
|
|
6609
|
+
# generate the LTO
|
|
6610
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6611
|
+
M,
|
|
6612
|
+
N,
|
|
6613
|
+
solver,
|
|
6614
|
+
solver_enum,
|
|
6615
|
+
fill_mode,
|
|
6616
|
+
arch,
|
|
6617
|
+
precision_enum,
|
|
6618
|
+
num_threads,
|
|
6619
|
+
parameter_list,
|
|
6620
|
+
builder,
|
|
6621
|
+
)
|
|
6510
6622
|
|
|
6511
|
-
|
|
6623
|
+
return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
|
|
6512
6624
|
|
|
6513
6625
|
|
|
6514
6626
|
add_builtin(
|
|
@@ -6602,57 +6714,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6602
6714
|
f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
|
|
6603
6715
|
)
|
|
6604
6716
|
|
|
6605
|
-
|
|
6606
|
-
|
|
6607
|
-
lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
|
|
6608
|
-
|
|
6609
|
-
# early out if LTO for this combination already exists for this module
|
|
6610
|
-
if lto_symbol in builder.ltoirs:
|
|
6611
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6612
|
-
|
|
6613
|
-
# otherwise compile LTO
|
|
6614
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6615
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6717
|
+
solver = "potrs"
|
|
6718
|
+
solver_enum = cusolver_function_map[solver]
|
|
6616
6719
|
|
|
6617
|
-
# cuSOLVERDx only
|
|
6720
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6618
6721
|
# so we use upper to mimic a row-major input
|
|
6619
|
-
|
|
6620
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6621
|
-
lto_code.name.encode("utf-8"),
|
|
6622
|
-
lto_symbol.encode("utf-8"),
|
|
6623
|
-
0,
|
|
6624
|
-
None,
|
|
6625
|
-
None,
|
|
6626
|
-
arch,
|
|
6627
|
-
M,
|
|
6628
|
-
N,
|
|
6629
|
-
cusolver_function_map["potrs"],
|
|
6630
|
-
precision_enum,
|
|
6631
|
-
cusolver_fill_mode_map["upper"],
|
|
6632
|
-
num_threads,
|
|
6633
|
-
)
|
|
6722
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6634
6723
|
|
|
6635
|
-
|
|
6636
|
-
|
|
6637
|
-
|
|
6638
|
-
if Path(f.name).exists():
|
|
6639
|
-
Path(f.name).unlink()
|
|
6640
|
-
raise RuntimeError("Failed to compile tile_cholesky_solve")
|
|
6724
|
+
arch = options["output_arch"]
|
|
6725
|
+
num_threads = options["block_dim"]
|
|
6726
|
+
parameter_list = f"({dtype}*, {dtype}*)"
|
|
6641
6727
|
|
|
6728
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6729
|
+
# CPU/no-MathDx dispatch
|
|
6730
|
+
return ((0, L, x, y), [], [], 0)
|
|
6642
6731
|
else:
|
|
6643
|
-
|
|
6644
|
-
|
|
6645
|
-
|
|
6646
|
-
|
|
6647
|
-
|
|
6648
|
-
|
|
6649
|
-
|
|
6650
|
-
|
|
6651
|
-
|
|
6652
|
-
|
|
6653
|
-
|
|
6654
|
-
|
|
6655
|
-
|
|
6732
|
+
# generate the LTO
|
|
6733
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6734
|
+
M,
|
|
6735
|
+
N,
|
|
6736
|
+
solver,
|
|
6737
|
+
solver_enum,
|
|
6738
|
+
fill_mode,
|
|
6739
|
+
arch,
|
|
6740
|
+
precision_enum,
|
|
6741
|
+
num_threads,
|
|
6742
|
+
parameter_list,
|
|
6743
|
+
builder,
|
|
6744
|
+
)
|
|
6745
|
+
|
|
6746
|
+
return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
|
|
6656
6747
|
|
|
6657
6748
|
|
|
6658
6749
|
add_builtin(
|