warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -20,11 +20,11 @@ import functools
|
|
|
20
20
|
import math
|
|
21
21
|
from typing import Any, Callable, Mapping, Sequence
|
|
22
22
|
|
|
23
|
-
import warp.build
|
|
24
|
-
import warp.context
|
|
25
|
-
import warp.utils
|
|
26
|
-
from warp.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
-
from warp.types import *
|
|
23
|
+
import warp._src.build
|
|
24
|
+
import warp._src.context
|
|
25
|
+
import warp._src.utils
|
|
26
|
+
from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
+
from warp._src.types import *
|
|
28
28
|
|
|
29
29
|
from .context import add_builtin
|
|
30
30
|
|
|
@@ -61,11 +61,11 @@ def sametypes_create_value_func(default: TypeVar):
|
|
|
61
61
|
|
|
62
62
|
def extract_tuple(arg, as_constant=False):
|
|
63
63
|
if isinstance(arg, Var):
|
|
64
|
-
if isinstance(arg.type, warp.types.tuple_t):
|
|
64
|
+
if isinstance(arg.type, warp._src.types.tuple_t):
|
|
65
65
|
out = arg.type.values
|
|
66
66
|
else:
|
|
67
67
|
out = (arg,)
|
|
68
|
-
elif isinstance(arg, warp.types.tuple_t):
|
|
68
|
+
elif isinstance(arg, warp._src.types.tuple_t):
|
|
69
69
|
out = arg.values
|
|
70
70
|
elif not isinstance(arg, Sequence):
|
|
71
71
|
out = (arg,)
|
|
@@ -82,7 +82,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
82
82
|
if arg_types is None:
|
|
83
83
|
return int
|
|
84
84
|
|
|
85
|
-
length = warp.types.type_length(arg_types["a"])
|
|
85
|
+
length = warp._src.types.type_length(arg_types["a"])
|
|
86
86
|
return Var(None, type=int, constant=length)
|
|
87
87
|
|
|
88
88
|
|
|
@@ -126,6 +126,7 @@ add_builtin(
|
|
|
126
126
|
value_func=sametypes_create_value_func(Scalar),
|
|
127
127
|
doc="Return -1 if ``x`` < 0, return 1 otherwise.",
|
|
128
128
|
group="Scalar Math",
|
|
129
|
+
is_differentiable=False,
|
|
129
130
|
)
|
|
130
131
|
|
|
131
132
|
add_builtin(
|
|
@@ -134,6 +135,7 @@ add_builtin(
|
|
|
134
135
|
value_func=sametypes_create_value_func(Scalar),
|
|
135
136
|
doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
|
|
136
137
|
group="Scalar Math",
|
|
138
|
+
is_differentiable=False,
|
|
137
139
|
)
|
|
138
140
|
add_builtin(
|
|
139
141
|
"nonzero",
|
|
@@ -141,6 +143,7 @@ add_builtin(
|
|
|
141
143
|
value_func=sametypes_create_value_func(Scalar),
|
|
142
144
|
doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
|
|
143
145
|
group="Scalar Math",
|
|
146
|
+
is_differentiable=False,
|
|
144
147
|
)
|
|
145
148
|
|
|
146
149
|
add_builtin(
|
|
@@ -282,7 +285,36 @@ add_builtin(
|
|
|
282
285
|
group="Scalar Math",
|
|
283
286
|
require_original_output_arg=True,
|
|
284
287
|
)
|
|
285
|
-
|
|
288
|
+
add_builtin(
|
|
289
|
+
"erf",
|
|
290
|
+
input_types={"x": Float},
|
|
291
|
+
value_func=sametypes_create_value_func(Float),
|
|
292
|
+
doc="Return the error function of ``x``.",
|
|
293
|
+
group="Scalar Math",
|
|
294
|
+
)
|
|
295
|
+
add_builtin(
|
|
296
|
+
"erfc",
|
|
297
|
+
input_types={"x": Float},
|
|
298
|
+
value_func=sametypes_create_value_func(Float),
|
|
299
|
+
doc="Return the complementary error function of ``x``.",
|
|
300
|
+
group="Scalar Math",
|
|
301
|
+
)
|
|
302
|
+
add_builtin(
|
|
303
|
+
"erfinv",
|
|
304
|
+
input_types={"x": Float},
|
|
305
|
+
value_func=sametypes_create_value_func(Float),
|
|
306
|
+
doc="Return the inverse error function of ``x``.",
|
|
307
|
+
group="Scalar Math",
|
|
308
|
+
require_original_output_arg=True,
|
|
309
|
+
)
|
|
310
|
+
add_builtin(
|
|
311
|
+
"erfcinv",
|
|
312
|
+
input_types={"x": Float},
|
|
313
|
+
value_func=sametypes_create_value_func(Float),
|
|
314
|
+
doc="Return the inverse complementary error function of ``x``.",
|
|
315
|
+
group="Scalar Math",
|
|
316
|
+
require_original_output_arg=True,
|
|
317
|
+
)
|
|
286
318
|
add_builtin(
|
|
287
319
|
"round",
|
|
288
320
|
input_types={"x": Float},
|
|
@@ -292,6 +324,7 @@ add_builtin(
|
|
|
292
324
|
|
|
293
325
|
This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
|
|
294
326
|
Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
|
|
327
|
+
is_differentiable=False,
|
|
295
328
|
)
|
|
296
329
|
|
|
297
330
|
add_builtin(
|
|
@@ -302,6 +335,7 @@ add_builtin(
|
|
|
302
335
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
|
|
303
336
|
|
|
304
337
|
It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
|
|
338
|
+
is_differentiable=False,
|
|
305
339
|
)
|
|
306
340
|
|
|
307
341
|
add_builtin(
|
|
@@ -314,6 +348,7 @@ add_builtin(
|
|
|
314
348
|
In other words, it discards the fractional part of ``x``.
|
|
315
349
|
It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
|
|
316
350
|
Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
|
|
351
|
+
is_differentiable=False,
|
|
317
352
|
)
|
|
318
353
|
|
|
319
354
|
add_builtin(
|
|
@@ -322,6 +357,7 @@ add_builtin(
|
|
|
322
357
|
value_func=sametypes_create_value_func(Float),
|
|
323
358
|
group="Scalar Math",
|
|
324
359
|
doc="""Return the largest integer that is less than or equal to ``x``.""",
|
|
360
|
+
is_differentiable=False,
|
|
325
361
|
)
|
|
326
362
|
|
|
327
363
|
add_builtin(
|
|
@@ -330,6 +366,7 @@ add_builtin(
|
|
|
330
366
|
value_func=sametypes_create_value_func(Float),
|
|
331
367
|
group="Scalar Math",
|
|
332
368
|
doc="""Return the smallest integer that is greater than or equal to ``x``.""",
|
|
369
|
+
is_differentiable=False,
|
|
333
370
|
)
|
|
334
371
|
|
|
335
372
|
add_builtin(
|
|
@@ -340,6 +377,7 @@ add_builtin(
|
|
|
340
377
|
doc="""Retrieve the fractional part of ``x``.
|
|
341
378
|
|
|
342
379
|
In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
|
|
380
|
+
is_differentiable=False,
|
|
343
381
|
)
|
|
344
382
|
|
|
345
383
|
add_builtin(
|
|
@@ -348,6 +386,7 @@ add_builtin(
|
|
|
348
386
|
value_type=builtins.bool,
|
|
349
387
|
group="Scalar Math",
|
|
350
388
|
doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
|
|
389
|
+
is_differentiable=False,
|
|
351
390
|
)
|
|
352
391
|
add_builtin(
|
|
353
392
|
"isfinite",
|
|
@@ -355,6 +394,7 @@ add_builtin(
|
|
|
355
394
|
value_type=builtins.bool,
|
|
356
395
|
group="Vector Math",
|
|
357
396
|
doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
|
|
397
|
+
is_differentiable=False,
|
|
358
398
|
)
|
|
359
399
|
add_builtin(
|
|
360
400
|
"isfinite",
|
|
@@ -362,6 +402,7 @@ add_builtin(
|
|
|
362
402
|
value_type=builtins.bool,
|
|
363
403
|
group="Vector Math",
|
|
364
404
|
doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
|
|
405
|
+
is_differentiable=False,
|
|
365
406
|
)
|
|
366
407
|
add_builtin(
|
|
367
408
|
"isfinite",
|
|
@@ -369,6 +410,7 @@ add_builtin(
|
|
|
369
410
|
value_type=builtins.bool,
|
|
370
411
|
group="Vector Math",
|
|
371
412
|
doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
|
|
413
|
+
is_differentiable=False,
|
|
372
414
|
)
|
|
373
415
|
|
|
374
416
|
add_builtin(
|
|
@@ -377,6 +419,7 @@ add_builtin(
|
|
|
377
419
|
value_type=builtins.bool,
|
|
378
420
|
doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
|
|
379
421
|
group="Scalar Math",
|
|
422
|
+
is_differentiable=False,
|
|
380
423
|
)
|
|
381
424
|
add_builtin(
|
|
382
425
|
"isnan",
|
|
@@ -384,6 +427,7 @@ add_builtin(
|
|
|
384
427
|
value_type=builtins.bool,
|
|
385
428
|
group="Vector Math",
|
|
386
429
|
doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
|
|
430
|
+
is_differentiable=False,
|
|
387
431
|
)
|
|
388
432
|
add_builtin(
|
|
389
433
|
"isnan",
|
|
@@ -391,6 +435,7 @@ add_builtin(
|
|
|
391
435
|
value_type=builtins.bool,
|
|
392
436
|
group="Vector Math",
|
|
393
437
|
doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
|
|
438
|
+
is_differentiable=False,
|
|
394
439
|
)
|
|
395
440
|
add_builtin(
|
|
396
441
|
"isnan",
|
|
@@ -398,6 +443,7 @@ add_builtin(
|
|
|
398
443
|
value_type=builtins.bool,
|
|
399
444
|
group="Vector Math",
|
|
400
445
|
doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
|
|
446
|
+
is_differentiable=False,
|
|
401
447
|
)
|
|
402
448
|
|
|
403
449
|
add_builtin(
|
|
@@ -406,6 +452,7 @@ add_builtin(
|
|
|
406
452
|
value_type=builtins.bool,
|
|
407
453
|
group="Scalar Math",
|
|
408
454
|
doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
|
|
455
|
+
is_differentiable=False,
|
|
409
456
|
)
|
|
410
457
|
add_builtin(
|
|
411
458
|
"isinf",
|
|
@@ -413,6 +460,7 @@ add_builtin(
|
|
|
413
460
|
value_type=builtins.bool,
|
|
414
461
|
group="Vector Math",
|
|
415
462
|
doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
463
|
+
is_differentiable=False,
|
|
416
464
|
)
|
|
417
465
|
add_builtin(
|
|
418
466
|
"isinf",
|
|
@@ -420,6 +468,7 @@ add_builtin(
|
|
|
420
468
|
value_type=builtins.bool,
|
|
421
469
|
group="Vector Math",
|
|
422
470
|
doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
471
|
+
is_differentiable=False,
|
|
423
472
|
)
|
|
424
473
|
add_builtin(
|
|
425
474
|
"isinf",
|
|
@@ -427,6 +476,7 @@ add_builtin(
|
|
|
427
476
|
value_type=builtins.bool,
|
|
428
477
|
group="Vector Math",
|
|
429
478
|
doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
479
|
+
is_differentiable=False,
|
|
430
480
|
)
|
|
431
481
|
|
|
432
482
|
|
|
@@ -534,7 +584,7 @@ add_builtin(
|
|
|
534
584
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
535
585
|
doc="Return the index of the minimum element of a vector ``a``.",
|
|
536
586
|
group="Vector Math",
|
|
537
|
-
|
|
587
|
+
is_differentiable=False,
|
|
538
588
|
)
|
|
539
589
|
add_builtin(
|
|
540
590
|
"argmax",
|
|
@@ -542,7 +592,7 @@ add_builtin(
|
|
|
542
592
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
543
593
|
doc="Return the index of the maximum element of a vector ``a``.",
|
|
544
594
|
group="Vector Math",
|
|
545
|
-
|
|
595
|
+
is_differentiable=False,
|
|
546
596
|
)
|
|
547
597
|
|
|
548
598
|
add_builtin(
|
|
@@ -867,7 +917,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
867
917
|
|
|
868
918
|
if dtype is None:
|
|
869
919
|
dtype = value_type
|
|
870
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
920
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
871
921
|
raise RuntimeError(
|
|
872
922
|
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
873
923
|
)
|
|
@@ -888,7 +938,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
888
938
|
|
|
889
939
|
if dtype is None:
|
|
890
940
|
dtype = value_type
|
|
891
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
941
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
892
942
|
raise RuntimeError(
|
|
893
943
|
f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
|
|
894
944
|
)
|
|
@@ -971,7 +1021,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
971
1021
|
|
|
972
1022
|
if dtype is None:
|
|
973
1023
|
dtype = value_type
|
|
974
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1024
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
975
1025
|
raise RuntimeError(
|
|
976
1026
|
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
977
1027
|
)
|
|
@@ -981,7 +1031,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
981
1031
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
982
1032
|
|
|
983
1033
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
984
|
-
warp.utils.warn(
|
|
1034
|
+
warp._src.utils.warn(
|
|
985
1035
|
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
986
1036
|
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
987
1037
|
DeprecationWarning,
|
|
@@ -1010,7 +1060,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1010
1060
|
|
|
1011
1061
|
if dtype is None:
|
|
1012
1062
|
dtype = value_type
|
|
1013
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1063
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1014
1064
|
raise RuntimeError(
|
|
1015
1065
|
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
1016
1066
|
)
|
|
@@ -1182,48 +1232,18 @@ add_builtin(
|
|
|
1182
1232
|
doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
|
|
1183
1233
|
group="Vector Math",
|
|
1184
1234
|
export=False,
|
|
1235
|
+
is_differentiable=False,
|
|
1185
1236
|
)
|
|
1186
1237
|
|
|
1187
1238
|
|
|
1188
1239
|
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1189
|
-
warp.utils.warn(
|
|
1190
|
-
"the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1191
|
-
"and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
|
|
1192
|
-
DeprecationWarning,
|
|
1193
|
-
)
|
|
1194
1240
|
if arg_types is None:
|
|
1195
1241
|
return matrix(shape=(4, 4), dtype=Float)
|
|
1196
1242
|
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
value_type = scalar_infer_type(value_arg_types)
|
|
1202
|
-
except RuntimeError:
|
|
1203
|
-
raise RuntimeError(
|
|
1204
|
-
"all values given when constructing a transformation matrix must have the same type"
|
|
1205
|
-
) from None
|
|
1206
|
-
|
|
1207
|
-
if dtype is None:
|
|
1208
|
-
dtype = value_type
|
|
1209
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1210
|
-
raise RuntimeError(
|
|
1211
|
-
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1212
|
-
)
|
|
1213
|
-
|
|
1214
|
-
return matrix(shape=(4, 4), dtype=dtype)
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1218
|
-
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1219
|
-
# Further validate the given argument values if needed and map them
|
|
1220
|
-
# to the underlying C++ function's runtime and template params.
|
|
1221
|
-
|
|
1222
|
-
dtype = return_type._wp_scalar_type_
|
|
1223
|
-
|
|
1224
|
-
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1225
|
-
template_args = (4, 4, dtype)
|
|
1226
|
-
return (func_args, template_args)
|
|
1243
|
+
raise RuntimeError(
|
|
1244
|
+
"the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1245
|
+
"and 3D scale vector has been removed in favor of `wp.transform_compose()`."
|
|
1246
|
+
)
|
|
1227
1247
|
|
|
1228
1248
|
|
|
1229
1249
|
add_builtin(
|
|
@@ -1237,13 +1257,14 @@ add_builtin(
|
|
|
1237
1257
|
defaults={"dtype": None},
|
|
1238
1258
|
value_func=matrix_transform_value_func,
|
|
1239
1259
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1240
|
-
dispatch_func=matrix_transform_dispatch_func,
|
|
1241
1260
|
native_func="mat_t",
|
|
1242
1261
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
1243
1262
|
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1244
1263
|
|
|
1245
|
-
..
|
|
1246
|
-
This function has been
|
|
1264
|
+
.. versionremoved:: 1.10
|
|
1265
|
+
This function has been removed in favor of :func:`warp.math.transform_compose()`.
|
|
1266
|
+
|
|
1267
|
+
.. deprecated:: 1.8""",
|
|
1247
1268
|
group="Vector Math",
|
|
1248
1269
|
export=False,
|
|
1249
1270
|
)
|
|
@@ -1438,7 +1459,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1438
1459
|
|
|
1439
1460
|
if dtype is None:
|
|
1440
1461
|
dtype = value_type
|
|
1441
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1462
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1442
1463
|
raise RuntimeError(
|
|
1443
1464
|
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1444
1465
|
)
|
|
@@ -1546,6 +1567,7 @@ add_builtin(
|
|
|
1546
1567
|
group="Quaternion Math",
|
|
1547
1568
|
doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
|
|
1548
1569
|
export=True,
|
|
1570
|
+
is_differentiable=False,
|
|
1549
1571
|
)
|
|
1550
1572
|
|
|
1551
1573
|
add_builtin(
|
|
@@ -1674,7 +1696,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1674
1696
|
value_type = strip_reference(variadic_arg_types[0])
|
|
1675
1697
|
if dtype is None:
|
|
1676
1698
|
dtype = value_type
|
|
1677
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1699
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1678
1700
|
raise RuntimeError(
|
|
1679
1701
|
f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
|
|
1680
1702
|
)
|
|
@@ -1687,7 +1709,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1687
1709
|
|
|
1688
1710
|
if dtype is None:
|
|
1689
1711
|
dtype = value_type
|
|
1690
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1712
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1691
1713
|
raise RuntimeError(
|
|
1692
1714
|
f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
|
|
1693
1715
|
)
|
|
@@ -1712,7 +1734,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
|
|
|
1712
1734
|
dtype = arg_values.get("dtype", None)
|
|
1713
1735
|
if dtype is None:
|
|
1714
1736
|
dtype = value_type
|
|
1715
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1737
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1716
1738
|
raise RuntimeError(
|
|
1717
1739
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1718
1740
|
)
|
|
@@ -1727,9 +1749,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1727
1749
|
|
|
1728
1750
|
dtype = return_type._wp_scalar_type_
|
|
1729
1751
|
|
|
1730
|
-
variadic_args =
|
|
1752
|
+
variadic_args = args.get("args", ())
|
|
1753
|
+
variadic_arg_count = len(variadic_args)
|
|
1754
|
+
|
|
1755
|
+
if variadic_arg_count == 7:
|
|
1756
|
+
func_args = variadic_args
|
|
1757
|
+
else:
|
|
1758
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1759
|
+
if "p" in args and "q" not in args:
|
|
1760
|
+
quat_ident = warp._src.codegen.Var(
|
|
1761
|
+
label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
|
|
1762
|
+
)
|
|
1763
|
+
func_args += (quat_ident,)
|
|
1731
1764
|
|
|
1732
|
-
func_args = variadic_args
|
|
1733
1765
|
template_args = (dtype,)
|
|
1734
1766
|
return (func_args, template_args)
|
|
1735
1767
|
|
|
@@ -1737,7 +1769,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1737
1769
|
add_builtin(
|
|
1738
1770
|
"transformation",
|
|
1739
1771
|
input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
|
|
1740
|
-
defaults={"dtype": None},
|
|
1772
|
+
defaults={"q": None, "dtype": None},
|
|
1741
1773
|
value_func=transformation_pq_value_func,
|
|
1742
1774
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1743
1775
|
dispatch_func=transformation_dispatch_func,
|
|
@@ -1795,6 +1827,7 @@ add_builtin(
|
|
|
1795
1827
|
group="Transformations",
|
|
1796
1828
|
doc="Construct an identity transform with zero translation and identity rotation.",
|
|
1797
1829
|
export=True,
|
|
1830
|
+
is_differentiable=False,
|
|
1798
1831
|
)
|
|
1799
1832
|
|
|
1800
1833
|
add_builtin(
|
|
@@ -1928,7 +1961,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1928
1961
|
|
|
1929
1962
|
if dtype is None:
|
|
1930
1963
|
dtype = value_type
|
|
1931
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1964
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1932
1965
|
raise RuntimeError(
|
|
1933
1966
|
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1934
1967
|
)
|
|
@@ -2122,7 +2155,7 @@ add_builtin(
|
|
|
2122
2155
|
value_func=tile_zeros_value_func,
|
|
2123
2156
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2124
2157
|
variadic=False,
|
|
2125
|
-
|
|
2158
|
+
is_differentiable=False,
|
|
2126
2159
|
doc="""Allocate a tile of zero-initialized items.
|
|
2127
2160
|
|
|
2128
2161
|
:param shape: Shape of the output tile
|
|
@@ -2142,7 +2175,7 @@ add_builtin(
|
|
|
2142
2175
|
value_func=tile_zeros_value_func,
|
|
2143
2176
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2144
2177
|
variadic=False,
|
|
2145
|
-
|
|
2178
|
+
is_differentiable=False,
|
|
2146
2179
|
hidden=True,
|
|
2147
2180
|
group="Tile Primitives",
|
|
2148
2181
|
export=False,
|
|
@@ -2194,7 +2227,7 @@ add_builtin(
|
|
|
2194
2227
|
defaults={"storage": "register"},
|
|
2195
2228
|
value_func=tile_ones_value_func,
|
|
2196
2229
|
dispatch_func=tile_ones_dispatch_func,
|
|
2197
|
-
|
|
2230
|
+
is_differentiable=False,
|
|
2198
2231
|
doc="""Allocate a tile of one-initialized items.
|
|
2199
2232
|
|
|
2200
2233
|
:param shape: Shape of the output tile
|
|
@@ -2213,7 +2246,86 @@ add_builtin(
|
|
|
2213
2246
|
defaults={"storage": "register"},
|
|
2214
2247
|
value_func=tile_ones_value_func,
|
|
2215
2248
|
dispatch_func=tile_ones_dispatch_func,
|
|
2216
|
-
|
|
2249
|
+
is_differentiable=False,
|
|
2250
|
+
hidden=True,
|
|
2251
|
+
group="Tile Primitives",
|
|
2252
|
+
export=False,
|
|
2253
|
+
)
|
|
2254
|
+
|
|
2255
|
+
|
|
2256
|
+
def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2257
|
+
# return generic type (for doc builds)
|
|
2258
|
+
if arg_types is None:
|
|
2259
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2260
|
+
|
|
2261
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2262
|
+
|
|
2263
|
+
if None in shape:
|
|
2264
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2265
|
+
|
|
2266
|
+
if "value" not in arg_values:
|
|
2267
|
+
raise TypeError("tile_full() missing required keyword argument 'value'")
|
|
2268
|
+
|
|
2269
|
+
if "dtype" not in arg_values:
|
|
2270
|
+
raise TypeError("tile_full() missing required keyword argument 'dtype'")
|
|
2271
|
+
|
|
2272
|
+
if "storage" not in arg_values:
|
|
2273
|
+
raise TypeError("tile_full() missing required keyword argument 'storage'")
|
|
2274
|
+
|
|
2275
|
+
if arg_values["storage"] not in {"shared", "register"}:
|
|
2276
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
2277
|
+
|
|
2278
|
+
dtype = arg_values["dtype"]
|
|
2279
|
+
|
|
2280
|
+
return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
2281
|
+
|
|
2282
|
+
|
|
2283
|
+
def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2284
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2285
|
+
|
|
2286
|
+
if None in shape:
|
|
2287
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2288
|
+
|
|
2289
|
+
dtype = arg_values["dtype"]
|
|
2290
|
+
value = arg_values["value"]
|
|
2291
|
+
|
|
2292
|
+
func_args = [value]
|
|
2293
|
+
|
|
2294
|
+
template_args = []
|
|
2295
|
+
template_args.append(dtype)
|
|
2296
|
+
template_args.extend(shape)
|
|
2297
|
+
|
|
2298
|
+
return (func_args, template_args)
|
|
2299
|
+
|
|
2300
|
+
|
|
2301
|
+
add_builtin(
|
|
2302
|
+
"tile_full",
|
|
2303
|
+
input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
|
|
2304
|
+
defaults={"storage": "register"},
|
|
2305
|
+
value_func=tile_full_value_func,
|
|
2306
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2307
|
+
is_differentiable=False,
|
|
2308
|
+
doc="""Allocate a tile filled with the specified value.
|
|
2309
|
+
|
|
2310
|
+
:param shape: Shape of the output tile
|
|
2311
|
+
:param value: Value to fill the tile with
|
|
2312
|
+
:param dtype: Data type of output tile's elements
|
|
2313
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2314
|
+
(default) or ``"shared"`` for shared memory.
|
|
2315
|
+
:returns: A tile filled with the specified value""",
|
|
2316
|
+
group="Tile Primitives",
|
|
2317
|
+
export=False,
|
|
2318
|
+
)
|
|
2319
|
+
|
|
2320
|
+
|
|
2321
|
+
# overload for scalar shape
|
|
2322
|
+
add_builtin(
|
|
2323
|
+
"tile_full",
|
|
2324
|
+
input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
|
|
2325
|
+
defaults={"storage": "register"},
|
|
2326
|
+
value_func=tile_full_value_func,
|
|
2327
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2328
|
+
is_differentiable=False,
|
|
2217
2329
|
hidden=True,
|
|
2218
2330
|
group="Tile Primitives",
|
|
2219
2331
|
export=False,
|
|
@@ -2275,13 +2387,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2275
2387
|
args = arg_values["args"]
|
|
2276
2388
|
|
|
2277
2389
|
if len(args) == 1:
|
|
2278
|
-
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2390
|
+
start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2279
2391
|
stop = args[0]
|
|
2280
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2392
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2281
2393
|
elif len(args) == 2:
|
|
2282
2394
|
start = args[0]
|
|
2283
2395
|
stop = args[1]
|
|
2284
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2396
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2285
2397
|
elif len(args) == 3:
|
|
2286
2398
|
start = args[0]
|
|
2287
2399
|
stop = args[1]
|
|
@@ -2304,7 +2416,7 @@ add_builtin(
|
|
|
2304
2416
|
value_func=tile_arange_value_func,
|
|
2305
2417
|
dispatch_func=tile_arange_dispatch_func,
|
|
2306
2418
|
variadic=True,
|
|
2307
|
-
|
|
2419
|
+
is_differentiable=False,
|
|
2308
2420
|
doc="""Generate a tile of linearly spaced elements.
|
|
2309
2421
|
|
|
2310
2422
|
:param args: Variable-length positional arguments, interpreted as:
|
|
@@ -3099,7 +3211,7 @@ add_builtin(
|
|
|
3099
3211
|
:param shape: Shape of the returned slice
|
|
3100
3212
|
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
3101
3213
|
group="Tile Primitives",
|
|
3102
|
-
|
|
3214
|
+
is_differentiable=False,
|
|
3103
3215
|
export=False,
|
|
3104
3216
|
)
|
|
3105
3217
|
|
|
@@ -3346,7 +3458,32 @@ add_builtin(
|
|
|
3346
3458
|
|
|
3347
3459
|
add_builtin(
|
|
3348
3460
|
"assign",
|
|
3349
|
-
input_types={"dst": tile(dtype=Any, shape=Tuple[int,
|
|
3461
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
|
|
3462
|
+
value_func=tile_assign_value_func,
|
|
3463
|
+
group="Tile Primitives",
|
|
3464
|
+
export=False,
|
|
3465
|
+
hidden=True,
|
|
3466
|
+
)
|
|
3467
|
+
|
|
3468
|
+
add_builtin(
|
|
3469
|
+
"assign",
|
|
3470
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
|
|
3471
|
+
value_func=tile_assign_value_func,
|
|
3472
|
+
group="Tile Primitives",
|
|
3473
|
+
export=False,
|
|
3474
|
+
hidden=True,
|
|
3475
|
+
)
|
|
3476
|
+
|
|
3477
|
+
add_builtin(
|
|
3478
|
+
"assign",
|
|
3479
|
+
input_types={
|
|
3480
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3481
|
+
"i": int,
|
|
3482
|
+
"j": int,
|
|
3483
|
+
"k": int,
|
|
3484
|
+
"l": int,
|
|
3485
|
+
"src": Any,
|
|
3486
|
+
},
|
|
3350
3487
|
value_func=tile_assign_value_func,
|
|
3351
3488
|
group="Tile Primitives",
|
|
3352
3489
|
export=False,
|
|
@@ -3355,7 +3492,15 @@ add_builtin(
|
|
|
3355
3492
|
|
|
3356
3493
|
add_builtin(
|
|
3357
3494
|
"assign",
|
|
3358
|
-
input_types={
|
|
3495
|
+
input_types={
|
|
3496
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3497
|
+
"i": int,
|
|
3498
|
+
"j": int,
|
|
3499
|
+
"k": int,
|
|
3500
|
+
"l": int,
|
|
3501
|
+
"m": int,
|
|
3502
|
+
"src": Any,
|
|
3503
|
+
},
|
|
3359
3504
|
value_func=tile_assign_value_func,
|
|
3360
3505
|
group="Tile Primitives",
|
|
3361
3506
|
export=False,
|
|
@@ -3370,6 +3515,8 @@ add_builtin(
|
|
|
3370
3515
|
"j": int,
|
|
3371
3516
|
"k": int,
|
|
3372
3517
|
"l": int,
|
|
3518
|
+
"m": int,
|
|
3519
|
+
"n": int,
|
|
3373
3520
|
"src": Any,
|
|
3374
3521
|
},
|
|
3375
3522
|
value_func=tile_assign_value_func,
|
|
@@ -3391,7 +3538,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3391
3538
|
|
|
3392
3539
|
if preserve_type:
|
|
3393
3540
|
dtype = arg_types["x"]
|
|
3394
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3541
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3395
3542
|
|
|
3396
3543
|
return tile(dtype=dtype, shape=shape)
|
|
3397
3544
|
|
|
@@ -3399,18 +3546,18 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3399
3546
|
if type_is_vector(arg_types["x"]):
|
|
3400
3547
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3401
3548
|
length = arg_types["x"]._shape_[0]
|
|
3402
|
-
shape = (length, warp.codegen.options["block_dim"])
|
|
3549
|
+
shape = (length, warp._src.codegen.options["block_dim"])
|
|
3403
3550
|
elif type_is_quaternion(arg_types["x"]):
|
|
3404
3551
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3405
|
-
shape = (4, warp.codegen.options["block_dim"])
|
|
3552
|
+
shape = (4, warp._src.codegen.options["block_dim"])
|
|
3406
3553
|
elif type_is_matrix(arg_types["x"]):
|
|
3407
3554
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3408
3555
|
rows = arg_types["x"]._shape_[0]
|
|
3409
3556
|
cols = arg_types["x"]._shape_[1]
|
|
3410
|
-
shape = (rows, cols, warp.codegen.options["block_dim"])
|
|
3557
|
+
shape = (rows, cols, warp._src.codegen.options["block_dim"])
|
|
3411
3558
|
else:
|
|
3412
3559
|
dtype = arg_types["x"]
|
|
3413
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3560
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3414
3561
|
|
|
3415
3562
|
return tile(dtype=dtype, shape=shape)
|
|
3416
3563
|
|
|
@@ -3500,17 +3647,17 @@ def untile_value_func(arg_types, arg_values):
|
|
|
3500
3647
|
if not is_tile(t):
|
|
3501
3648
|
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
3502
3649
|
|
|
3503
|
-
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
3650
|
+
if t.shape[-1] != warp._src.codegen.options["block_dim"]:
|
|
3504
3651
|
raise ValueError(
|
|
3505
|
-
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
3652
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
|
|
3506
3653
|
)
|
|
3507
3654
|
|
|
3508
3655
|
if len(t.shape) == 1:
|
|
3509
3656
|
return t.dtype
|
|
3510
3657
|
elif len(t.shape) == 2:
|
|
3511
|
-
return warp.types.vector(t.shape[0], t.dtype)
|
|
3658
|
+
return warp._src.types.vector(t.shape[0], t.dtype)
|
|
3512
3659
|
elif len(t.shape) == 3:
|
|
3513
|
-
return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3660
|
+
return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3514
3661
|
else:
|
|
3515
3662
|
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
3516
3663
|
|
|
@@ -3572,7 +3719,36 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
3572
3719
|
# force the input tile to shared memory
|
|
3573
3720
|
arg_types["a"].storage = "shared"
|
|
3574
3721
|
|
|
3575
|
-
|
|
3722
|
+
# count the number of indices (all parameters except the tile "a")
|
|
3723
|
+
num_indices = len(arg_types) - 1
|
|
3724
|
+
tile_dtype = arg_types["a"].dtype
|
|
3725
|
+
tile_shape = arg_types["a"].shape
|
|
3726
|
+
|
|
3727
|
+
if type_is_vector(tile_dtype):
|
|
3728
|
+
if num_indices == len(tile_shape):
|
|
3729
|
+
return tile_dtype
|
|
3730
|
+
elif num_indices == len(tile_shape) + 1:
|
|
3731
|
+
return tile_dtype._wp_scalar_type_
|
|
3732
|
+
else:
|
|
3733
|
+
raise IndexError(
|
|
3734
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3735
|
+
)
|
|
3736
|
+
elif type_is_matrix(tile_dtype):
|
|
3737
|
+
if num_indices == len(tile_shape):
|
|
3738
|
+
return tile_dtype
|
|
3739
|
+
elif num_indices == len(tile_shape) + 2:
|
|
3740
|
+
return tile_dtype._wp_scalar_type_
|
|
3741
|
+
else:
|
|
3742
|
+
raise IndexError(
|
|
3743
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
|
|
3744
|
+
)
|
|
3745
|
+
else:
|
|
3746
|
+
# scalar element: index count must exactly match tile rank
|
|
3747
|
+
if num_indices == len(tile_shape):
|
|
3748
|
+
return tile_dtype
|
|
3749
|
+
raise IndexError(
|
|
3750
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3751
|
+
)
|
|
3576
3752
|
|
|
3577
3753
|
|
|
3578
3754
|
add_builtin(
|
|
@@ -3596,7 +3772,7 @@ add_builtin(
|
|
|
3596
3772
|
|
|
3597
3773
|
add_builtin(
|
|
3598
3774
|
"tile_extract",
|
|
3599
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3775
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
|
|
3600
3776
|
value_func=tile_extract_value_func,
|
|
3601
3777
|
variadic=False,
|
|
3602
3778
|
doc="""Extract a single element from the tile.
|
|
@@ -3607,7 +3783,28 @@ add_builtin(
|
|
|
3607
3783
|
|
|
3608
3784
|
:param a: Tile to extract the element from
|
|
3609
3785
|
:param i: Coordinate of element on first dimension
|
|
3610
|
-
:param j: Coordinate of element on the second dimension
|
|
3786
|
+
:param j: Coordinate of element on the second dimension, or vector index
|
|
3787
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3788
|
+
group="Tile Primitives",
|
|
3789
|
+
hidden=True,
|
|
3790
|
+
export=False,
|
|
3791
|
+
)
|
|
3792
|
+
|
|
3793
|
+
add_builtin(
|
|
3794
|
+
"tile_extract",
|
|
3795
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
|
|
3796
|
+
value_func=tile_extract_value_func,
|
|
3797
|
+
variadic=False,
|
|
3798
|
+
doc="""Extract a single element from the tile.
|
|
3799
|
+
|
|
3800
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3801
|
+
|
|
3802
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3803
|
+
|
|
3804
|
+
:param a: Tile to extract the element from
|
|
3805
|
+
:param i: Coordinate of element on first dimension
|
|
3806
|
+
:param j: Coordinate of element on the second dimension, or first matrix index
|
|
3807
|
+
:param k: Coordinate of element on the third dimension, or vector index, or second matrix index
|
|
3611
3808
|
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3612
3809
|
group="Tile Primitives",
|
|
3613
3810
|
hidden=True,
|
|
@@ -3616,7 +3813,36 @@ add_builtin(
|
|
|
3616
3813
|
|
|
3617
3814
|
add_builtin(
|
|
3618
3815
|
"tile_extract",
|
|
3619
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3816
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
|
|
3817
|
+
value_func=tile_extract_value_func,
|
|
3818
|
+
variadic=False,
|
|
3819
|
+
doc="""Extract a single element from the tile.
|
|
3820
|
+
|
|
3821
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3822
|
+
|
|
3823
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3824
|
+
|
|
3825
|
+
:param a: Tile to extract the element from
|
|
3826
|
+
:param i: Coordinate of element on first dimension
|
|
3827
|
+
:param j: Coordinate of element on the second dimension
|
|
3828
|
+
:param k: Coordinate of element on the third dimension, or first matrix index
|
|
3829
|
+
:param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
|
|
3830
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3831
|
+
group="Tile Primitives",
|
|
3832
|
+
hidden=True,
|
|
3833
|
+
export=False,
|
|
3834
|
+
)
|
|
3835
|
+
|
|
3836
|
+
add_builtin(
|
|
3837
|
+
"tile_extract",
|
|
3838
|
+
input_types={
|
|
3839
|
+
"a": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3840
|
+
"i": int,
|
|
3841
|
+
"j": int,
|
|
3842
|
+
"k": int,
|
|
3843
|
+
"l": int,
|
|
3844
|
+
"m": int,
|
|
3845
|
+
},
|
|
3620
3846
|
value_func=tile_extract_value_func,
|
|
3621
3847
|
variadic=False,
|
|
3622
3848
|
doc="""Extract a single element from the tile.
|
|
@@ -3629,7 +3855,9 @@ add_builtin(
|
|
|
3629
3855
|
:param i: Coordinate of element on first dimension
|
|
3630
3856
|
:param j: Coordinate of element on the second dimension
|
|
3631
3857
|
:param k: Coordinate of element on the third dimension
|
|
3632
|
-
:
|
|
3858
|
+
:param l: Coordinate of element on the fourth dimension, or first matrix index
|
|
3859
|
+
:param m: Vector index, or second matrix index
|
|
3860
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3633
3861
|
group="Tile Primitives",
|
|
3634
3862
|
hidden=True,
|
|
3635
3863
|
export=False,
|
|
@@ -3637,7 +3865,15 @@ add_builtin(
|
|
|
3637
3865
|
|
|
3638
3866
|
add_builtin(
|
|
3639
3867
|
"tile_extract",
|
|
3640
|
-
input_types={
|
|
3868
|
+
input_types={
|
|
3869
|
+
"a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
|
|
3870
|
+
"i": int,
|
|
3871
|
+
"j": int,
|
|
3872
|
+
"k": int,
|
|
3873
|
+
"l": int,
|
|
3874
|
+
"m": int,
|
|
3875
|
+
"n": int,
|
|
3876
|
+
},
|
|
3641
3877
|
value_func=tile_extract_value_func,
|
|
3642
3878
|
variadic=False,
|
|
3643
3879
|
doc="""Extract a single element from the tile.
|
|
@@ -3651,6 +3887,8 @@ add_builtin(
|
|
|
3651
3887
|
:param j: Coordinate of element on the second dimension
|
|
3652
3888
|
:param k: Coordinate of element on the third dimension
|
|
3653
3889
|
:param l: Coordinate of element on the fourth dimension
|
|
3890
|
+
:param m: Vector index, or first matrix index
|
|
3891
|
+
:param n: Second matrix index
|
|
3654
3892
|
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3655
3893
|
group="Tile Primitives",
|
|
3656
3894
|
hidden=True,
|
|
@@ -3737,50 +3975,161 @@ add_builtin(
|
|
|
3737
3975
|
export=False,
|
|
3738
3976
|
)
|
|
3739
3977
|
|
|
3740
|
-
|
|
3741
|
-
def tile_transpose_value_func(arg_types, arg_values):
|
|
3742
|
-
# return generic type (for doc builds)
|
|
3743
|
-
if arg_types is None:
|
|
3744
|
-
return tile(dtype=Any, shape=Tuple[int, int])
|
|
3745
|
-
|
|
3746
|
-
if len(arg_types) != 1:
|
|
3747
|
-
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3748
|
-
|
|
3749
|
-
t = arg_types["a"]
|
|
3750
|
-
|
|
3751
|
-
if not is_tile(t):
|
|
3752
|
-
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
3753
|
-
|
|
3754
|
-
layout = None
|
|
3755
|
-
|
|
3756
|
-
# flip layout
|
|
3757
|
-
if t.layout == "rowmajor":
|
|
3758
|
-
layout = "colmajor"
|
|
3759
|
-
elif t.layout == "colmajor":
|
|
3760
|
-
layout = "rowmajor"
|
|
3761
|
-
|
|
3762
|
-
# force the input tile to shared memory
|
|
3763
|
-
t.storage = "shared"
|
|
3764
|
-
|
|
3765
|
-
return tile(
|
|
3766
|
-
dtype=t.dtype,
|
|
3767
|
-
shape=t.shape[::-1],
|
|
3768
|
-
storage=t.storage,
|
|
3769
|
-
strides=t.strides[::-1],
|
|
3770
|
-
layout=layout,
|
|
3771
|
-
owner=False,
|
|
3772
|
-
)
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
3978
|
add_builtin(
|
|
3776
|
-
"
|
|
3777
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3778
|
-
value_func=
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3979
|
+
"tile_bit_and_inplace",
|
|
3980
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3981
|
+
value_func=tile_inplace_value_func,
|
|
3982
|
+
group="Tile Primitives",
|
|
3983
|
+
hidden=True,
|
|
3984
|
+
export=False,
|
|
3985
|
+
is_differentiable=False,
|
|
3986
|
+
)
|
|
3987
|
+
add_builtin(
|
|
3988
|
+
"tile_bit_and_inplace",
|
|
3989
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3990
|
+
value_func=tile_inplace_value_func,
|
|
3991
|
+
group="Tile Primitives",
|
|
3992
|
+
hidden=True,
|
|
3993
|
+
export=False,
|
|
3994
|
+
is_differentiable=False,
|
|
3995
|
+
)
|
|
3996
|
+
add_builtin(
|
|
3997
|
+
"tile_bit_and_inplace",
|
|
3998
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
3999
|
+
value_func=tile_inplace_value_func,
|
|
4000
|
+
group="Tile Primitives",
|
|
4001
|
+
hidden=True,
|
|
4002
|
+
export=False,
|
|
4003
|
+
is_differentiable=False,
|
|
4004
|
+
)
|
|
4005
|
+
add_builtin(
|
|
4006
|
+
"tile_bit_and_inplace",
|
|
4007
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4008
|
+
value_func=tile_inplace_value_func,
|
|
4009
|
+
group="Tile Primitives",
|
|
4010
|
+
hidden=True,
|
|
4011
|
+
export=False,
|
|
4012
|
+
is_differentiable=False,
|
|
4013
|
+
)
|
|
4014
|
+
|
|
4015
|
+
add_builtin(
|
|
4016
|
+
"tile_bit_or_inplace",
|
|
4017
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4018
|
+
value_func=tile_inplace_value_func,
|
|
4019
|
+
group="Tile Primitives",
|
|
4020
|
+
hidden=True,
|
|
4021
|
+
export=False,
|
|
4022
|
+
is_differentiable=False,
|
|
4023
|
+
)
|
|
4024
|
+
add_builtin(
|
|
4025
|
+
"tile_bit_or_inplace",
|
|
4026
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4027
|
+
value_func=tile_inplace_value_func,
|
|
4028
|
+
group="Tile Primitives",
|
|
4029
|
+
hidden=True,
|
|
4030
|
+
export=False,
|
|
4031
|
+
is_differentiable=False,
|
|
4032
|
+
)
|
|
4033
|
+
add_builtin(
|
|
4034
|
+
"tile_bit_or_inplace",
|
|
4035
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4036
|
+
value_func=tile_inplace_value_func,
|
|
4037
|
+
group="Tile Primitives",
|
|
4038
|
+
hidden=True,
|
|
4039
|
+
export=False,
|
|
4040
|
+
is_differentiable=False,
|
|
4041
|
+
)
|
|
4042
|
+
add_builtin(
|
|
4043
|
+
"tile_bit_or_inplace",
|
|
4044
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4045
|
+
value_func=tile_inplace_value_func,
|
|
4046
|
+
group="Tile Primitives",
|
|
4047
|
+
hidden=True,
|
|
4048
|
+
export=False,
|
|
4049
|
+
is_differentiable=False,
|
|
4050
|
+
)
|
|
4051
|
+
|
|
4052
|
+
add_builtin(
|
|
4053
|
+
"tile_bit_xor_inplace",
|
|
4054
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4055
|
+
value_func=tile_inplace_value_func,
|
|
4056
|
+
group="Tile Primitives",
|
|
4057
|
+
hidden=True,
|
|
4058
|
+
export=False,
|
|
4059
|
+
is_differentiable=False,
|
|
4060
|
+
)
|
|
4061
|
+
add_builtin(
|
|
4062
|
+
"tile_bit_xor_inplace",
|
|
4063
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4064
|
+
value_func=tile_inplace_value_func,
|
|
4065
|
+
group="Tile Primitives",
|
|
4066
|
+
hidden=True,
|
|
4067
|
+
export=False,
|
|
4068
|
+
is_differentiable=False,
|
|
4069
|
+
)
|
|
4070
|
+
add_builtin(
|
|
4071
|
+
"tile_bit_xor_inplace",
|
|
4072
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4073
|
+
value_func=tile_inplace_value_func,
|
|
4074
|
+
group="Tile Primitives",
|
|
4075
|
+
hidden=True,
|
|
4076
|
+
export=False,
|
|
4077
|
+
is_differentiable=False,
|
|
4078
|
+
)
|
|
4079
|
+
add_builtin(
|
|
4080
|
+
"tile_bit_xor_inplace",
|
|
4081
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4082
|
+
value_func=tile_inplace_value_func,
|
|
4083
|
+
group="Tile Primitives",
|
|
4084
|
+
hidden=True,
|
|
4085
|
+
export=False,
|
|
4086
|
+
is_differentiable=False,
|
|
4087
|
+
)
|
|
4088
|
+
|
|
4089
|
+
|
|
4090
|
+
def tile_transpose_value_func(arg_types, arg_values):
|
|
4091
|
+
# return generic type (for doc builds)
|
|
4092
|
+
if arg_types is None:
|
|
4093
|
+
return tile(dtype=Any, shape=Tuple[int, int])
|
|
4094
|
+
|
|
4095
|
+
if len(arg_types) != 1:
|
|
4096
|
+
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
4097
|
+
|
|
4098
|
+
t = arg_types["a"]
|
|
4099
|
+
|
|
4100
|
+
if not is_tile(t):
|
|
4101
|
+
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
4102
|
+
|
|
4103
|
+
layout = None
|
|
4104
|
+
|
|
4105
|
+
# flip layout
|
|
4106
|
+
if t.layout == "rowmajor":
|
|
4107
|
+
layout = "colmajor"
|
|
4108
|
+
elif t.layout == "colmajor":
|
|
4109
|
+
layout = "rowmajor"
|
|
4110
|
+
|
|
4111
|
+
# force the input tile to shared memory
|
|
4112
|
+
t.storage = "shared"
|
|
4113
|
+
|
|
4114
|
+
return tile(
|
|
4115
|
+
dtype=t.dtype,
|
|
4116
|
+
shape=t.shape[::-1],
|
|
4117
|
+
storage=t.storage,
|
|
4118
|
+
strides=t.strides[::-1],
|
|
4119
|
+
layout=layout,
|
|
4120
|
+
owner=False,
|
|
4121
|
+
)
|
|
4122
|
+
|
|
4123
|
+
|
|
4124
|
+
add_builtin(
|
|
4125
|
+
"tile_transpose",
|
|
4126
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
|
|
4127
|
+
value_func=tile_transpose_value_func,
|
|
4128
|
+
variadic=True,
|
|
4129
|
+
doc="""Transpose a tile.
|
|
4130
|
+
|
|
4131
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
4132
|
+
Register tiles will first be transferred to shared memory before transposition.
|
|
3784
4133
|
|
|
3785
4134
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
3786
4135
|
:returns: Tile with ``shape=(N,M)``""",
|
|
@@ -3910,6 +4259,80 @@ add_builtin(
|
|
|
3910
4259
|
)
|
|
3911
4260
|
|
|
3912
4261
|
|
|
4262
|
+
def tile_sum_axis_value_func(arg_types, arg_values):
|
|
4263
|
+
if arg_types is None:
|
|
4264
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
4265
|
+
|
|
4266
|
+
a = arg_types["a"]
|
|
4267
|
+
|
|
4268
|
+
if not is_tile(a):
|
|
4269
|
+
raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
|
|
4270
|
+
|
|
4271
|
+
# force input tile to shared
|
|
4272
|
+
a.storage = "shared"
|
|
4273
|
+
|
|
4274
|
+
axis = arg_values["axis"]
|
|
4275
|
+
shape = a.shape
|
|
4276
|
+
|
|
4277
|
+
if axis < 0 or axis >= len(shape):
|
|
4278
|
+
raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
|
|
4279
|
+
|
|
4280
|
+
# shape is identical less the axis reduction is along
|
|
4281
|
+
if len(shape) > 1:
|
|
4282
|
+
new_shape = shape[:axis] + shape[axis + 1 :]
|
|
4283
|
+
else:
|
|
4284
|
+
new_shape = (1,)
|
|
4285
|
+
|
|
4286
|
+
return tile(dtype=a.dtype, shape=new_shape)
|
|
4287
|
+
|
|
4288
|
+
|
|
4289
|
+
def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
4290
|
+
tile = arg_values["a"]
|
|
4291
|
+
axis_var = arg_values["axis"]
|
|
4292
|
+
if not hasattr(axis_var, "constant") or axis_var.constant is None:
|
|
4293
|
+
raise ValueError("tile_sum() axis must be a compile-time constant")
|
|
4294
|
+
axis = axis_var.constant
|
|
4295
|
+
|
|
4296
|
+
return ((tile,), (axis,))
|
|
4297
|
+
|
|
4298
|
+
|
|
4299
|
+
add_builtin(
|
|
4300
|
+
"tile_sum",
|
|
4301
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4302
|
+
value_func=tile_sum_axis_value_func,
|
|
4303
|
+
dispatch_func=tile_sum_axis_dispatch_func,
|
|
4304
|
+
doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
|
|
4305
|
+
|
|
4306
|
+
:param a: The input tile. Must reside in shared memory.
|
|
4307
|
+
:param axis: The tile axis to compute the sum across. Must be a compile-time constant.
|
|
4308
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4309
|
+
|
|
4310
|
+
Example:
|
|
4311
|
+
|
|
4312
|
+
.. code-block:: python
|
|
4313
|
+
|
|
4314
|
+
@wp.kernel
|
|
4315
|
+
def compute():
|
|
4316
|
+
|
|
4317
|
+
t = wp.tile_ones(dtype=float, shape=(8, 8))
|
|
4318
|
+
s = wp.tile_sum(t, axis=0)
|
|
4319
|
+
|
|
4320
|
+
print(s)
|
|
4321
|
+
|
|
4322
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
4323
|
+
|
|
4324
|
+
Prints:
|
|
4325
|
+
|
|
4326
|
+
.. code-block:: text
|
|
4327
|
+
|
|
4328
|
+
[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
|
|
4329
|
+
|
|
4330
|
+
""",
|
|
4331
|
+
group="Tile Primitives",
|
|
4332
|
+
export=False,
|
|
4333
|
+
)
|
|
4334
|
+
|
|
4335
|
+
|
|
3913
4336
|
def tile_sort_value_func(arg_types, arg_values):
|
|
3914
4337
|
# return generic type (for doc builds)
|
|
3915
4338
|
if arg_types is None:
|
|
@@ -3986,6 +4409,7 @@ add_builtin(
|
|
|
3986
4409
|
""",
|
|
3987
4410
|
group="Tile Primitives",
|
|
3988
4411
|
export=False,
|
|
4412
|
+
is_differentiable=False,
|
|
3989
4413
|
)
|
|
3990
4414
|
|
|
3991
4415
|
|
|
@@ -4039,6 +4463,7 @@ add_builtin(
|
|
|
4039
4463
|
""",
|
|
4040
4464
|
group="Tile Primitives",
|
|
4041
4465
|
export=False,
|
|
4466
|
+
is_differentiable=False,
|
|
4042
4467
|
)
|
|
4043
4468
|
|
|
4044
4469
|
|
|
@@ -4092,6 +4517,7 @@ add_builtin(
|
|
|
4092
4517
|
""",
|
|
4093
4518
|
group="Tile Primitives",
|
|
4094
4519
|
export=False,
|
|
4520
|
+
is_differentiable=False,
|
|
4095
4521
|
)
|
|
4096
4522
|
|
|
4097
4523
|
|
|
@@ -4144,6 +4570,7 @@ add_builtin(
|
|
|
4144
4570
|
""",
|
|
4145
4571
|
group="Tile Primitives",
|
|
4146
4572
|
export=False,
|
|
4573
|
+
is_differentiable=False,
|
|
4147
4574
|
)
|
|
4148
4575
|
|
|
4149
4576
|
|
|
@@ -4196,10 +4623,10 @@ add_builtin(
|
|
|
4196
4623
|
""",
|
|
4197
4624
|
group="Tile Primitives",
|
|
4198
4625
|
export=False,
|
|
4626
|
+
is_differentiable=False,
|
|
4199
4627
|
)
|
|
4200
4628
|
|
|
4201
4629
|
|
|
4202
|
-
# does type propagation for load()
|
|
4203
4630
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
4204
4631
|
if arg_types is None:
|
|
4205
4632
|
return tile(dtype=Scalar, shape=(1,))
|
|
@@ -4253,6 +4680,88 @@ add_builtin(
|
|
|
4253
4680
|
""",
|
|
4254
4681
|
group="Tile Primitives",
|
|
4255
4682
|
export=False,
|
|
4683
|
+
is_differentiable=False,
|
|
4684
|
+
)
|
|
4685
|
+
|
|
4686
|
+
|
|
4687
|
+
def tile_reduce_axis_value_func(arg_types, arg_values):
|
|
4688
|
+
if arg_types is None:
|
|
4689
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
4690
|
+
|
|
4691
|
+
a = arg_types["a"]
|
|
4692
|
+
|
|
4693
|
+
if not is_tile(a):
|
|
4694
|
+
raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
|
|
4695
|
+
|
|
4696
|
+
# force input tile to shared memory
|
|
4697
|
+
a.storage = "shared"
|
|
4698
|
+
|
|
4699
|
+
axis = arg_values["axis"]
|
|
4700
|
+
shape = a.shape
|
|
4701
|
+
|
|
4702
|
+
if axis < 0 or axis >= len(shape):
|
|
4703
|
+
raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
|
|
4704
|
+
|
|
4705
|
+
# shape is identical less the axis reduction is along
|
|
4706
|
+
if len(shape) > 1:
|
|
4707
|
+
new_shape = shape[:axis] + shape[axis + 1 :]
|
|
4708
|
+
else:
|
|
4709
|
+
new_shape = (1,)
|
|
4710
|
+
|
|
4711
|
+
return tile(dtype=a.dtype, shape=new_shape)
|
|
4712
|
+
|
|
4713
|
+
|
|
4714
|
+
add_builtin(
|
|
4715
|
+
"tile_reduce",
|
|
4716
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4717
|
+
value_func=tile_reduce_axis_value_func,
|
|
4718
|
+
native_func="tile_reduce_axis",
|
|
4719
|
+
doc="""Apply a custom reduction operator across a tile axis.
|
|
4720
|
+
|
|
4721
|
+
This function cooperatively performs a reduction using the provided operator across an axis of the tile.
|
|
4722
|
+
|
|
4723
|
+
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
4724
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
|
|
4725
|
+
:param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
|
|
4726
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4727
|
+
|
|
4728
|
+
Example:
|
|
4729
|
+
|
|
4730
|
+
.. code-block:: python
|
|
4731
|
+
|
|
4732
|
+
TILE_M = wp.constant(4)
|
|
4733
|
+
TILE_N = wp.constant(2)
|
|
4734
|
+
|
|
4735
|
+
@wp.kernel
|
|
4736
|
+
def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
|
|
4737
|
+
|
|
4738
|
+
a = wp.tile_load(x, shape=(TILE_M, TILE_N))
|
|
4739
|
+
b = wp.tile_reduce(wp.add, a, axis=1)
|
|
4740
|
+
wp.tile_store(y, b)
|
|
4741
|
+
|
|
4742
|
+
arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
|
|
4743
|
+
|
|
4744
|
+
x = wp.array(arr, dtype=float)
|
|
4745
|
+
y = wp.zeros(TILE_M, dtype=float)
|
|
4746
|
+
|
|
4747
|
+
wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
|
|
4748
|
+
|
|
4749
|
+
print(x.numpy())
|
|
4750
|
+
print(y.numpy())
|
|
4751
|
+
|
|
4752
|
+
Prints:
|
|
4753
|
+
|
|
4754
|
+
.. code-block:: text
|
|
4755
|
+
|
|
4756
|
+
[[0. 1.]
|
|
4757
|
+
[2. 3.]
|
|
4758
|
+
[4. 5.]
|
|
4759
|
+
[6. 7.]]
|
|
4760
|
+
[ 1. 5. 9. 13.]
|
|
4761
|
+
""",
|
|
4762
|
+
group="Tile Primitives",
|
|
4763
|
+
export=False,
|
|
4764
|
+
is_differentiable=False,
|
|
4256
4765
|
)
|
|
4257
4766
|
|
|
4258
4767
|
|
|
@@ -4316,6 +4825,7 @@ add_builtin(
|
|
|
4316
4825
|
""",
|
|
4317
4826
|
group="Tile Primitives",
|
|
4318
4827
|
export=False,
|
|
4828
|
+
is_differentiable=False,
|
|
4319
4829
|
)
|
|
4320
4830
|
|
|
4321
4831
|
|
|
@@ -4379,6 +4889,7 @@ add_builtin(
|
|
|
4379
4889
|
""",
|
|
4380
4890
|
group="Tile Primitives",
|
|
4381
4891
|
export=False,
|
|
4892
|
+
is_differentiable=False,
|
|
4382
4893
|
)
|
|
4383
4894
|
|
|
4384
4895
|
|
|
@@ -4632,6 +5143,7 @@ add_builtin(
|
|
|
4632
5143
|
doc="WIP",
|
|
4633
5144
|
group="Utility",
|
|
4634
5145
|
hidden=True,
|
|
5146
|
+
is_differentiable=False,
|
|
4635
5147
|
)
|
|
4636
5148
|
|
|
4637
5149
|
add_builtin(
|
|
@@ -4647,6 +5159,7 @@ add_builtin(
|
|
|
4647
5159
|
doc="WIP",
|
|
4648
5160
|
group="Utility",
|
|
4649
5161
|
hidden=True,
|
|
5162
|
+
is_differentiable=False,
|
|
4650
5163
|
)
|
|
4651
5164
|
|
|
4652
5165
|
add_builtin(
|
|
@@ -4656,6 +5169,7 @@ add_builtin(
|
|
|
4656
5169
|
doc="WIP",
|
|
4657
5170
|
group="Utility",
|
|
4658
5171
|
hidden=True,
|
|
5172
|
+
is_differentiable=False,
|
|
4659
5173
|
)
|
|
4660
5174
|
|
|
4661
5175
|
add_builtin(
|
|
@@ -4707,6 +5221,7 @@ add_builtin(
|
|
|
4707
5221
|
:param low: The lower bound of the bounding box in BVH space
|
|
4708
5222
|
:param high: The upper bound of the bounding box in BVH space""",
|
|
4709
5223
|
export=False,
|
|
5224
|
+
is_differentiable=False,
|
|
4710
5225
|
)
|
|
4711
5226
|
|
|
4712
5227
|
add_builtin(
|
|
@@ -4722,6 +5237,7 @@ add_builtin(
|
|
|
4722
5237
|
:param start: The start of the ray in BVH space
|
|
4723
5238
|
:param dir: The direction of the ray in BVH space""",
|
|
4724
5239
|
export=False,
|
|
5240
|
+
is_differentiable=False,
|
|
4725
5241
|
)
|
|
4726
5242
|
|
|
4727
5243
|
add_builtin(
|
|
@@ -4732,6 +5248,7 @@ add_builtin(
|
|
|
4732
5248
|
doc="""Move to the next bound returned by the query.
|
|
4733
5249
|
The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
|
|
4734
5250
|
export=False,
|
|
5251
|
+
is_differentiable=False,
|
|
4735
5252
|
)
|
|
4736
5253
|
|
|
4737
5254
|
add_builtin(
|
|
@@ -5066,12 +5583,13 @@ add_builtin(
|
|
|
5066
5583
|
group="Geometry",
|
|
5067
5584
|
doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
|
|
5068
5585
|
|
|
5069
|
-
This query can be used to iterate over all triangles inside a volume.
|
|
5586
|
+
This query can be used to iterate over all bounding boxes of the triangles inside a volume.
|
|
5070
5587
|
|
|
5071
5588
|
:param id: The mesh identifier
|
|
5072
5589
|
:param low: The lower bound of the bounding box in mesh space
|
|
5073
5590
|
:param high: The upper bound of the bounding box in mesh space""",
|
|
5074
5591
|
export=False,
|
|
5592
|
+
is_differentiable=False,
|
|
5075
5593
|
)
|
|
5076
5594
|
|
|
5077
5595
|
add_builtin(
|
|
@@ -5079,10 +5597,11 @@ add_builtin(
|
|
|
5079
5597
|
input_types={"query": MeshQueryAABB, "index": int},
|
|
5080
5598
|
value_type=builtins.bool,
|
|
5081
5599
|
group="Geometry",
|
|
5082
|
-
doc="""Move to the next triangle
|
|
5600
|
+
doc="""Move to the next triangle whose bounding box overlaps the query bounding box.
|
|
5083
5601
|
|
|
5084
5602
|
The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
|
|
5085
5603
|
export=False,
|
|
5604
|
+
is_differentiable=False,
|
|
5086
5605
|
)
|
|
5087
5606
|
|
|
5088
5607
|
add_builtin(
|
|
@@ -5112,6 +5631,7 @@ add_builtin(
|
|
|
5112
5631
|
|
|
5113
5632
|
This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
|
|
5114
5633
|
export=False,
|
|
5634
|
+
is_differentiable=False,
|
|
5115
5635
|
)
|
|
5116
5636
|
|
|
5117
5637
|
add_builtin(
|
|
@@ -5123,6 +5643,7 @@ add_builtin(
|
|
|
5123
5643
|
|
|
5124
5644
|
The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
|
|
5125
5645
|
export=False,
|
|
5646
|
+
is_differentiable=False,
|
|
5126
5647
|
)
|
|
5127
5648
|
|
|
5128
5649
|
add_builtin(
|
|
@@ -5136,6 +5657,7 @@ add_builtin(
|
|
|
5136
5657
|
|
|
5137
5658
|
Returns -1 if the :class:`HashGrid` has not been reserved.""",
|
|
5138
5659
|
export=False,
|
|
5660
|
+
is_differentiable=False,
|
|
5139
5661
|
)
|
|
5140
5662
|
|
|
5141
5663
|
add_builtin(
|
|
@@ -5145,15 +5667,34 @@ add_builtin(
|
|
|
5145
5667
|
group="Geometry",
|
|
5146
5668
|
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5147
5669
|
|
|
5670
|
+
This function works with single precision, may return incorrect results in some case.
|
|
5671
|
+
|
|
5148
5672
|
Returns > 0 if triangles intersect.""",
|
|
5149
5673
|
export=False,
|
|
5674
|
+
is_differentiable=False,
|
|
5150
5675
|
)
|
|
5151
5676
|
|
|
5677
|
+
|
|
5678
|
+
add_builtin(
|
|
5679
|
+
"intersect_tri_tri",
|
|
5680
|
+
input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
|
|
5681
|
+
value_type=int,
|
|
5682
|
+
group="Geometry",
|
|
5683
|
+
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5684
|
+
|
|
5685
|
+
This function works with double precision, results are more accurate than the single precision version.
|
|
5686
|
+
|
|
5687
|
+
Returns > 0 if triangles intersect.""",
|
|
5688
|
+
export=False,
|
|
5689
|
+
is_differentiable=False,
|
|
5690
|
+
)
|
|
5691
|
+
|
|
5692
|
+
|
|
5152
5693
|
add_builtin(
|
|
5153
5694
|
"mesh_get",
|
|
5154
5695
|
input_types={"id": uint64},
|
|
5155
5696
|
value_type=Mesh,
|
|
5156
|
-
|
|
5697
|
+
is_differentiable=False,
|
|
5157
5698
|
group="Geometry",
|
|
5158
5699
|
doc="""Retrieves the mesh given its index.""",
|
|
5159
5700
|
export=False,
|
|
@@ -5166,6 +5707,7 @@ add_builtin(
|
|
|
5166
5707
|
group="Geometry",
|
|
5167
5708
|
doc="""Evaluates the face normal the mesh given a face index.""",
|
|
5168
5709
|
export=False,
|
|
5710
|
+
is_differentiable=False,
|
|
5169
5711
|
)
|
|
5170
5712
|
|
|
5171
5713
|
add_builtin(
|
|
@@ -5175,6 +5717,7 @@ add_builtin(
|
|
|
5175
5717
|
group="Geometry",
|
|
5176
5718
|
doc="""Returns the point of the mesh given a index.""",
|
|
5177
5719
|
export=False,
|
|
5720
|
+
is_differentiable=False,
|
|
5178
5721
|
)
|
|
5179
5722
|
|
|
5180
5723
|
add_builtin(
|
|
@@ -5184,6 +5727,7 @@ add_builtin(
|
|
|
5184
5727
|
group="Geometry",
|
|
5185
5728
|
doc="""Returns the velocity of the mesh given a index.""",
|
|
5186
5729
|
export=False,
|
|
5730
|
+
is_differentiable=False,
|
|
5187
5731
|
)
|
|
5188
5732
|
|
|
5189
5733
|
add_builtin(
|
|
@@ -5193,6 +5737,7 @@ add_builtin(
|
|
|
5193
5737
|
group="Geometry",
|
|
5194
5738
|
doc="""Returns the point-index of the mesh given a face-vertex index.""",
|
|
5195
5739
|
export=False,
|
|
5740
|
+
is_differentiable=False,
|
|
5196
5741
|
)
|
|
5197
5742
|
|
|
5198
5743
|
|
|
@@ -5233,12 +5778,32 @@ add_builtin(
|
|
|
5233
5778
|
# ---------------------------------
|
|
5234
5779
|
# Iterators
|
|
5235
5780
|
|
|
5236
|
-
add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
|
|
5237
5781
|
add_builtin(
|
|
5238
|
-
"iter_next",
|
|
5782
|
+
"iter_next",
|
|
5783
|
+
input_types={"range": range_t},
|
|
5784
|
+
value_type=int,
|
|
5785
|
+
group="Utility",
|
|
5786
|
+
export=False,
|
|
5787
|
+
hidden=True,
|
|
5788
|
+
is_differentiable=False,
|
|
5789
|
+
)
|
|
5790
|
+
add_builtin(
|
|
5791
|
+
"iter_next",
|
|
5792
|
+
input_types={"query": HashGridQuery},
|
|
5793
|
+
value_type=int,
|
|
5794
|
+
group="Utility",
|
|
5795
|
+
export=False,
|
|
5796
|
+
hidden=True,
|
|
5797
|
+
is_differentiable=False,
|
|
5239
5798
|
)
|
|
5240
5799
|
add_builtin(
|
|
5241
|
-
"iter_next",
|
|
5800
|
+
"iter_next",
|
|
5801
|
+
input_types={"query": MeshQueryAABB},
|
|
5802
|
+
value_type=int,
|
|
5803
|
+
group="Utility",
|
|
5804
|
+
export=False,
|
|
5805
|
+
hidden=True,
|
|
5806
|
+
is_differentiable=False,
|
|
5242
5807
|
)
|
|
5243
5808
|
|
|
5244
5809
|
add_builtin(
|
|
@@ -5249,6 +5814,7 @@ add_builtin(
|
|
|
5249
5814
|
group="Utility",
|
|
5250
5815
|
doc="""Returns the range in reversed order.""",
|
|
5251
5816
|
export=False,
|
|
5817
|
+
is_differentiable=False,
|
|
5252
5818
|
)
|
|
5253
5819
|
|
|
5254
5820
|
# ---------------------------------
|
|
@@ -5268,8 +5834,8 @@ _volume_supported_value_types = {
|
|
|
5268
5834
|
|
|
5269
5835
|
|
|
5270
5836
|
def _is_volume_type_supported(dtype):
|
|
5271
|
-
for
|
|
5272
|
-
if types_equal(
|
|
5837
|
+
for value_type in _volume_supported_value_types:
|
|
5838
|
+
if types_equal(value_type, dtype):
|
|
5273
5839
|
return True
|
|
5274
5840
|
return False
|
|
5275
5841
|
|
|
@@ -5397,6 +5963,7 @@ add_builtin(
|
|
|
5397
5963
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
|
|
5398
5964
|
|
|
5399
5965
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5966
|
+
is_differentiable=False,
|
|
5400
5967
|
)
|
|
5401
5968
|
|
|
5402
5969
|
|
|
@@ -5417,6 +5984,7 @@ add_builtin(
|
|
|
5417
5984
|
export=False,
|
|
5418
5985
|
group="Volumes",
|
|
5419
5986
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5987
|
+
is_differentiable=False,
|
|
5420
5988
|
)
|
|
5421
5989
|
|
|
5422
5990
|
add_builtin(
|
|
@@ -5447,6 +6015,7 @@ add_builtin(
|
|
|
5447
6015
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5448
6016
|
|
|
5449
6017
|
If the voxel at this index does not exist, this function returns the background value""",
|
|
6018
|
+
is_differentiable=False,
|
|
5450
6019
|
)
|
|
5451
6020
|
|
|
5452
6021
|
add_builtin(
|
|
@@ -5455,6 +6024,7 @@ add_builtin(
|
|
|
5455
6024
|
group="Volumes",
|
|
5456
6025
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5457
6026
|
export=False,
|
|
6027
|
+
is_differentiable=False,
|
|
5458
6028
|
)
|
|
5459
6029
|
|
|
5460
6030
|
add_builtin(
|
|
@@ -5475,6 +6045,7 @@ add_builtin(
|
|
|
5475
6045
|
doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5476
6046
|
|
|
5477
6047
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
6048
|
+
is_differentiable=False,
|
|
5478
6049
|
)
|
|
5479
6050
|
|
|
5480
6051
|
add_builtin(
|
|
@@ -5483,6 +6054,7 @@ add_builtin(
|
|
|
5483
6054
|
group="Volumes",
|
|
5484
6055
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5485
6056
|
export=False,
|
|
6057
|
+
is_differentiable=False,
|
|
5486
6058
|
)
|
|
5487
6059
|
|
|
5488
6060
|
add_builtin(
|
|
@@ -5501,6 +6073,7 @@ add_builtin(
|
|
|
5501
6073
|
doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5502
6074
|
|
|
5503
6075
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
6076
|
+
is_differentiable=False,
|
|
5504
6077
|
)
|
|
5505
6078
|
|
|
5506
6079
|
add_builtin(
|
|
@@ -5509,6 +6082,7 @@ add_builtin(
|
|
|
5509
6082
|
group="Volumes",
|
|
5510
6083
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5511
6084
|
export=False,
|
|
6085
|
+
is_differentiable=False,
|
|
5512
6086
|
)
|
|
5513
6087
|
|
|
5514
6088
|
|
|
@@ -5590,6 +6164,7 @@ add_builtin(
|
|
|
5590
6164
|
If the voxel at this index does not exist, this function returns -1.
|
|
5591
6165
|
This function is available for both index grids and classical volumes.
|
|
5592
6166
|
""",
|
|
6167
|
+
is_differentiable=False,
|
|
5593
6168
|
)
|
|
5594
6169
|
|
|
5595
6170
|
add_builtin(
|
|
@@ -5631,6 +6206,7 @@ add_builtin(
|
|
|
5631
6206
|
value_type=uint32,
|
|
5632
6207
|
group="Random",
|
|
5633
6208
|
doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
|
|
6209
|
+
is_differentiable=False,
|
|
5634
6210
|
)
|
|
5635
6211
|
|
|
5636
6212
|
add_builtin(
|
|
@@ -5642,6 +6218,7 @@ add_builtin(
|
|
|
5642
6218
|
|
|
5643
6219
|
This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
|
|
5644
6220
|
but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
|
|
6221
|
+
is_differentiable=False,
|
|
5645
6222
|
)
|
|
5646
6223
|
|
|
5647
6224
|
add_builtin(
|
|
@@ -5650,6 +6227,7 @@ add_builtin(
|
|
|
5650
6227
|
value_type=int,
|
|
5651
6228
|
group="Random",
|
|
5652
6229
|
doc="Return a random integer in the range [-2^31, 2^31).",
|
|
6230
|
+
is_differentiable=False,
|
|
5653
6231
|
)
|
|
5654
6232
|
add_builtin(
|
|
5655
6233
|
"randi",
|
|
@@ -5657,6 +6235,7 @@ add_builtin(
|
|
|
5657
6235
|
value_type=int,
|
|
5658
6236
|
group="Random",
|
|
5659
6237
|
doc="Return a random integer between [low, high).",
|
|
6238
|
+
is_differentiable=False,
|
|
5660
6239
|
)
|
|
5661
6240
|
add_builtin(
|
|
5662
6241
|
"randu",
|
|
@@ -5664,6 +6243,7 @@ add_builtin(
|
|
|
5664
6243
|
value_type=uint32,
|
|
5665
6244
|
group="Random",
|
|
5666
6245
|
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
6246
|
+
is_differentiable=False,
|
|
5667
6247
|
)
|
|
5668
6248
|
add_builtin(
|
|
5669
6249
|
"randu",
|
|
@@ -5671,6 +6251,7 @@ add_builtin(
|
|
|
5671
6251
|
value_type=uint32,
|
|
5672
6252
|
group="Random",
|
|
5673
6253
|
doc="Return a random unsigned integer between [low, high).",
|
|
6254
|
+
is_differentiable=False,
|
|
5674
6255
|
)
|
|
5675
6256
|
add_builtin(
|
|
5676
6257
|
"randf",
|
|
@@ -5678,6 +6259,7 @@ add_builtin(
|
|
|
5678
6259
|
value_type=float,
|
|
5679
6260
|
group="Random",
|
|
5680
6261
|
doc="Return a random float between [0.0, 1.0).",
|
|
6262
|
+
is_differentiable=False,
|
|
5681
6263
|
)
|
|
5682
6264
|
add_builtin(
|
|
5683
6265
|
"randf",
|
|
@@ -5685,6 +6267,7 @@ add_builtin(
|
|
|
5685
6267
|
value_type=float,
|
|
5686
6268
|
group="Random",
|
|
5687
6269
|
doc="Return a random float between [low, high).",
|
|
6270
|
+
is_differentiable=False,
|
|
5688
6271
|
)
|
|
5689
6272
|
add_builtin(
|
|
5690
6273
|
"randn",
|
|
@@ -5692,6 +6275,7 @@ add_builtin(
|
|
|
5692
6275
|
value_type=float,
|
|
5693
6276
|
group="Random",
|
|
5694
6277
|
doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
|
|
6278
|
+
is_differentiable=False,
|
|
5695
6279
|
)
|
|
5696
6280
|
|
|
5697
6281
|
add_builtin(
|
|
@@ -5700,6 +6284,7 @@ add_builtin(
|
|
|
5700
6284
|
value_type=int,
|
|
5701
6285
|
group="Random",
|
|
5702
6286
|
doc="Inverse-transform sample a cumulative distribution function.",
|
|
6287
|
+
is_differentiable=False,
|
|
5703
6288
|
)
|
|
5704
6289
|
add_builtin(
|
|
5705
6290
|
"sample_triangle",
|
|
@@ -5707,6 +6292,7 @@ add_builtin(
|
|
|
5707
6292
|
value_type=vec2,
|
|
5708
6293
|
group="Random",
|
|
5709
6294
|
doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
|
|
6295
|
+
is_differentiable=False,
|
|
5710
6296
|
)
|
|
5711
6297
|
add_builtin(
|
|
5712
6298
|
"sample_unit_ring",
|
|
@@ -5714,6 +6300,7 @@ add_builtin(
|
|
|
5714
6300
|
value_type=vec2,
|
|
5715
6301
|
group="Random",
|
|
5716
6302
|
doc="Uniformly sample a ring in the xy plane.",
|
|
6303
|
+
is_differentiable=False,
|
|
5717
6304
|
)
|
|
5718
6305
|
add_builtin(
|
|
5719
6306
|
"sample_unit_disk",
|
|
@@ -5721,6 +6308,7 @@ add_builtin(
|
|
|
5721
6308
|
value_type=vec2,
|
|
5722
6309
|
group="Random",
|
|
5723
6310
|
doc="Uniformly sample a disk in the xy plane.",
|
|
6311
|
+
is_differentiable=False,
|
|
5724
6312
|
)
|
|
5725
6313
|
add_builtin(
|
|
5726
6314
|
"sample_unit_sphere_surface",
|
|
@@ -5728,6 +6316,7 @@ add_builtin(
|
|
|
5728
6316
|
value_type=vec3,
|
|
5729
6317
|
group="Random",
|
|
5730
6318
|
doc="Uniformly sample a unit sphere surface.",
|
|
6319
|
+
is_differentiable=False,
|
|
5731
6320
|
)
|
|
5732
6321
|
add_builtin(
|
|
5733
6322
|
"sample_unit_sphere",
|
|
@@ -5735,6 +6324,7 @@ add_builtin(
|
|
|
5735
6324
|
value_type=vec3,
|
|
5736
6325
|
group="Random",
|
|
5737
6326
|
doc="Uniformly sample a unit sphere.",
|
|
6327
|
+
is_differentiable=False,
|
|
5738
6328
|
)
|
|
5739
6329
|
add_builtin(
|
|
5740
6330
|
"sample_unit_hemisphere_surface",
|
|
@@ -5742,6 +6332,7 @@ add_builtin(
|
|
|
5742
6332
|
value_type=vec3,
|
|
5743
6333
|
group="Random",
|
|
5744
6334
|
doc="Uniformly sample a unit hemisphere surface.",
|
|
6335
|
+
is_differentiable=False,
|
|
5745
6336
|
)
|
|
5746
6337
|
add_builtin(
|
|
5747
6338
|
"sample_unit_hemisphere",
|
|
@@ -5749,6 +6340,7 @@ add_builtin(
|
|
|
5749
6340
|
value_type=vec3,
|
|
5750
6341
|
group="Random",
|
|
5751
6342
|
doc="Uniformly sample a unit hemisphere.",
|
|
6343
|
+
is_differentiable=False,
|
|
5752
6344
|
)
|
|
5753
6345
|
add_builtin(
|
|
5754
6346
|
"sample_unit_square",
|
|
@@ -5756,6 +6348,7 @@ add_builtin(
|
|
|
5756
6348
|
value_type=vec2,
|
|
5757
6349
|
group="Random",
|
|
5758
6350
|
doc="Uniformly sample a unit square.",
|
|
6351
|
+
is_differentiable=False,
|
|
5759
6352
|
)
|
|
5760
6353
|
add_builtin(
|
|
5761
6354
|
"sample_unit_cube",
|
|
@@ -5763,6 +6356,7 @@ add_builtin(
|
|
|
5763
6356
|
value_type=vec3,
|
|
5764
6357
|
group="Random",
|
|
5765
6358
|
doc="Uniformly sample a unit cube.",
|
|
6359
|
+
is_differentiable=False,
|
|
5766
6360
|
)
|
|
5767
6361
|
|
|
5768
6362
|
add_builtin(
|
|
@@ -5774,6 +6368,7 @@ add_builtin(
|
|
|
5774
6368
|
|
|
5775
6369
|
:param state: RNG state
|
|
5776
6370
|
:param lam: The expected value of the distribution""",
|
|
6371
|
+
is_differentiable=False,
|
|
5777
6372
|
)
|
|
5778
6373
|
|
|
5779
6374
|
add_builtin(
|
|
@@ -5841,7 +6436,7 @@ add_builtin(
|
|
|
5841
6436
|
value_type=vec2,
|
|
5842
6437
|
group="Random",
|
|
5843
6438
|
doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
|
|
5844
|
-
|
|
6439
|
+
is_differentiable=False,
|
|
5845
6440
|
)
|
|
5846
6441
|
add_builtin(
|
|
5847
6442
|
"curlnoise",
|
|
@@ -5850,7 +6445,7 @@ add_builtin(
|
|
|
5850
6445
|
value_type=vec3,
|
|
5851
6446
|
group="Random",
|
|
5852
6447
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5853
|
-
|
|
6448
|
+
is_differentiable=False,
|
|
5854
6449
|
)
|
|
5855
6450
|
add_builtin(
|
|
5856
6451
|
"curlnoise",
|
|
@@ -5859,7 +6454,7 @@ add_builtin(
|
|
|
5859
6454
|
value_type=vec3,
|
|
5860
6455
|
group="Random",
|
|
5861
6456
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5862
|
-
|
|
6457
|
+
is_differentiable=False,
|
|
5863
6458
|
)
|
|
5864
6459
|
|
|
5865
6460
|
|
|
@@ -5891,9 +6486,16 @@ add_builtin(
|
|
|
5891
6486
|
dispatch_func=printf_dispatch_func,
|
|
5892
6487
|
group="Utility",
|
|
5893
6488
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
6489
|
+
is_differentiable=False,
|
|
5894
6490
|
)
|
|
5895
6491
|
|
|
5896
|
-
add_builtin(
|
|
6492
|
+
add_builtin(
|
|
6493
|
+
"print",
|
|
6494
|
+
input_types={"value": Any},
|
|
6495
|
+
doc="Print variable to stdout",
|
|
6496
|
+
export=False,
|
|
6497
|
+
group="Utility",
|
|
6498
|
+
)
|
|
5897
6499
|
|
|
5898
6500
|
add_builtin(
|
|
5899
6501
|
"breakpoint",
|
|
@@ -5903,6 +6505,7 @@ add_builtin(
|
|
|
5903
6505
|
group="Utility",
|
|
5904
6506
|
namespace="",
|
|
5905
6507
|
native_func="__debugbreak",
|
|
6508
|
+
is_differentiable=False,
|
|
5906
6509
|
)
|
|
5907
6510
|
|
|
5908
6511
|
# helpers
|
|
@@ -5920,6 +6523,7 @@ add_builtin(
|
|
|
5920
6523
|
This function may not be called from user-defined Warp functions.""",
|
|
5921
6524
|
namespace="",
|
|
5922
6525
|
native_func="builtin_tid1d",
|
|
6526
|
+
is_differentiable=False,
|
|
5923
6527
|
)
|
|
5924
6528
|
|
|
5925
6529
|
add_builtin(
|
|
@@ -5930,6 +6534,7 @@ add_builtin(
|
|
|
5930
6534
|
doc="Returns the number of threads in the current block.",
|
|
5931
6535
|
namespace="",
|
|
5932
6536
|
native_func="builtin_block_dim",
|
|
6537
|
+
is_differentiable=False,
|
|
5933
6538
|
)
|
|
5934
6539
|
|
|
5935
6540
|
add_builtin(
|
|
@@ -5944,6 +6549,7 @@ add_builtin(
|
|
|
5944
6549
|
This function may not be called from user-defined Warp functions.""",
|
|
5945
6550
|
namespace="",
|
|
5946
6551
|
native_func="builtin_tid2d",
|
|
6552
|
+
is_differentiable=False,
|
|
5947
6553
|
)
|
|
5948
6554
|
|
|
5949
6555
|
add_builtin(
|
|
@@ -5958,6 +6564,7 @@ add_builtin(
|
|
|
5958
6564
|
This function may not be called from user-defined Warp functions.""",
|
|
5959
6565
|
namespace="",
|
|
5960
6566
|
native_func="builtin_tid3d",
|
|
6567
|
+
is_differentiable=False,
|
|
5961
6568
|
)
|
|
5962
6569
|
|
|
5963
6570
|
add_builtin(
|
|
@@ -5972,17 +6579,37 @@ add_builtin(
|
|
|
5972
6579
|
This function may not be called from user-defined Warp functions.""",
|
|
5973
6580
|
namespace="",
|
|
5974
6581
|
native_func="builtin_tid4d",
|
|
6582
|
+
is_differentiable=False,
|
|
5975
6583
|
)
|
|
5976
6584
|
|
|
5977
6585
|
|
|
6586
|
+
def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6587
|
+
a = arg_types["a"]
|
|
6588
|
+
|
|
6589
|
+
# if the input is a shared tile, we force a copy
|
|
6590
|
+
if is_tile(a) and a.storage == "shared":
|
|
6591
|
+
return tile(
|
|
6592
|
+
dtype=a.dtype,
|
|
6593
|
+
shape=a.shape,
|
|
6594
|
+
storage=a.storage,
|
|
6595
|
+
strides=a.strides,
|
|
6596
|
+
layout=a.layout,
|
|
6597
|
+
owner=True,
|
|
6598
|
+
)
|
|
6599
|
+
|
|
6600
|
+
return a
|
|
6601
|
+
|
|
6602
|
+
|
|
5978
6603
|
add_builtin(
|
|
5979
6604
|
"copy",
|
|
5980
6605
|
input_types={"a": Any},
|
|
5981
|
-
value_func=
|
|
6606
|
+
value_func=copy_value_func,
|
|
5982
6607
|
hidden=True,
|
|
5983
6608
|
export=False,
|
|
5984
6609
|
group="Utility",
|
|
5985
6610
|
)
|
|
6611
|
+
|
|
6612
|
+
|
|
5986
6613
|
add_builtin(
|
|
5987
6614
|
"assign",
|
|
5988
6615
|
input_types={"dest": Any, "src": Any},
|
|
@@ -5992,61 +6619,88 @@ add_builtin(
|
|
|
5992
6619
|
)
|
|
5993
6620
|
|
|
5994
6621
|
|
|
5995
|
-
def
|
|
5996
|
-
|
|
5997
|
-
|
|
5998
|
-
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
5999
|
-
category=DeprecationWarning,
|
|
6000
|
-
)
|
|
6001
|
-
|
|
6002
|
-
func_args = tuple(args.values())
|
|
6003
|
-
template_args = ()
|
|
6622
|
+
def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6623
|
+
if arg_types is None:
|
|
6624
|
+
return Any
|
|
6004
6625
|
|
|
6005
|
-
|
|
6626
|
+
raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
|
|
6006
6627
|
|
|
6007
6628
|
|
|
6008
6629
|
add_builtin(
|
|
6009
6630
|
"select",
|
|
6010
6631
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
6011
|
-
value_func=
|
|
6012
|
-
dispatch_func=select_dispatch_func,
|
|
6632
|
+
value_func=select_value_func,
|
|
6013
6633
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6014
6634
|
|
|
6015
|
-
..
|
|
6635
|
+
.. versionremoved:: 1.10
|
|
6016
6636
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6017
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6637
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6638
|
+
|
|
6639
|
+
.. deprecated:: 1.7""",
|
|
6018
6640
|
group="Utility",
|
|
6019
6641
|
)
|
|
6020
6642
|
for t in int_types:
|
|
6021
6643
|
add_builtin(
|
|
6022
6644
|
"select",
|
|
6023
6645
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
6024
|
-
value_func=
|
|
6025
|
-
dispatch_func=select_dispatch_func,
|
|
6646
|
+
value_func=select_value_func,
|
|
6026
6647
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6027
6648
|
|
|
6028
|
-
..
|
|
6649
|
+
.. versionremoved:: 1.10
|
|
6029
6650
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6030
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6651
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6652
|
+
|
|
6653
|
+
.. deprecated:: 1.7""",
|
|
6031
6654
|
group="Utility",
|
|
6032
6655
|
)
|
|
6033
6656
|
add_builtin(
|
|
6034
6657
|
"select",
|
|
6035
6658
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
6036
|
-
value_func=
|
|
6037
|
-
dispatch_func=select_dispatch_func,
|
|
6659
|
+
value_func=select_value_func,
|
|
6038
6660
|
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6039
6661
|
|
|
6040
|
-
..
|
|
6662
|
+
.. versionremoved:: 1.10
|
|
6041
6663
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6042
|
-
``where(arr, value_if_true, value_if_false)``.
|
|
6664
|
+
``where(arr, value_if_true, value_if_false)``.
|
|
6665
|
+
|
|
6666
|
+
.. deprecated:: 1.7""",
|
|
6043
6667
|
group="Utility",
|
|
6044
6668
|
)
|
|
6045
6669
|
|
|
6670
|
+
|
|
6671
|
+
def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6672
|
+
if arg_types is None:
|
|
6673
|
+
return Any
|
|
6674
|
+
|
|
6675
|
+
v_true = arg_types["value_if_true"]
|
|
6676
|
+
v_false = arg_types["value_if_false"]
|
|
6677
|
+
|
|
6678
|
+
if not types_equal(v_true, v_false):
|
|
6679
|
+
raise RuntimeError(f"where() true value type ({v_true}) must be of the same type as the false type ({v_false})")
|
|
6680
|
+
|
|
6681
|
+
if is_tile(v_false):
|
|
6682
|
+
if v_true.storage == "register":
|
|
6683
|
+
return v_true
|
|
6684
|
+
if v_false.storage == "register":
|
|
6685
|
+
return v_false
|
|
6686
|
+
|
|
6687
|
+
# both v_true and v_false are shared
|
|
6688
|
+
return tile(
|
|
6689
|
+
dtype=v_true.dtype,
|
|
6690
|
+
shape=v_true.shape,
|
|
6691
|
+
storage=v_true.storage,
|
|
6692
|
+
strides=v_true.strides,
|
|
6693
|
+
layout=v_true.layout,
|
|
6694
|
+
owner=True,
|
|
6695
|
+
)
|
|
6696
|
+
|
|
6697
|
+
return v_true
|
|
6698
|
+
|
|
6699
|
+
|
|
6046
6700
|
add_builtin(
|
|
6047
6701
|
"where",
|
|
6048
6702
|
input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
|
|
6049
|
-
value_func=
|
|
6703
|
+
value_func=where_value_func,
|
|
6050
6704
|
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6051
6705
|
group="Utility",
|
|
6052
6706
|
)
|
|
@@ -6054,14 +6708,14 @@ for t in int_types:
|
|
|
6054
6708
|
add_builtin(
|
|
6055
6709
|
"where",
|
|
6056
6710
|
input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
|
|
6057
|
-
value_func=
|
|
6711
|
+
value_func=where_value_func,
|
|
6058
6712
|
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6059
6713
|
group="Utility",
|
|
6060
6714
|
)
|
|
6061
6715
|
add_builtin(
|
|
6062
6716
|
"where",
|
|
6063
6717
|
input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
|
|
6064
|
-
value_func=
|
|
6718
|
+
value_func=where_value_func,
|
|
6065
6719
|
doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6066
6720
|
group="Utility",
|
|
6067
6721
|
)
|
|
@@ -6099,7 +6753,7 @@ add_builtin(
|
|
|
6099
6753
|
group="Utility",
|
|
6100
6754
|
hidden=True,
|
|
6101
6755
|
export=False,
|
|
6102
|
-
|
|
6756
|
+
is_differentiable=False,
|
|
6103
6757
|
)
|
|
6104
6758
|
|
|
6105
6759
|
|
|
@@ -6140,7 +6794,7 @@ add_builtin(
|
|
|
6140
6794
|
native_func="fixedarray_t",
|
|
6141
6795
|
group="Utility",
|
|
6142
6796
|
export=False,
|
|
6143
|
-
|
|
6797
|
+
is_differentiable=False,
|
|
6144
6798
|
hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
|
|
6145
6799
|
)
|
|
6146
6800
|
|
|
@@ -6183,14 +6837,13 @@ for array_type in array_types:
|
|
|
6183
6837
|
# does argument checking and type propagation for view()
|
|
6184
6838
|
def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6185
6839
|
arr_type = arg_types["arr"]
|
|
6186
|
-
idx_types = tuple(arg_types[x] for x in "
|
|
6840
|
+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
6187
6841
|
|
|
6188
6842
|
if not is_array(arr_type):
|
|
6189
6843
|
raise RuntimeError("view() first argument must be an array")
|
|
6190
6844
|
|
|
6191
6845
|
idx_count = len(idx_types)
|
|
6192
|
-
|
|
6193
|
-
if idx_count >= arr_type.ndim:
|
|
6846
|
+
if idx_count > arr_type.ndim:
|
|
6194
6847
|
raise RuntimeError(
|
|
6195
6848
|
f"Trying to create an array view with {idx_count} indices, "
|
|
6196
6849
|
f"but the array only has {arr_type.ndim} dimension(s). "
|
|
@@ -6198,14 +6851,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6198
6851
|
f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
|
|
6199
6852
|
)
|
|
6200
6853
|
|
|
6201
|
-
|
|
6202
|
-
|
|
6203
|
-
|
|
6204
|
-
|
|
6854
|
+
has_slice = any(is_slice(x) for x in idx_types)
|
|
6855
|
+
if has_slice:
|
|
6856
|
+
# check index types
|
|
6857
|
+
for t in idx_types:
|
|
6858
|
+
if not (type_is_int(t) or is_slice(t)):
|
|
6859
|
+
raise RuntimeError(
|
|
6860
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6861
|
+
)
|
|
6862
|
+
|
|
6863
|
+
# Each integer index collapses one dimension.
|
|
6864
|
+
int_count = sum(x.step == 0 for x in idx_types)
|
|
6865
|
+
ndim = arr_type.ndim - int_count
|
|
6866
|
+
assert ndim > 0
|
|
6867
|
+
else:
|
|
6868
|
+
if idx_count == arr_type.ndim:
|
|
6869
|
+
raise RuntimeError("Expected to call `address()` instead of `view()`")
|
|
6870
|
+
|
|
6871
|
+
# check index types
|
|
6872
|
+
for t in idx_types:
|
|
6873
|
+
if not type_is_int(t):
|
|
6874
|
+
raise RuntimeError(
|
|
6875
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6876
|
+
)
|
|
6877
|
+
|
|
6878
|
+
# create an array view with leading dimensions removed
|
|
6879
|
+
ndim = arr_type.ndim - idx_count
|
|
6880
|
+
assert ndim > 0
|
|
6205
6881
|
|
|
6206
|
-
# create an array view with leading dimensions removed
|
|
6207
6882
|
dtype = arr_type.dtype
|
|
6208
|
-
ndim = arr_type.ndim - idx_count
|
|
6209
6883
|
if isinstance(arr_type, (fabricarray, indexedfabricarray)):
|
|
6210
6884
|
# fabric array of arrays: return array attribute as a regular array
|
|
6211
6885
|
return array(dtype=dtype, ndim=ndim)
|
|
@@ -6216,8 +6890,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6216
6890
|
for array_type in array_types:
|
|
6217
6891
|
add_builtin(
|
|
6218
6892
|
"view",
|
|
6219
|
-
input_types={
|
|
6220
|
-
|
|
6893
|
+
input_types={
|
|
6894
|
+
"arr": array_type(dtype=Any),
|
|
6895
|
+
"i": Any,
|
|
6896
|
+
"j": Any,
|
|
6897
|
+
"k": Any,
|
|
6898
|
+
"l": Any,
|
|
6899
|
+
},
|
|
6900
|
+
defaults={
|
|
6901
|
+
"j": None,
|
|
6902
|
+
"k": None,
|
|
6903
|
+
"l": None,
|
|
6904
|
+
},
|
|
6221
6905
|
constraint=sametypes,
|
|
6222
6906
|
hidden=True,
|
|
6223
6907
|
value_func=view_value_func,
|
|
@@ -6321,6 +7005,7 @@ add_builtin(
|
|
|
6321
7005
|
hidden=True,
|
|
6322
7006
|
skip_replay=True,
|
|
6323
7007
|
group="Utility",
|
|
7008
|
+
is_differentiable=False,
|
|
6324
7009
|
)
|
|
6325
7010
|
|
|
6326
7011
|
|
|
@@ -6337,6 +7022,7 @@ add_builtin(
|
|
|
6337
7022
|
dispatch_func=load_dispatch_func,
|
|
6338
7023
|
hidden=True,
|
|
6339
7024
|
group="Utility",
|
|
7025
|
+
is_differentiable=False,
|
|
6340
7026
|
)
|
|
6341
7027
|
|
|
6342
7028
|
|
|
@@ -6412,6 +7098,13 @@ def create_atomic_op_value_func(op: str):
|
|
|
6412
7098
|
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
6413
7099
|
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
6414
7100
|
)
|
|
7101
|
+
elif op in ("and", "or", "xor"):
|
|
7102
|
+
supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
|
|
7103
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
|
|
7104
|
+
raise RuntimeError(
|
|
7105
|
+
f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
|
|
7106
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
7107
|
+
)
|
|
6415
7108
|
else:
|
|
6416
7109
|
raise NotImplementedError
|
|
6417
7110
|
|
|
@@ -6653,6 +7346,7 @@ for array_type in array_types:
|
|
|
6653
7346
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6654
7347
|
group="Utility",
|
|
6655
7348
|
skip_replay=True,
|
|
7349
|
+
is_differentiable=False,
|
|
6656
7350
|
)
|
|
6657
7351
|
add_builtin(
|
|
6658
7352
|
"atomic_cas",
|
|
@@ -6666,6 +7360,7 @@ for array_type in array_types:
|
|
|
6666
7360
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6667
7361
|
group="Utility",
|
|
6668
7362
|
skip_replay=True,
|
|
7363
|
+
is_differentiable=False,
|
|
6669
7364
|
)
|
|
6670
7365
|
add_builtin(
|
|
6671
7366
|
"atomic_cas",
|
|
@@ -6679,6 +7374,7 @@ for array_type in array_types:
|
|
|
6679
7374
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6680
7375
|
group="Utility",
|
|
6681
7376
|
skip_replay=True,
|
|
7377
|
+
is_differentiable=False,
|
|
6682
7378
|
)
|
|
6683
7379
|
add_builtin(
|
|
6684
7380
|
"atomic_cas",
|
|
@@ -6700,6 +7396,7 @@ for array_type in array_types:
|
|
|
6700
7396
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6701
7397
|
group="Utility",
|
|
6702
7398
|
skip_replay=True,
|
|
7399
|
+
is_differentiable=False,
|
|
6703
7400
|
)
|
|
6704
7401
|
|
|
6705
7402
|
add_builtin(
|
|
@@ -6714,6 +7411,7 @@ for array_type in array_types:
|
|
|
6714
7411
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6715
7412
|
group="Utility",
|
|
6716
7413
|
skip_replay=True,
|
|
7414
|
+
is_differentiable=False,
|
|
6717
7415
|
)
|
|
6718
7416
|
add_builtin(
|
|
6719
7417
|
"atomic_exch",
|
|
@@ -6727,32 +7425,193 @@ for array_type in array_types:
|
|
|
6727
7425
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6728
7426
|
group="Utility",
|
|
6729
7427
|
skip_replay=True,
|
|
7428
|
+
is_differentiable=False,
|
|
7429
|
+
)
|
|
7430
|
+
add_builtin(
|
|
7431
|
+
"atomic_exch",
|
|
7432
|
+
hidden=hidden,
|
|
7433
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7434
|
+
constraint=atomic_op_constraint,
|
|
7435
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
7436
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7437
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
|
|
7438
|
+
|
|
7439
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7440
|
+
group="Utility",
|
|
7441
|
+
skip_replay=True,
|
|
7442
|
+
is_differentiable=False,
|
|
7443
|
+
)
|
|
7444
|
+
add_builtin(
|
|
7445
|
+
"atomic_exch",
|
|
7446
|
+
hidden=hidden,
|
|
7447
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7448
|
+
constraint=atomic_op_constraint,
|
|
7449
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
7450
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7451
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
|
|
7452
|
+
|
|
7453
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7454
|
+
group="Utility",
|
|
7455
|
+
skip_replay=True,
|
|
7456
|
+
)
|
|
7457
|
+
|
|
7458
|
+
add_builtin(
|
|
7459
|
+
"atomic_and",
|
|
7460
|
+
hidden=hidden,
|
|
7461
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7462
|
+
constraint=atomic_op_constraint,
|
|
7463
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7464
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7465
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7466
|
+
This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
|
|
7467
|
+
group="Utility",
|
|
7468
|
+
skip_replay=True,
|
|
7469
|
+
is_differentiable=False,
|
|
7470
|
+
)
|
|
7471
|
+
add_builtin(
|
|
7472
|
+
"atomic_and",
|
|
7473
|
+
hidden=hidden,
|
|
7474
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7475
|
+
constraint=atomic_op_constraint,
|
|
7476
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7477
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7478
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7479
|
+
This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
|
|
7480
|
+
group="Utility",
|
|
7481
|
+
skip_replay=True,
|
|
7482
|
+
is_differentiable=False,
|
|
7483
|
+
)
|
|
7484
|
+
add_builtin(
|
|
7485
|
+
"atomic_and",
|
|
7486
|
+
hidden=hidden,
|
|
7487
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7488
|
+
constraint=atomic_op_constraint,
|
|
7489
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7490
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7491
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7492
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
|
|
7493
|
+
group="Utility",
|
|
7494
|
+
skip_replay=True,
|
|
7495
|
+
is_differentiable=False,
|
|
7496
|
+
)
|
|
7497
|
+
add_builtin(
|
|
7498
|
+
"atomic_and",
|
|
7499
|
+
hidden=hidden,
|
|
7500
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7501
|
+
constraint=atomic_op_constraint,
|
|
7502
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7503
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7504
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7505
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
|
|
7506
|
+
group="Utility",
|
|
7507
|
+
skip_replay=True,
|
|
7508
|
+
is_differentiable=False,
|
|
7509
|
+
)
|
|
7510
|
+
|
|
7511
|
+
add_builtin(
|
|
7512
|
+
"atomic_or",
|
|
7513
|
+
hidden=hidden,
|
|
7514
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7515
|
+
constraint=atomic_op_constraint,
|
|
7516
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7517
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7518
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7519
|
+
This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
|
|
7520
|
+
group="Utility",
|
|
7521
|
+
skip_replay=True,
|
|
7522
|
+
is_differentiable=False,
|
|
7523
|
+
)
|
|
7524
|
+
add_builtin(
|
|
7525
|
+
"atomic_or",
|
|
7526
|
+
hidden=hidden,
|
|
7527
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7528
|
+
constraint=atomic_op_constraint,
|
|
7529
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7530
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7531
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7532
|
+
This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
|
|
7533
|
+
group="Utility",
|
|
7534
|
+
skip_replay=True,
|
|
7535
|
+
is_differentiable=False,
|
|
7536
|
+
)
|
|
7537
|
+
add_builtin(
|
|
7538
|
+
"atomic_or",
|
|
7539
|
+
hidden=hidden,
|
|
7540
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7541
|
+
constraint=atomic_op_constraint,
|
|
7542
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7543
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7544
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7545
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
|
|
7546
|
+
group="Utility",
|
|
7547
|
+
skip_replay=True,
|
|
7548
|
+
is_differentiable=False,
|
|
7549
|
+
)
|
|
7550
|
+
add_builtin(
|
|
7551
|
+
"atomic_or",
|
|
7552
|
+
hidden=hidden,
|
|
7553
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7554
|
+
constraint=atomic_op_constraint,
|
|
7555
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7556
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7557
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7558
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
|
|
7559
|
+
group="Utility",
|
|
7560
|
+
skip_replay=True,
|
|
7561
|
+
is_differentiable=False,
|
|
7562
|
+
)
|
|
7563
|
+
|
|
7564
|
+
add_builtin(
|
|
7565
|
+
"atomic_xor",
|
|
7566
|
+
hidden=hidden,
|
|
7567
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7568
|
+
constraint=atomic_op_constraint,
|
|
7569
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7570
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7571
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7572
|
+
This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
|
|
7573
|
+
group="Utility",
|
|
7574
|
+
skip_replay=True,
|
|
7575
|
+
is_differentiable=False,
|
|
7576
|
+
)
|
|
7577
|
+
add_builtin(
|
|
7578
|
+
"atomic_xor",
|
|
7579
|
+
hidden=hidden,
|
|
7580
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7581
|
+
constraint=atomic_op_constraint,
|
|
7582
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7583
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7584
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7585
|
+
This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
|
|
7586
|
+
group="Utility",
|
|
7587
|
+
skip_replay=True,
|
|
7588
|
+
is_differentiable=False,
|
|
6730
7589
|
)
|
|
6731
7590
|
add_builtin(
|
|
6732
|
-
"
|
|
7591
|
+
"atomic_xor",
|
|
6733
7592
|
hidden=hidden,
|
|
6734
7593
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
6735
7594
|
constraint=atomic_op_constraint,
|
|
6736
|
-
value_func=create_atomic_op_value_func("
|
|
7595
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
6737
7596
|
dispatch_func=atomic_op_dispatch_func,
|
|
6738
|
-
doc="""Atomically
|
|
6739
|
-
|
|
6740
|
-
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7597
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7598
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
|
|
6741
7599
|
group="Utility",
|
|
6742
7600
|
skip_replay=True,
|
|
7601
|
+
is_differentiable=False,
|
|
6743
7602
|
)
|
|
6744
7603
|
add_builtin(
|
|
6745
|
-
"
|
|
7604
|
+
"atomic_xor",
|
|
6746
7605
|
hidden=hidden,
|
|
6747
7606
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
6748
7607
|
constraint=atomic_op_constraint,
|
|
6749
|
-
value_func=create_atomic_op_value_func("
|
|
7608
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
6750
7609
|
dispatch_func=atomic_op_dispatch_func,
|
|
6751
|
-
doc="""Atomically
|
|
6752
|
-
|
|
6753
|
-
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7610
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7611
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
|
|
6754
7612
|
group="Utility",
|
|
6755
7613
|
skip_replay=True,
|
|
7614
|
+
is_differentiable=False,
|
|
6756
7615
|
)
|
|
6757
7616
|
|
|
6758
7617
|
|
|
@@ -6903,6 +7762,7 @@ add_builtin(
|
|
|
6903
7762
|
hidden=True,
|
|
6904
7763
|
group="Utility",
|
|
6905
7764
|
skip_replay=True,
|
|
7765
|
+
is_differentiable=False,
|
|
6906
7766
|
)
|
|
6907
7767
|
# implements &quaternion[index]
|
|
6908
7768
|
add_builtin(
|
|
@@ -6913,6 +7773,7 @@ add_builtin(
|
|
|
6913
7773
|
hidden=True,
|
|
6914
7774
|
group="Utility",
|
|
6915
7775
|
skip_replay=True,
|
|
7776
|
+
is_differentiable=False,
|
|
6916
7777
|
)
|
|
6917
7778
|
# implements &transformation[index]
|
|
6918
7779
|
add_builtin(
|
|
@@ -6923,6 +7784,7 @@ add_builtin(
|
|
|
6923
7784
|
hidden=True,
|
|
6924
7785
|
group="Utility",
|
|
6925
7786
|
skip_replay=True,
|
|
7787
|
+
is_differentiable=False,
|
|
6926
7788
|
)
|
|
6927
7789
|
# implements &(*vector)[index]
|
|
6928
7790
|
add_builtin(
|
|
@@ -6933,6 +7795,7 @@ add_builtin(
|
|
|
6933
7795
|
hidden=True,
|
|
6934
7796
|
group="Utility",
|
|
6935
7797
|
skip_replay=True,
|
|
7798
|
+
is_differentiable=False,
|
|
6936
7799
|
)
|
|
6937
7800
|
# implements &(*matrix)[i, j]
|
|
6938
7801
|
add_builtin(
|
|
@@ -6943,6 +7806,7 @@ add_builtin(
|
|
|
6943
7806
|
hidden=True,
|
|
6944
7807
|
group="Utility",
|
|
6945
7808
|
skip_replay=True,
|
|
7809
|
+
is_differentiable=False,
|
|
6946
7810
|
)
|
|
6947
7811
|
# implements &(*quaternion)[index]
|
|
6948
7812
|
add_builtin(
|
|
@@ -6953,6 +7817,7 @@ add_builtin(
|
|
|
6953
7817
|
hidden=True,
|
|
6954
7818
|
group="Utility",
|
|
6955
7819
|
skip_replay=True,
|
|
7820
|
+
is_differentiable=False,
|
|
6956
7821
|
)
|
|
6957
7822
|
# implements &(*transformation)[index]
|
|
6958
7823
|
add_builtin(
|
|
@@ -6963,6 +7828,7 @@ add_builtin(
|
|
|
6963
7828
|
hidden=True,
|
|
6964
7829
|
group="Utility",
|
|
6965
7830
|
skip_replay=True,
|
|
7831
|
+
is_differentiable=False,
|
|
6966
7832
|
)
|
|
6967
7833
|
|
|
6968
7834
|
|
|
@@ -7158,6 +8024,43 @@ add_builtin(
|
|
|
7158
8024
|
)
|
|
7159
8025
|
|
|
7160
8026
|
|
|
8027
|
+
# implements vector[idx] &= scalar
|
|
8028
|
+
add_builtin(
|
|
8029
|
+
"bit_and_inplace",
|
|
8030
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8031
|
+
value_type=None,
|
|
8032
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8033
|
+
hidden=True,
|
|
8034
|
+
export=False,
|
|
8035
|
+
group="Utility",
|
|
8036
|
+
is_differentiable=False,
|
|
8037
|
+
)
|
|
8038
|
+
|
|
8039
|
+
# implements vector[idx] |= scalar
|
|
8040
|
+
add_builtin(
|
|
8041
|
+
"bit_or_inplace",
|
|
8042
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8043
|
+
value_type=None,
|
|
8044
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8045
|
+
hidden=True,
|
|
8046
|
+
export=False,
|
|
8047
|
+
group="Utility",
|
|
8048
|
+
is_differentiable=False,
|
|
8049
|
+
)
|
|
8050
|
+
|
|
8051
|
+
# implements vector[idx] ^= scalar
|
|
8052
|
+
add_builtin(
|
|
8053
|
+
"bit_xor_inplace",
|
|
8054
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8055
|
+
value_type=None,
|
|
8056
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8057
|
+
hidden=True,
|
|
8058
|
+
export=False,
|
|
8059
|
+
group="Utility",
|
|
8060
|
+
is_differentiable=False,
|
|
8061
|
+
)
|
|
8062
|
+
|
|
8063
|
+
|
|
7161
8064
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
7162
8065
|
mat_type = arg_types["a"]
|
|
7163
8066
|
row_type = mat_type._wp_row_type_
|
|
@@ -7173,6 +8076,7 @@ add_builtin(
|
|
|
7173
8076
|
hidden=True,
|
|
7174
8077
|
group="Utility",
|
|
7175
8078
|
skip_replay=True,
|
|
8079
|
+
is_differentiable=False,
|
|
7176
8080
|
)
|
|
7177
8081
|
|
|
7178
8082
|
|
|
@@ -7191,6 +8095,7 @@ add_builtin(
|
|
|
7191
8095
|
hidden=True,
|
|
7192
8096
|
group="Utility",
|
|
7193
8097
|
skip_replay=True,
|
|
8098
|
+
is_differentiable=False,
|
|
7194
8099
|
)
|
|
7195
8100
|
|
|
7196
8101
|
|
|
@@ -7390,6 +8295,78 @@ add_builtin(
|
|
|
7390
8295
|
)
|
|
7391
8296
|
|
|
7392
8297
|
|
|
8298
|
+
# implements matrix[i] &= value
|
|
8299
|
+
add_builtin(
|
|
8300
|
+
"bit_and_inplace",
|
|
8301
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8302
|
+
value_type=None,
|
|
8303
|
+
hidden=True,
|
|
8304
|
+
export=False,
|
|
8305
|
+
group="Utility",
|
|
8306
|
+
is_differentiable=False,
|
|
8307
|
+
)
|
|
8308
|
+
|
|
8309
|
+
|
|
8310
|
+
# implements matrix[i,j] &= value
|
|
8311
|
+
add_builtin(
|
|
8312
|
+
"bit_and_inplace",
|
|
8313
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8314
|
+
value_type=None,
|
|
8315
|
+
hidden=True,
|
|
8316
|
+
export=False,
|
|
8317
|
+
group="Utility",
|
|
8318
|
+
is_differentiable=False,
|
|
8319
|
+
)
|
|
8320
|
+
|
|
8321
|
+
|
|
8322
|
+
# implements matrix[i] |= value
|
|
8323
|
+
add_builtin(
|
|
8324
|
+
"bit_or_inplace",
|
|
8325
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8326
|
+
value_type=None,
|
|
8327
|
+
hidden=True,
|
|
8328
|
+
export=False,
|
|
8329
|
+
group="Utility",
|
|
8330
|
+
is_differentiable=False,
|
|
8331
|
+
)
|
|
8332
|
+
|
|
8333
|
+
|
|
8334
|
+
# implements matrix[i,j] |= value
|
|
8335
|
+
add_builtin(
|
|
8336
|
+
"bit_or_inplace",
|
|
8337
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8338
|
+
value_type=None,
|
|
8339
|
+
hidden=True,
|
|
8340
|
+
export=False,
|
|
8341
|
+
group="Utility",
|
|
8342
|
+
is_differentiable=False,
|
|
8343
|
+
)
|
|
8344
|
+
|
|
8345
|
+
|
|
8346
|
+
# implements matrix[i] ^= value
|
|
8347
|
+
add_builtin(
|
|
8348
|
+
"bit_xor_inplace",
|
|
8349
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8350
|
+
value_type=None,
|
|
8351
|
+
hidden=True,
|
|
8352
|
+
export=False,
|
|
8353
|
+
group="Utility",
|
|
8354
|
+
is_differentiable=False,
|
|
8355
|
+
)
|
|
8356
|
+
|
|
8357
|
+
|
|
8358
|
+
# implements matrix[i,j] ^= value
|
|
8359
|
+
add_builtin(
|
|
8360
|
+
"bit_xor_inplace",
|
|
8361
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8362
|
+
value_type=None,
|
|
8363
|
+
hidden=True,
|
|
8364
|
+
export=False,
|
|
8365
|
+
group="Utility",
|
|
8366
|
+
is_differentiable=False,
|
|
8367
|
+
)
|
|
8368
|
+
|
|
8369
|
+
|
|
7393
8370
|
for t in scalar_types + vector_types + (bool,):
|
|
7394
8371
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
7395
8372
|
continue
|
|
@@ -7401,6 +8378,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7401
8378
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7402
8379
|
group="Utility",
|
|
7403
8380
|
hidden=True,
|
|
8381
|
+
is_differentiable=False,
|
|
7404
8382
|
)
|
|
7405
8383
|
|
|
7406
8384
|
add_builtin(
|
|
@@ -7411,6 +8389,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7411
8389
|
group="Utility",
|
|
7412
8390
|
hidden=True,
|
|
7413
8391
|
export=False,
|
|
8392
|
+
is_differentiable=False,
|
|
7414
8393
|
)
|
|
7415
8394
|
|
|
7416
8395
|
|
|
@@ -7429,6 +8408,7 @@ add_builtin(
|
|
|
7429
8408
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7430
8409
|
group="Utility",
|
|
7431
8410
|
hidden=True,
|
|
8411
|
+
is_differentiable=False,
|
|
7432
8412
|
)
|
|
7433
8413
|
add_builtin(
|
|
7434
8414
|
"expect_neq",
|
|
@@ -7439,6 +8419,7 @@ add_builtin(
|
|
|
7439
8419
|
group="Utility",
|
|
7440
8420
|
hidden=True,
|
|
7441
8421
|
export=False,
|
|
8422
|
+
is_differentiable=False,
|
|
7442
8423
|
)
|
|
7443
8424
|
|
|
7444
8425
|
add_builtin(
|
|
@@ -7449,6 +8430,7 @@ add_builtin(
|
|
|
7449
8430
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7450
8431
|
group="Utility",
|
|
7451
8432
|
hidden=True,
|
|
8433
|
+
is_differentiable=False,
|
|
7452
8434
|
)
|
|
7453
8435
|
add_builtin(
|
|
7454
8436
|
"expect_neq",
|
|
@@ -7459,6 +8441,7 @@ add_builtin(
|
|
|
7459
8441
|
group="Utility",
|
|
7460
8442
|
hidden=True,
|
|
7461
8443
|
export=False,
|
|
8444
|
+
is_differentiable=False,
|
|
7462
8445
|
)
|
|
7463
8446
|
|
|
7464
8447
|
add_builtin(
|
|
@@ -7549,6 +8532,7 @@ add_builtin(
|
|
|
7549
8532
|
value_type=None,
|
|
7550
8533
|
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7551
8534
|
group="Utility",
|
|
8535
|
+
is_differentiable=False,
|
|
7552
8536
|
)
|
|
7553
8537
|
add_builtin(
|
|
7554
8538
|
"expect_near",
|
|
@@ -7558,6 +8542,7 @@ add_builtin(
|
|
|
7558
8542
|
value_type=None,
|
|
7559
8543
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7560
8544
|
group="Utility",
|
|
8545
|
+
is_differentiable=False,
|
|
7561
8546
|
)
|
|
7562
8547
|
add_builtin(
|
|
7563
8548
|
"expect_near",
|
|
@@ -7567,6 +8552,7 @@ add_builtin(
|
|
|
7567
8552
|
value_type=None,
|
|
7568
8553
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7569
8554
|
group="Utility",
|
|
8555
|
+
is_differentiable=False,
|
|
7570
8556
|
)
|
|
7571
8557
|
add_builtin(
|
|
7572
8558
|
"expect_near",
|
|
@@ -7580,6 +8566,7 @@ add_builtin(
|
|
|
7580
8566
|
value_type=None,
|
|
7581
8567
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7582
8568
|
group="Utility",
|
|
8569
|
+
is_differentiable=False,
|
|
7583
8570
|
)
|
|
7584
8571
|
|
|
7585
8572
|
# ---------------------------------
|
|
@@ -7590,6 +8577,7 @@ add_builtin(
|
|
|
7590
8577
|
input_types={"arr": array(dtype=Scalar), "value": Scalar},
|
|
7591
8578
|
value_type=int,
|
|
7592
8579
|
doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
|
|
8580
|
+
is_differentiable=False,
|
|
7593
8581
|
)
|
|
7594
8582
|
|
|
7595
8583
|
add_builtin(
|
|
@@ -7597,6 +8585,7 @@ add_builtin(
|
|
|
7597
8585
|
input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
|
|
7598
8586
|
value_type=int,
|
|
7599
8587
|
doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
|
|
8588
|
+
is_differentiable=False,
|
|
7600
8589
|
)
|
|
7601
8590
|
|
|
7602
8591
|
# ---------------------------------
|
|
@@ -7672,12 +8661,157 @@ add_builtin(
|
|
|
7672
8661
|
)
|
|
7673
8662
|
|
|
7674
8663
|
# bitwise operators
|
|
7675
|
-
add_builtin(
|
|
7676
|
-
|
|
7677
|
-
|
|
7678
|
-
|
|
7679
|
-
|
|
7680
|
-
|
|
8664
|
+
add_builtin(
|
|
8665
|
+
"bit_and",
|
|
8666
|
+
input_types={"a": Int, "b": Int},
|
|
8667
|
+
value_func=sametypes_create_value_func(Int),
|
|
8668
|
+
group="Operators",
|
|
8669
|
+
is_differentiable=False,
|
|
8670
|
+
)
|
|
8671
|
+
add_builtin(
|
|
8672
|
+
"bit_and",
|
|
8673
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8674
|
+
constraint=sametypes,
|
|
8675
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8676
|
+
doc="",
|
|
8677
|
+
group="Operators",
|
|
8678
|
+
is_differentiable=False,
|
|
8679
|
+
)
|
|
8680
|
+
add_builtin(
|
|
8681
|
+
"bit_and",
|
|
8682
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8683
|
+
constraint=sametypes,
|
|
8684
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8685
|
+
doc="",
|
|
8686
|
+
group="Operators",
|
|
8687
|
+
is_differentiable=False,
|
|
8688
|
+
)
|
|
8689
|
+
|
|
8690
|
+
add_builtin(
|
|
8691
|
+
"bit_or",
|
|
8692
|
+
input_types={"a": Int, "b": Int},
|
|
8693
|
+
value_func=sametypes_create_value_func(Int),
|
|
8694
|
+
group="Operators",
|
|
8695
|
+
is_differentiable=False,
|
|
8696
|
+
)
|
|
8697
|
+
add_builtin(
|
|
8698
|
+
"bit_or",
|
|
8699
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8700
|
+
constraint=sametypes,
|
|
8701
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8702
|
+
doc="",
|
|
8703
|
+
group="Operators",
|
|
8704
|
+
is_differentiable=False,
|
|
8705
|
+
)
|
|
8706
|
+
add_builtin(
|
|
8707
|
+
"bit_or",
|
|
8708
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8709
|
+
constraint=sametypes,
|
|
8710
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8711
|
+
doc="",
|
|
8712
|
+
group="Operators",
|
|
8713
|
+
is_differentiable=False,
|
|
8714
|
+
)
|
|
8715
|
+
|
|
8716
|
+
add_builtin(
|
|
8717
|
+
"bit_xor",
|
|
8718
|
+
input_types={"a": Int, "b": Int},
|
|
8719
|
+
value_func=sametypes_create_value_func(Int),
|
|
8720
|
+
group="Operators",
|
|
8721
|
+
is_differentiable=False,
|
|
8722
|
+
)
|
|
8723
|
+
add_builtin(
|
|
8724
|
+
"bit_xor",
|
|
8725
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8726
|
+
constraint=sametypes,
|
|
8727
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8728
|
+
doc="",
|
|
8729
|
+
group="Operators",
|
|
8730
|
+
is_differentiable=False,
|
|
8731
|
+
)
|
|
8732
|
+
add_builtin(
|
|
8733
|
+
"bit_xor",
|
|
8734
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8735
|
+
constraint=sametypes,
|
|
8736
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8737
|
+
doc="",
|
|
8738
|
+
group="Operators",
|
|
8739
|
+
is_differentiable=False,
|
|
8740
|
+
)
|
|
8741
|
+
|
|
8742
|
+
add_builtin(
|
|
8743
|
+
"lshift",
|
|
8744
|
+
input_types={"a": Int, "b": Int},
|
|
8745
|
+
value_func=sametypes_create_value_func(Int),
|
|
8746
|
+
group="Operators",
|
|
8747
|
+
is_differentiable=False,
|
|
8748
|
+
)
|
|
8749
|
+
add_builtin(
|
|
8750
|
+
"lshift",
|
|
8751
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8752
|
+
constraint=sametypes,
|
|
8753
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8754
|
+
doc="",
|
|
8755
|
+
group="Operators",
|
|
8756
|
+
is_differentiable=False,
|
|
8757
|
+
)
|
|
8758
|
+
add_builtin(
|
|
8759
|
+
"lshift",
|
|
8760
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8761
|
+
constraint=sametypes,
|
|
8762
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8763
|
+
doc="",
|
|
8764
|
+
group="Operators",
|
|
8765
|
+
is_differentiable=False,
|
|
8766
|
+
)
|
|
8767
|
+
|
|
8768
|
+
add_builtin(
|
|
8769
|
+
"rshift",
|
|
8770
|
+
input_types={"a": Int, "b": Int},
|
|
8771
|
+
value_func=sametypes_create_value_func(Int),
|
|
8772
|
+
group="Operators",
|
|
8773
|
+
is_differentiable=False,
|
|
8774
|
+
)
|
|
8775
|
+
add_builtin(
|
|
8776
|
+
"rshift",
|
|
8777
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8778
|
+
constraint=sametypes,
|
|
8779
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8780
|
+
doc="",
|
|
8781
|
+
group="Operators",
|
|
8782
|
+
is_differentiable=False,
|
|
8783
|
+
)
|
|
8784
|
+
add_builtin(
|
|
8785
|
+
"rshift",
|
|
8786
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8787
|
+
constraint=sametypes,
|
|
8788
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8789
|
+
doc="",
|
|
8790
|
+
group="Operators",
|
|
8791
|
+
is_differentiable=False,
|
|
8792
|
+
)
|
|
8793
|
+
|
|
8794
|
+
add_builtin(
|
|
8795
|
+
"invert",
|
|
8796
|
+
input_types={"a": Int},
|
|
8797
|
+
value_func=sametypes_create_value_func(Int),
|
|
8798
|
+
group="Operators",
|
|
8799
|
+
is_differentiable=False,
|
|
8800
|
+
)
|
|
8801
|
+
add_builtin(
|
|
8802
|
+
"invert",
|
|
8803
|
+
input_types={"a": vector(length=Any, dtype=Int)},
|
|
8804
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8805
|
+
group="Operators",
|
|
8806
|
+
is_differentiable=False,
|
|
8807
|
+
)
|
|
8808
|
+
add_builtin(
|
|
8809
|
+
"invert",
|
|
8810
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
|
|
8811
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8812
|
+
group="Operators",
|
|
8813
|
+
is_differentiable=False,
|
|
8814
|
+
)
|
|
7681
8815
|
|
|
7682
8816
|
|
|
7683
8817
|
add_builtin(
|
|
@@ -7878,6 +9012,7 @@ add_builtin(
|
|
|
7878
9012
|
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
7879
9013
|
doc="Modulo operation using truncated division.",
|
|
7880
9014
|
group="Operators",
|
|
9015
|
+
is_differentiable=False,
|
|
7881
9016
|
)
|
|
7882
9017
|
|
|
7883
9018
|
add_builtin(
|
|
@@ -7937,6 +9072,7 @@ add_builtin(
|
|
|
7937
9072
|
value_func=sametypes_create_value_func(Scalar),
|
|
7938
9073
|
doc="",
|
|
7939
9074
|
group="Operators",
|
|
9075
|
+
is_differentiable=False,
|
|
7940
9076
|
)
|
|
7941
9077
|
|
|
7942
9078
|
add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
@@ -7984,12 +9120,28 @@ add_builtin(
|
|
|
7984
9120
|
group="Operators",
|
|
7985
9121
|
)
|
|
7986
9122
|
|
|
7987
|
-
add_builtin(
|
|
9123
|
+
add_builtin(
|
|
9124
|
+
"unot",
|
|
9125
|
+
input_types={"a": builtins.bool},
|
|
9126
|
+
value_type=builtins.bool,
|
|
9127
|
+
doc="",
|
|
9128
|
+
group="Operators",
|
|
9129
|
+
is_differentiable=False,
|
|
9130
|
+
)
|
|
7988
9131
|
for t in int_types:
|
|
7989
|
-
add_builtin(
|
|
9132
|
+
add_builtin(
|
|
9133
|
+
"unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
|
|
9134
|
+
)
|
|
7990
9135
|
|
|
7991
9136
|
|
|
7992
|
-
add_builtin(
|
|
9137
|
+
add_builtin(
|
|
9138
|
+
"unot",
|
|
9139
|
+
input_types={"a": array(dtype=Any)},
|
|
9140
|
+
value_type=builtins.bool,
|
|
9141
|
+
doc="",
|
|
9142
|
+
group="Operators",
|
|
9143
|
+
is_differentiable=False,
|
|
9144
|
+
)
|
|
7993
9145
|
|
|
7994
9146
|
|
|
7995
9147
|
# Tile operators
|
|
@@ -8061,6 +9213,45 @@ add_builtin(
|
|
|
8061
9213
|
export=False,
|
|
8062
9214
|
)
|
|
8063
9215
|
|
|
9216
|
+
add_builtin(
|
|
9217
|
+
"bit_and",
|
|
9218
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9219
|
+
value_func=tile_binary_map_value_func,
|
|
9220
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9221
|
+
# variadic=True,
|
|
9222
|
+
native_func="tile_bit_and",
|
|
9223
|
+
doc="Bitwise AND each element of two tiles together",
|
|
9224
|
+
group="Tile Primitives",
|
|
9225
|
+
export=False,
|
|
9226
|
+
is_differentiable=False,
|
|
9227
|
+
)
|
|
9228
|
+
|
|
9229
|
+
add_builtin(
|
|
9230
|
+
"bit_or",
|
|
9231
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9232
|
+
value_func=tile_binary_map_value_func,
|
|
9233
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9234
|
+
# variadic=True,
|
|
9235
|
+
native_func="tile_bit_or",
|
|
9236
|
+
doc="Bitwise OR each element of two tiles together",
|
|
9237
|
+
group="Tile Primitives",
|
|
9238
|
+
export=False,
|
|
9239
|
+
is_differentiable=False,
|
|
9240
|
+
)
|
|
9241
|
+
|
|
9242
|
+
add_builtin(
|
|
9243
|
+
"bit_xor",
|
|
9244
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9245
|
+
value_func=tile_binary_map_value_func,
|
|
9246
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9247
|
+
# variadic=True,
|
|
9248
|
+
native_func="tile_bit_xor",
|
|
9249
|
+
doc="Bitwise XOR each element of two tiles together",
|
|
9250
|
+
group="Tile Primitives",
|
|
9251
|
+
export=False,
|
|
9252
|
+
is_differentiable=False,
|
|
9253
|
+
)
|
|
9254
|
+
|
|
8064
9255
|
|
|
8065
9256
|
add_builtin(
|
|
8066
9257
|
"mul",
|
|
@@ -8122,6 +9313,45 @@ add_builtin(
|
|
|
8122
9313
|
)
|
|
8123
9314
|
|
|
8124
9315
|
|
|
9316
|
+
add_builtin(
|
|
9317
|
+
"bit_and_inplace",
|
|
9318
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9319
|
+
value_type=None,
|
|
9320
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9321
|
+
export=False,
|
|
9322
|
+
hidden=True,
|
|
9323
|
+
native_func="tile_bit_and_inplace",
|
|
9324
|
+
group="Operators",
|
|
9325
|
+
is_differentiable=False,
|
|
9326
|
+
)
|
|
9327
|
+
|
|
9328
|
+
|
|
9329
|
+
add_builtin(
|
|
9330
|
+
"bit_or_inplace",
|
|
9331
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9332
|
+
value_type=None,
|
|
9333
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9334
|
+
export=False,
|
|
9335
|
+
hidden=True,
|
|
9336
|
+
native_func="tile_bit_or_inplace",
|
|
9337
|
+
group="Operators",
|
|
9338
|
+
is_differentiable=False,
|
|
9339
|
+
)
|
|
9340
|
+
|
|
9341
|
+
|
|
9342
|
+
add_builtin(
|
|
9343
|
+
"bit_xor_inplace",
|
|
9344
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9345
|
+
value_type=None,
|
|
9346
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9347
|
+
export=False,
|
|
9348
|
+
hidden=True,
|
|
9349
|
+
native_func="tile_bit_xor_inplace",
|
|
9350
|
+
group="Operators",
|
|
9351
|
+
is_differentiable=False,
|
|
9352
|
+
)
|
|
9353
|
+
|
|
9354
|
+
|
|
8125
9355
|
def tile_diag_add_value_func(arg_types, arg_values):
|
|
8126
9356
|
if arg_types is None:
|
|
8127
9357
|
return tile(dtype=Any, shape=Tuple[int, int])
|
|
@@ -8163,7 +9393,7 @@ def tile_diag_add_lto_dispatch_func(
|
|
|
8163
9393
|
return_values: List[Var],
|
|
8164
9394
|
arg_values: Mapping[str, Var],
|
|
8165
9395
|
options: Mapping[str, Any],
|
|
8166
|
-
builder: warp.context.ModuleBuilder,
|
|
9396
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8167
9397
|
):
|
|
8168
9398
|
a = arg_values["a"]
|
|
8169
9399
|
d = arg_values["d"]
|
|
@@ -8183,6 +9413,7 @@ add_builtin(
|
|
|
8183
9413
|
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
8184
9414
|
group="Tile Primitives",
|
|
8185
9415
|
export=False,
|
|
9416
|
+
is_differentiable=False,
|
|
8186
9417
|
)
|
|
8187
9418
|
|
|
8188
9419
|
|
|
@@ -8239,7 +9470,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8239
9470
|
return_values: List[Var],
|
|
8240
9471
|
arg_values: Mapping[str, Var],
|
|
8241
9472
|
options: Mapping[str, Any],
|
|
8242
|
-
builder: warp.context.ModuleBuilder,
|
|
9473
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8243
9474
|
):
|
|
8244
9475
|
a = arg_values["a"]
|
|
8245
9476
|
b = arg_values["b"]
|
|
@@ -8277,7 +9508,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8277
9508
|
num_threads = options["block_dim"]
|
|
8278
9509
|
arch = options["output_arch"]
|
|
8279
9510
|
|
|
8280
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9511
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8281
9512
|
# CPU/no-MathDx dispatch
|
|
8282
9513
|
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
8283
9514
|
else:
|
|
@@ -8290,7 +9521,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8290
9521
|
|
|
8291
9522
|
# generate the LTOs
|
|
8292
9523
|
# C += A * B
|
|
8293
|
-
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
9524
|
+
(fun_forward, lto_forward) = warp._src.build.build_lto_dot(
|
|
8294
9525
|
M,
|
|
8295
9526
|
N,
|
|
8296
9527
|
K,
|
|
@@ -8306,7 +9537,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8306
9537
|
)
|
|
8307
9538
|
if warp.config.enable_backward:
|
|
8308
9539
|
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
8309
|
-
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
9540
|
+
(fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
|
|
8310
9541
|
M,
|
|
8311
9542
|
K,
|
|
8312
9543
|
N,
|
|
@@ -8321,7 +9552,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8321
9552
|
builder,
|
|
8322
9553
|
)
|
|
8323
9554
|
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
8324
|
-
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
9555
|
+
(fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
|
|
8325
9556
|
K,
|
|
8326
9557
|
N,
|
|
8327
9558
|
M,
|
|
@@ -8438,7 +9669,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8438
9669
|
return_values: List[Var],
|
|
8439
9670
|
arg_values: Mapping[str, Var],
|
|
8440
9671
|
options: Mapping[str, Any],
|
|
8441
|
-
builder: warp.context.ModuleBuilder,
|
|
9672
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8442
9673
|
direction: str | None = None,
|
|
8443
9674
|
):
|
|
8444
9675
|
inout = arg_values["inout"]
|
|
@@ -8467,12 +9698,12 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8467
9698
|
arch = options["output_arch"]
|
|
8468
9699
|
ept = size // num_threads
|
|
8469
9700
|
|
|
8470
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9701
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8471
9702
|
# CPU/no-MathDx dispatch
|
|
8472
9703
|
return ([], [], [], 0)
|
|
8473
9704
|
else:
|
|
8474
9705
|
# generate the LTO
|
|
8475
|
-
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
9706
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
|
|
8476
9707
|
arch, size, ept, direction, dir, precision, builder
|
|
8477
9708
|
)
|
|
8478
9709
|
|
|
@@ -8510,6 +9741,7 @@ add_builtin(
|
|
|
8510
9741
|
group="Tile Primitives",
|
|
8511
9742
|
export=False,
|
|
8512
9743
|
namespace="",
|
|
9744
|
+
is_differentiable=False,
|
|
8513
9745
|
)
|
|
8514
9746
|
|
|
8515
9747
|
add_builtin(
|
|
@@ -8531,6 +9763,7 @@ add_builtin(
|
|
|
8531
9763
|
group="Tile Primitives",
|
|
8532
9764
|
export=False,
|
|
8533
9765
|
namespace="",
|
|
9766
|
+
is_differentiable=False,
|
|
8534
9767
|
)
|
|
8535
9768
|
|
|
8536
9769
|
|
|
@@ -8575,7 +9808,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8575
9808
|
return_values: List[Var],
|
|
8576
9809
|
arg_values: Mapping[str, Var],
|
|
8577
9810
|
options: Mapping[str, Any],
|
|
8578
|
-
builder: warp.context.ModuleBuilder,
|
|
9811
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8579
9812
|
):
|
|
8580
9813
|
a = arg_values["A"]
|
|
8581
9814
|
# force source tile to shared memory
|
|
@@ -8595,7 +9828,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8595
9828
|
|
|
8596
9829
|
arch = options["output_arch"]
|
|
8597
9830
|
|
|
8598
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9831
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8599
9832
|
# CPU/no-MathDx dispatch
|
|
8600
9833
|
return ((0, a, out), [], [], 0)
|
|
8601
9834
|
else:
|
|
@@ -8610,7 +9843,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8610
9843
|
req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
|
|
8611
9844
|
|
|
8612
9845
|
# generate the LTO
|
|
8613
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9846
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8614
9847
|
M,
|
|
8615
9848
|
N,
|
|
8616
9849
|
1,
|
|
@@ -8655,6 +9888,7 @@ add_builtin(
|
|
|
8655
9888
|
group="Tile Primitives",
|
|
8656
9889
|
export=False,
|
|
8657
9890
|
namespace="",
|
|
9891
|
+
is_differentiable=False,
|
|
8658
9892
|
)
|
|
8659
9893
|
|
|
8660
9894
|
|
|
@@ -8698,7 +9932,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8698
9932
|
return_values: List[Var],
|
|
8699
9933
|
arg_values: Mapping[str, Var],
|
|
8700
9934
|
options: Mapping[str, Any],
|
|
8701
|
-
builder: warp.context.ModuleBuilder,
|
|
9935
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8702
9936
|
):
|
|
8703
9937
|
L = arg_values["L"]
|
|
8704
9938
|
y = arg_values["y"]
|
|
@@ -8727,7 +9961,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8727
9961
|
|
|
8728
9962
|
arch = options["output_arch"]
|
|
8729
9963
|
|
|
8730
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9964
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8731
9965
|
# CPU/no-MathDx dispatch
|
|
8732
9966
|
return ((0, L, y, x), [], [], 0)
|
|
8733
9967
|
else:
|
|
@@ -8743,7 +9977,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8743
9977
|
req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8744
9978
|
|
|
8745
9979
|
# generate the LTO
|
|
8746
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9980
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8747
9981
|
M,
|
|
8748
9982
|
N,
|
|
8749
9983
|
NRHS,
|
|
@@ -8785,6 +10019,7 @@ add_builtin(
|
|
|
8785
10019
|
group="Tile Primitives",
|
|
8786
10020
|
export=False,
|
|
8787
10021
|
namespace="",
|
|
10022
|
+
is_differentiable=False,
|
|
8788
10023
|
)
|
|
8789
10024
|
|
|
8790
10025
|
|
|
@@ -8794,7 +10029,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8794
10029
|
return_values: List[Var],
|
|
8795
10030
|
arg_values: Mapping[str, Var],
|
|
8796
10031
|
options: Mapping[str, Any],
|
|
8797
|
-
builder: warp.context.ModuleBuilder,
|
|
10032
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8798
10033
|
):
|
|
8799
10034
|
L = arg_values["L"]
|
|
8800
10035
|
y = arg_values["y"]
|
|
@@ -8823,7 +10058,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8823
10058
|
|
|
8824
10059
|
arch = options["output_arch"]
|
|
8825
10060
|
|
|
8826
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10061
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8827
10062
|
# CPU/no-MathDx dispatch
|
|
8828
10063
|
return ((0, L, y, z), [], [], 0)
|
|
8829
10064
|
else:
|
|
@@ -8839,7 +10074,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8839
10074
|
req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8840
10075
|
|
|
8841
10076
|
# generate the LTO
|
|
8842
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10077
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8843
10078
|
M,
|
|
8844
10079
|
N,
|
|
8845
10080
|
NRHS,
|
|
@@ -8917,6 +10152,7 @@ add_builtin(
|
|
|
8917
10152
|
group="Tile Primitives",
|
|
8918
10153
|
export=False,
|
|
8919
10154
|
namespace="",
|
|
10155
|
+
is_differentiable=False,
|
|
8920
10156
|
)
|
|
8921
10157
|
|
|
8922
10158
|
|
|
@@ -8926,7 +10162,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8926
10162
|
return_values: List[Var],
|
|
8927
10163
|
arg_values: Mapping[str, Var],
|
|
8928
10164
|
options: Mapping[str, Any],
|
|
8929
|
-
builder: warp.context.ModuleBuilder,
|
|
10165
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8930
10166
|
):
|
|
8931
10167
|
U = arg_values["U"]
|
|
8932
10168
|
z = arg_values["z"]
|
|
@@ -8955,7 +10191,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8955
10191
|
|
|
8956
10192
|
arch = options["output_arch"]
|
|
8957
10193
|
|
|
8958
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10194
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8959
10195
|
# CPU/no-MathDx dispatch
|
|
8960
10196
|
return ((0, U, z, x), [], [], 0)
|
|
8961
10197
|
else:
|
|
@@ -8971,7 +10207,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8971
10207
|
req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
|
|
8972
10208
|
|
|
8973
10209
|
# generate the LTO
|
|
8974
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10210
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8975
10211
|
M,
|
|
8976
10212
|
N,
|
|
8977
10213
|
NRHS,
|
|
@@ -9049,6 +10285,7 @@ add_builtin(
|
|
|
9049
10285
|
group="Tile Primitives",
|
|
9050
10286
|
export=False,
|
|
9051
10287
|
namespace="",
|
|
10288
|
+
is_differentiable=False,
|
|
9052
10289
|
)
|
|
9053
10290
|
|
|
9054
10291
|
|
|
@@ -9068,6 +10305,7 @@ add_builtin(
|
|
|
9068
10305
|
The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
|
|
9069
10306
|
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
9070
10307
|
group="Code Generation",
|
|
10308
|
+
is_differentiable=False,
|
|
9071
10309
|
)
|
|
9072
10310
|
|
|
9073
10311
|
|
|
@@ -9092,6 +10330,7 @@ add_builtin(
|
|
|
9092
10330
|
doc="Return the number of elements in a vector.",
|
|
9093
10331
|
group="Utility",
|
|
9094
10332
|
export=False,
|
|
10333
|
+
is_differentiable=False,
|
|
9095
10334
|
)
|
|
9096
10335
|
|
|
9097
10336
|
add_builtin(
|
|
@@ -9101,6 +10340,7 @@ add_builtin(
|
|
|
9101
10340
|
doc="Return the number of elements in a quaternion.",
|
|
9102
10341
|
group="Utility",
|
|
9103
10342
|
export=False,
|
|
10343
|
+
is_differentiable=False,
|
|
9104
10344
|
)
|
|
9105
10345
|
|
|
9106
10346
|
add_builtin(
|
|
@@ -9110,6 +10350,7 @@ add_builtin(
|
|
|
9110
10350
|
doc="Return the number of rows in a matrix.",
|
|
9111
10351
|
group="Utility",
|
|
9112
10352
|
export=False,
|
|
10353
|
+
is_differentiable=False,
|
|
9113
10354
|
)
|
|
9114
10355
|
|
|
9115
10356
|
add_builtin(
|
|
@@ -9119,6 +10360,7 @@ add_builtin(
|
|
|
9119
10360
|
doc="Return the number of elements in a transformation.",
|
|
9120
10361
|
group="Utility",
|
|
9121
10362
|
export=False,
|
|
10363
|
+
is_differentiable=False,
|
|
9122
10364
|
)
|
|
9123
10365
|
|
|
9124
10366
|
add_builtin(
|
|
@@ -9128,6 +10370,7 @@ add_builtin(
|
|
|
9128
10370
|
doc="Return the size of the first dimension in an array.",
|
|
9129
10371
|
group="Utility",
|
|
9130
10372
|
export=False,
|
|
10373
|
+
is_differentiable=False,
|
|
9131
10374
|
)
|
|
9132
10375
|
|
|
9133
10376
|
add_builtin(
|
|
@@ -9137,6 +10380,33 @@ add_builtin(
|
|
|
9137
10380
|
doc="Return the number of rows in a tile.",
|
|
9138
10381
|
group="Utility",
|
|
9139
10382
|
export=False,
|
|
10383
|
+
is_differentiable=False,
|
|
10384
|
+
)
|
|
10385
|
+
|
|
10386
|
+
|
|
10387
|
+
def cast_value_func(arg_types, arg_values):
|
|
10388
|
+
# Return generic type for doc builds.
|
|
10389
|
+
if arg_types is None:
|
|
10390
|
+
return Any
|
|
10391
|
+
|
|
10392
|
+
return arg_values["dtype"]
|
|
10393
|
+
|
|
10394
|
+
|
|
10395
|
+
def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
10396
|
+
func_args = (args["a"],)
|
|
10397
|
+
template_args = (args["dtype"],)
|
|
10398
|
+
return (func_args, template_args)
|
|
10399
|
+
|
|
10400
|
+
|
|
10401
|
+
add_builtin(
|
|
10402
|
+
"cast",
|
|
10403
|
+
input_types={"a": Any, "dtype": Any},
|
|
10404
|
+
value_func=cast_value_func,
|
|
10405
|
+
dispatch_func=cast_dispatch_func,
|
|
10406
|
+
doc="Reinterpret a value as a different type while preserving its bit pattern.",
|
|
10407
|
+
group="Utility",
|
|
10408
|
+
export=False,
|
|
10409
|
+
is_differentiable=False,
|
|
9140
10410
|
)
|
|
9141
10411
|
|
|
9142
10412
|
|
|
@@ -9163,7 +10433,7 @@ add_builtin(
|
|
|
9163
10433
|
doc="Construct a tuple from a list of values",
|
|
9164
10434
|
group="Utility",
|
|
9165
10435
|
hidden=True,
|
|
9166
|
-
|
|
10436
|
+
is_differentiable=False,
|
|
9167
10437
|
export=False,
|
|
9168
10438
|
)
|
|
9169
10439
|
|
|
@@ -9200,7 +10470,7 @@ add_builtin(
|
|
|
9200
10470
|
dispatch_func=tuple_extract_dispatch_func,
|
|
9201
10471
|
group="Utility",
|
|
9202
10472
|
hidden=True,
|
|
9203
|
-
|
|
10473
|
+
is_differentiable=False,
|
|
9204
10474
|
)
|
|
9205
10475
|
|
|
9206
10476
|
|
|
@@ -9211,6 +10481,7 @@ add_builtin(
|
|
|
9211
10481
|
doc="Return the number of elements in a tuple.",
|
|
9212
10482
|
group="Utility",
|
|
9213
10483
|
export=False,
|
|
10484
|
+
is_differentiable=False,
|
|
9214
10485
|
)
|
|
9215
10486
|
|
|
9216
10487
|
# ---------------------------------
|
|
@@ -9229,5 +10500,5 @@ add_builtin(
|
|
|
9229
10500
|
export=False,
|
|
9230
10501
|
group="Utility",
|
|
9231
10502
|
hidden=True,
|
|
9232
|
-
|
|
10503
|
+
is_differentiable=False,
|
|
9233
10504
|
)
|