warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +882 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/{builtins.py → _src/builtins.py} +1435 -379
- warp/_src/codegen.py +4361 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/{fem → _src/fem}/domain.py +42 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/{fem → _src/fem}/field/nodal_field.py +32 -15
- warp/{fem → _src/fem}/field/restriction.py +3 -1
- warp/{fem → _src/fem}/field/virtual.py +55 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
- warp/{fem → _src/fem}/geometry/element.py +34 -10
- warp/{fem → _src/fem}/geometry/geometry.py +50 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
- warp/{fem → _src/fem}/geometry/partition.py +123 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
- warp/{fem → _src/fem}/integrate.py +166 -158
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
- warp/_src/fem/space/basis_space.py +681 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
- warp/{fem → _src/fem}/space/function_space.py +16 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
- warp/{fem → _src/fem}/space/partition.py +119 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/restriction.py +68 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
- warp/_src/fem/space/topology.py +461 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +3 -3
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +521 -250
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +18 -17
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +578 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -20,14 +20,16 @@ import functools
|
|
|
20
20
|
import math
|
|
21
21
|
from typing import Any, Callable, Mapping, Sequence
|
|
22
22
|
|
|
23
|
-
import warp.build
|
|
24
|
-
import warp.context
|
|
25
|
-
import warp.utils
|
|
26
|
-
from warp.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
-
from warp.types import *
|
|
23
|
+
import warp._src.build
|
|
24
|
+
import warp._src.context
|
|
25
|
+
import warp._src.utils
|
|
26
|
+
from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
+
from warp._src.types import *
|
|
28
28
|
|
|
29
29
|
from .context import add_builtin
|
|
30
30
|
|
|
31
|
+
_wp_module_name_ = "warp.builtins"
|
|
32
|
+
|
|
31
33
|
|
|
32
34
|
def seq_check_equal(seq_1, seq_2):
|
|
33
35
|
if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
|
|
@@ -61,11 +63,11 @@ def sametypes_create_value_func(default: TypeVar):
|
|
|
61
63
|
|
|
62
64
|
def extract_tuple(arg, as_constant=False):
|
|
63
65
|
if isinstance(arg, Var):
|
|
64
|
-
if isinstance(arg.type, warp.types.tuple_t):
|
|
66
|
+
if isinstance(arg.type, warp._src.types.tuple_t):
|
|
65
67
|
out = arg.type.values
|
|
66
68
|
else:
|
|
67
69
|
out = (arg,)
|
|
68
|
-
elif isinstance(arg, warp.types.tuple_t):
|
|
70
|
+
elif isinstance(arg, warp._src.types.tuple_t):
|
|
69
71
|
out = arg.values
|
|
70
72
|
elif not isinstance(arg, Sequence):
|
|
71
73
|
out = (arg,)
|
|
@@ -82,7 +84,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
82
84
|
if arg_types is None:
|
|
83
85
|
return int
|
|
84
86
|
|
|
85
|
-
length = warp.types.type_length(arg_types["a"])
|
|
87
|
+
length = warp._src.types.type_length(arg_types["a"])
|
|
86
88
|
return Var(None, type=int, constant=length)
|
|
87
89
|
|
|
88
90
|
|
|
@@ -126,7 +128,7 @@ add_builtin(
|
|
|
126
128
|
value_func=sametypes_create_value_func(Scalar),
|
|
127
129
|
doc="Return -1 if ``x`` < 0, return 1 otherwise.",
|
|
128
130
|
group="Scalar Math",
|
|
129
|
-
|
|
131
|
+
is_differentiable=False,
|
|
130
132
|
)
|
|
131
133
|
|
|
132
134
|
add_builtin(
|
|
@@ -135,7 +137,7 @@ add_builtin(
|
|
|
135
137
|
value_func=sametypes_create_value_func(Scalar),
|
|
136
138
|
doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
|
|
137
139
|
group="Scalar Math",
|
|
138
|
-
|
|
140
|
+
is_differentiable=False,
|
|
139
141
|
)
|
|
140
142
|
add_builtin(
|
|
141
143
|
"nonzero",
|
|
@@ -143,7 +145,7 @@ add_builtin(
|
|
|
143
145
|
value_func=sametypes_create_value_func(Scalar),
|
|
144
146
|
doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
|
|
145
147
|
group="Scalar Math",
|
|
146
|
-
|
|
148
|
+
is_differentiable=False,
|
|
147
149
|
)
|
|
148
150
|
|
|
149
151
|
add_builtin(
|
|
@@ -285,7 +287,36 @@ add_builtin(
|
|
|
285
287
|
group="Scalar Math",
|
|
286
288
|
require_original_output_arg=True,
|
|
287
289
|
)
|
|
288
|
-
|
|
290
|
+
add_builtin(
|
|
291
|
+
"erf",
|
|
292
|
+
input_types={"x": Float},
|
|
293
|
+
value_func=sametypes_create_value_func(Float),
|
|
294
|
+
doc="Return the error function of ``x``.",
|
|
295
|
+
group="Scalar Math",
|
|
296
|
+
)
|
|
297
|
+
add_builtin(
|
|
298
|
+
"erfc",
|
|
299
|
+
input_types={"x": Float},
|
|
300
|
+
value_func=sametypes_create_value_func(Float),
|
|
301
|
+
doc="Return the complementary error function of ``x``.",
|
|
302
|
+
group="Scalar Math",
|
|
303
|
+
)
|
|
304
|
+
add_builtin(
|
|
305
|
+
"erfinv",
|
|
306
|
+
input_types={"x": Float},
|
|
307
|
+
value_func=sametypes_create_value_func(Float),
|
|
308
|
+
doc="Return the inverse error function of ``x``.",
|
|
309
|
+
group="Scalar Math",
|
|
310
|
+
require_original_output_arg=True,
|
|
311
|
+
)
|
|
312
|
+
add_builtin(
|
|
313
|
+
"erfcinv",
|
|
314
|
+
input_types={"x": Float},
|
|
315
|
+
value_func=sametypes_create_value_func(Float),
|
|
316
|
+
doc="Return the inverse complementary error function of ``x``.",
|
|
317
|
+
group="Scalar Math",
|
|
318
|
+
require_original_output_arg=True,
|
|
319
|
+
)
|
|
289
320
|
add_builtin(
|
|
290
321
|
"round",
|
|
291
322
|
input_types={"x": Float},
|
|
@@ -295,7 +326,7 @@ add_builtin(
|
|
|
295
326
|
|
|
296
327
|
This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
|
|
297
328
|
Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
|
|
298
|
-
|
|
329
|
+
is_differentiable=False,
|
|
299
330
|
)
|
|
300
331
|
|
|
301
332
|
add_builtin(
|
|
@@ -306,7 +337,7 @@ add_builtin(
|
|
|
306
337
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
|
|
307
338
|
|
|
308
339
|
It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
|
|
309
|
-
|
|
340
|
+
is_differentiable=False,
|
|
310
341
|
)
|
|
311
342
|
|
|
312
343
|
add_builtin(
|
|
@@ -319,7 +350,7 @@ add_builtin(
|
|
|
319
350
|
In other words, it discards the fractional part of ``x``.
|
|
320
351
|
It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
|
|
321
352
|
Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
|
|
322
|
-
|
|
353
|
+
is_differentiable=False,
|
|
323
354
|
)
|
|
324
355
|
|
|
325
356
|
add_builtin(
|
|
@@ -328,7 +359,7 @@ add_builtin(
|
|
|
328
359
|
value_func=sametypes_create_value_func(Float),
|
|
329
360
|
group="Scalar Math",
|
|
330
361
|
doc="""Return the largest integer that is less than or equal to ``x``.""",
|
|
331
|
-
|
|
362
|
+
is_differentiable=False,
|
|
332
363
|
)
|
|
333
364
|
|
|
334
365
|
add_builtin(
|
|
@@ -337,7 +368,7 @@ add_builtin(
|
|
|
337
368
|
value_func=sametypes_create_value_func(Float),
|
|
338
369
|
group="Scalar Math",
|
|
339
370
|
doc="""Return the smallest integer that is greater than or equal to ``x``.""",
|
|
340
|
-
|
|
371
|
+
is_differentiable=False,
|
|
341
372
|
)
|
|
342
373
|
|
|
343
374
|
add_builtin(
|
|
@@ -348,7 +379,7 @@ add_builtin(
|
|
|
348
379
|
doc="""Retrieve the fractional part of ``x``.
|
|
349
380
|
|
|
350
381
|
In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
|
|
351
|
-
|
|
382
|
+
is_differentiable=False,
|
|
352
383
|
)
|
|
353
384
|
|
|
354
385
|
add_builtin(
|
|
@@ -357,7 +388,7 @@ add_builtin(
|
|
|
357
388
|
value_type=builtins.bool,
|
|
358
389
|
group="Scalar Math",
|
|
359
390
|
doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
|
|
360
|
-
|
|
391
|
+
is_differentiable=False,
|
|
361
392
|
)
|
|
362
393
|
add_builtin(
|
|
363
394
|
"isfinite",
|
|
@@ -365,7 +396,7 @@ add_builtin(
|
|
|
365
396
|
value_type=builtins.bool,
|
|
366
397
|
group="Vector Math",
|
|
367
398
|
doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
|
|
368
|
-
|
|
399
|
+
is_differentiable=False,
|
|
369
400
|
)
|
|
370
401
|
add_builtin(
|
|
371
402
|
"isfinite",
|
|
@@ -373,7 +404,7 @@ add_builtin(
|
|
|
373
404
|
value_type=builtins.bool,
|
|
374
405
|
group="Vector Math",
|
|
375
406
|
doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
|
|
376
|
-
|
|
407
|
+
is_differentiable=False,
|
|
377
408
|
)
|
|
378
409
|
add_builtin(
|
|
379
410
|
"isfinite",
|
|
@@ -381,7 +412,7 @@ add_builtin(
|
|
|
381
412
|
value_type=builtins.bool,
|
|
382
413
|
group="Vector Math",
|
|
383
414
|
doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
|
|
384
|
-
|
|
415
|
+
is_differentiable=False,
|
|
385
416
|
)
|
|
386
417
|
|
|
387
418
|
add_builtin(
|
|
@@ -390,7 +421,7 @@ add_builtin(
|
|
|
390
421
|
value_type=builtins.bool,
|
|
391
422
|
doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
|
|
392
423
|
group="Scalar Math",
|
|
393
|
-
|
|
424
|
+
is_differentiable=False,
|
|
394
425
|
)
|
|
395
426
|
add_builtin(
|
|
396
427
|
"isnan",
|
|
@@ -398,7 +429,7 @@ add_builtin(
|
|
|
398
429
|
value_type=builtins.bool,
|
|
399
430
|
group="Vector Math",
|
|
400
431
|
doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
|
|
401
|
-
|
|
432
|
+
is_differentiable=False,
|
|
402
433
|
)
|
|
403
434
|
add_builtin(
|
|
404
435
|
"isnan",
|
|
@@ -406,7 +437,7 @@ add_builtin(
|
|
|
406
437
|
value_type=builtins.bool,
|
|
407
438
|
group="Vector Math",
|
|
408
439
|
doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
|
|
409
|
-
|
|
440
|
+
is_differentiable=False,
|
|
410
441
|
)
|
|
411
442
|
add_builtin(
|
|
412
443
|
"isnan",
|
|
@@ -414,7 +445,7 @@ add_builtin(
|
|
|
414
445
|
value_type=builtins.bool,
|
|
415
446
|
group="Vector Math",
|
|
416
447
|
doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
|
|
417
|
-
|
|
448
|
+
is_differentiable=False,
|
|
418
449
|
)
|
|
419
450
|
|
|
420
451
|
add_builtin(
|
|
@@ -423,7 +454,7 @@ add_builtin(
|
|
|
423
454
|
value_type=builtins.bool,
|
|
424
455
|
group="Scalar Math",
|
|
425
456
|
doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
|
|
426
|
-
|
|
457
|
+
is_differentiable=False,
|
|
427
458
|
)
|
|
428
459
|
add_builtin(
|
|
429
460
|
"isinf",
|
|
@@ -431,7 +462,7 @@ add_builtin(
|
|
|
431
462
|
value_type=builtins.bool,
|
|
432
463
|
group="Vector Math",
|
|
433
464
|
doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
434
|
-
|
|
465
|
+
is_differentiable=False,
|
|
435
466
|
)
|
|
436
467
|
add_builtin(
|
|
437
468
|
"isinf",
|
|
@@ -439,7 +470,7 @@ add_builtin(
|
|
|
439
470
|
value_type=builtins.bool,
|
|
440
471
|
group="Vector Math",
|
|
441
472
|
doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
442
|
-
|
|
473
|
+
is_differentiable=False,
|
|
443
474
|
)
|
|
444
475
|
add_builtin(
|
|
445
476
|
"isinf",
|
|
@@ -447,7 +478,7 @@ add_builtin(
|
|
|
447
478
|
value_type=builtins.bool,
|
|
448
479
|
group="Vector Math",
|
|
449
480
|
doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
450
|
-
|
|
481
|
+
is_differentiable=False,
|
|
451
482
|
)
|
|
452
483
|
|
|
453
484
|
|
|
@@ -555,7 +586,7 @@ add_builtin(
|
|
|
555
586
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
556
587
|
doc="Return the index of the minimum element of a vector ``a``.",
|
|
557
588
|
group="Vector Math",
|
|
558
|
-
|
|
589
|
+
is_differentiable=False,
|
|
559
590
|
)
|
|
560
591
|
add_builtin(
|
|
561
592
|
"argmax",
|
|
@@ -563,7 +594,7 @@ add_builtin(
|
|
|
563
594
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
564
595
|
doc="Return the index of the maximum element of a vector ``a``.",
|
|
565
596
|
group="Vector Math",
|
|
566
|
-
|
|
597
|
+
is_differentiable=False,
|
|
567
598
|
)
|
|
568
599
|
|
|
569
600
|
add_builtin(
|
|
@@ -888,7 +919,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
888
919
|
|
|
889
920
|
if dtype is None:
|
|
890
921
|
dtype = value_type
|
|
891
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
922
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
892
923
|
raise RuntimeError(
|
|
893
924
|
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
894
925
|
)
|
|
@@ -909,7 +940,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
909
940
|
|
|
910
941
|
if dtype is None:
|
|
911
942
|
dtype = value_type
|
|
912
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
943
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
913
944
|
raise RuntimeError(
|
|
914
945
|
f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
|
|
915
946
|
)
|
|
@@ -992,7 +1023,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
992
1023
|
|
|
993
1024
|
if dtype is None:
|
|
994
1025
|
dtype = value_type
|
|
995
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1026
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
996
1027
|
raise RuntimeError(
|
|
997
1028
|
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
998
1029
|
)
|
|
@@ -1002,7 +1033,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1002
1033
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
1003
1034
|
|
|
1004
1035
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
1005
|
-
warp.utils.warn(
|
|
1036
|
+
warp._src.utils.warn(
|
|
1006
1037
|
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
1007
1038
|
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
1008
1039
|
DeprecationWarning,
|
|
@@ -1031,7 +1062,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1031
1062
|
|
|
1032
1063
|
if dtype is None:
|
|
1033
1064
|
dtype = value_type
|
|
1034
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1065
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1035
1066
|
raise RuntimeError(
|
|
1036
1067
|
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
1037
1068
|
)
|
|
@@ -1203,49 +1234,18 @@ add_builtin(
|
|
|
1203
1234
|
doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
|
|
1204
1235
|
group="Vector Math",
|
|
1205
1236
|
export=False,
|
|
1206
|
-
|
|
1237
|
+
is_differentiable=False,
|
|
1207
1238
|
)
|
|
1208
1239
|
|
|
1209
1240
|
|
|
1210
1241
|
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1211
|
-
warp.utils.warn(
|
|
1212
|
-
"the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1213
|
-
"and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
|
|
1214
|
-
DeprecationWarning,
|
|
1215
|
-
)
|
|
1216
1242
|
if arg_types is None:
|
|
1217
1243
|
return matrix(shape=(4, 4), dtype=Float)
|
|
1218
1244
|
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
value_type = scalar_infer_type(value_arg_types)
|
|
1224
|
-
except RuntimeError:
|
|
1225
|
-
raise RuntimeError(
|
|
1226
|
-
"all values given when constructing a transformation matrix must have the same type"
|
|
1227
|
-
) from None
|
|
1228
|
-
|
|
1229
|
-
if dtype is None:
|
|
1230
|
-
dtype = value_type
|
|
1231
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1232
|
-
raise RuntimeError(
|
|
1233
|
-
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1234
|
-
)
|
|
1235
|
-
|
|
1236
|
-
return matrix(shape=(4, 4), dtype=dtype)
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1240
|
-
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1241
|
-
# Further validate the given argument values if needed and map them
|
|
1242
|
-
# to the underlying C++ function's runtime and template params.
|
|
1243
|
-
|
|
1244
|
-
dtype = return_type._wp_scalar_type_
|
|
1245
|
-
|
|
1246
|
-
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1247
|
-
template_args = (4, 4, dtype)
|
|
1248
|
-
return (func_args, template_args)
|
|
1245
|
+
raise RuntimeError(
|
|
1246
|
+
"the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1247
|
+
"and 3D scale vector has been removed in favor of `wp.transform_compose()`."
|
|
1248
|
+
)
|
|
1249
1249
|
|
|
1250
1250
|
|
|
1251
1251
|
add_builtin(
|
|
@@ -1259,13 +1259,14 @@ add_builtin(
|
|
|
1259
1259
|
defaults={"dtype": None},
|
|
1260
1260
|
value_func=matrix_transform_value_func,
|
|
1261
1261
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1262
|
-
dispatch_func=matrix_transform_dispatch_func,
|
|
1263
1262
|
native_func="mat_t",
|
|
1264
1263
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
1265
1264
|
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1266
1265
|
|
|
1267
|
-
..
|
|
1268
|
-
This function has been
|
|
1266
|
+
.. versionremoved:: 1.10
|
|
1267
|
+
This function has been removed in favor of :func:`warp.math.transform_compose()`.
|
|
1268
|
+
|
|
1269
|
+
.. deprecated:: 1.8""",
|
|
1269
1270
|
group="Vector Math",
|
|
1270
1271
|
export=False,
|
|
1271
1272
|
)
|
|
@@ -1460,7 +1461,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1460
1461
|
|
|
1461
1462
|
if dtype is None:
|
|
1462
1463
|
dtype = value_type
|
|
1463
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1464
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1464
1465
|
raise RuntimeError(
|
|
1465
1466
|
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1466
1467
|
)
|
|
@@ -1568,7 +1569,7 @@ add_builtin(
|
|
|
1568
1569
|
group="Quaternion Math",
|
|
1569
1570
|
doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
|
|
1570
1571
|
export=True,
|
|
1571
|
-
|
|
1572
|
+
is_differentiable=False,
|
|
1572
1573
|
)
|
|
1573
1574
|
|
|
1574
1575
|
add_builtin(
|
|
@@ -1697,7 +1698,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1697
1698
|
value_type = strip_reference(variadic_arg_types[0])
|
|
1698
1699
|
if dtype is None:
|
|
1699
1700
|
dtype = value_type
|
|
1700
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1701
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1701
1702
|
raise RuntimeError(
|
|
1702
1703
|
f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
|
|
1703
1704
|
)
|
|
@@ -1710,7 +1711,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1710
1711
|
|
|
1711
1712
|
if dtype is None:
|
|
1712
1713
|
dtype = value_type
|
|
1713
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1714
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1714
1715
|
raise RuntimeError(
|
|
1715
1716
|
f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
|
|
1716
1717
|
)
|
|
@@ -1735,7 +1736,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
|
|
|
1735
1736
|
dtype = arg_values.get("dtype", None)
|
|
1736
1737
|
if dtype is None:
|
|
1737
1738
|
dtype = value_type
|
|
1738
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1739
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1739
1740
|
raise RuntimeError(
|
|
1740
1741
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1741
1742
|
)
|
|
@@ -1750,9 +1751,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1750
1751
|
|
|
1751
1752
|
dtype = return_type._wp_scalar_type_
|
|
1752
1753
|
|
|
1753
|
-
variadic_args =
|
|
1754
|
+
variadic_args = args.get("args", ())
|
|
1755
|
+
variadic_arg_count = len(variadic_args)
|
|
1756
|
+
|
|
1757
|
+
if variadic_arg_count == 7:
|
|
1758
|
+
func_args = variadic_args
|
|
1759
|
+
else:
|
|
1760
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1761
|
+
if "p" in args and "q" not in args:
|
|
1762
|
+
quat_ident = warp._src.codegen.Var(
|
|
1763
|
+
label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
|
|
1764
|
+
)
|
|
1765
|
+
func_args += (quat_ident,)
|
|
1754
1766
|
|
|
1755
|
-
func_args = variadic_args
|
|
1756
1767
|
template_args = (dtype,)
|
|
1757
1768
|
return (func_args, template_args)
|
|
1758
1769
|
|
|
@@ -1760,7 +1771,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1760
1771
|
add_builtin(
|
|
1761
1772
|
"transformation",
|
|
1762
1773
|
input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
|
|
1763
|
-
defaults={"dtype": None},
|
|
1774
|
+
defaults={"q": None, "dtype": None},
|
|
1764
1775
|
value_func=transformation_pq_value_func,
|
|
1765
1776
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1766
1777
|
dispatch_func=transformation_dispatch_func,
|
|
@@ -1784,7 +1795,6 @@ add_builtin(
|
|
|
1784
1795
|
doc="Construct a spatial transform vector of given dtype.",
|
|
1785
1796
|
group="Spatial Math",
|
|
1786
1797
|
export=False,
|
|
1787
|
-
missing_grad=True,
|
|
1788
1798
|
)
|
|
1789
1799
|
|
|
1790
1800
|
|
|
@@ -1819,7 +1829,7 @@ add_builtin(
|
|
|
1819
1829
|
group="Transformations",
|
|
1820
1830
|
doc="Construct an identity transform with zero translation and identity rotation.",
|
|
1821
1831
|
export=True,
|
|
1822
|
-
|
|
1832
|
+
is_differentiable=False,
|
|
1823
1833
|
)
|
|
1824
1834
|
|
|
1825
1835
|
add_builtin(
|
|
@@ -1953,7 +1963,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1953
1963
|
|
|
1954
1964
|
if dtype is None:
|
|
1955
1965
|
dtype = value_type
|
|
1956
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1966
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1957
1967
|
raise RuntimeError(
|
|
1958
1968
|
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1959
1969
|
)
|
|
@@ -2147,7 +2157,7 @@ add_builtin(
|
|
|
2147
2157
|
value_func=tile_zeros_value_func,
|
|
2148
2158
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2149
2159
|
variadic=False,
|
|
2150
|
-
|
|
2160
|
+
is_differentiable=False,
|
|
2151
2161
|
doc="""Allocate a tile of zero-initialized items.
|
|
2152
2162
|
|
|
2153
2163
|
:param shape: Shape of the output tile
|
|
@@ -2167,7 +2177,7 @@ add_builtin(
|
|
|
2167
2177
|
value_func=tile_zeros_value_func,
|
|
2168
2178
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2169
2179
|
variadic=False,
|
|
2170
|
-
|
|
2180
|
+
is_differentiable=False,
|
|
2171
2181
|
hidden=True,
|
|
2172
2182
|
group="Tile Primitives",
|
|
2173
2183
|
export=False,
|
|
@@ -2219,7 +2229,7 @@ add_builtin(
|
|
|
2219
2229
|
defaults={"storage": "register"},
|
|
2220
2230
|
value_func=tile_ones_value_func,
|
|
2221
2231
|
dispatch_func=tile_ones_dispatch_func,
|
|
2222
|
-
|
|
2232
|
+
is_differentiable=False,
|
|
2223
2233
|
doc="""Allocate a tile of one-initialized items.
|
|
2224
2234
|
|
|
2225
2235
|
:param shape: Shape of the output tile
|
|
@@ -2238,7 +2248,86 @@ add_builtin(
|
|
|
2238
2248
|
defaults={"storage": "register"},
|
|
2239
2249
|
value_func=tile_ones_value_func,
|
|
2240
2250
|
dispatch_func=tile_ones_dispatch_func,
|
|
2241
|
-
|
|
2251
|
+
is_differentiable=False,
|
|
2252
|
+
hidden=True,
|
|
2253
|
+
group="Tile Primitives",
|
|
2254
|
+
export=False,
|
|
2255
|
+
)
|
|
2256
|
+
|
|
2257
|
+
|
|
2258
|
+
def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2259
|
+
# return generic type (for doc builds)
|
|
2260
|
+
if arg_types is None:
|
|
2261
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2262
|
+
|
|
2263
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2264
|
+
|
|
2265
|
+
if None in shape:
|
|
2266
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2267
|
+
|
|
2268
|
+
if "value" not in arg_values:
|
|
2269
|
+
raise TypeError("tile_full() missing required keyword argument 'value'")
|
|
2270
|
+
|
|
2271
|
+
if "dtype" not in arg_values:
|
|
2272
|
+
raise TypeError("tile_full() missing required keyword argument 'dtype'")
|
|
2273
|
+
|
|
2274
|
+
if "storage" not in arg_values:
|
|
2275
|
+
raise TypeError("tile_full() missing required keyword argument 'storage'")
|
|
2276
|
+
|
|
2277
|
+
if arg_values["storage"] not in {"shared", "register"}:
|
|
2278
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
2279
|
+
|
|
2280
|
+
dtype = arg_values["dtype"]
|
|
2281
|
+
|
|
2282
|
+
return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
2283
|
+
|
|
2284
|
+
|
|
2285
|
+
def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2286
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2287
|
+
|
|
2288
|
+
if None in shape:
|
|
2289
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2290
|
+
|
|
2291
|
+
dtype = arg_values["dtype"]
|
|
2292
|
+
value = arg_values["value"]
|
|
2293
|
+
|
|
2294
|
+
func_args = [value]
|
|
2295
|
+
|
|
2296
|
+
template_args = []
|
|
2297
|
+
template_args.append(dtype)
|
|
2298
|
+
template_args.extend(shape)
|
|
2299
|
+
|
|
2300
|
+
return (func_args, template_args)
|
|
2301
|
+
|
|
2302
|
+
|
|
2303
|
+
add_builtin(
|
|
2304
|
+
"tile_full",
|
|
2305
|
+
input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
|
|
2306
|
+
defaults={"storage": "register"},
|
|
2307
|
+
value_func=tile_full_value_func,
|
|
2308
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2309
|
+
is_differentiable=False,
|
|
2310
|
+
doc="""Allocate a tile filled with the specified value.
|
|
2311
|
+
|
|
2312
|
+
:param shape: Shape of the output tile
|
|
2313
|
+
:param value: Value to fill the tile with
|
|
2314
|
+
:param dtype: Data type of output tile's elements
|
|
2315
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2316
|
+
(default) or ``"shared"`` for shared memory.
|
|
2317
|
+
:returns: A tile filled with the specified value""",
|
|
2318
|
+
group="Tile Primitives",
|
|
2319
|
+
export=False,
|
|
2320
|
+
)
|
|
2321
|
+
|
|
2322
|
+
|
|
2323
|
+
# overload for scalar shape
|
|
2324
|
+
add_builtin(
|
|
2325
|
+
"tile_full",
|
|
2326
|
+
input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
|
|
2327
|
+
defaults={"storage": "register"},
|
|
2328
|
+
value_func=tile_full_value_func,
|
|
2329
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2330
|
+
is_differentiable=False,
|
|
2242
2331
|
hidden=True,
|
|
2243
2332
|
group="Tile Primitives",
|
|
2244
2333
|
export=False,
|
|
@@ -2300,13 +2389,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2300
2389
|
args = arg_values["args"]
|
|
2301
2390
|
|
|
2302
2391
|
if len(args) == 1:
|
|
2303
|
-
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2392
|
+
start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2304
2393
|
stop = args[0]
|
|
2305
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2394
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2306
2395
|
elif len(args) == 2:
|
|
2307
2396
|
start = args[0]
|
|
2308
2397
|
stop = args[1]
|
|
2309
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2398
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2310
2399
|
elif len(args) == 3:
|
|
2311
2400
|
start = args[0]
|
|
2312
2401
|
stop = args[1]
|
|
@@ -2329,7 +2418,7 @@ add_builtin(
|
|
|
2329
2418
|
value_func=tile_arange_value_func,
|
|
2330
2419
|
dispatch_func=tile_arange_dispatch_func,
|
|
2331
2420
|
variadic=True,
|
|
2332
|
-
|
|
2421
|
+
is_differentiable=False,
|
|
2333
2422
|
doc="""Generate a tile of linearly spaced elements.
|
|
2334
2423
|
|
|
2335
2424
|
:param args: Variable-length positional arguments, interpreted as:
|
|
@@ -3124,7 +3213,7 @@ add_builtin(
|
|
|
3124
3213
|
:param shape: Shape of the returned slice
|
|
3125
3214
|
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
3126
3215
|
group="Tile Primitives",
|
|
3127
|
-
|
|
3216
|
+
is_differentiable=False,
|
|
3128
3217
|
export=False,
|
|
3129
3218
|
)
|
|
3130
3219
|
|
|
@@ -3371,7 +3460,32 @@ add_builtin(
|
|
|
3371
3460
|
|
|
3372
3461
|
add_builtin(
|
|
3373
3462
|
"assign",
|
|
3374
|
-
input_types={"dst": tile(dtype=Any, shape=Tuple[int,
|
|
3463
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
|
|
3464
|
+
value_func=tile_assign_value_func,
|
|
3465
|
+
group="Tile Primitives",
|
|
3466
|
+
export=False,
|
|
3467
|
+
hidden=True,
|
|
3468
|
+
)
|
|
3469
|
+
|
|
3470
|
+
add_builtin(
|
|
3471
|
+
"assign",
|
|
3472
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
|
|
3473
|
+
value_func=tile_assign_value_func,
|
|
3474
|
+
group="Tile Primitives",
|
|
3475
|
+
export=False,
|
|
3476
|
+
hidden=True,
|
|
3477
|
+
)
|
|
3478
|
+
|
|
3479
|
+
add_builtin(
|
|
3480
|
+
"assign",
|
|
3481
|
+
input_types={
|
|
3482
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3483
|
+
"i": int,
|
|
3484
|
+
"j": int,
|
|
3485
|
+
"k": int,
|
|
3486
|
+
"l": int,
|
|
3487
|
+
"src": Any,
|
|
3488
|
+
},
|
|
3375
3489
|
value_func=tile_assign_value_func,
|
|
3376
3490
|
group="Tile Primitives",
|
|
3377
3491
|
export=False,
|
|
@@ -3380,7 +3494,15 @@ add_builtin(
|
|
|
3380
3494
|
|
|
3381
3495
|
add_builtin(
|
|
3382
3496
|
"assign",
|
|
3383
|
-
input_types={
|
|
3497
|
+
input_types={
|
|
3498
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3499
|
+
"i": int,
|
|
3500
|
+
"j": int,
|
|
3501
|
+
"k": int,
|
|
3502
|
+
"l": int,
|
|
3503
|
+
"m": int,
|
|
3504
|
+
"src": Any,
|
|
3505
|
+
},
|
|
3384
3506
|
value_func=tile_assign_value_func,
|
|
3385
3507
|
group="Tile Primitives",
|
|
3386
3508
|
export=False,
|
|
@@ -3395,6 +3517,8 @@ add_builtin(
|
|
|
3395
3517
|
"j": int,
|
|
3396
3518
|
"k": int,
|
|
3397
3519
|
"l": int,
|
|
3520
|
+
"m": int,
|
|
3521
|
+
"n": int,
|
|
3398
3522
|
"src": Any,
|
|
3399
3523
|
},
|
|
3400
3524
|
value_func=tile_assign_value_func,
|
|
@@ -3416,7 +3540,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3416
3540
|
|
|
3417
3541
|
if preserve_type:
|
|
3418
3542
|
dtype = arg_types["x"]
|
|
3419
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3543
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3420
3544
|
|
|
3421
3545
|
return tile(dtype=dtype, shape=shape)
|
|
3422
3546
|
|
|
@@ -3424,18 +3548,18 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3424
3548
|
if type_is_vector(arg_types["x"]):
|
|
3425
3549
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3426
3550
|
length = arg_types["x"]._shape_[0]
|
|
3427
|
-
shape = (length, warp.codegen.options["block_dim"])
|
|
3551
|
+
shape = (length, warp._src.codegen.options["block_dim"])
|
|
3428
3552
|
elif type_is_quaternion(arg_types["x"]):
|
|
3429
3553
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3430
|
-
shape = (4, warp.codegen.options["block_dim"])
|
|
3554
|
+
shape = (4, warp._src.codegen.options["block_dim"])
|
|
3431
3555
|
elif type_is_matrix(arg_types["x"]):
|
|
3432
3556
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3433
3557
|
rows = arg_types["x"]._shape_[0]
|
|
3434
3558
|
cols = arg_types["x"]._shape_[1]
|
|
3435
|
-
shape = (rows, cols, warp.codegen.options["block_dim"])
|
|
3559
|
+
shape = (rows, cols, warp._src.codegen.options["block_dim"])
|
|
3436
3560
|
else:
|
|
3437
3561
|
dtype = arg_types["x"]
|
|
3438
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3562
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3439
3563
|
|
|
3440
3564
|
return tile(dtype=dtype, shape=shape)
|
|
3441
3565
|
|
|
@@ -3525,17 +3649,17 @@ def untile_value_func(arg_types, arg_values):
|
|
|
3525
3649
|
if not is_tile(t):
|
|
3526
3650
|
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
3527
3651
|
|
|
3528
|
-
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
3652
|
+
if t.shape[-1] != warp._src.codegen.options["block_dim"]:
|
|
3529
3653
|
raise ValueError(
|
|
3530
|
-
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
3654
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
|
|
3531
3655
|
)
|
|
3532
3656
|
|
|
3533
3657
|
if len(t.shape) == 1:
|
|
3534
3658
|
return t.dtype
|
|
3535
3659
|
elif len(t.shape) == 2:
|
|
3536
|
-
return warp.types.vector(t.shape[0], t.dtype)
|
|
3660
|
+
return warp._src.types.vector(t.shape[0], t.dtype)
|
|
3537
3661
|
elif len(t.shape) == 3:
|
|
3538
|
-
return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3662
|
+
return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3539
3663
|
else:
|
|
3540
3664
|
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
3541
3665
|
|
|
@@ -3597,7 +3721,36 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
3597
3721
|
# force the input tile to shared memory
|
|
3598
3722
|
arg_types["a"].storage = "shared"
|
|
3599
3723
|
|
|
3600
|
-
|
|
3724
|
+
# count the number of indices (all parameters except the tile "a")
|
|
3725
|
+
num_indices = len(arg_types) - 1
|
|
3726
|
+
tile_dtype = arg_types["a"].dtype
|
|
3727
|
+
tile_shape = arg_types["a"].shape
|
|
3728
|
+
|
|
3729
|
+
if type_is_vector(tile_dtype):
|
|
3730
|
+
if num_indices == len(tile_shape):
|
|
3731
|
+
return tile_dtype
|
|
3732
|
+
elif num_indices == len(tile_shape) + 1:
|
|
3733
|
+
return tile_dtype._wp_scalar_type_
|
|
3734
|
+
else:
|
|
3735
|
+
raise IndexError(
|
|
3736
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3737
|
+
)
|
|
3738
|
+
elif type_is_matrix(tile_dtype):
|
|
3739
|
+
if num_indices == len(tile_shape):
|
|
3740
|
+
return tile_dtype
|
|
3741
|
+
elif num_indices == len(tile_shape) + 2:
|
|
3742
|
+
return tile_dtype._wp_scalar_type_
|
|
3743
|
+
else:
|
|
3744
|
+
raise IndexError(
|
|
3745
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
|
|
3746
|
+
)
|
|
3747
|
+
else:
|
|
3748
|
+
# scalar element: index count must exactly match tile rank
|
|
3749
|
+
if num_indices == len(tile_shape):
|
|
3750
|
+
return tile_dtype
|
|
3751
|
+
raise IndexError(
|
|
3752
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3753
|
+
)
|
|
3601
3754
|
|
|
3602
3755
|
|
|
3603
3756
|
add_builtin(
|
|
@@ -3621,7 +3774,7 @@ add_builtin(
|
|
|
3621
3774
|
|
|
3622
3775
|
add_builtin(
|
|
3623
3776
|
"tile_extract",
|
|
3624
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3777
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
|
|
3625
3778
|
value_func=tile_extract_value_func,
|
|
3626
3779
|
variadic=False,
|
|
3627
3780
|
doc="""Extract a single element from the tile.
|
|
@@ -3632,7 +3785,7 @@ add_builtin(
|
|
|
3632
3785
|
|
|
3633
3786
|
:param a: Tile to extract the element from
|
|
3634
3787
|
:param i: Coordinate of element on first dimension
|
|
3635
|
-
:param j: Coordinate of element on the second dimension
|
|
3788
|
+
:param j: Coordinate of element on the second dimension, or vector index
|
|
3636
3789
|
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3637
3790
|
group="Tile Primitives",
|
|
3638
3791
|
hidden=True,
|
|
@@ -3641,7 +3794,57 @@ add_builtin(
|
|
|
3641
3794
|
|
|
3642
3795
|
add_builtin(
|
|
3643
3796
|
"tile_extract",
|
|
3644
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3797
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
|
|
3798
|
+
value_func=tile_extract_value_func,
|
|
3799
|
+
variadic=False,
|
|
3800
|
+
doc="""Extract a single element from the tile.
|
|
3801
|
+
|
|
3802
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3803
|
+
|
|
3804
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3805
|
+
|
|
3806
|
+
:param a: Tile to extract the element from
|
|
3807
|
+
:param i: Coordinate of element on first dimension
|
|
3808
|
+
:param j: Coordinate of element on the second dimension, or first matrix index
|
|
3809
|
+
:param k: Coordinate of element on the third dimension, or vector index, or second matrix index
|
|
3810
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3811
|
+
group="Tile Primitives",
|
|
3812
|
+
hidden=True,
|
|
3813
|
+
export=False,
|
|
3814
|
+
)
|
|
3815
|
+
|
|
3816
|
+
add_builtin(
|
|
3817
|
+
"tile_extract",
|
|
3818
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
|
|
3819
|
+
value_func=tile_extract_value_func,
|
|
3820
|
+
variadic=False,
|
|
3821
|
+
doc="""Extract a single element from the tile.
|
|
3822
|
+
|
|
3823
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3824
|
+
|
|
3825
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3826
|
+
|
|
3827
|
+
:param a: Tile to extract the element from
|
|
3828
|
+
:param i: Coordinate of element on first dimension
|
|
3829
|
+
:param j: Coordinate of element on the second dimension
|
|
3830
|
+
:param k: Coordinate of element on the third dimension, or first matrix index
|
|
3831
|
+
:param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
|
|
3832
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3833
|
+
group="Tile Primitives",
|
|
3834
|
+
hidden=True,
|
|
3835
|
+
export=False,
|
|
3836
|
+
)
|
|
3837
|
+
|
|
3838
|
+
add_builtin(
|
|
3839
|
+
"tile_extract",
|
|
3840
|
+
input_types={
|
|
3841
|
+
"a": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3842
|
+
"i": int,
|
|
3843
|
+
"j": int,
|
|
3844
|
+
"k": int,
|
|
3845
|
+
"l": int,
|
|
3846
|
+
"m": int,
|
|
3847
|
+
},
|
|
3645
3848
|
value_func=tile_extract_value_func,
|
|
3646
3849
|
variadic=False,
|
|
3647
3850
|
doc="""Extract a single element from the tile.
|
|
@@ -3654,7 +3857,9 @@ add_builtin(
|
|
|
3654
3857
|
:param i: Coordinate of element on first dimension
|
|
3655
3858
|
:param j: Coordinate of element on the second dimension
|
|
3656
3859
|
:param k: Coordinate of element on the third dimension
|
|
3657
|
-
:
|
|
3860
|
+
:param l: Coordinate of element on the fourth dimension, or first matrix index
|
|
3861
|
+
:param m: Vector index, or second matrix index
|
|
3862
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3658
3863
|
group="Tile Primitives",
|
|
3659
3864
|
hidden=True,
|
|
3660
3865
|
export=False,
|
|
@@ -3662,7 +3867,15 @@ add_builtin(
|
|
|
3662
3867
|
|
|
3663
3868
|
add_builtin(
|
|
3664
3869
|
"tile_extract",
|
|
3665
|
-
input_types={
|
|
3870
|
+
input_types={
|
|
3871
|
+
"a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
|
|
3872
|
+
"i": int,
|
|
3873
|
+
"j": int,
|
|
3874
|
+
"k": int,
|
|
3875
|
+
"l": int,
|
|
3876
|
+
"m": int,
|
|
3877
|
+
"n": int,
|
|
3878
|
+
},
|
|
3666
3879
|
value_func=tile_extract_value_func,
|
|
3667
3880
|
variadic=False,
|
|
3668
3881
|
doc="""Extract a single element from the tile.
|
|
@@ -3676,6 +3889,8 @@ add_builtin(
|
|
|
3676
3889
|
:param j: Coordinate of element on the second dimension
|
|
3677
3890
|
:param k: Coordinate of element on the third dimension
|
|
3678
3891
|
:param l: Coordinate of element on the fourth dimension
|
|
3892
|
+
:param m: Vector index, or first matrix index
|
|
3893
|
+
:param n: Second matrix index
|
|
3679
3894
|
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3680
3895
|
group="Tile Primitives",
|
|
3681
3896
|
hidden=True,
|
|
@@ -3762,49 +3977,160 @@ add_builtin(
|
|
|
3762
3977
|
export=False,
|
|
3763
3978
|
)
|
|
3764
3979
|
|
|
3765
|
-
|
|
3766
|
-
def tile_transpose_value_func(arg_types, arg_values):
|
|
3767
|
-
# return generic type (for doc builds)
|
|
3768
|
-
if arg_types is None:
|
|
3769
|
-
return tile(dtype=Any, shape=Tuple[int, int])
|
|
3770
|
-
|
|
3771
|
-
if len(arg_types) != 1:
|
|
3772
|
-
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3773
|
-
|
|
3774
|
-
t = arg_types["a"]
|
|
3775
|
-
|
|
3776
|
-
if not is_tile(t):
|
|
3777
|
-
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
3778
|
-
|
|
3779
|
-
layout = None
|
|
3780
|
-
|
|
3781
|
-
# flip layout
|
|
3782
|
-
if t.layout == "rowmajor":
|
|
3783
|
-
layout = "colmajor"
|
|
3784
|
-
elif t.layout == "colmajor":
|
|
3785
|
-
layout = "rowmajor"
|
|
3786
|
-
|
|
3787
|
-
# force the input tile to shared memory
|
|
3788
|
-
t.storage = "shared"
|
|
3789
|
-
|
|
3790
|
-
return tile(
|
|
3791
|
-
dtype=t.dtype,
|
|
3792
|
-
shape=t.shape[::-1],
|
|
3793
|
-
storage=t.storage,
|
|
3794
|
-
strides=t.strides[::-1],
|
|
3795
|
-
layout=layout,
|
|
3796
|
-
owner=False,
|
|
3797
|
-
)
|
|
3798
|
-
|
|
3799
|
-
|
|
3800
3980
|
add_builtin(
|
|
3801
|
-
"
|
|
3802
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3803
|
-
value_func=
|
|
3804
|
-
|
|
3805
|
-
|
|
3806
|
-
|
|
3807
|
-
|
|
3981
|
+
"tile_bit_and_inplace",
|
|
3982
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3983
|
+
value_func=tile_inplace_value_func,
|
|
3984
|
+
group="Tile Primitives",
|
|
3985
|
+
hidden=True,
|
|
3986
|
+
export=False,
|
|
3987
|
+
is_differentiable=False,
|
|
3988
|
+
)
|
|
3989
|
+
add_builtin(
|
|
3990
|
+
"tile_bit_and_inplace",
|
|
3991
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3992
|
+
value_func=tile_inplace_value_func,
|
|
3993
|
+
group="Tile Primitives",
|
|
3994
|
+
hidden=True,
|
|
3995
|
+
export=False,
|
|
3996
|
+
is_differentiable=False,
|
|
3997
|
+
)
|
|
3998
|
+
add_builtin(
|
|
3999
|
+
"tile_bit_and_inplace",
|
|
4000
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4001
|
+
value_func=tile_inplace_value_func,
|
|
4002
|
+
group="Tile Primitives",
|
|
4003
|
+
hidden=True,
|
|
4004
|
+
export=False,
|
|
4005
|
+
is_differentiable=False,
|
|
4006
|
+
)
|
|
4007
|
+
add_builtin(
|
|
4008
|
+
"tile_bit_and_inplace",
|
|
4009
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4010
|
+
value_func=tile_inplace_value_func,
|
|
4011
|
+
group="Tile Primitives",
|
|
4012
|
+
hidden=True,
|
|
4013
|
+
export=False,
|
|
4014
|
+
is_differentiable=False,
|
|
4015
|
+
)
|
|
4016
|
+
|
|
4017
|
+
add_builtin(
|
|
4018
|
+
"tile_bit_or_inplace",
|
|
4019
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4020
|
+
value_func=tile_inplace_value_func,
|
|
4021
|
+
group="Tile Primitives",
|
|
4022
|
+
hidden=True,
|
|
4023
|
+
export=False,
|
|
4024
|
+
is_differentiable=False,
|
|
4025
|
+
)
|
|
4026
|
+
add_builtin(
|
|
4027
|
+
"tile_bit_or_inplace",
|
|
4028
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4029
|
+
value_func=tile_inplace_value_func,
|
|
4030
|
+
group="Tile Primitives",
|
|
4031
|
+
hidden=True,
|
|
4032
|
+
export=False,
|
|
4033
|
+
is_differentiable=False,
|
|
4034
|
+
)
|
|
4035
|
+
add_builtin(
|
|
4036
|
+
"tile_bit_or_inplace",
|
|
4037
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4038
|
+
value_func=tile_inplace_value_func,
|
|
4039
|
+
group="Tile Primitives",
|
|
4040
|
+
hidden=True,
|
|
4041
|
+
export=False,
|
|
4042
|
+
is_differentiable=False,
|
|
4043
|
+
)
|
|
4044
|
+
add_builtin(
|
|
4045
|
+
"tile_bit_or_inplace",
|
|
4046
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4047
|
+
value_func=tile_inplace_value_func,
|
|
4048
|
+
group="Tile Primitives",
|
|
4049
|
+
hidden=True,
|
|
4050
|
+
export=False,
|
|
4051
|
+
is_differentiable=False,
|
|
4052
|
+
)
|
|
4053
|
+
|
|
4054
|
+
add_builtin(
|
|
4055
|
+
"tile_bit_xor_inplace",
|
|
4056
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4057
|
+
value_func=tile_inplace_value_func,
|
|
4058
|
+
group="Tile Primitives",
|
|
4059
|
+
hidden=True,
|
|
4060
|
+
export=False,
|
|
4061
|
+
is_differentiable=False,
|
|
4062
|
+
)
|
|
4063
|
+
add_builtin(
|
|
4064
|
+
"tile_bit_xor_inplace",
|
|
4065
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4066
|
+
value_func=tile_inplace_value_func,
|
|
4067
|
+
group="Tile Primitives",
|
|
4068
|
+
hidden=True,
|
|
4069
|
+
export=False,
|
|
4070
|
+
is_differentiable=False,
|
|
4071
|
+
)
|
|
4072
|
+
add_builtin(
|
|
4073
|
+
"tile_bit_xor_inplace",
|
|
4074
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4075
|
+
value_func=tile_inplace_value_func,
|
|
4076
|
+
group="Tile Primitives",
|
|
4077
|
+
hidden=True,
|
|
4078
|
+
export=False,
|
|
4079
|
+
is_differentiable=False,
|
|
4080
|
+
)
|
|
4081
|
+
add_builtin(
|
|
4082
|
+
"tile_bit_xor_inplace",
|
|
4083
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4084
|
+
value_func=tile_inplace_value_func,
|
|
4085
|
+
group="Tile Primitives",
|
|
4086
|
+
hidden=True,
|
|
4087
|
+
export=False,
|
|
4088
|
+
is_differentiable=False,
|
|
4089
|
+
)
|
|
4090
|
+
|
|
4091
|
+
|
|
4092
|
+
def tile_transpose_value_func(arg_types, arg_values):
|
|
4093
|
+
# return generic type (for doc builds)
|
|
4094
|
+
if arg_types is None:
|
|
4095
|
+
return tile(dtype=Any, shape=Tuple[int, int])
|
|
4096
|
+
|
|
4097
|
+
if len(arg_types) != 1:
|
|
4098
|
+
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
4099
|
+
|
|
4100
|
+
t = arg_types["a"]
|
|
4101
|
+
|
|
4102
|
+
if not is_tile(t):
|
|
4103
|
+
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
4104
|
+
|
|
4105
|
+
layout = None
|
|
4106
|
+
|
|
4107
|
+
# flip layout
|
|
4108
|
+
if t.layout == "rowmajor":
|
|
4109
|
+
layout = "colmajor"
|
|
4110
|
+
elif t.layout == "colmajor":
|
|
4111
|
+
layout = "rowmajor"
|
|
4112
|
+
|
|
4113
|
+
# force the input tile to shared memory
|
|
4114
|
+
t.storage = "shared"
|
|
4115
|
+
|
|
4116
|
+
return tile(
|
|
4117
|
+
dtype=t.dtype,
|
|
4118
|
+
shape=t.shape[::-1],
|
|
4119
|
+
storage=t.storage,
|
|
4120
|
+
strides=t.strides[::-1],
|
|
4121
|
+
layout=layout,
|
|
4122
|
+
owner=False,
|
|
4123
|
+
)
|
|
4124
|
+
|
|
4125
|
+
|
|
4126
|
+
add_builtin(
|
|
4127
|
+
"tile_transpose",
|
|
4128
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
|
|
4129
|
+
value_func=tile_transpose_value_func,
|
|
4130
|
+
variadic=True,
|
|
4131
|
+
doc="""Transpose a tile.
|
|
4132
|
+
|
|
4133
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
3808
4134
|
Register tiles will first be transferred to shared memory before transposition.
|
|
3809
4135
|
|
|
3810
4136
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
@@ -3935,6 +4261,80 @@ add_builtin(
|
|
|
3935
4261
|
)
|
|
3936
4262
|
|
|
3937
4263
|
|
|
4264
|
+
def tile_sum_axis_value_func(arg_types, arg_values):
|
|
4265
|
+
if arg_types is None:
|
|
4266
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
4267
|
+
|
|
4268
|
+
a = arg_types["a"]
|
|
4269
|
+
|
|
4270
|
+
if not is_tile(a):
|
|
4271
|
+
raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
|
|
4272
|
+
|
|
4273
|
+
# force input tile to shared
|
|
4274
|
+
a.storage = "shared"
|
|
4275
|
+
|
|
4276
|
+
axis = arg_values["axis"]
|
|
4277
|
+
shape = a.shape
|
|
4278
|
+
|
|
4279
|
+
if axis < 0 or axis >= len(shape):
|
|
4280
|
+
raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
|
|
4281
|
+
|
|
4282
|
+
# shape is identical less the axis reduction is along
|
|
4283
|
+
if len(shape) > 1:
|
|
4284
|
+
new_shape = shape[:axis] + shape[axis + 1 :]
|
|
4285
|
+
else:
|
|
4286
|
+
new_shape = (1,)
|
|
4287
|
+
|
|
4288
|
+
return tile(dtype=a.dtype, shape=new_shape)
|
|
4289
|
+
|
|
4290
|
+
|
|
4291
|
+
def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
4292
|
+
tile = arg_values["a"]
|
|
4293
|
+
axis_var = arg_values["axis"]
|
|
4294
|
+
if not hasattr(axis_var, "constant") or axis_var.constant is None:
|
|
4295
|
+
raise ValueError("tile_sum() axis must be a compile-time constant")
|
|
4296
|
+
axis = axis_var.constant
|
|
4297
|
+
|
|
4298
|
+
return ((tile,), (axis,))
|
|
4299
|
+
|
|
4300
|
+
|
|
4301
|
+
add_builtin(
|
|
4302
|
+
"tile_sum",
|
|
4303
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4304
|
+
value_func=tile_sum_axis_value_func,
|
|
4305
|
+
dispatch_func=tile_sum_axis_dispatch_func,
|
|
4306
|
+
doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
|
|
4307
|
+
|
|
4308
|
+
:param a: The input tile. Must reside in shared memory.
|
|
4309
|
+
:param axis: The tile axis to compute the sum across. Must be a compile-time constant.
|
|
4310
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4311
|
+
|
|
4312
|
+
Example:
|
|
4313
|
+
|
|
4314
|
+
.. code-block:: python
|
|
4315
|
+
|
|
4316
|
+
@wp.kernel
|
|
4317
|
+
def compute():
|
|
4318
|
+
|
|
4319
|
+
t = wp.tile_ones(dtype=float, shape=(8, 8))
|
|
4320
|
+
s = wp.tile_sum(t, axis=0)
|
|
4321
|
+
|
|
4322
|
+
print(s)
|
|
4323
|
+
|
|
4324
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
4325
|
+
|
|
4326
|
+
Prints:
|
|
4327
|
+
|
|
4328
|
+
.. code-block:: text
|
|
4329
|
+
|
|
4330
|
+
[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
|
|
4331
|
+
|
|
4332
|
+
""",
|
|
4333
|
+
group="Tile Primitives",
|
|
4334
|
+
export=False,
|
|
4335
|
+
)
|
|
4336
|
+
|
|
4337
|
+
|
|
3938
4338
|
def tile_sort_value_func(arg_types, arg_values):
|
|
3939
4339
|
# return generic type (for doc builds)
|
|
3940
4340
|
if arg_types is None:
|
|
@@ -4011,7 +4411,7 @@ add_builtin(
|
|
|
4011
4411
|
""",
|
|
4012
4412
|
group="Tile Primitives",
|
|
4013
4413
|
export=False,
|
|
4014
|
-
|
|
4414
|
+
is_differentiable=False,
|
|
4015
4415
|
)
|
|
4016
4416
|
|
|
4017
4417
|
|
|
@@ -4065,7 +4465,7 @@ add_builtin(
|
|
|
4065
4465
|
""",
|
|
4066
4466
|
group="Tile Primitives",
|
|
4067
4467
|
export=False,
|
|
4068
|
-
|
|
4468
|
+
is_differentiable=False,
|
|
4069
4469
|
)
|
|
4070
4470
|
|
|
4071
4471
|
|
|
@@ -4119,7 +4519,7 @@ add_builtin(
|
|
|
4119
4519
|
""",
|
|
4120
4520
|
group="Tile Primitives",
|
|
4121
4521
|
export=False,
|
|
4122
|
-
|
|
4522
|
+
is_differentiable=False,
|
|
4123
4523
|
)
|
|
4124
4524
|
|
|
4125
4525
|
|
|
@@ -4172,7 +4572,7 @@ add_builtin(
|
|
|
4172
4572
|
""",
|
|
4173
4573
|
group="Tile Primitives",
|
|
4174
4574
|
export=False,
|
|
4175
|
-
|
|
4575
|
+
is_differentiable=False,
|
|
4176
4576
|
)
|
|
4177
4577
|
|
|
4178
4578
|
|
|
@@ -4225,11 +4625,10 @@ add_builtin(
|
|
|
4225
4625
|
""",
|
|
4226
4626
|
group="Tile Primitives",
|
|
4227
4627
|
export=False,
|
|
4228
|
-
|
|
4628
|
+
is_differentiable=False,
|
|
4229
4629
|
)
|
|
4230
4630
|
|
|
4231
4631
|
|
|
4232
|
-
# does type propagation for load()
|
|
4233
4632
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
4234
4633
|
if arg_types is None:
|
|
4235
4634
|
return tile(dtype=Scalar, shape=(1,))
|
|
@@ -4283,7 +4682,88 @@ add_builtin(
|
|
|
4283
4682
|
""",
|
|
4284
4683
|
group="Tile Primitives",
|
|
4285
4684
|
export=False,
|
|
4286
|
-
|
|
4685
|
+
is_differentiable=False,
|
|
4686
|
+
)
|
|
4687
|
+
|
|
4688
|
+
|
|
4689
|
+
def tile_reduce_axis_value_func(arg_types, arg_values):
|
|
4690
|
+
if arg_types is None:
|
|
4691
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
4692
|
+
|
|
4693
|
+
a = arg_types["a"]
|
|
4694
|
+
|
|
4695
|
+
if not is_tile(a):
|
|
4696
|
+
raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
|
|
4697
|
+
|
|
4698
|
+
# force input tile to shared memory
|
|
4699
|
+
a.storage = "shared"
|
|
4700
|
+
|
|
4701
|
+
axis = arg_values["axis"]
|
|
4702
|
+
shape = a.shape
|
|
4703
|
+
|
|
4704
|
+
if axis < 0 or axis >= len(shape):
|
|
4705
|
+
raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
|
|
4706
|
+
|
|
4707
|
+
# shape is identical less the axis reduction is along
|
|
4708
|
+
if len(shape) > 1:
|
|
4709
|
+
new_shape = shape[:axis] + shape[axis + 1 :]
|
|
4710
|
+
else:
|
|
4711
|
+
new_shape = (1,)
|
|
4712
|
+
|
|
4713
|
+
return tile(dtype=a.dtype, shape=new_shape)
|
|
4714
|
+
|
|
4715
|
+
|
|
4716
|
+
add_builtin(
|
|
4717
|
+
"tile_reduce",
|
|
4718
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4719
|
+
value_func=tile_reduce_axis_value_func,
|
|
4720
|
+
native_func="tile_reduce_axis",
|
|
4721
|
+
doc="""Apply a custom reduction operator across a tile axis.
|
|
4722
|
+
|
|
4723
|
+
This function cooperatively performs a reduction using the provided operator across an axis of the tile.
|
|
4724
|
+
|
|
4725
|
+
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
4726
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
|
|
4727
|
+
:param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
|
|
4728
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4729
|
+
|
|
4730
|
+
Example:
|
|
4731
|
+
|
|
4732
|
+
.. code-block:: python
|
|
4733
|
+
|
|
4734
|
+
TILE_M = wp.constant(4)
|
|
4735
|
+
TILE_N = wp.constant(2)
|
|
4736
|
+
|
|
4737
|
+
@wp.kernel
|
|
4738
|
+
def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
|
|
4739
|
+
|
|
4740
|
+
a = wp.tile_load(x, shape=(TILE_M, TILE_N))
|
|
4741
|
+
b = wp.tile_reduce(wp.add, a, axis=1)
|
|
4742
|
+
wp.tile_store(y, b)
|
|
4743
|
+
|
|
4744
|
+
arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
|
|
4745
|
+
|
|
4746
|
+
x = wp.array(arr, dtype=float)
|
|
4747
|
+
y = wp.zeros(TILE_M, dtype=float)
|
|
4748
|
+
|
|
4749
|
+
wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
|
|
4750
|
+
|
|
4751
|
+
print(x.numpy())
|
|
4752
|
+
print(y.numpy())
|
|
4753
|
+
|
|
4754
|
+
Prints:
|
|
4755
|
+
|
|
4756
|
+
.. code-block:: text
|
|
4757
|
+
|
|
4758
|
+
[[0. 1.]
|
|
4759
|
+
[2. 3.]
|
|
4760
|
+
[4. 5.]
|
|
4761
|
+
[6. 7.]]
|
|
4762
|
+
[ 1. 5. 9. 13.]
|
|
4763
|
+
""",
|
|
4764
|
+
group="Tile Primitives",
|
|
4765
|
+
export=False,
|
|
4766
|
+
is_differentiable=False,
|
|
4287
4767
|
)
|
|
4288
4768
|
|
|
4289
4769
|
|
|
@@ -4347,7 +4827,7 @@ add_builtin(
|
|
|
4347
4827
|
""",
|
|
4348
4828
|
group="Tile Primitives",
|
|
4349
4829
|
export=False,
|
|
4350
|
-
|
|
4830
|
+
is_differentiable=False,
|
|
4351
4831
|
)
|
|
4352
4832
|
|
|
4353
4833
|
|
|
@@ -4411,7 +4891,7 @@ add_builtin(
|
|
|
4411
4891
|
""",
|
|
4412
4892
|
group="Tile Primitives",
|
|
4413
4893
|
export=False,
|
|
4414
|
-
|
|
4894
|
+
is_differentiable=False,
|
|
4415
4895
|
)
|
|
4416
4896
|
|
|
4417
4897
|
|
|
@@ -4665,7 +5145,7 @@ add_builtin(
|
|
|
4665
5145
|
doc="WIP",
|
|
4666
5146
|
group="Utility",
|
|
4667
5147
|
hidden=True,
|
|
4668
|
-
|
|
5148
|
+
is_differentiable=False,
|
|
4669
5149
|
)
|
|
4670
5150
|
|
|
4671
5151
|
add_builtin(
|
|
@@ -4681,7 +5161,7 @@ add_builtin(
|
|
|
4681
5161
|
doc="WIP",
|
|
4682
5162
|
group="Utility",
|
|
4683
5163
|
hidden=True,
|
|
4684
|
-
|
|
5164
|
+
is_differentiable=False,
|
|
4685
5165
|
)
|
|
4686
5166
|
|
|
4687
5167
|
add_builtin(
|
|
@@ -4691,7 +5171,7 @@ add_builtin(
|
|
|
4691
5171
|
doc="WIP",
|
|
4692
5172
|
group="Utility",
|
|
4693
5173
|
hidden=True,
|
|
4694
|
-
|
|
5174
|
+
is_differentiable=False,
|
|
4695
5175
|
)
|
|
4696
5176
|
|
|
4697
5177
|
add_builtin(
|
|
@@ -4743,7 +5223,7 @@ add_builtin(
|
|
|
4743
5223
|
:param low: The lower bound of the bounding box in BVH space
|
|
4744
5224
|
:param high: The upper bound of the bounding box in BVH space""",
|
|
4745
5225
|
export=False,
|
|
4746
|
-
|
|
5226
|
+
is_differentiable=False,
|
|
4747
5227
|
)
|
|
4748
5228
|
|
|
4749
5229
|
add_builtin(
|
|
@@ -4759,7 +5239,7 @@ add_builtin(
|
|
|
4759
5239
|
:param start: The start of the ray in BVH space
|
|
4760
5240
|
:param dir: The direction of the ray in BVH space""",
|
|
4761
5241
|
export=False,
|
|
4762
|
-
|
|
5242
|
+
is_differentiable=False,
|
|
4763
5243
|
)
|
|
4764
5244
|
|
|
4765
5245
|
add_builtin(
|
|
@@ -4770,7 +5250,7 @@ add_builtin(
|
|
|
4770
5250
|
doc="""Move to the next bound returned by the query.
|
|
4771
5251
|
The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
|
|
4772
5252
|
export=False,
|
|
4773
|
-
|
|
5253
|
+
is_differentiable=False,
|
|
4774
5254
|
)
|
|
4775
5255
|
|
|
4776
5256
|
add_builtin(
|
|
@@ -5111,7 +5591,7 @@ add_builtin(
|
|
|
5111
5591
|
:param low: The lower bound of the bounding box in mesh space
|
|
5112
5592
|
:param high: The upper bound of the bounding box in mesh space""",
|
|
5113
5593
|
export=False,
|
|
5114
|
-
|
|
5594
|
+
is_differentiable=False,
|
|
5115
5595
|
)
|
|
5116
5596
|
|
|
5117
5597
|
add_builtin(
|
|
@@ -5123,7 +5603,7 @@ add_builtin(
|
|
|
5123
5603
|
|
|
5124
5604
|
The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
|
|
5125
5605
|
export=False,
|
|
5126
|
-
|
|
5606
|
+
is_differentiable=False,
|
|
5127
5607
|
)
|
|
5128
5608
|
|
|
5129
5609
|
add_builtin(
|
|
@@ -5153,7 +5633,7 @@ add_builtin(
|
|
|
5153
5633
|
|
|
5154
5634
|
This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
|
|
5155
5635
|
export=False,
|
|
5156
|
-
|
|
5636
|
+
is_differentiable=False,
|
|
5157
5637
|
)
|
|
5158
5638
|
|
|
5159
5639
|
add_builtin(
|
|
@@ -5165,7 +5645,7 @@ add_builtin(
|
|
|
5165
5645
|
|
|
5166
5646
|
The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
|
|
5167
5647
|
export=False,
|
|
5168
|
-
|
|
5648
|
+
is_differentiable=False,
|
|
5169
5649
|
)
|
|
5170
5650
|
|
|
5171
5651
|
add_builtin(
|
|
@@ -5179,7 +5659,7 @@ add_builtin(
|
|
|
5179
5659
|
|
|
5180
5660
|
Returns -1 if the :class:`HashGrid` has not been reserved.""",
|
|
5181
5661
|
export=False,
|
|
5182
|
-
|
|
5662
|
+
is_differentiable=False,
|
|
5183
5663
|
)
|
|
5184
5664
|
|
|
5185
5665
|
add_builtin(
|
|
@@ -5189,16 +5669,34 @@ add_builtin(
|
|
|
5189
5669
|
group="Geometry",
|
|
5190
5670
|
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5191
5671
|
|
|
5672
|
+
This function works with single precision, may return incorrect results in some case.
|
|
5673
|
+
|
|
5674
|
+
Returns > 0 if triangles intersect.""",
|
|
5675
|
+
export=False,
|
|
5676
|
+
is_differentiable=False,
|
|
5677
|
+
)
|
|
5678
|
+
|
|
5679
|
+
|
|
5680
|
+
add_builtin(
|
|
5681
|
+
"intersect_tri_tri",
|
|
5682
|
+
input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
|
|
5683
|
+
value_type=int,
|
|
5684
|
+
group="Geometry",
|
|
5685
|
+
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5686
|
+
|
|
5687
|
+
This function works with double precision, results are more accurate than the single precision version.
|
|
5688
|
+
|
|
5192
5689
|
Returns > 0 if triangles intersect.""",
|
|
5193
5690
|
export=False,
|
|
5194
|
-
|
|
5691
|
+
is_differentiable=False,
|
|
5195
5692
|
)
|
|
5196
5693
|
|
|
5694
|
+
|
|
5197
5695
|
add_builtin(
|
|
5198
5696
|
"mesh_get",
|
|
5199
5697
|
input_types={"id": uint64},
|
|
5200
5698
|
value_type=Mesh,
|
|
5201
|
-
|
|
5699
|
+
is_differentiable=False,
|
|
5202
5700
|
group="Geometry",
|
|
5203
5701
|
doc="""Retrieves the mesh given its index.""",
|
|
5204
5702
|
export=False,
|
|
@@ -5211,7 +5709,7 @@ add_builtin(
|
|
|
5211
5709
|
group="Geometry",
|
|
5212
5710
|
doc="""Evaluates the face normal the mesh given a face index.""",
|
|
5213
5711
|
export=False,
|
|
5214
|
-
|
|
5712
|
+
is_differentiable=False,
|
|
5215
5713
|
)
|
|
5216
5714
|
|
|
5217
5715
|
add_builtin(
|
|
@@ -5221,7 +5719,7 @@ add_builtin(
|
|
|
5221
5719
|
group="Geometry",
|
|
5222
5720
|
doc="""Returns the point of the mesh given a index.""",
|
|
5223
5721
|
export=False,
|
|
5224
|
-
|
|
5722
|
+
is_differentiable=False,
|
|
5225
5723
|
)
|
|
5226
5724
|
|
|
5227
5725
|
add_builtin(
|
|
@@ -5231,7 +5729,7 @@ add_builtin(
|
|
|
5231
5729
|
group="Geometry",
|
|
5232
5730
|
doc="""Returns the velocity of the mesh given a index.""",
|
|
5233
5731
|
export=False,
|
|
5234
|
-
|
|
5732
|
+
is_differentiable=False,
|
|
5235
5733
|
)
|
|
5236
5734
|
|
|
5237
5735
|
add_builtin(
|
|
@@ -5241,7 +5739,7 @@ add_builtin(
|
|
|
5241
5739
|
group="Geometry",
|
|
5242
5740
|
doc="""Returns the point-index of the mesh given a face-vertex index.""",
|
|
5243
5741
|
export=False,
|
|
5244
|
-
|
|
5742
|
+
is_differentiable=False,
|
|
5245
5743
|
)
|
|
5246
5744
|
|
|
5247
5745
|
|
|
@@ -5289,7 +5787,7 @@ add_builtin(
|
|
|
5289
5787
|
group="Utility",
|
|
5290
5788
|
export=False,
|
|
5291
5789
|
hidden=True,
|
|
5292
|
-
|
|
5790
|
+
is_differentiable=False,
|
|
5293
5791
|
)
|
|
5294
5792
|
add_builtin(
|
|
5295
5793
|
"iter_next",
|
|
@@ -5298,7 +5796,7 @@ add_builtin(
|
|
|
5298
5796
|
group="Utility",
|
|
5299
5797
|
export=False,
|
|
5300
5798
|
hidden=True,
|
|
5301
|
-
|
|
5799
|
+
is_differentiable=False,
|
|
5302
5800
|
)
|
|
5303
5801
|
add_builtin(
|
|
5304
5802
|
"iter_next",
|
|
@@ -5307,7 +5805,7 @@ add_builtin(
|
|
|
5307
5805
|
group="Utility",
|
|
5308
5806
|
export=False,
|
|
5309
5807
|
hidden=True,
|
|
5310
|
-
|
|
5808
|
+
is_differentiable=False,
|
|
5311
5809
|
)
|
|
5312
5810
|
|
|
5313
5811
|
add_builtin(
|
|
@@ -5318,7 +5816,7 @@ add_builtin(
|
|
|
5318
5816
|
group="Utility",
|
|
5319
5817
|
doc="""Returns the range in reversed order.""",
|
|
5320
5818
|
export=False,
|
|
5321
|
-
|
|
5819
|
+
is_differentiable=False,
|
|
5322
5820
|
)
|
|
5323
5821
|
|
|
5324
5822
|
# ---------------------------------
|
|
@@ -5338,8 +5836,8 @@ _volume_supported_value_types = {
|
|
|
5338
5836
|
|
|
5339
5837
|
|
|
5340
5838
|
def _is_volume_type_supported(dtype):
|
|
5341
|
-
for
|
|
5342
|
-
if types_equal(
|
|
5839
|
+
for value_type in _volume_supported_value_types:
|
|
5840
|
+
if types_equal(value_type, dtype):
|
|
5343
5841
|
return True
|
|
5344
5842
|
return False
|
|
5345
5843
|
|
|
@@ -5467,7 +5965,7 @@ add_builtin(
|
|
|
5467
5965
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
|
|
5468
5966
|
|
|
5469
5967
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5470
|
-
|
|
5968
|
+
is_differentiable=False,
|
|
5471
5969
|
)
|
|
5472
5970
|
|
|
5473
5971
|
|
|
@@ -5488,7 +5986,7 @@ add_builtin(
|
|
|
5488
5986
|
export=False,
|
|
5489
5987
|
group="Volumes",
|
|
5490
5988
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5491
|
-
|
|
5989
|
+
is_differentiable=False,
|
|
5492
5990
|
)
|
|
5493
5991
|
|
|
5494
5992
|
add_builtin(
|
|
@@ -5519,7 +6017,7 @@ add_builtin(
|
|
|
5519
6017
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5520
6018
|
|
|
5521
6019
|
If the voxel at this index does not exist, this function returns the background value""",
|
|
5522
|
-
|
|
6020
|
+
is_differentiable=False,
|
|
5523
6021
|
)
|
|
5524
6022
|
|
|
5525
6023
|
add_builtin(
|
|
@@ -5528,7 +6026,7 @@ add_builtin(
|
|
|
5528
6026
|
group="Volumes",
|
|
5529
6027
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5530
6028
|
export=False,
|
|
5531
|
-
|
|
6029
|
+
is_differentiable=False,
|
|
5532
6030
|
)
|
|
5533
6031
|
|
|
5534
6032
|
add_builtin(
|
|
@@ -5549,7 +6047,7 @@ add_builtin(
|
|
|
5549
6047
|
doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5550
6048
|
|
|
5551
6049
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5552
|
-
|
|
6050
|
+
is_differentiable=False,
|
|
5553
6051
|
)
|
|
5554
6052
|
|
|
5555
6053
|
add_builtin(
|
|
@@ -5558,7 +6056,7 @@ add_builtin(
|
|
|
5558
6056
|
group="Volumes",
|
|
5559
6057
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5560
6058
|
export=False,
|
|
5561
|
-
|
|
6059
|
+
is_differentiable=False,
|
|
5562
6060
|
)
|
|
5563
6061
|
|
|
5564
6062
|
add_builtin(
|
|
@@ -5577,7 +6075,7 @@ add_builtin(
|
|
|
5577
6075
|
doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5578
6076
|
|
|
5579
6077
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5580
|
-
|
|
6078
|
+
is_differentiable=False,
|
|
5581
6079
|
)
|
|
5582
6080
|
|
|
5583
6081
|
add_builtin(
|
|
@@ -5586,7 +6084,7 @@ add_builtin(
|
|
|
5586
6084
|
group="Volumes",
|
|
5587
6085
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5588
6086
|
export=False,
|
|
5589
|
-
|
|
6087
|
+
is_differentiable=False,
|
|
5590
6088
|
)
|
|
5591
6089
|
|
|
5592
6090
|
|
|
@@ -5668,7 +6166,7 @@ add_builtin(
|
|
|
5668
6166
|
If the voxel at this index does not exist, this function returns -1.
|
|
5669
6167
|
This function is available for both index grids and classical volumes.
|
|
5670
6168
|
""",
|
|
5671
|
-
|
|
6169
|
+
is_differentiable=False,
|
|
5672
6170
|
)
|
|
5673
6171
|
|
|
5674
6172
|
add_builtin(
|
|
@@ -5710,7 +6208,7 @@ add_builtin(
|
|
|
5710
6208
|
value_type=uint32,
|
|
5711
6209
|
group="Random",
|
|
5712
6210
|
doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
|
|
5713
|
-
|
|
6211
|
+
is_differentiable=False,
|
|
5714
6212
|
)
|
|
5715
6213
|
|
|
5716
6214
|
add_builtin(
|
|
@@ -5722,7 +6220,7 @@ add_builtin(
|
|
|
5722
6220
|
|
|
5723
6221
|
This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
|
|
5724
6222
|
but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
|
|
5725
|
-
|
|
6223
|
+
is_differentiable=False,
|
|
5726
6224
|
)
|
|
5727
6225
|
|
|
5728
6226
|
add_builtin(
|
|
@@ -5731,7 +6229,7 @@ add_builtin(
|
|
|
5731
6229
|
value_type=int,
|
|
5732
6230
|
group="Random",
|
|
5733
6231
|
doc="Return a random integer in the range [-2^31, 2^31).",
|
|
5734
|
-
|
|
6232
|
+
is_differentiable=False,
|
|
5735
6233
|
)
|
|
5736
6234
|
add_builtin(
|
|
5737
6235
|
"randi",
|
|
@@ -5739,7 +6237,7 @@ add_builtin(
|
|
|
5739
6237
|
value_type=int,
|
|
5740
6238
|
group="Random",
|
|
5741
6239
|
doc="Return a random integer between [low, high).",
|
|
5742
|
-
|
|
6240
|
+
is_differentiable=False,
|
|
5743
6241
|
)
|
|
5744
6242
|
add_builtin(
|
|
5745
6243
|
"randu",
|
|
@@ -5747,7 +6245,7 @@ add_builtin(
|
|
|
5747
6245
|
value_type=uint32,
|
|
5748
6246
|
group="Random",
|
|
5749
6247
|
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
5750
|
-
|
|
6248
|
+
is_differentiable=False,
|
|
5751
6249
|
)
|
|
5752
6250
|
add_builtin(
|
|
5753
6251
|
"randu",
|
|
@@ -5755,7 +6253,7 @@ add_builtin(
|
|
|
5755
6253
|
value_type=uint32,
|
|
5756
6254
|
group="Random",
|
|
5757
6255
|
doc="Return a random unsigned integer between [low, high).",
|
|
5758
|
-
|
|
6256
|
+
is_differentiable=False,
|
|
5759
6257
|
)
|
|
5760
6258
|
add_builtin(
|
|
5761
6259
|
"randf",
|
|
@@ -5763,7 +6261,7 @@ add_builtin(
|
|
|
5763
6261
|
value_type=float,
|
|
5764
6262
|
group="Random",
|
|
5765
6263
|
doc="Return a random float between [0.0, 1.0).",
|
|
5766
|
-
|
|
6264
|
+
is_differentiable=False,
|
|
5767
6265
|
)
|
|
5768
6266
|
add_builtin(
|
|
5769
6267
|
"randf",
|
|
@@ -5771,7 +6269,7 @@ add_builtin(
|
|
|
5771
6269
|
value_type=float,
|
|
5772
6270
|
group="Random",
|
|
5773
6271
|
doc="Return a random float between [low, high).",
|
|
5774
|
-
|
|
6272
|
+
is_differentiable=False,
|
|
5775
6273
|
)
|
|
5776
6274
|
add_builtin(
|
|
5777
6275
|
"randn",
|
|
@@ -5779,7 +6277,7 @@ add_builtin(
|
|
|
5779
6277
|
value_type=float,
|
|
5780
6278
|
group="Random",
|
|
5781
6279
|
doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
|
|
5782
|
-
|
|
6280
|
+
is_differentiable=False,
|
|
5783
6281
|
)
|
|
5784
6282
|
|
|
5785
6283
|
add_builtin(
|
|
@@ -5788,7 +6286,7 @@ add_builtin(
|
|
|
5788
6286
|
value_type=int,
|
|
5789
6287
|
group="Random",
|
|
5790
6288
|
doc="Inverse-transform sample a cumulative distribution function.",
|
|
5791
|
-
|
|
6289
|
+
is_differentiable=False,
|
|
5792
6290
|
)
|
|
5793
6291
|
add_builtin(
|
|
5794
6292
|
"sample_triangle",
|
|
@@ -5796,7 +6294,7 @@ add_builtin(
|
|
|
5796
6294
|
value_type=vec2,
|
|
5797
6295
|
group="Random",
|
|
5798
6296
|
doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
|
|
5799
|
-
|
|
6297
|
+
is_differentiable=False,
|
|
5800
6298
|
)
|
|
5801
6299
|
add_builtin(
|
|
5802
6300
|
"sample_unit_ring",
|
|
@@ -5804,7 +6302,7 @@ add_builtin(
|
|
|
5804
6302
|
value_type=vec2,
|
|
5805
6303
|
group="Random",
|
|
5806
6304
|
doc="Uniformly sample a ring in the xy plane.",
|
|
5807
|
-
|
|
6305
|
+
is_differentiable=False,
|
|
5808
6306
|
)
|
|
5809
6307
|
add_builtin(
|
|
5810
6308
|
"sample_unit_disk",
|
|
@@ -5812,7 +6310,7 @@ add_builtin(
|
|
|
5812
6310
|
value_type=vec2,
|
|
5813
6311
|
group="Random",
|
|
5814
6312
|
doc="Uniformly sample a disk in the xy plane.",
|
|
5815
|
-
|
|
6313
|
+
is_differentiable=False,
|
|
5816
6314
|
)
|
|
5817
6315
|
add_builtin(
|
|
5818
6316
|
"sample_unit_sphere_surface",
|
|
@@ -5820,7 +6318,7 @@ add_builtin(
|
|
|
5820
6318
|
value_type=vec3,
|
|
5821
6319
|
group="Random",
|
|
5822
6320
|
doc="Uniformly sample a unit sphere surface.",
|
|
5823
|
-
|
|
6321
|
+
is_differentiable=False,
|
|
5824
6322
|
)
|
|
5825
6323
|
add_builtin(
|
|
5826
6324
|
"sample_unit_sphere",
|
|
@@ -5828,7 +6326,7 @@ add_builtin(
|
|
|
5828
6326
|
value_type=vec3,
|
|
5829
6327
|
group="Random",
|
|
5830
6328
|
doc="Uniformly sample a unit sphere.",
|
|
5831
|
-
|
|
6329
|
+
is_differentiable=False,
|
|
5832
6330
|
)
|
|
5833
6331
|
add_builtin(
|
|
5834
6332
|
"sample_unit_hemisphere_surface",
|
|
@@ -5836,7 +6334,7 @@ add_builtin(
|
|
|
5836
6334
|
value_type=vec3,
|
|
5837
6335
|
group="Random",
|
|
5838
6336
|
doc="Uniformly sample a unit hemisphere surface.",
|
|
5839
|
-
|
|
6337
|
+
is_differentiable=False,
|
|
5840
6338
|
)
|
|
5841
6339
|
add_builtin(
|
|
5842
6340
|
"sample_unit_hemisphere",
|
|
@@ -5844,7 +6342,7 @@ add_builtin(
|
|
|
5844
6342
|
value_type=vec3,
|
|
5845
6343
|
group="Random",
|
|
5846
6344
|
doc="Uniformly sample a unit hemisphere.",
|
|
5847
|
-
|
|
6345
|
+
is_differentiable=False,
|
|
5848
6346
|
)
|
|
5849
6347
|
add_builtin(
|
|
5850
6348
|
"sample_unit_square",
|
|
@@ -5852,7 +6350,7 @@ add_builtin(
|
|
|
5852
6350
|
value_type=vec2,
|
|
5853
6351
|
group="Random",
|
|
5854
6352
|
doc="Uniformly sample a unit square.",
|
|
5855
|
-
|
|
6353
|
+
is_differentiable=False,
|
|
5856
6354
|
)
|
|
5857
6355
|
add_builtin(
|
|
5858
6356
|
"sample_unit_cube",
|
|
@@ -5860,7 +6358,7 @@ add_builtin(
|
|
|
5860
6358
|
value_type=vec3,
|
|
5861
6359
|
group="Random",
|
|
5862
6360
|
doc="Uniformly sample a unit cube.",
|
|
5863
|
-
|
|
6361
|
+
is_differentiable=False,
|
|
5864
6362
|
)
|
|
5865
6363
|
|
|
5866
6364
|
add_builtin(
|
|
@@ -5872,7 +6370,7 @@ add_builtin(
|
|
|
5872
6370
|
|
|
5873
6371
|
:param state: RNG state
|
|
5874
6372
|
:param lam: The expected value of the distribution""",
|
|
5875
|
-
|
|
6373
|
+
is_differentiable=False,
|
|
5876
6374
|
)
|
|
5877
6375
|
|
|
5878
6376
|
add_builtin(
|
|
@@ -5940,7 +6438,7 @@ add_builtin(
|
|
|
5940
6438
|
value_type=vec2,
|
|
5941
6439
|
group="Random",
|
|
5942
6440
|
doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
|
|
5943
|
-
|
|
6441
|
+
is_differentiable=False,
|
|
5944
6442
|
)
|
|
5945
6443
|
add_builtin(
|
|
5946
6444
|
"curlnoise",
|
|
@@ -5949,7 +6447,7 @@ add_builtin(
|
|
|
5949
6447
|
value_type=vec3,
|
|
5950
6448
|
group="Random",
|
|
5951
6449
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5952
|
-
|
|
6450
|
+
is_differentiable=False,
|
|
5953
6451
|
)
|
|
5954
6452
|
add_builtin(
|
|
5955
6453
|
"curlnoise",
|
|
@@ -5958,7 +6456,7 @@ add_builtin(
|
|
|
5958
6456
|
value_type=vec3,
|
|
5959
6457
|
group="Random",
|
|
5960
6458
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5961
|
-
|
|
6459
|
+
is_differentiable=False,
|
|
5962
6460
|
)
|
|
5963
6461
|
|
|
5964
6462
|
|
|
@@ -5990,7 +6488,7 @@ add_builtin(
|
|
|
5990
6488
|
dispatch_func=printf_dispatch_func,
|
|
5991
6489
|
group="Utility",
|
|
5992
6490
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
5993
|
-
|
|
6491
|
+
is_differentiable=False,
|
|
5994
6492
|
)
|
|
5995
6493
|
|
|
5996
6494
|
add_builtin(
|
|
@@ -6009,7 +6507,7 @@ add_builtin(
|
|
|
6009
6507
|
group="Utility",
|
|
6010
6508
|
namespace="",
|
|
6011
6509
|
native_func="__debugbreak",
|
|
6012
|
-
|
|
6510
|
+
is_differentiable=False,
|
|
6013
6511
|
)
|
|
6014
6512
|
|
|
6015
6513
|
# helpers
|
|
@@ -6027,7 +6525,7 @@ add_builtin(
|
|
|
6027
6525
|
This function may not be called from user-defined Warp functions.""",
|
|
6028
6526
|
namespace="",
|
|
6029
6527
|
native_func="builtin_tid1d",
|
|
6030
|
-
|
|
6528
|
+
is_differentiable=False,
|
|
6031
6529
|
)
|
|
6032
6530
|
|
|
6033
6531
|
add_builtin(
|
|
@@ -6038,7 +6536,7 @@ add_builtin(
|
|
|
6038
6536
|
doc="Returns the number of threads in the current block.",
|
|
6039
6537
|
namespace="",
|
|
6040
6538
|
native_func="builtin_block_dim",
|
|
6041
|
-
|
|
6539
|
+
is_differentiable=False,
|
|
6042
6540
|
)
|
|
6043
6541
|
|
|
6044
6542
|
add_builtin(
|
|
@@ -6053,7 +6551,7 @@ add_builtin(
|
|
|
6053
6551
|
This function may not be called from user-defined Warp functions.""",
|
|
6054
6552
|
namespace="",
|
|
6055
6553
|
native_func="builtin_tid2d",
|
|
6056
|
-
|
|
6554
|
+
is_differentiable=False,
|
|
6057
6555
|
)
|
|
6058
6556
|
|
|
6059
6557
|
add_builtin(
|
|
@@ -6068,7 +6566,7 @@ add_builtin(
|
|
|
6068
6566
|
This function may not be called from user-defined Warp functions.""",
|
|
6069
6567
|
namespace="",
|
|
6070
6568
|
native_func="builtin_tid3d",
|
|
6071
|
-
|
|
6569
|
+
is_differentiable=False,
|
|
6072
6570
|
)
|
|
6073
6571
|
|
|
6074
6572
|
add_builtin(
|
|
@@ -6083,7 +6581,7 @@ add_builtin(
|
|
|
6083
6581
|
This function may not be called from user-defined Warp functions.""",
|
|
6084
6582
|
namespace="",
|
|
6085
6583
|
native_func="builtin_tid4d",
|
|
6086
|
-
|
|
6584
|
+
is_differentiable=False,
|
|
6087
6585
|
)
|
|
6088
6586
|
|
|
6089
6587
|
|
|
@@ -6127,56 +6625,20 @@ def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
6127
6625
|
if arg_types is None:
|
|
6128
6626
|
return Any
|
|
6129
6627
|
|
|
6130
|
-
|
|
6131
|
-
v_false = arg_types["value_if_false"]
|
|
6628
|
+
raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
|
|
6132
6629
|
|
|
6133
|
-
if not types_equal(v_true, v_false):
|
|
6134
|
-
raise RuntimeError(
|
|
6135
|
-
f"select() true value type ({v_true}) must be of the same type as the false type ({v_false})"
|
|
6136
|
-
)
|
|
6137
6630
|
|
|
6138
|
-
|
|
6139
|
-
|
|
6140
|
-
|
|
6141
|
-
|
|
6142
|
-
|
|
6631
|
+
add_builtin(
|
|
6632
|
+
"select",
|
|
6633
|
+
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
6634
|
+
value_func=select_value_func,
|
|
6635
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6143
6636
|
|
|
6144
|
-
|
|
6145
|
-
return tile(
|
|
6146
|
-
dtype=v_true.dtype,
|
|
6147
|
-
shape=v_true.shape,
|
|
6148
|
-
storage=v_true.storage,
|
|
6149
|
-
strides=v_true.strides,
|
|
6150
|
-
layout=v_true.layout,
|
|
6151
|
-
owner=True,
|
|
6152
|
-
)
|
|
6153
|
-
|
|
6154
|
-
return v_true
|
|
6155
|
-
|
|
6156
|
-
|
|
6157
|
-
def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
6158
|
-
warp.utils.warn(
|
|
6159
|
-
"wp.select() is deprecated and will be removed in a future\n"
|
|
6160
|
-
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
6161
|
-
category=DeprecationWarning,
|
|
6162
|
-
)
|
|
6163
|
-
|
|
6164
|
-
func_args = tuple(args.values())
|
|
6165
|
-
template_args = ()
|
|
6166
|
-
|
|
6167
|
-
return (func_args, template_args)
|
|
6168
|
-
|
|
6169
|
-
|
|
6170
|
-
add_builtin(
|
|
6171
|
-
"select",
|
|
6172
|
-
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
6173
|
-
value_func=select_value_func,
|
|
6174
|
-
dispatch_func=select_dispatch_func,
|
|
6175
|
-
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6176
|
-
|
|
6177
|
-
.. deprecated:: 1.7
|
|
6637
|
+
.. versionremoved:: 1.10
|
|
6178
6638
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6179
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6639
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6640
|
+
|
|
6641
|
+
.. deprecated:: 1.7""",
|
|
6180
6642
|
group="Utility",
|
|
6181
6643
|
)
|
|
6182
6644
|
for t in int_types:
|
|
@@ -6184,24 +6646,26 @@ for t in int_types:
|
|
|
6184
6646
|
"select",
|
|
6185
6647
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
6186
6648
|
value_func=select_value_func,
|
|
6187
|
-
dispatch_func=select_dispatch_func,
|
|
6188
6649
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6189
6650
|
|
|
6190
|
-
..
|
|
6651
|
+
.. versionremoved:: 1.10
|
|
6191
6652
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6192
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6653
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6654
|
+
|
|
6655
|
+
.. deprecated:: 1.7""",
|
|
6193
6656
|
group="Utility",
|
|
6194
6657
|
)
|
|
6195
6658
|
add_builtin(
|
|
6196
6659
|
"select",
|
|
6197
6660
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
6198
6661
|
value_func=select_value_func,
|
|
6199
|
-
dispatch_func=select_dispatch_func,
|
|
6200
6662
|
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6201
6663
|
|
|
6202
|
-
..
|
|
6664
|
+
.. versionremoved:: 1.10
|
|
6203
6665
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6204
|
-
``where(arr, value_if_true, value_if_false)``.
|
|
6666
|
+
``where(arr, value_if_true, value_if_false)``.
|
|
6667
|
+
|
|
6668
|
+
.. deprecated:: 1.7""",
|
|
6205
6669
|
group="Utility",
|
|
6206
6670
|
)
|
|
6207
6671
|
|
|
@@ -6291,7 +6755,7 @@ add_builtin(
|
|
|
6291
6755
|
group="Utility",
|
|
6292
6756
|
hidden=True,
|
|
6293
6757
|
export=False,
|
|
6294
|
-
|
|
6758
|
+
is_differentiable=False,
|
|
6295
6759
|
)
|
|
6296
6760
|
|
|
6297
6761
|
|
|
@@ -6332,7 +6796,7 @@ add_builtin(
|
|
|
6332
6796
|
native_func="fixedarray_t",
|
|
6333
6797
|
group="Utility",
|
|
6334
6798
|
export=False,
|
|
6335
|
-
|
|
6799
|
+
is_differentiable=False,
|
|
6336
6800
|
hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
|
|
6337
6801
|
)
|
|
6338
6802
|
|
|
@@ -6375,14 +6839,13 @@ for array_type in array_types:
|
|
|
6375
6839
|
# does argument checking and type propagation for view()
|
|
6376
6840
|
def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6377
6841
|
arr_type = arg_types["arr"]
|
|
6378
|
-
idx_types = tuple(arg_types[x] for x in "
|
|
6842
|
+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
6379
6843
|
|
|
6380
6844
|
if not is_array(arr_type):
|
|
6381
6845
|
raise RuntimeError("view() first argument must be an array")
|
|
6382
6846
|
|
|
6383
6847
|
idx_count = len(idx_types)
|
|
6384
|
-
|
|
6385
|
-
if idx_count >= arr_type.ndim:
|
|
6848
|
+
if idx_count > arr_type.ndim:
|
|
6386
6849
|
raise RuntimeError(
|
|
6387
6850
|
f"Trying to create an array view with {idx_count} indices, "
|
|
6388
6851
|
f"but the array only has {arr_type.ndim} dimension(s). "
|
|
@@ -6390,14 +6853,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6390
6853
|
f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
|
|
6391
6854
|
)
|
|
6392
6855
|
|
|
6393
|
-
|
|
6394
|
-
|
|
6395
|
-
|
|
6396
|
-
|
|
6856
|
+
has_slice = any(is_slice(x) for x in idx_types)
|
|
6857
|
+
if has_slice:
|
|
6858
|
+
# check index types
|
|
6859
|
+
for t in idx_types:
|
|
6860
|
+
if not (type_is_int(t) or is_slice(t)):
|
|
6861
|
+
raise RuntimeError(
|
|
6862
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6863
|
+
)
|
|
6864
|
+
|
|
6865
|
+
# Each integer index collapses one dimension.
|
|
6866
|
+
int_count = sum(x.step == 0 for x in idx_types)
|
|
6867
|
+
ndim = arr_type.ndim - int_count
|
|
6868
|
+
assert ndim > 0
|
|
6869
|
+
else:
|
|
6870
|
+
if idx_count == arr_type.ndim:
|
|
6871
|
+
raise RuntimeError("Expected to call `address()` instead of `view()`")
|
|
6872
|
+
|
|
6873
|
+
# check index types
|
|
6874
|
+
for t in idx_types:
|
|
6875
|
+
if not type_is_int(t):
|
|
6876
|
+
raise RuntimeError(
|
|
6877
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6878
|
+
)
|
|
6879
|
+
|
|
6880
|
+
# create an array view with leading dimensions removed
|
|
6881
|
+
ndim = arr_type.ndim - idx_count
|
|
6882
|
+
assert ndim > 0
|
|
6397
6883
|
|
|
6398
|
-
# create an array view with leading dimensions removed
|
|
6399
6884
|
dtype = arr_type.dtype
|
|
6400
|
-
ndim = arr_type.ndim - idx_count
|
|
6401
6885
|
if isinstance(arr_type, (fabricarray, indexedfabricarray)):
|
|
6402
6886
|
# fabric array of arrays: return array attribute as a regular array
|
|
6403
6887
|
return array(dtype=dtype, ndim=ndim)
|
|
@@ -6408,8 +6892,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6408
6892
|
for array_type in array_types:
|
|
6409
6893
|
add_builtin(
|
|
6410
6894
|
"view",
|
|
6411
|
-
input_types={
|
|
6412
|
-
|
|
6895
|
+
input_types={
|
|
6896
|
+
"arr": array_type(dtype=Any),
|
|
6897
|
+
"i": Any,
|
|
6898
|
+
"j": Any,
|
|
6899
|
+
"k": Any,
|
|
6900
|
+
"l": Any,
|
|
6901
|
+
},
|
|
6902
|
+
defaults={
|
|
6903
|
+
"j": None,
|
|
6904
|
+
"k": None,
|
|
6905
|
+
"l": None,
|
|
6906
|
+
},
|
|
6413
6907
|
constraint=sametypes,
|
|
6414
6908
|
hidden=True,
|
|
6415
6909
|
value_func=view_value_func,
|
|
@@ -6513,7 +7007,7 @@ add_builtin(
|
|
|
6513
7007
|
hidden=True,
|
|
6514
7008
|
skip_replay=True,
|
|
6515
7009
|
group="Utility",
|
|
6516
|
-
|
|
7010
|
+
is_differentiable=False,
|
|
6517
7011
|
)
|
|
6518
7012
|
|
|
6519
7013
|
|
|
@@ -6530,7 +7024,7 @@ add_builtin(
|
|
|
6530
7024
|
dispatch_func=load_dispatch_func,
|
|
6531
7025
|
hidden=True,
|
|
6532
7026
|
group="Utility",
|
|
6533
|
-
|
|
7027
|
+
is_differentiable=False,
|
|
6534
7028
|
)
|
|
6535
7029
|
|
|
6536
7030
|
|
|
@@ -6606,6 +7100,13 @@ def create_atomic_op_value_func(op: str):
|
|
|
6606
7100
|
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
6607
7101
|
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
6608
7102
|
)
|
|
7103
|
+
elif op in ("and", "or", "xor"):
|
|
7104
|
+
supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
|
|
7105
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
|
|
7106
|
+
raise RuntimeError(
|
|
7107
|
+
f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
|
|
7108
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
7109
|
+
)
|
|
6609
7110
|
else:
|
|
6610
7111
|
raise NotImplementedError
|
|
6611
7112
|
|
|
@@ -6639,7 +7140,8 @@ for array_type in array_types:
|
|
|
6639
7140
|
value_func=create_atomic_op_value_func("add"),
|
|
6640
7141
|
dispatch_func=atomic_op_dispatch_func,
|
|
6641
7142
|
doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
6642
|
-
|
|
7143
|
+
|
|
7144
|
+
This function is automatically invoked when using the syntax ``arr[i] += value``.""",
|
|
6643
7145
|
group="Utility",
|
|
6644
7146
|
skip_replay=True,
|
|
6645
7147
|
)
|
|
@@ -6651,7 +7153,8 @@ for array_type in array_types:
|
|
|
6651
7153
|
value_func=create_atomic_op_value_func("add"),
|
|
6652
7154
|
dispatch_func=atomic_op_dispatch_func,
|
|
6653
7155
|
doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
6654
|
-
|
|
7156
|
+
|
|
7157
|
+
This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
|
|
6655
7158
|
group="Utility",
|
|
6656
7159
|
skip_replay=True,
|
|
6657
7160
|
)
|
|
@@ -6663,7 +7166,8 @@ for array_type in array_types:
|
|
|
6663
7166
|
value_func=create_atomic_op_value_func("add"),
|
|
6664
7167
|
dispatch_func=atomic_op_dispatch_func,
|
|
6665
7168
|
doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
6666
|
-
|
|
7169
|
+
|
|
7170
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
|
|
6667
7171
|
group="Utility",
|
|
6668
7172
|
skip_replay=True,
|
|
6669
7173
|
)
|
|
@@ -6675,7 +7179,8 @@ for array_type in array_types:
|
|
|
6675
7179
|
value_func=create_atomic_op_value_func("add"),
|
|
6676
7180
|
dispatch_func=atomic_op_dispatch_func,
|
|
6677
7181
|
doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
6678
|
-
|
|
7182
|
+
|
|
7183
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
|
|
6679
7184
|
group="Utility",
|
|
6680
7185
|
skip_replay=True,
|
|
6681
7186
|
)
|
|
@@ -6688,7 +7193,8 @@ for array_type in array_types:
|
|
|
6688
7193
|
value_func=create_atomic_op_value_func("sub"),
|
|
6689
7194
|
dispatch_func=atomic_op_dispatch_func,
|
|
6690
7195
|
doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
6691
|
-
|
|
7196
|
+
|
|
7197
|
+
This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
|
|
6692
7198
|
group="Utility",
|
|
6693
7199
|
skip_replay=True,
|
|
6694
7200
|
)
|
|
@@ -6700,7 +7206,8 @@ for array_type in array_types:
|
|
|
6700
7206
|
value_func=create_atomic_op_value_func("sub"),
|
|
6701
7207
|
dispatch_func=atomic_op_dispatch_func,
|
|
6702
7208
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
6703
|
-
|
|
7209
|
+
|
|
7210
|
+
This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
|
|
6704
7211
|
group="Utility",
|
|
6705
7212
|
skip_replay=True,
|
|
6706
7213
|
)
|
|
@@ -6712,7 +7219,8 @@ for array_type in array_types:
|
|
|
6712
7219
|
value_func=create_atomic_op_value_func("sub"),
|
|
6713
7220
|
dispatch_func=atomic_op_dispatch_func,
|
|
6714
7221
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
6715
|
-
|
|
7222
|
+
|
|
7223
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
|
|
6716
7224
|
group="Utility",
|
|
6717
7225
|
skip_replay=True,
|
|
6718
7226
|
)
|
|
@@ -6724,7 +7232,8 @@ for array_type in array_types:
|
|
|
6724
7232
|
value_func=create_atomic_op_value_func("sub"),
|
|
6725
7233
|
dispatch_func=atomic_op_dispatch_func,
|
|
6726
7234
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
6727
|
-
|
|
7235
|
+
|
|
7236
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
|
|
6728
7237
|
group="Utility",
|
|
6729
7238
|
skip_replay=True,
|
|
6730
7239
|
)
|
|
@@ -6847,7 +7356,7 @@ for array_type in array_types:
|
|
|
6847
7356
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6848
7357
|
group="Utility",
|
|
6849
7358
|
skip_replay=True,
|
|
6850
|
-
|
|
7359
|
+
is_differentiable=False,
|
|
6851
7360
|
)
|
|
6852
7361
|
add_builtin(
|
|
6853
7362
|
"atomic_cas",
|
|
@@ -6861,7 +7370,7 @@ for array_type in array_types:
|
|
|
6861
7370
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6862
7371
|
group="Utility",
|
|
6863
7372
|
skip_replay=True,
|
|
6864
|
-
|
|
7373
|
+
is_differentiable=False,
|
|
6865
7374
|
)
|
|
6866
7375
|
add_builtin(
|
|
6867
7376
|
"atomic_cas",
|
|
@@ -6875,7 +7384,7 @@ for array_type in array_types:
|
|
|
6875
7384
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6876
7385
|
group="Utility",
|
|
6877
7386
|
skip_replay=True,
|
|
6878
|
-
|
|
7387
|
+
is_differentiable=False,
|
|
6879
7388
|
)
|
|
6880
7389
|
add_builtin(
|
|
6881
7390
|
"atomic_cas",
|
|
@@ -6897,7 +7406,7 @@ for array_type in array_types:
|
|
|
6897
7406
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6898
7407
|
group="Utility",
|
|
6899
7408
|
skip_replay=True,
|
|
6900
|
-
|
|
7409
|
+
is_differentiable=False,
|
|
6901
7410
|
)
|
|
6902
7411
|
|
|
6903
7412
|
add_builtin(
|
|
@@ -6912,7 +7421,7 @@ for array_type in array_types:
|
|
|
6912
7421
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6913
7422
|
group="Utility",
|
|
6914
7423
|
skip_replay=True,
|
|
6915
|
-
|
|
7424
|
+
is_differentiable=False,
|
|
6916
7425
|
)
|
|
6917
7426
|
add_builtin(
|
|
6918
7427
|
"atomic_exch",
|
|
@@ -6926,7 +7435,7 @@ for array_type in array_types:
|
|
|
6926
7435
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6927
7436
|
group="Utility",
|
|
6928
7437
|
skip_replay=True,
|
|
6929
|
-
|
|
7438
|
+
is_differentiable=False,
|
|
6930
7439
|
)
|
|
6931
7440
|
add_builtin(
|
|
6932
7441
|
"atomic_exch",
|
|
@@ -6940,7 +7449,7 @@ for array_type in array_types:
|
|
|
6940
7449
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6941
7450
|
group="Utility",
|
|
6942
7451
|
skip_replay=True,
|
|
6943
|
-
|
|
7452
|
+
is_differentiable=False,
|
|
6944
7453
|
)
|
|
6945
7454
|
add_builtin(
|
|
6946
7455
|
"atomic_exch",
|
|
@@ -6956,6 +7465,177 @@ for array_type in array_types:
|
|
|
6956
7465
|
skip_replay=True,
|
|
6957
7466
|
)
|
|
6958
7467
|
|
|
7468
|
+
add_builtin(
|
|
7469
|
+
"atomic_and",
|
|
7470
|
+
hidden=hidden,
|
|
7471
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7472
|
+
constraint=atomic_op_constraint,
|
|
7473
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7474
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7475
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7476
|
+
|
|
7477
|
+
This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
|
|
7478
|
+
group="Utility",
|
|
7479
|
+
skip_replay=True,
|
|
7480
|
+
is_differentiable=False,
|
|
7481
|
+
)
|
|
7482
|
+
add_builtin(
|
|
7483
|
+
"atomic_and",
|
|
7484
|
+
hidden=hidden,
|
|
7485
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7486
|
+
constraint=atomic_op_constraint,
|
|
7487
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7488
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7489
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7490
|
+
|
|
7491
|
+
This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
|
|
7492
|
+
group="Utility",
|
|
7493
|
+
skip_replay=True,
|
|
7494
|
+
is_differentiable=False,
|
|
7495
|
+
)
|
|
7496
|
+
add_builtin(
|
|
7497
|
+
"atomic_and",
|
|
7498
|
+
hidden=hidden,
|
|
7499
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7500
|
+
constraint=atomic_op_constraint,
|
|
7501
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7502
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7503
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7504
|
+
|
|
7505
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
|
|
7506
|
+
group="Utility",
|
|
7507
|
+
skip_replay=True,
|
|
7508
|
+
is_differentiable=False,
|
|
7509
|
+
)
|
|
7510
|
+
add_builtin(
|
|
7511
|
+
"atomic_and",
|
|
7512
|
+
hidden=hidden,
|
|
7513
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7514
|
+
constraint=atomic_op_constraint,
|
|
7515
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7516
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7517
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7518
|
+
|
|
7519
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
|
|
7520
|
+
group="Utility",
|
|
7521
|
+
skip_replay=True,
|
|
7522
|
+
is_differentiable=False,
|
|
7523
|
+
)
|
|
7524
|
+
|
|
7525
|
+
add_builtin(
|
|
7526
|
+
"atomic_or",
|
|
7527
|
+
hidden=hidden,
|
|
7528
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7529
|
+
constraint=atomic_op_constraint,
|
|
7530
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7531
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7532
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7533
|
+
|
|
7534
|
+
This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
|
|
7535
|
+
group="Utility",
|
|
7536
|
+
skip_replay=True,
|
|
7537
|
+
is_differentiable=False,
|
|
7538
|
+
)
|
|
7539
|
+
add_builtin(
|
|
7540
|
+
"atomic_or",
|
|
7541
|
+
hidden=hidden,
|
|
7542
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7543
|
+
constraint=atomic_op_constraint,
|
|
7544
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7545
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7546
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7547
|
+
|
|
7548
|
+
This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
|
|
7549
|
+
group="Utility",
|
|
7550
|
+
skip_replay=True,
|
|
7551
|
+
is_differentiable=False,
|
|
7552
|
+
)
|
|
7553
|
+
add_builtin(
|
|
7554
|
+
"atomic_or",
|
|
7555
|
+
hidden=hidden,
|
|
7556
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7557
|
+
constraint=atomic_op_constraint,
|
|
7558
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7559
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7560
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7561
|
+
|
|
7562
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
|
|
7563
|
+
group="Utility",
|
|
7564
|
+
skip_replay=True,
|
|
7565
|
+
is_differentiable=False,
|
|
7566
|
+
)
|
|
7567
|
+
add_builtin(
|
|
7568
|
+
"atomic_or",
|
|
7569
|
+
hidden=hidden,
|
|
7570
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7571
|
+
constraint=atomic_op_constraint,
|
|
7572
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7573
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7574
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7575
|
+
|
|
7576
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
|
|
7577
|
+
group="Utility",
|
|
7578
|
+
skip_replay=True,
|
|
7579
|
+
is_differentiable=False,
|
|
7580
|
+
)
|
|
7581
|
+
|
|
7582
|
+
add_builtin(
|
|
7583
|
+
"atomic_xor",
|
|
7584
|
+
hidden=hidden,
|
|
7585
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7586
|
+
constraint=atomic_op_constraint,
|
|
7587
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7588
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7589
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7590
|
+
|
|
7591
|
+
This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
|
|
7592
|
+
group="Utility",
|
|
7593
|
+
skip_replay=True,
|
|
7594
|
+
is_differentiable=False,
|
|
7595
|
+
)
|
|
7596
|
+
add_builtin(
|
|
7597
|
+
"atomic_xor",
|
|
7598
|
+
hidden=hidden,
|
|
7599
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7600
|
+
constraint=atomic_op_constraint,
|
|
7601
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7602
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7603
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7604
|
+
|
|
7605
|
+
This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
|
|
7606
|
+
group="Utility",
|
|
7607
|
+
skip_replay=True,
|
|
7608
|
+
is_differentiable=False,
|
|
7609
|
+
)
|
|
7610
|
+
add_builtin(
|
|
7611
|
+
"atomic_xor",
|
|
7612
|
+
hidden=hidden,
|
|
7613
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7614
|
+
constraint=atomic_op_constraint,
|
|
7615
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7616
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7617
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7618
|
+
|
|
7619
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
|
|
7620
|
+
group="Utility",
|
|
7621
|
+
skip_replay=True,
|
|
7622
|
+
is_differentiable=False,
|
|
7623
|
+
)
|
|
7624
|
+
add_builtin(
|
|
7625
|
+
"atomic_xor",
|
|
7626
|
+
hidden=hidden,
|
|
7627
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7628
|
+
constraint=atomic_op_constraint,
|
|
7629
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7630
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7631
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7632
|
+
|
|
7633
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
|
|
7634
|
+
group="Utility",
|
|
7635
|
+
skip_replay=True,
|
|
7636
|
+
is_differentiable=False,
|
|
7637
|
+
)
|
|
7638
|
+
|
|
6959
7639
|
|
|
6960
7640
|
# used to index into builtin types, i.e.: y = vec3[1]
|
|
6961
7641
|
def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
@@ -7104,7 +7784,7 @@ add_builtin(
|
|
|
7104
7784
|
hidden=True,
|
|
7105
7785
|
group="Utility",
|
|
7106
7786
|
skip_replay=True,
|
|
7107
|
-
|
|
7787
|
+
is_differentiable=False,
|
|
7108
7788
|
)
|
|
7109
7789
|
# implements &quaternion[index]
|
|
7110
7790
|
add_builtin(
|
|
@@ -7115,7 +7795,7 @@ add_builtin(
|
|
|
7115
7795
|
hidden=True,
|
|
7116
7796
|
group="Utility",
|
|
7117
7797
|
skip_replay=True,
|
|
7118
|
-
|
|
7798
|
+
is_differentiable=False,
|
|
7119
7799
|
)
|
|
7120
7800
|
# implements &transformation[index]
|
|
7121
7801
|
add_builtin(
|
|
@@ -7126,7 +7806,7 @@ add_builtin(
|
|
|
7126
7806
|
hidden=True,
|
|
7127
7807
|
group="Utility",
|
|
7128
7808
|
skip_replay=True,
|
|
7129
|
-
|
|
7809
|
+
is_differentiable=False,
|
|
7130
7810
|
)
|
|
7131
7811
|
# implements &(*vector)[index]
|
|
7132
7812
|
add_builtin(
|
|
@@ -7137,7 +7817,7 @@ add_builtin(
|
|
|
7137
7817
|
hidden=True,
|
|
7138
7818
|
group="Utility",
|
|
7139
7819
|
skip_replay=True,
|
|
7140
|
-
|
|
7820
|
+
is_differentiable=False,
|
|
7141
7821
|
)
|
|
7142
7822
|
# implements &(*matrix)[i, j]
|
|
7143
7823
|
add_builtin(
|
|
@@ -7148,7 +7828,7 @@ add_builtin(
|
|
|
7148
7828
|
hidden=True,
|
|
7149
7829
|
group="Utility",
|
|
7150
7830
|
skip_replay=True,
|
|
7151
|
-
|
|
7831
|
+
is_differentiable=False,
|
|
7152
7832
|
)
|
|
7153
7833
|
# implements &(*quaternion)[index]
|
|
7154
7834
|
add_builtin(
|
|
@@ -7159,7 +7839,7 @@ add_builtin(
|
|
|
7159
7839
|
hidden=True,
|
|
7160
7840
|
group="Utility",
|
|
7161
7841
|
skip_replay=True,
|
|
7162
|
-
|
|
7842
|
+
is_differentiable=False,
|
|
7163
7843
|
)
|
|
7164
7844
|
# implements &(*transformation)[index]
|
|
7165
7845
|
add_builtin(
|
|
@@ -7170,7 +7850,7 @@ add_builtin(
|
|
|
7170
7850
|
hidden=True,
|
|
7171
7851
|
group="Utility",
|
|
7172
7852
|
skip_replay=True,
|
|
7173
|
-
|
|
7853
|
+
is_differentiable=False,
|
|
7174
7854
|
)
|
|
7175
7855
|
|
|
7176
7856
|
|
|
@@ -7366,6 +8046,43 @@ add_builtin(
|
|
|
7366
8046
|
)
|
|
7367
8047
|
|
|
7368
8048
|
|
|
8049
|
+
# implements vector[idx] &= scalar
|
|
8050
|
+
add_builtin(
|
|
8051
|
+
"bit_and_inplace",
|
|
8052
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8053
|
+
value_type=None,
|
|
8054
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8055
|
+
hidden=True,
|
|
8056
|
+
export=False,
|
|
8057
|
+
group="Utility",
|
|
8058
|
+
is_differentiable=False,
|
|
8059
|
+
)
|
|
8060
|
+
|
|
8061
|
+
# implements vector[idx] |= scalar
|
|
8062
|
+
add_builtin(
|
|
8063
|
+
"bit_or_inplace",
|
|
8064
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8065
|
+
value_type=None,
|
|
8066
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8067
|
+
hidden=True,
|
|
8068
|
+
export=False,
|
|
8069
|
+
group="Utility",
|
|
8070
|
+
is_differentiable=False,
|
|
8071
|
+
)
|
|
8072
|
+
|
|
8073
|
+
# implements vector[idx] ^= scalar
|
|
8074
|
+
add_builtin(
|
|
8075
|
+
"bit_xor_inplace",
|
|
8076
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8077
|
+
value_type=None,
|
|
8078
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8079
|
+
hidden=True,
|
|
8080
|
+
export=False,
|
|
8081
|
+
group="Utility",
|
|
8082
|
+
is_differentiable=False,
|
|
8083
|
+
)
|
|
8084
|
+
|
|
8085
|
+
|
|
7369
8086
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
7370
8087
|
mat_type = arg_types["a"]
|
|
7371
8088
|
row_type = mat_type._wp_row_type_
|
|
@@ -7381,7 +8098,7 @@ add_builtin(
|
|
|
7381
8098
|
hidden=True,
|
|
7382
8099
|
group="Utility",
|
|
7383
8100
|
skip_replay=True,
|
|
7384
|
-
|
|
8101
|
+
is_differentiable=False,
|
|
7385
8102
|
)
|
|
7386
8103
|
|
|
7387
8104
|
|
|
@@ -7400,7 +8117,7 @@ add_builtin(
|
|
|
7400
8117
|
hidden=True,
|
|
7401
8118
|
group="Utility",
|
|
7402
8119
|
skip_replay=True,
|
|
7403
|
-
|
|
8120
|
+
is_differentiable=False,
|
|
7404
8121
|
)
|
|
7405
8122
|
|
|
7406
8123
|
|
|
@@ -7600,6 +8317,78 @@ add_builtin(
|
|
|
7600
8317
|
)
|
|
7601
8318
|
|
|
7602
8319
|
|
|
8320
|
+
# implements matrix[i] &= value
|
|
8321
|
+
add_builtin(
|
|
8322
|
+
"bit_and_inplace",
|
|
8323
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8324
|
+
value_type=None,
|
|
8325
|
+
hidden=True,
|
|
8326
|
+
export=False,
|
|
8327
|
+
group="Utility",
|
|
8328
|
+
is_differentiable=False,
|
|
8329
|
+
)
|
|
8330
|
+
|
|
8331
|
+
|
|
8332
|
+
# implements matrix[i,j] &= value
|
|
8333
|
+
add_builtin(
|
|
8334
|
+
"bit_and_inplace",
|
|
8335
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8336
|
+
value_type=None,
|
|
8337
|
+
hidden=True,
|
|
8338
|
+
export=False,
|
|
8339
|
+
group="Utility",
|
|
8340
|
+
is_differentiable=False,
|
|
8341
|
+
)
|
|
8342
|
+
|
|
8343
|
+
|
|
8344
|
+
# implements matrix[i] |= value
|
|
8345
|
+
add_builtin(
|
|
8346
|
+
"bit_or_inplace",
|
|
8347
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8348
|
+
value_type=None,
|
|
8349
|
+
hidden=True,
|
|
8350
|
+
export=False,
|
|
8351
|
+
group="Utility",
|
|
8352
|
+
is_differentiable=False,
|
|
8353
|
+
)
|
|
8354
|
+
|
|
8355
|
+
|
|
8356
|
+
# implements matrix[i,j] |= value
|
|
8357
|
+
add_builtin(
|
|
8358
|
+
"bit_or_inplace",
|
|
8359
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8360
|
+
value_type=None,
|
|
8361
|
+
hidden=True,
|
|
8362
|
+
export=False,
|
|
8363
|
+
group="Utility",
|
|
8364
|
+
is_differentiable=False,
|
|
8365
|
+
)
|
|
8366
|
+
|
|
8367
|
+
|
|
8368
|
+
# implements matrix[i] ^= value
|
|
8369
|
+
add_builtin(
|
|
8370
|
+
"bit_xor_inplace",
|
|
8371
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8372
|
+
value_type=None,
|
|
8373
|
+
hidden=True,
|
|
8374
|
+
export=False,
|
|
8375
|
+
group="Utility",
|
|
8376
|
+
is_differentiable=False,
|
|
8377
|
+
)
|
|
8378
|
+
|
|
8379
|
+
|
|
8380
|
+
# implements matrix[i,j] ^= value
|
|
8381
|
+
add_builtin(
|
|
8382
|
+
"bit_xor_inplace",
|
|
8383
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8384
|
+
value_type=None,
|
|
8385
|
+
hidden=True,
|
|
8386
|
+
export=False,
|
|
8387
|
+
group="Utility",
|
|
8388
|
+
is_differentiable=False,
|
|
8389
|
+
)
|
|
8390
|
+
|
|
8391
|
+
|
|
7603
8392
|
for t in scalar_types + vector_types + (bool,):
|
|
7604
8393
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
7605
8394
|
continue
|
|
@@ -7611,7 +8400,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7611
8400
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7612
8401
|
group="Utility",
|
|
7613
8402
|
hidden=True,
|
|
7614
|
-
|
|
8403
|
+
is_differentiable=False,
|
|
7615
8404
|
)
|
|
7616
8405
|
|
|
7617
8406
|
add_builtin(
|
|
@@ -7622,7 +8411,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7622
8411
|
group="Utility",
|
|
7623
8412
|
hidden=True,
|
|
7624
8413
|
export=False,
|
|
7625
|
-
|
|
8414
|
+
is_differentiable=False,
|
|
7626
8415
|
)
|
|
7627
8416
|
|
|
7628
8417
|
|
|
@@ -7641,7 +8430,7 @@ add_builtin(
|
|
|
7641
8430
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7642
8431
|
group="Utility",
|
|
7643
8432
|
hidden=True,
|
|
7644
|
-
|
|
8433
|
+
is_differentiable=False,
|
|
7645
8434
|
)
|
|
7646
8435
|
add_builtin(
|
|
7647
8436
|
"expect_neq",
|
|
@@ -7652,7 +8441,7 @@ add_builtin(
|
|
|
7652
8441
|
group="Utility",
|
|
7653
8442
|
hidden=True,
|
|
7654
8443
|
export=False,
|
|
7655
|
-
|
|
8444
|
+
is_differentiable=False,
|
|
7656
8445
|
)
|
|
7657
8446
|
|
|
7658
8447
|
add_builtin(
|
|
@@ -7663,7 +8452,7 @@ add_builtin(
|
|
|
7663
8452
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7664
8453
|
group="Utility",
|
|
7665
8454
|
hidden=True,
|
|
7666
|
-
|
|
8455
|
+
is_differentiable=False,
|
|
7667
8456
|
)
|
|
7668
8457
|
add_builtin(
|
|
7669
8458
|
"expect_neq",
|
|
@@ -7674,7 +8463,7 @@ add_builtin(
|
|
|
7674
8463
|
group="Utility",
|
|
7675
8464
|
hidden=True,
|
|
7676
8465
|
export=False,
|
|
7677
|
-
|
|
8466
|
+
is_differentiable=False,
|
|
7678
8467
|
)
|
|
7679
8468
|
|
|
7680
8469
|
add_builtin(
|
|
@@ -7765,7 +8554,7 @@ add_builtin(
|
|
|
7765
8554
|
value_type=None,
|
|
7766
8555
|
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7767
8556
|
group="Utility",
|
|
7768
|
-
|
|
8557
|
+
is_differentiable=False,
|
|
7769
8558
|
)
|
|
7770
8559
|
add_builtin(
|
|
7771
8560
|
"expect_near",
|
|
@@ -7775,7 +8564,7 @@ add_builtin(
|
|
|
7775
8564
|
value_type=None,
|
|
7776
8565
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7777
8566
|
group="Utility",
|
|
7778
|
-
|
|
8567
|
+
is_differentiable=False,
|
|
7779
8568
|
)
|
|
7780
8569
|
add_builtin(
|
|
7781
8570
|
"expect_near",
|
|
@@ -7785,7 +8574,7 @@ add_builtin(
|
|
|
7785
8574
|
value_type=None,
|
|
7786
8575
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7787
8576
|
group="Utility",
|
|
7788
|
-
|
|
8577
|
+
is_differentiable=False,
|
|
7789
8578
|
)
|
|
7790
8579
|
add_builtin(
|
|
7791
8580
|
"expect_near",
|
|
@@ -7799,7 +8588,7 @@ add_builtin(
|
|
|
7799
8588
|
value_type=None,
|
|
7800
8589
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7801
8590
|
group="Utility",
|
|
7802
|
-
|
|
8591
|
+
is_differentiable=False,
|
|
7803
8592
|
)
|
|
7804
8593
|
|
|
7805
8594
|
# ---------------------------------
|
|
@@ -7810,7 +8599,7 @@ add_builtin(
|
|
|
7810
8599
|
input_types={"arr": array(dtype=Scalar), "value": Scalar},
|
|
7811
8600
|
value_type=int,
|
|
7812
8601
|
doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
|
|
7813
|
-
|
|
8602
|
+
is_differentiable=False,
|
|
7814
8603
|
)
|
|
7815
8604
|
|
|
7816
8605
|
add_builtin(
|
|
@@ -7818,7 +8607,7 @@ add_builtin(
|
|
|
7818
8607
|
input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
|
|
7819
8608
|
value_type=int,
|
|
7820
8609
|
doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
|
|
7821
|
-
|
|
8610
|
+
is_differentiable=False,
|
|
7822
8611
|
)
|
|
7823
8612
|
|
|
7824
8613
|
# ---------------------------------
|
|
@@ -7899,31 +8688,153 @@ add_builtin(
|
|
|
7899
8688
|
input_types={"a": Int, "b": Int},
|
|
7900
8689
|
value_func=sametypes_create_value_func(Int),
|
|
7901
8690
|
group="Operators",
|
|
7902
|
-
|
|
8691
|
+
is_differentiable=False,
|
|
7903
8692
|
)
|
|
8693
|
+
add_builtin(
|
|
8694
|
+
"bit_and",
|
|
8695
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8696
|
+
constraint=sametypes,
|
|
8697
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8698
|
+
doc="",
|
|
8699
|
+
group="Operators",
|
|
8700
|
+
is_differentiable=False,
|
|
8701
|
+
)
|
|
8702
|
+
add_builtin(
|
|
8703
|
+
"bit_and",
|
|
8704
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8705
|
+
constraint=sametypes,
|
|
8706
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8707
|
+
doc="",
|
|
8708
|
+
group="Operators",
|
|
8709
|
+
is_differentiable=False,
|
|
8710
|
+
)
|
|
8711
|
+
|
|
7904
8712
|
add_builtin(
|
|
7905
8713
|
"bit_or",
|
|
7906
8714
|
input_types={"a": Int, "b": Int},
|
|
7907
8715
|
value_func=sametypes_create_value_func(Int),
|
|
7908
8716
|
group="Operators",
|
|
7909
|
-
|
|
8717
|
+
is_differentiable=False,
|
|
7910
8718
|
)
|
|
8719
|
+
add_builtin(
|
|
8720
|
+
"bit_or",
|
|
8721
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8722
|
+
constraint=sametypes,
|
|
8723
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8724
|
+
doc="",
|
|
8725
|
+
group="Operators",
|
|
8726
|
+
is_differentiable=False,
|
|
8727
|
+
)
|
|
8728
|
+
add_builtin(
|
|
8729
|
+
"bit_or",
|
|
8730
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8731
|
+
constraint=sametypes,
|
|
8732
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8733
|
+
doc="",
|
|
8734
|
+
group="Operators",
|
|
8735
|
+
is_differentiable=False,
|
|
8736
|
+
)
|
|
8737
|
+
|
|
7911
8738
|
add_builtin(
|
|
7912
8739
|
"bit_xor",
|
|
7913
8740
|
input_types={"a": Int, "b": Int},
|
|
7914
8741
|
value_func=sametypes_create_value_func(Int),
|
|
7915
8742
|
group="Operators",
|
|
7916
|
-
|
|
8743
|
+
is_differentiable=False,
|
|
8744
|
+
)
|
|
8745
|
+
add_builtin(
|
|
8746
|
+
"bit_xor",
|
|
8747
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8748
|
+
constraint=sametypes,
|
|
8749
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8750
|
+
doc="",
|
|
8751
|
+
group="Operators",
|
|
8752
|
+
is_differentiable=False,
|
|
8753
|
+
)
|
|
8754
|
+
add_builtin(
|
|
8755
|
+
"bit_xor",
|
|
8756
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8757
|
+
constraint=sametypes,
|
|
8758
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8759
|
+
doc="",
|
|
8760
|
+
group="Operators",
|
|
8761
|
+
is_differentiable=False,
|
|
8762
|
+
)
|
|
8763
|
+
|
|
8764
|
+
add_builtin(
|
|
8765
|
+
"lshift",
|
|
8766
|
+
input_types={"a": Int, "b": Int},
|
|
8767
|
+
value_func=sametypes_create_value_func(Int),
|
|
8768
|
+
group="Operators",
|
|
8769
|
+
is_differentiable=False,
|
|
8770
|
+
)
|
|
8771
|
+
add_builtin(
|
|
8772
|
+
"lshift",
|
|
8773
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8774
|
+
constraint=sametypes,
|
|
8775
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8776
|
+
doc="",
|
|
8777
|
+
group="Operators",
|
|
8778
|
+
is_differentiable=False,
|
|
7917
8779
|
)
|
|
7918
|
-
add_builtin(
|
|
8780
|
+
add_builtin(
|
|
8781
|
+
"lshift",
|
|
8782
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8783
|
+
constraint=sametypes,
|
|
8784
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8785
|
+
doc="",
|
|
8786
|
+
group="Operators",
|
|
8787
|
+
is_differentiable=False,
|
|
8788
|
+
)
|
|
8789
|
+
|
|
7919
8790
|
add_builtin(
|
|
7920
8791
|
"rshift",
|
|
7921
8792
|
input_types={"a": Int, "b": Int},
|
|
7922
8793
|
value_func=sametypes_create_value_func(Int),
|
|
7923
8794
|
group="Operators",
|
|
7924
|
-
|
|
8795
|
+
is_differentiable=False,
|
|
8796
|
+
)
|
|
8797
|
+
add_builtin(
|
|
8798
|
+
"rshift",
|
|
8799
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8800
|
+
constraint=sametypes,
|
|
8801
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8802
|
+
doc="",
|
|
8803
|
+
group="Operators",
|
|
8804
|
+
is_differentiable=False,
|
|
8805
|
+
)
|
|
8806
|
+
add_builtin(
|
|
8807
|
+
"rshift",
|
|
8808
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8809
|
+
constraint=sametypes,
|
|
8810
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8811
|
+
doc="",
|
|
8812
|
+
group="Operators",
|
|
8813
|
+
is_differentiable=False,
|
|
8814
|
+
)
|
|
8815
|
+
|
|
8816
|
+
add_builtin(
|
|
8817
|
+
"invert",
|
|
8818
|
+
input_types={"a": Int},
|
|
8819
|
+
value_func=sametypes_create_value_func(Int),
|
|
8820
|
+
group="Operators",
|
|
8821
|
+
is_differentiable=False,
|
|
8822
|
+
)
|
|
8823
|
+
add_builtin(
|
|
8824
|
+
"invert",
|
|
8825
|
+
input_types={"a": vector(length=Any, dtype=Int)},
|
|
8826
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8827
|
+
group="Operators",
|
|
8828
|
+
is_differentiable=False,
|
|
7925
8829
|
)
|
|
7926
|
-
add_builtin(
|
|
8830
|
+
add_builtin(
|
|
8831
|
+
"invert",
|
|
8832
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
|
|
8833
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8834
|
+
group="Operators",
|
|
8835
|
+
is_differentiable=False,
|
|
8836
|
+
)
|
|
8837
|
+
|
|
7927
8838
|
|
|
7928
8839
|
add_builtin(
|
|
7929
8840
|
"mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
@@ -8123,7 +9034,7 @@ add_builtin(
|
|
|
8123
9034
|
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
8124
9035
|
doc="Modulo operation using truncated division.",
|
|
8125
9036
|
group="Operators",
|
|
8126
|
-
|
|
9037
|
+
is_differentiable=False,
|
|
8127
9038
|
)
|
|
8128
9039
|
|
|
8129
9040
|
add_builtin(
|
|
@@ -8183,7 +9094,7 @@ add_builtin(
|
|
|
8183
9094
|
value_func=sametypes_create_value_func(Scalar),
|
|
8184
9095
|
doc="",
|
|
8185
9096
|
group="Operators",
|
|
8186
|
-
|
|
9097
|
+
is_differentiable=False,
|
|
8187
9098
|
)
|
|
8188
9099
|
|
|
8189
9100
|
add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
@@ -8232,14 +9143,26 @@ add_builtin(
|
|
|
8232
9143
|
)
|
|
8233
9144
|
|
|
8234
9145
|
add_builtin(
|
|
8235
|
-
"unot",
|
|
9146
|
+
"unot",
|
|
9147
|
+
input_types={"a": builtins.bool},
|
|
9148
|
+
value_type=builtins.bool,
|
|
9149
|
+
doc="",
|
|
9150
|
+
group="Operators",
|
|
9151
|
+
is_differentiable=False,
|
|
8236
9152
|
)
|
|
8237
9153
|
for t in int_types:
|
|
8238
|
-
add_builtin(
|
|
9154
|
+
add_builtin(
|
|
9155
|
+
"unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
|
|
9156
|
+
)
|
|
8239
9157
|
|
|
8240
9158
|
|
|
8241
9159
|
add_builtin(
|
|
8242
|
-
"unot",
|
|
9160
|
+
"unot",
|
|
9161
|
+
input_types={"a": array(dtype=Any)},
|
|
9162
|
+
value_type=builtins.bool,
|
|
9163
|
+
doc="",
|
|
9164
|
+
group="Operators",
|
|
9165
|
+
is_differentiable=False,
|
|
8243
9166
|
)
|
|
8244
9167
|
|
|
8245
9168
|
|
|
@@ -8312,6 +9235,45 @@ add_builtin(
|
|
|
8312
9235
|
export=False,
|
|
8313
9236
|
)
|
|
8314
9237
|
|
|
9238
|
+
add_builtin(
|
|
9239
|
+
"bit_and",
|
|
9240
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9241
|
+
value_func=tile_binary_map_value_func,
|
|
9242
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9243
|
+
# variadic=True,
|
|
9244
|
+
native_func="tile_bit_and",
|
|
9245
|
+
doc="Bitwise AND each element of two tiles together",
|
|
9246
|
+
group="Tile Primitives",
|
|
9247
|
+
export=False,
|
|
9248
|
+
is_differentiable=False,
|
|
9249
|
+
)
|
|
9250
|
+
|
|
9251
|
+
add_builtin(
|
|
9252
|
+
"bit_or",
|
|
9253
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9254
|
+
value_func=tile_binary_map_value_func,
|
|
9255
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9256
|
+
# variadic=True,
|
|
9257
|
+
native_func="tile_bit_or",
|
|
9258
|
+
doc="Bitwise OR each element of two tiles together",
|
|
9259
|
+
group="Tile Primitives",
|
|
9260
|
+
export=False,
|
|
9261
|
+
is_differentiable=False,
|
|
9262
|
+
)
|
|
9263
|
+
|
|
9264
|
+
add_builtin(
|
|
9265
|
+
"bit_xor",
|
|
9266
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9267
|
+
value_func=tile_binary_map_value_func,
|
|
9268
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9269
|
+
# variadic=True,
|
|
9270
|
+
native_func="tile_bit_xor",
|
|
9271
|
+
doc="Bitwise XOR each element of two tiles together",
|
|
9272
|
+
group="Tile Primitives",
|
|
9273
|
+
export=False,
|
|
9274
|
+
is_differentiable=False,
|
|
9275
|
+
)
|
|
9276
|
+
|
|
8315
9277
|
|
|
8316
9278
|
add_builtin(
|
|
8317
9279
|
"mul",
|
|
@@ -8373,6 +9335,45 @@ add_builtin(
|
|
|
8373
9335
|
)
|
|
8374
9336
|
|
|
8375
9337
|
|
|
9338
|
+
add_builtin(
|
|
9339
|
+
"bit_and_inplace",
|
|
9340
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9341
|
+
value_type=None,
|
|
9342
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9343
|
+
export=False,
|
|
9344
|
+
hidden=True,
|
|
9345
|
+
native_func="tile_bit_and_inplace",
|
|
9346
|
+
group="Operators",
|
|
9347
|
+
is_differentiable=False,
|
|
9348
|
+
)
|
|
9349
|
+
|
|
9350
|
+
|
|
9351
|
+
add_builtin(
|
|
9352
|
+
"bit_or_inplace",
|
|
9353
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9354
|
+
value_type=None,
|
|
9355
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9356
|
+
export=False,
|
|
9357
|
+
hidden=True,
|
|
9358
|
+
native_func="tile_bit_or_inplace",
|
|
9359
|
+
group="Operators",
|
|
9360
|
+
is_differentiable=False,
|
|
9361
|
+
)
|
|
9362
|
+
|
|
9363
|
+
|
|
9364
|
+
add_builtin(
|
|
9365
|
+
"bit_xor_inplace",
|
|
9366
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9367
|
+
value_type=None,
|
|
9368
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9369
|
+
export=False,
|
|
9370
|
+
hidden=True,
|
|
9371
|
+
native_func="tile_bit_xor_inplace",
|
|
9372
|
+
group="Operators",
|
|
9373
|
+
is_differentiable=False,
|
|
9374
|
+
)
|
|
9375
|
+
|
|
9376
|
+
|
|
8376
9377
|
def tile_diag_add_value_func(arg_types, arg_values):
|
|
8377
9378
|
if arg_types is None:
|
|
8378
9379
|
return tile(dtype=Any, shape=Tuple[int, int])
|
|
@@ -8414,7 +9415,7 @@ def tile_diag_add_lto_dispatch_func(
|
|
|
8414
9415
|
return_values: List[Var],
|
|
8415
9416
|
arg_values: Mapping[str, Var],
|
|
8416
9417
|
options: Mapping[str, Any],
|
|
8417
|
-
builder: warp.context.ModuleBuilder,
|
|
9418
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8418
9419
|
):
|
|
8419
9420
|
a = arg_values["a"]
|
|
8420
9421
|
d = arg_values["d"]
|
|
@@ -8434,7 +9435,7 @@ add_builtin(
|
|
|
8434
9435
|
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
8435
9436
|
group="Tile Primitives",
|
|
8436
9437
|
export=False,
|
|
8437
|
-
|
|
9438
|
+
is_differentiable=False,
|
|
8438
9439
|
)
|
|
8439
9440
|
|
|
8440
9441
|
|
|
@@ -8491,7 +9492,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8491
9492
|
return_values: List[Var],
|
|
8492
9493
|
arg_values: Mapping[str, Var],
|
|
8493
9494
|
options: Mapping[str, Any],
|
|
8494
|
-
builder: warp.context.ModuleBuilder,
|
|
9495
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8495
9496
|
):
|
|
8496
9497
|
a = arg_values["a"]
|
|
8497
9498
|
b = arg_values["b"]
|
|
@@ -8529,7 +9530,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8529
9530
|
num_threads = options["block_dim"]
|
|
8530
9531
|
arch = options["output_arch"]
|
|
8531
9532
|
|
|
8532
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9533
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8533
9534
|
# CPU/no-MathDx dispatch
|
|
8534
9535
|
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
8535
9536
|
else:
|
|
@@ -8542,7 +9543,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8542
9543
|
|
|
8543
9544
|
# generate the LTOs
|
|
8544
9545
|
# C += A * B
|
|
8545
|
-
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
9546
|
+
(fun_forward, lto_forward) = warp._src.build.build_lto_dot(
|
|
8546
9547
|
M,
|
|
8547
9548
|
N,
|
|
8548
9549
|
K,
|
|
@@ -8558,7 +9559,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8558
9559
|
)
|
|
8559
9560
|
if warp.config.enable_backward:
|
|
8560
9561
|
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
8561
|
-
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
9562
|
+
(fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
|
|
8562
9563
|
M,
|
|
8563
9564
|
K,
|
|
8564
9565
|
N,
|
|
@@ -8573,7 +9574,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8573
9574
|
builder,
|
|
8574
9575
|
)
|
|
8575
9576
|
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
8576
|
-
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
9577
|
+
(fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
|
|
8577
9578
|
K,
|
|
8578
9579
|
N,
|
|
8579
9580
|
M,
|
|
@@ -8690,7 +9691,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8690
9691
|
return_values: List[Var],
|
|
8691
9692
|
arg_values: Mapping[str, Var],
|
|
8692
9693
|
options: Mapping[str, Any],
|
|
8693
|
-
builder: warp.context.ModuleBuilder,
|
|
9694
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8694
9695
|
direction: str | None = None,
|
|
8695
9696
|
):
|
|
8696
9697
|
inout = arg_values["inout"]
|
|
@@ -8719,12 +9720,12 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8719
9720
|
arch = options["output_arch"]
|
|
8720
9721
|
ept = size // num_threads
|
|
8721
9722
|
|
|
8722
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9723
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8723
9724
|
# CPU/no-MathDx dispatch
|
|
8724
9725
|
return ([], [], [], 0)
|
|
8725
9726
|
else:
|
|
8726
9727
|
# generate the LTO
|
|
8727
|
-
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
9728
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
|
|
8728
9729
|
arch, size, ept, direction, dir, precision, builder
|
|
8729
9730
|
)
|
|
8730
9731
|
|
|
@@ -8762,7 +9763,7 @@ add_builtin(
|
|
|
8762
9763
|
group="Tile Primitives",
|
|
8763
9764
|
export=False,
|
|
8764
9765
|
namespace="",
|
|
8765
|
-
|
|
9766
|
+
is_differentiable=False,
|
|
8766
9767
|
)
|
|
8767
9768
|
|
|
8768
9769
|
add_builtin(
|
|
@@ -8784,7 +9785,7 @@ add_builtin(
|
|
|
8784
9785
|
group="Tile Primitives",
|
|
8785
9786
|
export=False,
|
|
8786
9787
|
namespace="",
|
|
8787
|
-
|
|
9788
|
+
is_differentiable=False,
|
|
8788
9789
|
)
|
|
8789
9790
|
|
|
8790
9791
|
|
|
@@ -8829,7 +9830,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8829
9830
|
return_values: List[Var],
|
|
8830
9831
|
arg_values: Mapping[str, Var],
|
|
8831
9832
|
options: Mapping[str, Any],
|
|
8832
|
-
builder: warp.context.ModuleBuilder,
|
|
9833
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8833
9834
|
):
|
|
8834
9835
|
a = arg_values["A"]
|
|
8835
9836
|
# force source tile to shared memory
|
|
@@ -8849,7 +9850,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8849
9850
|
|
|
8850
9851
|
arch = options["output_arch"]
|
|
8851
9852
|
|
|
8852
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9853
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8853
9854
|
# CPU/no-MathDx dispatch
|
|
8854
9855
|
return ((0, a, out), [], [], 0)
|
|
8855
9856
|
else:
|
|
@@ -8864,7 +9865,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8864
9865
|
req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
|
|
8865
9866
|
|
|
8866
9867
|
# generate the LTO
|
|
8867
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9868
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8868
9869
|
M,
|
|
8869
9870
|
N,
|
|
8870
9871
|
1,
|
|
@@ -8909,7 +9910,7 @@ add_builtin(
|
|
|
8909
9910
|
group="Tile Primitives",
|
|
8910
9911
|
export=False,
|
|
8911
9912
|
namespace="",
|
|
8912
|
-
|
|
9913
|
+
is_differentiable=False,
|
|
8913
9914
|
)
|
|
8914
9915
|
|
|
8915
9916
|
|
|
@@ -8953,7 +9954,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8953
9954
|
return_values: List[Var],
|
|
8954
9955
|
arg_values: Mapping[str, Var],
|
|
8955
9956
|
options: Mapping[str, Any],
|
|
8956
|
-
builder: warp.context.ModuleBuilder,
|
|
9957
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8957
9958
|
):
|
|
8958
9959
|
L = arg_values["L"]
|
|
8959
9960
|
y = arg_values["y"]
|
|
@@ -8982,7 +9983,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8982
9983
|
|
|
8983
9984
|
arch = options["output_arch"]
|
|
8984
9985
|
|
|
8985
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9986
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8986
9987
|
# CPU/no-MathDx dispatch
|
|
8987
9988
|
return ((0, L, y, x), [], [], 0)
|
|
8988
9989
|
else:
|
|
@@ -8998,7 +9999,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8998
9999
|
req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8999
10000
|
|
|
9000
10001
|
# generate the LTO
|
|
9001
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10002
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9002
10003
|
M,
|
|
9003
10004
|
N,
|
|
9004
10005
|
NRHS,
|
|
@@ -9040,7 +10041,7 @@ add_builtin(
|
|
|
9040
10041
|
group="Tile Primitives",
|
|
9041
10042
|
export=False,
|
|
9042
10043
|
namespace="",
|
|
9043
|
-
|
|
10044
|
+
is_differentiable=False,
|
|
9044
10045
|
)
|
|
9045
10046
|
|
|
9046
10047
|
|
|
@@ -9050,7 +10051,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9050
10051
|
return_values: List[Var],
|
|
9051
10052
|
arg_values: Mapping[str, Var],
|
|
9052
10053
|
options: Mapping[str, Any],
|
|
9053
|
-
builder: warp.context.ModuleBuilder,
|
|
10054
|
+
builder: warp._src.context.ModuleBuilder,
|
|
9054
10055
|
):
|
|
9055
10056
|
L = arg_values["L"]
|
|
9056
10057
|
y = arg_values["y"]
|
|
@@ -9079,7 +10080,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9079
10080
|
|
|
9080
10081
|
arch = options["output_arch"]
|
|
9081
10082
|
|
|
9082
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10083
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
9083
10084
|
# CPU/no-MathDx dispatch
|
|
9084
10085
|
return ((0, L, y, z), [], [], 0)
|
|
9085
10086
|
else:
|
|
@@ -9095,7 +10096,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9095
10096
|
req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
9096
10097
|
|
|
9097
10098
|
# generate the LTO
|
|
9098
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10099
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9099
10100
|
M,
|
|
9100
10101
|
N,
|
|
9101
10102
|
NRHS,
|
|
@@ -9173,7 +10174,7 @@ add_builtin(
|
|
|
9173
10174
|
group="Tile Primitives",
|
|
9174
10175
|
export=False,
|
|
9175
10176
|
namespace="",
|
|
9176
|
-
|
|
10177
|
+
is_differentiable=False,
|
|
9177
10178
|
)
|
|
9178
10179
|
|
|
9179
10180
|
|
|
@@ -9183,7 +10184,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9183
10184
|
return_values: List[Var],
|
|
9184
10185
|
arg_values: Mapping[str, Var],
|
|
9185
10186
|
options: Mapping[str, Any],
|
|
9186
|
-
builder: warp.context.ModuleBuilder,
|
|
10187
|
+
builder: warp._src.context.ModuleBuilder,
|
|
9187
10188
|
):
|
|
9188
10189
|
U = arg_values["U"]
|
|
9189
10190
|
z = arg_values["z"]
|
|
@@ -9212,7 +10213,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9212
10213
|
|
|
9213
10214
|
arch = options["output_arch"]
|
|
9214
10215
|
|
|
9215
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10216
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
9216
10217
|
# CPU/no-MathDx dispatch
|
|
9217
10218
|
return ((0, U, z, x), [], [], 0)
|
|
9218
10219
|
else:
|
|
@@ -9228,7 +10229,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9228
10229
|
req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
|
|
9229
10230
|
|
|
9230
10231
|
# generate the LTO
|
|
9231
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10232
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9232
10233
|
M,
|
|
9233
10234
|
N,
|
|
9234
10235
|
NRHS,
|
|
@@ -9306,7 +10307,7 @@ add_builtin(
|
|
|
9306
10307
|
group="Tile Primitives",
|
|
9307
10308
|
export=False,
|
|
9308
10309
|
namespace="",
|
|
9309
|
-
|
|
10310
|
+
is_differentiable=False,
|
|
9310
10311
|
)
|
|
9311
10312
|
|
|
9312
10313
|
|
|
@@ -9326,7 +10327,7 @@ add_builtin(
|
|
|
9326
10327
|
The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
|
|
9327
10328
|
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
9328
10329
|
group="Code Generation",
|
|
9329
|
-
|
|
10330
|
+
is_differentiable=False,
|
|
9330
10331
|
)
|
|
9331
10332
|
|
|
9332
10333
|
|
|
@@ -9351,7 +10352,7 @@ add_builtin(
|
|
|
9351
10352
|
doc="Return the number of elements in a vector.",
|
|
9352
10353
|
group="Utility",
|
|
9353
10354
|
export=False,
|
|
9354
|
-
|
|
10355
|
+
is_differentiable=False,
|
|
9355
10356
|
)
|
|
9356
10357
|
|
|
9357
10358
|
add_builtin(
|
|
@@ -9361,7 +10362,7 @@ add_builtin(
|
|
|
9361
10362
|
doc="Return the number of elements in a quaternion.",
|
|
9362
10363
|
group="Utility",
|
|
9363
10364
|
export=False,
|
|
9364
|
-
|
|
10365
|
+
is_differentiable=False,
|
|
9365
10366
|
)
|
|
9366
10367
|
|
|
9367
10368
|
add_builtin(
|
|
@@ -9371,7 +10372,7 @@ add_builtin(
|
|
|
9371
10372
|
doc="Return the number of rows in a matrix.",
|
|
9372
10373
|
group="Utility",
|
|
9373
10374
|
export=False,
|
|
9374
|
-
|
|
10375
|
+
is_differentiable=False,
|
|
9375
10376
|
)
|
|
9376
10377
|
|
|
9377
10378
|
add_builtin(
|
|
@@ -9381,7 +10382,7 @@ add_builtin(
|
|
|
9381
10382
|
doc="Return the number of elements in a transformation.",
|
|
9382
10383
|
group="Utility",
|
|
9383
10384
|
export=False,
|
|
9384
|
-
|
|
10385
|
+
is_differentiable=False,
|
|
9385
10386
|
)
|
|
9386
10387
|
|
|
9387
10388
|
add_builtin(
|
|
@@ -9391,7 +10392,7 @@ add_builtin(
|
|
|
9391
10392
|
doc="Return the size of the first dimension in an array.",
|
|
9392
10393
|
group="Utility",
|
|
9393
10394
|
export=False,
|
|
9394
|
-
|
|
10395
|
+
is_differentiable=False,
|
|
9395
10396
|
)
|
|
9396
10397
|
|
|
9397
10398
|
add_builtin(
|
|
@@ -9401,7 +10402,62 @@ add_builtin(
|
|
|
9401
10402
|
doc="Return the number of rows in a tile.",
|
|
9402
10403
|
group="Utility",
|
|
9403
10404
|
export=False,
|
|
9404
|
-
|
|
10405
|
+
is_differentiable=False,
|
|
10406
|
+
)
|
|
10407
|
+
|
|
10408
|
+
|
|
10409
|
+
def cast_value_func(arg_types, arg_values):
|
|
10410
|
+
# Return generic type for doc builds.
|
|
10411
|
+
if arg_types is None:
|
|
10412
|
+
return Any
|
|
10413
|
+
|
|
10414
|
+
return arg_values["dtype"]
|
|
10415
|
+
|
|
10416
|
+
|
|
10417
|
+
def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
10418
|
+
func_args = (args["a"],)
|
|
10419
|
+
template_args = (args["dtype"],)
|
|
10420
|
+
return (func_args, template_args)
|
|
10421
|
+
|
|
10422
|
+
|
|
10423
|
+
add_builtin(
|
|
10424
|
+
"cast",
|
|
10425
|
+
input_types={"a": Any, "dtype": Any},
|
|
10426
|
+
value_func=cast_value_func,
|
|
10427
|
+
dispatch_func=cast_dispatch_func,
|
|
10428
|
+
group="Utility",
|
|
10429
|
+
export=False,
|
|
10430
|
+
is_differentiable=False,
|
|
10431
|
+
doc="""Reinterpret a value as a different type while preserving its bit pattern.
|
|
10432
|
+
|
|
10433
|
+
:param a: The value to cast
|
|
10434
|
+
:param dtype: The target type
|
|
10435
|
+
|
|
10436
|
+
Example:
|
|
10437
|
+
|
|
10438
|
+
.. code-block:: python
|
|
10439
|
+
|
|
10440
|
+
@wp.struct
|
|
10441
|
+
class MyStruct:
|
|
10442
|
+
f: wp.float16
|
|
10443
|
+
i: wp.int16
|
|
10444
|
+
|
|
10445
|
+
|
|
10446
|
+
@wp.kernel
|
|
10447
|
+
def compute():
|
|
10448
|
+
x = wp.int32(0x40000000)
|
|
10449
|
+
x_casted = wp.cast(x, wp.float32)
|
|
10450
|
+
wp.expect_eq(x_casted, 2.0) # 0x40000000
|
|
10451
|
+
|
|
10452
|
+
s = MyStruct()
|
|
10453
|
+
s.f = wp.float16(2.0) # 0x4000
|
|
10454
|
+
s.i = wp.int16(4096) # 0x1000
|
|
10455
|
+
s_casted = wp.cast(s, wp.int32)
|
|
10456
|
+
wp.expect_eq(s_casted, 0x10004000)
|
|
10457
|
+
|
|
10458
|
+
|
|
10459
|
+
wp.launch(compute, dim=1)
|
|
10460
|
+
""",
|
|
9405
10461
|
)
|
|
9406
10462
|
|
|
9407
10463
|
|
|
@@ -9428,7 +10484,7 @@ add_builtin(
|
|
|
9428
10484
|
doc="Construct a tuple from a list of values",
|
|
9429
10485
|
group="Utility",
|
|
9430
10486
|
hidden=True,
|
|
9431
|
-
|
|
10487
|
+
is_differentiable=False,
|
|
9432
10488
|
export=False,
|
|
9433
10489
|
)
|
|
9434
10490
|
|
|
@@ -9465,7 +10521,7 @@ add_builtin(
|
|
|
9465
10521
|
dispatch_func=tuple_extract_dispatch_func,
|
|
9466
10522
|
group="Utility",
|
|
9467
10523
|
hidden=True,
|
|
9468
|
-
|
|
10524
|
+
is_differentiable=False,
|
|
9469
10525
|
)
|
|
9470
10526
|
|
|
9471
10527
|
|
|
@@ -9476,7 +10532,7 @@ add_builtin(
|
|
|
9476
10532
|
doc="Return the number of elements in a tuple.",
|
|
9477
10533
|
group="Utility",
|
|
9478
10534
|
export=False,
|
|
9479
|
-
|
|
10535
|
+
is_differentiable=False,
|
|
9480
10536
|
)
|
|
9481
10537
|
|
|
9482
10538
|
# ---------------------------------
|
|
@@ -9495,5 +10551,5 @@ add_builtin(
|
|
|
9495
10551
|
export=False,
|
|
9496
10552
|
group="Utility",
|
|
9497
10553
|
hidden=True,
|
|
9498
|
-
|
|
10554
|
+
is_differentiable=False,
|
|
9499
10555
|
)
|