warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2302 -307
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/{builtins.py → _src/builtins.py} +1546 -224
- warp/_src/codegen.py +4361 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/{fem → _src/fem}/domain.py +42 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/{fem → _src/fem}/field/nodal_field.py +32 -15
- warp/{fem → _src/fem}/field/restriction.py +3 -1
- warp/{fem → _src/fem}/field/virtual.py +55 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
- warp/{fem → _src/fem}/geometry/element.py +34 -10
- warp/{fem → _src/fem}/geometry/geometry.py +50 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
- warp/{fem → _src/fem}/geometry/partition.py +123 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
- warp/{fem → _src/fem}/integrate.py +166 -158
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
- warp/_src/fem/space/basis_space.py +681 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
- warp/{fem → _src/fem}/space/function_space.py +16 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
- warp/{fem → _src/fem}/space/partition.py +119 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/restriction.py +68 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
- warp/_src/fem/space/topology.py +461 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +3 -3
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +581 -280
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +18 -17
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +580 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -20,14 +20,16 @@ import functools
|
|
|
20
20
|
import math
|
|
21
21
|
from typing import Any, Callable, Mapping, Sequence
|
|
22
22
|
|
|
23
|
-
import warp.build
|
|
24
|
-
import warp.context
|
|
25
|
-
import warp.utils
|
|
26
|
-
from warp.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
-
from warp.types import *
|
|
23
|
+
import warp._src.build
|
|
24
|
+
import warp._src.context
|
|
25
|
+
import warp._src.utils
|
|
26
|
+
from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
+
from warp._src.types import *
|
|
28
28
|
|
|
29
29
|
from .context import add_builtin
|
|
30
30
|
|
|
31
|
+
_wp_module_name_ = "warp.builtins"
|
|
32
|
+
|
|
31
33
|
|
|
32
34
|
def seq_check_equal(seq_1, seq_2):
|
|
33
35
|
if not isinstance(seq_1, Sequence) or not isinstance(seq_2, Sequence):
|
|
@@ -61,11 +63,11 @@ def sametypes_create_value_func(default: TypeVar):
|
|
|
61
63
|
|
|
62
64
|
def extract_tuple(arg, as_constant=False):
|
|
63
65
|
if isinstance(arg, Var):
|
|
64
|
-
if isinstance(arg.type, warp.types.tuple_t):
|
|
66
|
+
if isinstance(arg.type, warp._src.types.tuple_t):
|
|
65
67
|
out = arg.type.values
|
|
66
68
|
else:
|
|
67
69
|
out = (arg,)
|
|
68
|
-
elif isinstance(arg, warp.types.tuple_t):
|
|
70
|
+
elif isinstance(arg, warp._src.types.tuple_t):
|
|
69
71
|
out = arg.values
|
|
70
72
|
elif not isinstance(arg, Sequence):
|
|
71
73
|
out = (arg,)
|
|
@@ -82,7 +84,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
82
84
|
if arg_types is None:
|
|
83
85
|
return int
|
|
84
86
|
|
|
85
|
-
length = warp.types.type_length(arg_types["a"])
|
|
87
|
+
length = warp._src.types.type_length(arg_types["a"])
|
|
86
88
|
return Var(None, type=int, constant=length)
|
|
87
89
|
|
|
88
90
|
|
|
@@ -126,6 +128,7 @@ add_builtin(
|
|
|
126
128
|
value_func=sametypes_create_value_func(Scalar),
|
|
127
129
|
doc="Return -1 if ``x`` < 0, return 1 otherwise.",
|
|
128
130
|
group="Scalar Math",
|
|
131
|
+
is_differentiable=False,
|
|
129
132
|
)
|
|
130
133
|
|
|
131
134
|
add_builtin(
|
|
@@ -134,6 +137,7 @@ add_builtin(
|
|
|
134
137
|
value_func=sametypes_create_value_func(Scalar),
|
|
135
138
|
doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
|
|
136
139
|
group="Scalar Math",
|
|
140
|
+
is_differentiable=False,
|
|
137
141
|
)
|
|
138
142
|
add_builtin(
|
|
139
143
|
"nonzero",
|
|
@@ -141,6 +145,7 @@ add_builtin(
|
|
|
141
145
|
value_func=sametypes_create_value_func(Scalar),
|
|
142
146
|
doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
|
|
143
147
|
group="Scalar Math",
|
|
148
|
+
is_differentiable=False,
|
|
144
149
|
)
|
|
145
150
|
|
|
146
151
|
add_builtin(
|
|
@@ -282,7 +287,36 @@ add_builtin(
|
|
|
282
287
|
group="Scalar Math",
|
|
283
288
|
require_original_output_arg=True,
|
|
284
289
|
)
|
|
285
|
-
|
|
290
|
+
add_builtin(
|
|
291
|
+
"erf",
|
|
292
|
+
input_types={"x": Float},
|
|
293
|
+
value_func=sametypes_create_value_func(Float),
|
|
294
|
+
doc="Return the error function of ``x``.",
|
|
295
|
+
group="Scalar Math",
|
|
296
|
+
)
|
|
297
|
+
add_builtin(
|
|
298
|
+
"erfc",
|
|
299
|
+
input_types={"x": Float},
|
|
300
|
+
value_func=sametypes_create_value_func(Float),
|
|
301
|
+
doc="Return the complementary error function of ``x``.",
|
|
302
|
+
group="Scalar Math",
|
|
303
|
+
)
|
|
304
|
+
add_builtin(
|
|
305
|
+
"erfinv",
|
|
306
|
+
input_types={"x": Float},
|
|
307
|
+
value_func=sametypes_create_value_func(Float),
|
|
308
|
+
doc="Return the inverse error function of ``x``.",
|
|
309
|
+
group="Scalar Math",
|
|
310
|
+
require_original_output_arg=True,
|
|
311
|
+
)
|
|
312
|
+
add_builtin(
|
|
313
|
+
"erfcinv",
|
|
314
|
+
input_types={"x": Float},
|
|
315
|
+
value_func=sametypes_create_value_func(Float),
|
|
316
|
+
doc="Return the inverse complementary error function of ``x``.",
|
|
317
|
+
group="Scalar Math",
|
|
318
|
+
require_original_output_arg=True,
|
|
319
|
+
)
|
|
286
320
|
add_builtin(
|
|
287
321
|
"round",
|
|
288
322
|
input_types={"x": Float},
|
|
@@ -292,6 +326,7 @@ add_builtin(
|
|
|
292
326
|
|
|
293
327
|
This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
|
|
294
328
|
Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
|
|
329
|
+
is_differentiable=False,
|
|
295
330
|
)
|
|
296
331
|
|
|
297
332
|
add_builtin(
|
|
@@ -302,6 +337,7 @@ add_builtin(
|
|
|
302
337
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
|
|
303
338
|
|
|
304
339
|
It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
|
|
340
|
+
is_differentiable=False,
|
|
305
341
|
)
|
|
306
342
|
|
|
307
343
|
add_builtin(
|
|
@@ -314,6 +350,7 @@ add_builtin(
|
|
|
314
350
|
In other words, it discards the fractional part of ``x``.
|
|
315
351
|
It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
|
|
316
352
|
Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
|
|
353
|
+
is_differentiable=False,
|
|
317
354
|
)
|
|
318
355
|
|
|
319
356
|
add_builtin(
|
|
@@ -322,6 +359,7 @@ add_builtin(
|
|
|
322
359
|
value_func=sametypes_create_value_func(Float),
|
|
323
360
|
group="Scalar Math",
|
|
324
361
|
doc="""Return the largest integer that is less than or equal to ``x``.""",
|
|
362
|
+
is_differentiable=False,
|
|
325
363
|
)
|
|
326
364
|
|
|
327
365
|
add_builtin(
|
|
@@ -330,6 +368,7 @@ add_builtin(
|
|
|
330
368
|
value_func=sametypes_create_value_func(Float),
|
|
331
369
|
group="Scalar Math",
|
|
332
370
|
doc="""Return the smallest integer that is greater than or equal to ``x``.""",
|
|
371
|
+
is_differentiable=False,
|
|
333
372
|
)
|
|
334
373
|
|
|
335
374
|
add_builtin(
|
|
@@ -340,6 +379,7 @@ add_builtin(
|
|
|
340
379
|
doc="""Retrieve the fractional part of ``x``.
|
|
341
380
|
|
|
342
381
|
In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
|
|
382
|
+
is_differentiable=False,
|
|
343
383
|
)
|
|
344
384
|
|
|
345
385
|
add_builtin(
|
|
@@ -348,6 +388,7 @@ add_builtin(
|
|
|
348
388
|
value_type=builtins.bool,
|
|
349
389
|
group="Scalar Math",
|
|
350
390
|
doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
|
|
391
|
+
is_differentiable=False,
|
|
351
392
|
)
|
|
352
393
|
add_builtin(
|
|
353
394
|
"isfinite",
|
|
@@ -355,6 +396,7 @@ add_builtin(
|
|
|
355
396
|
value_type=builtins.bool,
|
|
356
397
|
group="Vector Math",
|
|
357
398
|
doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
|
|
399
|
+
is_differentiable=False,
|
|
358
400
|
)
|
|
359
401
|
add_builtin(
|
|
360
402
|
"isfinite",
|
|
@@ -362,6 +404,7 @@ add_builtin(
|
|
|
362
404
|
value_type=builtins.bool,
|
|
363
405
|
group="Vector Math",
|
|
364
406
|
doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
|
|
407
|
+
is_differentiable=False,
|
|
365
408
|
)
|
|
366
409
|
add_builtin(
|
|
367
410
|
"isfinite",
|
|
@@ -369,6 +412,7 @@ add_builtin(
|
|
|
369
412
|
value_type=builtins.bool,
|
|
370
413
|
group="Vector Math",
|
|
371
414
|
doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
|
|
415
|
+
is_differentiable=False,
|
|
372
416
|
)
|
|
373
417
|
|
|
374
418
|
add_builtin(
|
|
@@ -377,6 +421,7 @@ add_builtin(
|
|
|
377
421
|
value_type=builtins.bool,
|
|
378
422
|
doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
|
|
379
423
|
group="Scalar Math",
|
|
424
|
+
is_differentiable=False,
|
|
380
425
|
)
|
|
381
426
|
add_builtin(
|
|
382
427
|
"isnan",
|
|
@@ -384,6 +429,7 @@ add_builtin(
|
|
|
384
429
|
value_type=builtins.bool,
|
|
385
430
|
group="Vector Math",
|
|
386
431
|
doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
|
|
432
|
+
is_differentiable=False,
|
|
387
433
|
)
|
|
388
434
|
add_builtin(
|
|
389
435
|
"isnan",
|
|
@@ -391,6 +437,7 @@ add_builtin(
|
|
|
391
437
|
value_type=builtins.bool,
|
|
392
438
|
group="Vector Math",
|
|
393
439
|
doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
|
|
440
|
+
is_differentiable=False,
|
|
394
441
|
)
|
|
395
442
|
add_builtin(
|
|
396
443
|
"isnan",
|
|
@@ -398,6 +445,7 @@ add_builtin(
|
|
|
398
445
|
value_type=builtins.bool,
|
|
399
446
|
group="Vector Math",
|
|
400
447
|
doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
|
|
448
|
+
is_differentiable=False,
|
|
401
449
|
)
|
|
402
450
|
|
|
403
451
|
add_builtin(
|
|
@@ -406,6 +454,7 @@ add_builtin(
|
|
|
406
454
|
value_type=builtins.bool,
|
|
407
455
|
group="Scalar Math",
|
|
408
456
|
doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
|
|
457
|
+
is_differentiable=False,
|
|
409
458
|
)
|
|
410
459
|
add_builtin(
|
|
411
460
|
"isinf",
|
|
@@ -413,6 +462,7 @@ add_builtin(
|
|
|
413
462
|
value_type=builtins.bool,
|
|
414
463
|
group="Vector Math",
|
|
415
464
|
doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
465
|
+
is_differentiable=False,
|
|
416
466
|
)
|
|
417
467
|
add_builtin(
|
|
418
468
|
"isinf",
|
|
@@ -420,6 +470,7 @@ add_builtin(
|
|
|
420
470
|
value_type=builtins.bool,
|
|
421
471
|
group="Vector Math",
|
|
422
472
|
doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
473
|
+
is_differentiable=False,
|
|
423
474
|
)
|
|
424
475
|
add_builtin(
|
|
425
476
|
"isinf",
|
|
@@ -427,6 +478,7 @@ add_builtin(
|
|
|
427
478
|
value_type=builtins.bool,
|
|
428
479
|
group="Vector Math",
|
|
429
480
|
doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
481
|
+
is_differentiable=False,
|
|
430
482
|
)
|
|
431
483
|
|
|
432
484
|
|
|
@@ -534,7 +586,7 @@ add_builtin(
|
|
|
534
586
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
535
587
|
doc="Return the index of the minimum element of a vector ``a``.",
|
|
536
588
|
group="Vector Math",
|
|
537
|
-
|
|
589
|
+
is_differentiable=False,
|
|
538
590
|
)
|
|
539
591
|
add_builtin(
|
|
540
592
|
"argmax",
|
|
@@ -542,7 +594,7 @@ add_builtin(
|
|
|
542
594
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
543
595
|
doc="Return the index of the maximum element of a vector ``a``.",
|
|
544
596
|
group="Vector Math",
|
|
545
|
-
|
|
597
|
+
is_differentiable=False,
|
|
546
598
|
)
|
|
547
599
|
|
|
548
600
|
add_builtin(
|
|
@@ -867,7 +919,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
867
919
|
|
|
868
920
|
if dtype is None:
|
|
869
921
|
dtype = value_type
|
|
870
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
922
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
871
923
|
raise RuntimeError(
|
|
872
924
|
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
873
925
|
)
|
|
@@ -888,7 +940,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
888
940
|
|
|
889
941
|
if dtype is None:
|
|
890
942
|
dtype = value_type
|
|
891
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
943
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
892
944
|
raise RuntimeError(
|
|
893
945
|
f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
|
|
894
946
|
)
|
|
@@ -971,7 +1023,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
971
1023
|
|
|
972
1024
|
if dtype is None:
|
|
973
1025
|
dtype = value_type
|
|
974
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1026
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
975
1027
|
raise RuntimeError(
|
|
976
1028
|
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
977
1029
|
)
|
|
@@ -981,7 +1033,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
981
1033
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
982
1034
|
|
|
983
1035
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
984
|
-
warp.utils.warn(
|
|
1036
|
+
warp._src.utils.warn(
|
|
985
1037
|
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
986
1038
|
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
987
1039
|
DeprecationWarning,
|
|
@@ -1010,7 +1062,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1010
1062
|
|
|
1011
1063
|
if dtype is None:
|
|
1012
1064
|
dtype = value_type
|
|
1013
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1065
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1014
1066
|
raise RuntimeError(
|
|
1015
1067
|
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
1016
1068
|
)
|
|
@@ -1182,48 +1234,18 @@ add_builtin(
|
|
|
1182
1234
|
doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
|
|
1183
1235
|
group="Vector Math",
|
|
1184
1236
|
export=False,
|
|
1237
|
+
is_differentiable=False,
|
|
1185
1238
|
)
|
|
1186
1239
|
|
|
1187
1240
|
|
|
1188
1241
|
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1189
|
-
warp.utils.warn(
|
|
1190
|
-
"the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1191
|
-
"and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
|
|
1192
|
-
DeprecationWarning,
|
|
1193
|
-
)
|
|
1194
1242
|
if arg_types is None:
|
|
1195
1243
|
return matrix(shape=(4, 4), dtype=Float)
|
|
1196
1244
|
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
value_type = scalar_infer_type(value_arg_types)
|
|
1202
|
-
except RuntimeError:
|
|
1203
|
-
raise RuntimeError(
|
|
1204
|
-
"all values given when constructing a transformation matrix must have the same type"
|
|
1205
|
-
) from None
|
|
1206
|
-
|
|
1207
|
-
if dtype is None:
|
|
1208
|
-
dtype = value_type
|
|
1209
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1210
|
-
raise RuntimeError(
|
|
1211
|
-
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1212
|
-
)
|
|
1213
|
-
|
|
1214
|
-
return matrix(shape=(4, 4), dtype=dtype)
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1218
|
-
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1219
|
-
# Further validate the given argument values if needed and map them
|
|
1220
|
-
# to the underlying C++ function's runtime and template params.
|
|
1221
|
-
|
|
1222
|
-
dtype = return_type._wp_scalar_type_
|
|
1223
|
-
|
|
1224
|
-
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1225
|
-
template_args = (4, 4, dtype)
|
|
1226
|
-
return (func_args, template_args)
|
|
1245
|
+
raise RuntimeError(
|
|
1246
|
+
"the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1247
|
+
"and 3D scale vector has been removed in favor of `wp.transform_compose()`."
|
|
1248
|
+
)
|
|
1227
1249
|
|
|
1228
1250
|
|
|
1229
1251
|
add_builtin(
|
|
@@ -1237,13 +1259,14 @@ add_builtin(
|
|
|
1237
1259
|
defaults={"dtype": None},
|
|
1238
1260
|
value_func=matrix_transform_value_func,
|
|
1239
1261
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1240
|
-
dispatch_func=matrix_transform_dispatch_func,
|
|
1241
1262
|
native_func="mat_t",
|
|
1242
1263
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
1243
1264
|
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1244
1265
|
|
|
1245
|
-
..
|
|
1246
|
-
This function has been
|
|
1266
|
+
.. versionremoved:: 1.10
|
|
1267
|
+
This function has been removed in favor of :func:`warp.math.transform_compose()`.
|
|
1268
|
+
|
|
1269
|
+
.. deprecated:: 1.8""",
|
|
1247
1270
|
group="Vector Math",
|
|
1248
1271
|
export=False,
|
|
1249
1272
|
)
|
|
@@ -1438,7 +1461,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1438
1461
|
|
|
1439
1462
|
if dtype is None:
|
|
1440
1463
|
dtype = value_type
|
|
1441
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1464
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1442
1465
|
raise RuntimeError(
|
|
1443
1466
|
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1444
1467
|
)
|
|
@@ -1546,6 +1569,7 @@ add_builtin(
|
|
|
1546
1569
|
group="Quaternion Math",
|
|
1547
1570
|
doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
|
|
1548
1571
|
export=True,
|
|
1572
|
+
is_differentiable=False,
|
|
1549
1573
|
)
|
|
1550
1574
|
|
|
1551
1575
|
add_builtin(
|
|
@@ -1674,7 +1698,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1674
1698
|
value_type = strip_reference(variadic_arg_types[0])
|
|
1675
1699
|
if dtype is None:
|
|
1676
1700
|
dtype = value_type
|
|
1677
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1701
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1678
1702
|
raise RuntimeError(
|
|
1679
1703
|
f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
|
|
1680
1704
|
)
|
|
@@ -1687,7 +1711,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1687
1711
|
|
|
1688
1712
|
if dtype is None:
|
|
1689
1713
|
dtype = value_type
|
|
1690
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1714
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1691
1715
|
raise RuntimeError(
|
|
1692
1716
|
f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
|
|
1693
1717
|
)
|
|
@@ -1712,7 +1736,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
|
|
|
1712
1736
|
dtype = arg_values.get("dtype", None)
|
|
1713
1737
|
if dtype is None:
|
|
1714
1738
|
dtype = value_type
|
|
1715
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1739
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1716
1740
|
raise RuntimeError(
|
|
1717
1741
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1718
1742
|
)
|
|
@@ -1727,9 +1751,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1727
1751
|
|
|
1728
1752
|
dtype = return_type._wp_scalar_type_
|
|
1729
1753
|
|
|
1730
|
-
variadic_args =
|
|
1754
|
+
variadic_args = args.get("args", ())
|
|
1755
|
+
variadic_arg_count = len(variadic_args)
|
|
1756
|
+
|
|
1757
|
+
if variadic_arg_count == 7:
|
|
1758
|
+
func_args = variadic_args
|
|
1759
|
+
else:
|
|
1760
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1761
|
+
if "p" in args and "q" not in args:
|
|
1762
|
+
quat_ident = warp._src.codegen.Var(
|
|
1763
|
+
label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
|
|
1764
|
+
)
|
|
1765
|
+
func_args += (quat_ident,)
|
|
1731
1766
|
|
|
1732
|
-
func_args = variadic_args
|
|
1733
1767
|
template_args = (dtype,)
|
|
1734
1768
|
return (func_args, template_args)
|
|
1735
1769
|
|
|
@@ -1737,7 +1771,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1737
1771
|
add_builtin(
|
|
1738
1772
|
"transformation",
|
|
1739
1773
|
input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
|
|
1740
|
-
defaults={"dtype": None},
|
|
1774
|
+
defaults={"q": None, "dtype": None},
|
|
1741
1775
|
value_func=transformation_pq_value_func,
|
|
1742
1776
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1743
1777
|
dispatch_func=transformation_dispatch_func,
|
|
@@ -1795,6 +1829,7 @@ add_builtin(
|
|
|
1795
1829
|
group="Transformations",
|
|
1796
1830
|
doc="Construct an identity transform with zero translation and identity rotation.",
|
|
1797
1831
|
export=True,
|
|
1832
|
+
is_differentiable=False,
|
|
1798
1833
|
)
|
|
1799
1834
|
|
|
1800
1835
|
add_builtin(
|
|
@@ -1928,7 +1963,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1928
1963
|
|
|
1929
1964
|
if dtype is None:
|
|
1930
1965
|
dtype = value_type
|
|
1931
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1966
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1932
1967
|
raise RuntimeError(
|
|
1933
1968
|
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1934
1969
|
)
|
|
@@ -2122,7 +2157,7 @@ add_builtin(
|
|
|
2122
2157
|
value_func=tile_zeros_value_func,
|
|
2123
2158
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2124
2159
|
variadic=False,
|
|
2125
|
-
|
|
2160
|
+
is_differentiable=False,
|
|
2126
2161
|
doc="""Allocate a tile of zero-initialized items.
|
|
2127
2162
|
|
|
2128
2163
|
:param shape: Shape of the output tile
|
|
@@ -2142,7 +2177,7 @@ add_builtin(
|
|
|
2142
2177
|
value_func=tile_zeros_value_func,
|
|
2143
2178
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2144
2179
|
variadic=False,
|
|
2145
|
-
|
|
2180
|
+
is_differentiable=False,
|
|
2146
2181
|
hidden=True,
|
|
2147
2182
|
group="Tile Primitives",
|
|
2148
2183
|
export=False,
|
|
@@ -2194,7 +2229,7 @@ add_builtin(
|
|
|
2194
2229
|
defaults={"storage": "register"},
|
|
2195
2230
|
value_func=tile_ones_value_func,
|
|
2196
2231
|
dispatch_func=tile_ones_dispatch_func,
|
|
2197
|
-
|
|
2232
|
+
is_differentiable=False,
|
|
2198
2233
|
doc="""Allocate a tile of one-initialized items.
|
|
2199
2234
|
|
|
2200
2235
|
:param shape: Shape of the output tile
|
|
@@ -2213,7 +2248,86 @@ add_builtin(
|
|
|
2213
2248
|
defaults={"storage": "register"},
|
|
2214
2249
|
value_func=tile_ones_value_func,
|
|
2215
2250
|
dispatch_func=tile_ones_dispatch_func,
|
|
2216
|
-
|
|
2251
|
+
is_differentiable=False,
|
|
2252
|
+
hidden=True,
|
|
2253
|
+
group="Tile Primitives",
|
|
2254
|
+
export=False,
|
|
2255
|
+
)
|
|
2256
|
+
|
|
2257
|
+
|
|
2258
|
+
def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Resolve the tile type returned by ``tile_full()``.

    :param arg_types: Mapping of argument names to types, or ``None`` when
        the generic signature is requested (doc builds).
    :param arg_values: Mapping of argument names to their values; ``shape``,
        ``value``, ``dtype`` and ``storage`` are read from it.
    :returns: A tile type built from the requested ``dtype``, ``shape`` and ``storage``.
    :raises ValueError: If any shape entry is ``None`` after tuple extraction,
        or if ``storage`` is neither ``"shared"`` nor ``"register"``.
    :raises TypeError: If ``value``, ``dtype`` or ``storage`` is missing.
    """
    # Doc builds pass no argument types: answer with the generic tile type.
    if arg_types is None:
        return tile(dtype=Any, shape=Tuple[int, ...])

    dims = extract_tuple(arg_values["shape"], as_constant=True)
    if None in dims:
        raise ValueError("Tile functions require shape to be a compile time constant.")

    # Checked in this order so the first missing argument is the one reported.
    required = (
        ("value", "tile_full() missing required keyword argument 'value'"),
        ("dtype", "tile_full() missing required keyword argument 'dtype'"),
        ("storage", "tile_full() missing required keyword argument 'storage'"),
    )
    for key, message in required:
        if key not in arg_values:
            raise TypeError(message)

    storage = arg_values["storage"]
    if storage not in {"shared", "register"}:
        raise ValueError(f"Invalid value for 'storage': {storage!r}. Expected 'shared' or 'register'.")

    return tile(dtype=arg_values["dtype"], shape=dims, storage=storage)
|
|
2283
|
+
|
|
2284
|
+
|
|
2285
|
+
def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Split the ``tile_full()`` arguments into runtime and template arguments.

    :param arg_types: Mapping of argument names to types (not used here).
    :param return_type: Resolved return type (not used here).
    :param arg_values: Mapping of argument names to their values; ``shape``,
        ``dtype`` and ``value`` are read from it.
    :returns: ``(func_args, template_args)`` where ``func_args`` holds the fill
        value and ``template_args`` holds the dtype followed by each dimension.
    :raises ValueError: If any shape entry is ``None`` after tuple extraction.
    """
    dims = extract_tuple(arg_values["shape"], as_constant=True)
    if None in dims:
        raise ValueError("Tile functions require shape to be a compile time constant.")

    dtype = arg_values["dtype"]
    fill_value = arg_values["value"]

    # Only the fill value is a runtime argument; dtype and dimensions are template arguments.
    return ([fill_value], [dtype, *dims])
|
|
2301
|
+
|
|
2302
|
+
|
|
2303
|
+
add_builtin(
|
|
2304
|
+
"tile_full",
|
|
2305
|
+
input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
|
|
2306
|
+
defaults={"storage": "register"},
|
|
2307
|
+
value_func=tile_full_value_func,
|
|
2308
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2309
|
+
is_differentiable=False,
|
|
2310
|
+
doc="""Allocate a tile filled with the specified value.
|
|
2311
|
+
|
|
2312
|
+
:param shape: Shape of the output tile
|
|
2313
|
+
:param value: Value to fill the tile with
|
|
2314
|
+
:param dtype: Data type of output tile's elements
|
|
2315
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2316
|
+
(default) or ``"shared"`` for shared memory.
|
|
2317
|
+
:returns: A tile filled with the specified value""",
|
|
2318
|
+
group="Tile Primitives",
|
|
2319
|
+
export=False,
|
|
2320
|
+
)
|
|
2321
|
+
|
|
2322
|
+
|
|
2323
|
+
# overload for scalar shape
|
|
2324
|
+
add_builtin(
|
|
2325
|
+
"tile_full",
|
|
2326
|
+
input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
|
|
2327
|
+
defaults={"storage": "register"},
|
|
2328
|
+
value_func=tile_full_value_func,
|
|
2329
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2330
|
+
is_differentiable=False,
|
|
2217
2331
|
hidden=True,
|
|
2218
2332
|
group="Tile Primitives",
|
|
2219
2333
|
export=False,
|
|
@@ -2275,13 +2389,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2275
2389
|
args = arg_values["args"]
|
|
2276
2390
|
|
|
2277
2391
|
if len(args) == 1:
|
|
2278
|
-
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2392
|
+
start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2279
2393
|
stop = args[0]
|
|
2280
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2394
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2281
2395
|
elif len(args) == 2:
|
|
2282
2396
|
start = args[0]
|
|
2283
2397
|
stop = args[1]
|
|
2284
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2398
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2285
2399
|
elif len(args) == 3:
|
|
2286
2400
|
start = args[0]
|
|
2287
2401
|
stop = args[1]
|
|
@@ -2304,7 +2418,7 @@ add_builtin(
|
|
|
2304
2418
|
value_func=tile_arange_value_func,
|
|
2305
2419
|
dispatch_func=tile_arange_dispatch_func,
|
|
2306
2420
|
variadic=True,
|
|
2307
|
-
|
|
2421
|
+
is_differentiable=False,
|
|
2308
2422
|
doc="""Generate a tile of linearly spaced elements.
|
|
2309
2423
|
|
|
2310
2424
|
:param args: Variable-length positional arguments, interpreted as:
|
|
@@ -3099,7 +3213,7 @@ add_builtin(
|
|
|
3099
3213
|
:param shape: Shape of the returned slice
|
|
3100
3214
|
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
3101
3215
|
group="Tile Primitives",
|
|
3102
|
-
|
|
3216
|
+
is_differentiable=False,
|
|
3103
3217
|
export=False,
|
|
3104
3218
|
)
|
|
3105
3219
|
|
|
@@ -3346,7 +3460,32 @@ add_builtin(
|
|
|
3346
3460
|
|
|
3347
3461
|
add_builtin(
|
|
3348
3462
|
"assign",
|
|
3349
|
-
input_types={"dst": tile(dtype=Any, shape=Tuple[int,
|
|
3463
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
|
|
3464
|
+
value_func=tile_assign_value_func,
|
|
3465
|
+
group="Tile Primitives",
|
|
3466
|
+
export=False,
|
|
3467
|
+
hidden=True,
|
|
3468
|
+
)
|
|
3469
|
+
|
|
3470
|
+
add_builtin(
|
|
3471
|
+
"assign",
|
|
3472
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
|
|
3473
|
+
value_func=tile_assign_value_func,
|
|
3474
|
+
group="Tile Primitives",
|
|
3475
|
+
export=False,
|
|
3476
|
+
hidden=True,
|
|
3477
|
+
)
|
|
3478
|
+
|
|
3479
|
+
add_builtin(
|
|
3480
|
+
"assign",
|
|
3481
|
+
input_types={
|
|
3482
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3483
|
+
"i": int,
|
|
3484
|
+
"j": int,
|
|
3485
|
+
"k": int,
|
|
3486
|
+
"l": int,
|
|
3487
|
+
"src": Any,
|
|
3488
|
+
},
|
|
3350
3489
|
value_func=tile_assign_value_func,
|
|
3351
3490
|
group="Tile Primitives",
|
|
3352
3491
|
export=False,
|
|
@@ -3355,7 +3494,15 @@ add_builtin(
|
|
|
3355
3494
|
|
|
3356
3495
|
add_builtin(
|
|
3357
3496
|
"assign",
|
|
3358
|
-
input_types={
|
|
3497
|
+
input_types={
|
|
3498
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3499
|
+
"i": int,
|
|
3500
|
+
"j": int,
|
|
3501
|
+
"k": int,
|
|
3502
|
+
"l": int,
|
|
3503
|
+
"m": int,
|
|
3504
|
+
"src": Any,
|
|
3505
|
+
},
|
|
3359
3506
|
value_func=tile_assign_value_func,
|
|
3360
3507
|
group="Tile Primitives",
|
|
3361
3508
|
export=False,
|
|
@@ -3370,6 +3517,8 @@ add_builtin(
|
|
|
3370
3517
|
"j": int,
|
|
3371
3518
|
"k": int,
|
|
3372
3519
|
"l": int,
|
|
3520
|
+
"m": int,
|
|
3521
|
+
"n": int,
|
|
3373
3522
|
"src": Any,
|
|
3374
3523
|
},
|
|
3375
3524
|
value_func=tile_assign_value_func,
|
|
@@ -3391,7 +3540,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3391
3540
|
|
|
3392
3541
|
if preserve_type:
|
|
3393
3542
|
dtype = arg_types["x"]
|
|
3394
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3543
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3395
3544
|
|
|
3396
3545
|
return tile(dtype=dtype, shape=shape)
|
|
3397
3546
|
|
|
@@ -3399,18 +3548,18 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3399
3548
|
if type_is_vector(arg_types["x"]):
|
|
3400
3549
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3401
3550
|
length = arg_types["x"]._shape_[0]
|
|
3402
|
-
shape = (length, warp.codegen.options["block_dim"])
|
|
3551
|
+
shape = (length, warp._src.codegen.options["block_dim"])
|
|
3403
3552
|
elif type_is_quaternion(arg_types["x"]):
|
|
3404
3553
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3405
|
-
shape = (4, warp.codegen.options["block_dim"])
|
|
3554
|
+
shape = (4, warp._src.codegen.options["block_dim"])
|
|
3406
3555
|
elif type_is_matrix(arg_types["x"]):
|
|
3407
3556
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3408
3557
|
rows = arg_types["x"]._shape_[0]
|
|
3409
3558
|
cols = arg_types["x"]._shape_[1]
|
|
3410
|
-
shape = (rows, cols, warp.codegen.options["block_dim"])
|
|
3559
|
+
shape = (rows, cols, warp._src.codegen.options["block_dim"])
|
|
3411
3560
|
else:
|
|
3412
3561
|
dtype = arg_types["x"]
|
|
3413
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3562
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3414
3563
|
|
|
3415
3564
|
return tile(dtype=dtype, shape=shape)
|
|
3416
3565
|
|
|
@@ -3500,17 +3649,17 @@ def untile_value_func(arg_types, arg_values):
|
|
|
3500
3649
|
if not is_tile(t):
|
|
3501
3650
|
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
3502
3651
|
|
|
3503
|
-
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
3652
|
+
if t.shape[-1] != warp._src.codegen.options["block_dim"]:
|
|
3504
3653
|
raise ValueError(
|
|
3505
|
-
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
3654
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
|
|
3506
3655
|
)
|
|
3507
3656
|
|
|
3508
3657
|
if len(t.shape) == 1:
|
|
3509
3658
|
return t.dtype
|
|
3510
3659
|
elif len(t.shape) == 2:
|
|
3511
|
-
return warp.types.vector(t.shape[0], t.dtype)
|
|
3660
|
+
return warp._src.types.vector(t.shape[0], t.dtype)
|
|
3512
3661
|
elif len(t.shape) == 3:
|
|
3513
|
-
return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3662
|
+
return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3514
3663
|
else:
|
|
3515
3664
|
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
3516
3665
|
|
|
@@ -3572,7 +3721,36 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
3572
3721
|
# force the input tile to shared memory
|
|
3573
3722
|
arg_types["a"].storage = "shared"
|
|
3574
3723
|
|
|
3575
|
-
|
|
3724
|
+
# count the number of indices (all parameters except the tile "a")
|
|
3725
|
+
num_indices = len(arg_types) - 1
|
|
3726
|
+
tile_dtype = arg_types["a"].dtype
|
|
3727
|
+
tile_shape = arg_types["a"].shape
|
|
3728
|
+
|
|
3729
|
+
if type_is_vector(tile_dtype):
|
|
3730
|
+
if num_indices == len(tile_shape):
|
|
3731
|
+
return tile_dtype
|
|
3732
|
+
elif num_indices == len(tile_shape) + 1:
|
|
3733
|
+
return tile_dtype._wp_scalar_type_
|
|
3734
|
+
else:
|
|
3735
|
+
raise IndexError(
|
|
3736
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3737
|
+
)
|
|
3738
|
+
elif type_is_matrix(tile_dtype):
|
|
3739
|
+
if num_indices == len(tile_shape):
|
|
3740
|
+
return tile_dtype
|
|
3741
|
+
elif num_indices == len(tile_shape) + 2:
|
|
3742
|
+
return tile_dtype._wp_scalar_type_
|
|
3743
|
+
else:
|
|
3744
|
+
raise IndexError(
|
|
3745
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
|
|
3746
|
+
)
|
|
3747
|
+
else:
|
|
3748
|
+
# scalar element: index count must exactly match tile rank
|
|
3749
|
+
if num_indices == len(tile_shape):
|
|
3750
|
+
return tile_dtype
|
|
3751
|
+
raise IndexError(
|
|
3752
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3753
|
+
)
|
|
3576
3754
|
|
|
3577
3755
|
|
|
3578
3756
|
add_builtin(
|
|
@@ -3596,7 +3774,7 @@ add_builtin(
|
|
|
3596
3774
|
|
|
3597
3775
|
add_builtin(
|
|
3598
3776
|
"tile_extract",
|
|
3599
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3777
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
|
|
3600
3778
|
value_func=tile_extract_value_func,
|
|
3601
3779
|
variadic=False,
|
|
3602
3780
|
doc="""Extract a single element from the tile.
|
|
@@ -3607,7 +3785,7 @@ add_builtin(
|
|
|
3607
3785
|
|
|
3608
3786
|
:param a: Tile to extract the element from
|
|
3609
3787
|
:param i: Coordinate of element on first dimension
|
|
3610
|
-
:param j: Coordinate of element on the second dimension
|
|
3788
|
+
:param j: Coordinate of element on the second dimension, or vector index
|
|
3611
3789
|
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3612
3790
|
group="Tile Primitives",
|
|
3613
3791
|
hidden=True,
|
|
@@ -3616,7 +3794,57 @@ add_builtin(
|
|
|
3616
3794
|
|
|
3617
3795
|
add_builtin(
|
|
3618
3796
|
"tile_extract",
|
|
3619
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3797
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
|
|
3798
|
+
value_func=tile_extract_value_func,
|
|
3799
|
+
variadic=False,
|
|
3800
|
+
doc="""Extract a single element from the tile.
|
|
3801
|
+
|
|
3802
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3803
|
+
|
|
3804
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3805
|
+
|
|
3806
|
+
:param a: Tile to extract the element from
|
|
3807
|
+
:param i: Coordinate of element on first dimension
|
|
3808
|
+
:param j: Coordinate of element on the second dimension, or first matrix index
|
|
3809
|
+
:param k: Coordinate of element on the third dimension, or vector index, or second matrix index
|
|
3810
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3811
|
+
group="Tile Primitives",
|
|
3812
|
+
hidden=True,
|
|
3813
|
+
export=False,
|
|
3814
|
+
)
|
|
3815
|
+
|
|
3816
|
+
add_builtin(
|
|
3817
|
+
"tile_extract",
|
|
3818
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
|
|
3819
|
+
value_func=tile_extract_value_func,
|
|
3820
|
+
variadic=False,
|
|
3821
|
+
doc="""Extract a single element from the tile.
|
|
3822
|
+
|
|
3823
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3824
|
+
|
|
3825
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3826
|
+
|
|
3827
|
+
:param a: Tile to extract the element from
|
|
3828
|
+
:param i: Coordinate of element on first dimension
|
|
3829
|
+
:param j: Coordinate of element on the second dimension
|
|
3830
|
+
:param k: Coordinate of element on the third dimension, or first matrix index
|
|
3831
|
+
:param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
|
|
3832
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3833
|
+
group="Tile Primitives",
|
|
3834
|
+
hidden=True,
|
|
3835
|
+
export=False,
|
|
3836
|
+
)
|
|
3837
|
+
|
|
3838
|
+
add_builtin(
|
|
3839
|
+
"tile_extract",
|
|
3840
|
+
input_types={
|
|
3841
|
+
"a": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3842
|
+
"i": int,
|
|
3843
|
+
"j": int,
|
|
3844
|
+
"k": int,
|
|
3845
|
+
"l": int,
|
|
3846
|
+
"m": int,
|
|
3847
|
+
},
|
|
3620
3848
|
value_func=tile_extract_value_func,
|
|
3621
3849
|
variadic=False,
|
|
3622
3850
|
doc="""Extract a single element from the tile.
|
|
@@ -3629,7 +3857,9 @@ add_builtin(
|
|
|
3629
3857
|
:param i: Coordinate of element on first dimension
|
|
3630
3858
|
:param j: Coordinate of element on the second dimension
|
|
3631
3859
|
:param k: Coordinate of element on the third dimension
|
|
3632
|
-
:
|
|
3860
|
+
:param l: Coordinate of element on the fourth dimension, or first matrix index
|
|
3861
|
+
:param m: Vector index, or second matrix index
|
|
3862
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3633
3863
|
group="Tile Primitives",
|
|
3634
3864
|
hidden=True,
|
|
3635
3865
|
export=False,
|
|
@@ -3637,7 +3867,15 @@ add_builtin(
|
|
|
3637
3867
|
|
|
3638
3868
|
add_builtin(
|
|
3639
3869
|
"tile_extract",
|
|
3640
|
-
input_types={
|
|
3870
|
+
input_types={
|
|
3871
|
+
"a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
|
|
3872
|
+
"i": int,
|
|
3873
|
+
"j": int,
|
|
3874
|
+
"k": int,
|
|
3875
|
+
"l": int,
|
|
3876
|
+
"m": int,
|
|
3877
|
+
"n": int,
|
|
3878
|
+
},
|
|
3641
3879
|
value_func=tile_extract_value_func,
|
|
3642
3880
|
variadic=False,
|
|
3643
3881
|
doc="""Extract a single element from the tile.
|
|
@@ -3651,6 +3889,8 @@ add_builtin(
|
|
|
3651
3889
|
:param j: Coordinate of element on the second dimension
|
|
3652
3890
|
:param k: Coordinate of element on the third dimension
|
|
3653
3891
|
:param l: Coordinate of element on the fourth dimension
|
|
3892
|
+
:param m: Vector index, or first matrix index
|
|
3893
|
+
:param n: Second matrix index
|
|
3654
3894
|
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3655
3895
|
group="Tile Primitives",
|
|
3656
3896
|
hidden=True,
|
|
@@ -3737,49 +3977,160 @@ add_builtin(
|
|
|
3737
3977
|
export=False,
|
|
3738
3978
|
)
|
|
3739
3979
|
|
|
3740
|
-
|
|
3741
|
-
def tile_transpose_value_func(arg_types, arg_values):
|
|
3742
|
-
# return generic type (for doc builds)
|
|
3743
|
-
if arg_types is None:
|
|
3744
|
-
return tile(dtype=Any, shape=Tuple[int, int])
|
|
3745
|
-
|
|
3746
|
-
if len(arg_types) != 1:
|
|
3747
|
-
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3748
|
-
|
|
3749
|
-
t = arg_types["a"]
|
|
3750
|
-
|
|
3751
|
-
if not is_tile(t):
|
|
3752
|
-
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
3753
|
-
|
|
3754
|
-
layout = None
|
|
3755
|
-
|
|
3756
|
-
# flip layout
|
|
3757
|
-
if t.layout == "rowmajor":
|
|
3758
|
-
layout = "colmajor"
|
|
3759
|
-
elif t.layout == "colmajor":
|
|
3760
|
-
layout = "rowmajor"
|
|
3761
|
-
|
|
3762
|
-
# force the input tile to shared memory
|
|
3763
|
-
t.storage = "shared"
|
|
3764
|
-
|
|
3765
|
-
return tile(
|
|
3766
|
-
dtype=t.dtype,
|
|
3767
|
-
shape=t.shape[::-1],
|
|
3768
|
-
storage=t.storage,
|
|
3769
|
-
strides=t.strides[::-1],
|
|
3770
|
-
layout=layout,
|
|
3771
|
-
owner=False,
|
|
3772
|
-
)
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
3980
|
add_builtin(
|
|
3776
|
-
"
|
|
3777
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3778
|
-
value_func=
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3981
|
+
"tile_bit_and_inplace",
|
|
3982
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3983
|
+
value_func=tile_inplace_value_func,
|
|
3984
|
+
group="Tile Primitives",
|
|
3985
|
+
hidden=True,
|
|
3986
|
+
export=False,
|
|
3987
|
+
is_differentiable=False,
|
|
3988
|
+
)
|
|
3989
|
+
add_builtin(
|
|
3990
|
+
"tile_bit_and_inplace",
|
|
3991
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3992
|
+
value_func=tile_inplace_value_func,
|
|
3993
|
+
group="Tile Primitives",
|
|
3994
|
+
hidden=True,
|
|
3995
|
+
export=False,
|
|
3996
|
+
is_differentiable=False,
|
|
3997
|
+
)
|
|
3998
|
+
add_builtin(
|
|
3999
|
+
"tile_bit_and_inplace",
|
|
4000
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4001
|
+
value_func=tile_inplace_value_func,
|
|
4002
|
+
group="Tile Primitives",
|
|
4003
|
+
hidden=True,
|
|
4004
|
+
export=False,
|
|
4005
|
+
is_differentiable=False,
|
|
4006
|
+
)
|
|
4007
|
+
add_builtin(
|
|
4008
|
+
"tile_bit_and_inplace",
|
|
4009
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4010
|
+
value_func=tile_inplace_value_func,
|
|
4011
|
+
group="Tile Primitives",
|
|
4012
|
+
hidden=True,
|
|
4013
|
+
export=False,
|
|
4014
|
+
is_differentiable=False,
|
|
4015
|
+
)
|
|
4016
|
+
|
|
4017
|
+
add_builtin(
|
|
4018
|
+
"tile_bit_or_inplace",
|
|
4019
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4020
|
+
value_func=tile_inplace_value_func,
|
|
4021
|
+
group="Tile Primitives",
|
|
4022
|
+
hidden=True,
|
|
4023
|
+
export=False,
|
|
4024
|
+
is_differentiable=False,
|
|
4025
|
+
)
|
|
4026
|
+
add_builtin(
|
|
4027
|
+
"tile_bit_or_inplace",
|
|
4028
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4029
|
+
value_func=tile_inplace_value_func,
|
|
4030
|
+
group="Tile Primitives",
|
|
4031
|
+
hidden=True,
|
|
4032
|
+
export=False,
|
|
4033
|
+
is_differentiable=False,
|
|
4034
|
+
)
|
|
4035
|
+
add_builtin(
|
|
4036
|
+
"tile_bit_or_inplace",
|
|
4037
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4038
|
+
value_func=tile_inplace_value_func,
|
|
4039
|
+
group="Tile Primitives",
|
|
4040
|
+
hidden=True,
|
|
4041
|
+
export=False,
|
|
4042
|
+
is_differentiable=False,
|
|
4043
|
+
)
|
|
4044
|
+
add_builtin(
|
|
4045
|
+
"tile_bit_or_inplace",
|
|
4046
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4047
|
+
value_func=tile_inplace_value_func,
|
|
4048
|
+
group="Tile Primitives",
|
|
4049
|
+
hidden=True,
|
|
4050
|
+
export=False,
|
|
4051
|
+
is_differentiable=False,
|
|
4052
|
+
)
|
|
4053
|
+
|
|
4054
|
+
add_builtin(
|
|
4055
|
+
"tile_bit_xor_inplace",
|
|
4056
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4057
|
+
value_func=tile_inplace_value_func,
|
|
4058
|
+
group="Tile Primitives",
|
|
4059
|
+
hidden=True,
|
|
4060
|
+
export=False,
|
|
4061
|
+
is_differentiable=False,
|
|
4062
|
+
)
|
|
4063
|
+
add_builtin(
|
|
4064
|
+
"tile_bit_xor_inplace",
|
|
4065
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4066
|
+
value_func=tile_inplace_value_func,
|
|
4067
|
+
group="Tile Primitives",
|
|
4068
|
+
hidden=True,
|
|
4069
|
+
export=False,
|
|
4070
|
+
is_differentiable=False,
|
|
4071
|
+
)
|
|
4072
|
+
add_builtin(
|
|
4073
|
+
"tile_bit_xor_inplace",
|
|
4074
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4075
|
+
value_func=tile_inplace_value_func,
|
|
4076
|
+
group="Tile Primitives",
|
|
4077
|
+
hidden=True,
|
|
4078
|
+
export=False,
|
|
4079
|
+
is_differentiable=False,
|
|
4080
|
+
)
|
|
4081
|
+
add_builtin(
|
|
4082
|
+
"tile_bit_xor_inplace",
|
|
4083
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4084
|
+
value_func=tile_inplace_value_func,
|
|
4085
|
+
group="Tile Primitives",
|
|
4086
|
+
hidden=True,
|
|
4087
|
+
export=False,
|
|
4088
|
+
is_differentiable=False,
|
|
4089
|
+
)
|
|
4090
|
+
|
|
4091
|
+
|
|
4092
|
+
def tile_transpose_value_func(arg_types, arg_values):
    """Resolve the tile type returned by ``tile_transpose()``.

    :param arg_types: Mapping of argument names to types, or ``None`` when
        the generic signature is requested (doc builds).
    :param arg_values: Mapping of argument names to values (not used here).
    :returns: A non-owning tile type with the input's dtype, reversed shape and
        strides, and the row/column-major layout flipped.
    :raises TypeError: If more or fewer than one argument is given, or if the
        argument is not a tile.
    """
    # Doc builds pass no argument types: answer with the generic 2D tile type.
    if arg_types is None:
        return tile(dtype=Any, shape=Tuple[int, int])

    if len(arg_types) != 1:
        raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")

    src = arg_types["a"]
    if not is_tile(src):
        raise TypeError(f"tile_transpose() argument must be a tile, got {src!r}")

    # Flip the layout; any other layout value yields None.
    flipped_layout = {"rowmajor": "colmajor", "colmajor": "rowmajor"}.get(src.layout)

    # The input tile's storage is forced to shared before the result type is built.
    src.storage = "shared"

    return tile(
        dtype=src.dtype,
        shape=src.shape[::-1],
        storage=src.storage,
        strides=src.strides[::-1],
        layout=flipped_layout,
        owner=False,
    )
|
|
4124
|
+
|
|
4125
|
+
|
|
4126
|
+
add_builtin(
|
|
4127
|
+
"tile_transpose",
|
|
4128
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
|
|
4129
|
+
value_func=tile_transpose_value_func,
|
|
4130
|
+
variadic=True,
|
|
4131
|
+
doc="""Transpose a tile.
|
|
4132
|
+
|
|
4133
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
3783
4134
|
Register tiles will first be transferred to shared memory before transposition.
|
|
3784
4135
|
|
|
3785
4136
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
@@ -3910,6 +4261,80 @@ add_builtin(
|
|
|
3910
4261
|
)
|
|
3911
4262
|
|
|
3912
4263
|
|
|
4264
|
+
def tile_sum_axis_value_func(arg_types, arg_values):
    """Resolve the tile type returned by the axis-wise ``tile_sum()`` overload.

    :param arg_types: Mapping of argument names to types, or ``None`` when
        the generic signature is requested (doc builds).
    :param arg_values: Mapping of argument names to their values; ``axis`` is
        read from it.
    :returns: A tile type with the input's dtype and the input's shape with
        ``axis`` removed (``(1,)`` when the input tile is one-dimensional).
    :raises TypeError: If ``a`` is not a tile.
    :raises ValueError: If ``axis`` is not a compile-time constant or is out
        of bounds for the input tile.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return tile(dtype=Scalar, shape=Tuple[int, ...])

    a = arg_types["a"]

    if not is_tile(a):
        raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")

    # force input tile to shared
    a.storage = "shared"

    axis = arg_values["axis"]

    # NOTE(review): assumes a non-constant argument arrives here as None, as the
    # `None in shape` checks elsewhere in this file suggest - confirm. Without
    # this guard the bounds comparison below raises a bare TypeError.
    if axis is None:
        raise ValueError("tile_sum() axis must be a compile-time constant")

    shape = a.shape

    if axis < 0 or axis >= len(shape):
        raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")

    # the output shape is the input shape with the reduced axis removed;
    # a one-dimensional input reduces to a single-element tile
    if len(shape) > 1:
        new_shape = shape[:axis] + shape[axis + 1 :]
    else:
        new_shape = (1,)

    return tile(dtype=a.dtype, shape=new_shape)
|
|
4289
|
+
|
|
4290
|
+
|
|
4291
|
+
def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Split the arguments of ``tile_sum(a, axis)`` for code generation.

    :param arg_types: Mapping of argument names to types (unused here).
    :param return_type: The inferred return type (unused here).
    :param arg_values: Mapping of argument names to variables; ``"a"`` and
        ``"axis"`` are read from it.
    :returns: A pair of tuples ``((a,), (axis,))`` holding the tile variable
        and the constant axis value. NOTE(review): presumably the usual
        (function arguments, template arguments) split - confirm against the
        other dispatch functions in this file.
    :raises ValueError: If the axis variable has no ``constant`` attribute or
        its ``constant`` is ``None``.
    """
    # named input_tile (not ``tile``) to avoid shadowing the module-level tile type
    input_tile = arg_values["a"]
    axis_var = arg_values["axis"]

    # the axis is only accepted when it carries a constant value
    axis = getattr(axis_var, "constant", None)
    if axis is None:
        raise ValueError("tile_sum() axis must be a compile-time constant")

    return ((input_tile,), (axis,))
|
|
4299
|
+
|
|
4300
|
+
|
|
4301
|
+
add_builtin(
|
|
4302
|
+
"tile_sum",
|
|
4303
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4304
|
+
value_func=tile_sum_axis_value_func,
|
|
4305
|
+
dispatch_func=tile_sum_axis_dispatch_func,
|
|
4306
|
+
doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
|
|
4307
|
+
|
|
4308
|
+
:param a: The input tile. Must reside in shared memory.
|
|
4309
|
+
:param axis: The tile axis to compute the sum across. Must be a compile-time constant.
|
|
4310
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4311
|
+
|
|
4312
|
+
Example:
|
|
4313
|
+
|
|
4314
|
+
.. code-block:: python
|
|
4315
|
+
|
|
4316
|
+
@wp.kernel
|
|
4317
|
+
def compute():
|
|
4318
|
+
|
|
4319
|
+
t = wp.tile_ones(dtype=float, shape=(8, 8))
|
|
4320
|
+
s = wp.tile_sum(t, axis=0)
|
|
4321
|
+
|
|
4322
|
+
print(s)
|
|
4323
|
+
|
|
4324
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
4325
|
+
|
|
4326
|
+
Prints:
|
|
4327
|
+
|
|
4328
|
+
.. code-block:: text
|
|
4329
|
+
|
|
4330
|
+
[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
|
|
4331
|
+
|
|
4332
|
+
""",
|
|
4333
|
+
group="Tile Primitives",
|
|
4334
|
+
export=False,
|
|
4335
|
+
)
|
|
4336
|
+
|
|
4337
|
+
|
|
3913
4338
|
def tile_sort_value_func(arg_types, arg_values):
|
|
3914
4339
|
# return generic type (for doc builds)
|
|
3915
4340
|
if arg_types is None:
|
|
@@ -3986,6 +4411,7 @@ add_builtin(
|
|
|
3986
4411
|
""",
|
|
3987
4412
|
group="Tile Primitives",
|
|
3988
4413
|
export=False,
|
|
4414
|
+
is_differentiable=False,
|
|
3989
4415
|
)
|
|
3990
4416
|
|
|
3991
4417
|
|
|
@@ -4039,6 +4465,7 @@ add_builtin(
|
|
|
4039
4465
|
""",
|
|
4040
4466
|
group="Tile Primitives",
|
|
4041
4467
|
export=False,
|
|
4468
|
+
is_differentiable=False,
|
|
4042
4469
|
)
|
|
4043
4470
|
|
|
4044
4471
|
|
|
@@ -4092,6 +4519,7 @@ add_builtin(
|
|
|
4092
4519
|
""",
|
|
4093
4520
|
group="Tile Primitives",
|
|
4094
4521
|
export=False,
|
|
4522
|
+
is_differentiable=False,
|
|
4095
4523
|
)
|
|
4096
4524
|
|
|
4097
4525
|
|
|
@@ -4144,6 +4572,7 @@ add_builtin(
|
|
|
4144
4572
|
""",
|
|
4145
4573
|
group="Tile Primitives",
|
|
4146
4574
|
export=False,
|
|
4575
|
+
is_differentiable=False,
|
|
4147
4576
|
)
|
|
4148
4577
|
|
|
4149
4578
|
|
|
@@ -4196,10 +4625,10 @@ add_builtin(
|
|
|
4196
4625
|
""",
|
|
4197
4626
|
group="Tile Primitives",
|
|
4198
4627
|
export=False,
|
|
4628
|
+
is_differentiable=False,
|
|
4199
4629
|
)
|
|
4200
4630
|
|
|
4201
4631
|
|
|
4202
|
-
# does type propagation for load()
|
|
4203
4632
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
4204
4633
|
if arg_types is None:
|
|
4205
4634
|
return tile(dtype=Scalar, shape=(1,))
|
|
@@ -4253,6 +4682,88 @@ add_builtin(
|
|
|
4253
4682
|
""",
|
|
4254
4683
|
group="Tile Primitives",
|
|
4255
4684
|
export=False,
|
|
4685
|
+
is_differentiable=False,
|
|
4686
|
+
)
|
|
4687
|
+
|
|
4688
|
+
|
|
4689
|
+
def tile_reduce_axis_value_func(arg_types, arg_values):
    """Infer the return type of the ``tile_reduce(op, a, axis)`` overload.

    The result is a tile with the dtype of ``a`` whose shape is the shape of
    ``a`` with the ``axis`` dimension removed. Reducing a one-dimensional tile
    yields a tile of shape ``(1,)``.

    Note that this function sets ``a.storage`` to ``"shared"`` on the input
    tile type as a side effect.

    :param arg_types: Mapping of argument names to types, or ``None`` to
        request the generic return type.
    :param arg_values: Mapping of argument names to values; ``"axis"`` is read
        from it and compared against the tile rank.
    :raises TypeError: If ``a`` is not a tile.
    :raises ValueError: If ``axis`` is negative or not smaller than the tile rank.
    """
    # no concrete argument types available: report the generic tile type
    if arg_types is None:
        return tile(dtype=Scalar, shape=Tuple[int, ...])

    a = arg_types["a"]
    if not is_tile(a):
        raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")

    # the reduction reads its input from shared memory, so force it there
    a.storage = "shared"

    axis = arg_values["axis"]
    dims = a.shape
    ndim = len(dims)

    if axis < 0 or axis >= ndim:
        raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {ndim} dimensions")

    # drop the reduced axis; a 1D input collapses to a single-element tile
    reduced_shape = dims[:axis] + dims[axis + 1 :] if ndim > 1 else (1,)

    return tile(dtype=a.dtype, shape=reduced_shape)
|
|
4714
|
+
|
|
4715
|
+
|
|
4716
|
+
add_builtin(
|
|
4717
|
+
"tile_reduce",
|
|
4718
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4719
|
+
value_func=tile_reduce_axis_value_func,
|
|
4720
|
+
native_func="tile_reduce_axis",
|
|
4721
|
+
doc="""Apply a custom reduction operator across a tile axis.
|
|
4722
|
+
|
|
4723
|
+
This function cooperatively performs a reduction using the provided operator across an axis of the tile.
|
|
4724
|
+
|
|
4725
|
+
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
4726
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
|
|
4727
|
+
:param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
|
|
4728
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4729
|
+
|
|
4730
|
+
Example:
|
|
4731
|
+
|
|
4732
|
+
.. code-block:: python
|
|
4733
|
+
|
|
4734
|
+
TILE_M = wp.constant(4)
|
|
4735
|
+
TILE_N = wp.constant(2)
|
|
4736
|
+
|
|
4737
|
+
@wp.kernel
|
|
4738
|
+
def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
|
|
4739
|
+
|
|
4740
|
+
a = wp.tile_load(x, shape=(TILE_M, TILE_N))
|
|
4741
|
+
b = wp.tile_reduce(wp.add, a, axis=1)
|
|
4742
|
+
wp.tile_store(y, b)
|
|
4743
|
+
|
|
4744
|
+
arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
|
|
4745
|
+
|
|
4746
|
+
x = wp.array(arr, dtype=float)
|
|
4747
|
+
y = wp.zeros(TILE_M, dtype=float)
|
|
4748
|
+
|
|
4749
|
+
wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
|
|
4750
|
+
|
|
4751
|
+
print(x.numpy())
|
|
4752
|
+
print(y.numpy())
|
|
4753
|
+
|
|
4754
|
+
Prints:
|
|
4755
|
+
|
|
4756
|
+
.. code-block:: text
|
|
4757
|
+
|
|
4758
|
+
[[0. 1.]
|
|
4759
|
+
[2. 3.]
|
|
4760
|
+
[4. 5.]
|
|
4761
|
+
[6. 7.]]
|
|
4762
|
+
[ 1. 5. 9. 13.]
|
|
4763
|
+
""",
|
|
4764
|
+
group="Tile Primitives",
|
|
4765
|
+
export=False,
|
|
4766
|
+
is_differentiable=False,
|
|
4256
4767
|
)
|
|
4257
4768
|
|
|
4258
4769
|
|
|
@@ -4316,6 +4827,7 @@ add_builtin(
|
|
|
4316
4827
|
""",
|
|
4317
4828
|
group="Tile Primitives",
|
|
4318
4829
|
export=False,
|
|
4830
|
+
is_differentiable=False,
|
|
4319
4831
|
)
|
|
4320
4832
|
|
|
4321
4833
|
|
|
@@ -4379,6 +4891,7 @@ add_builtin(
|
|
|
4379
4891
|
""",
|
|
4380
4892
|
group="Tile Primitives",
|
|
4381
4893
|
export=False,
|
|
4894
|
+
is_differentiable=False,
|
|
4382
4895
|
)
|
|
4383
4896
|
|
|
4384
4897
|
|
|
@@ -4632,6 +5145,7 @@ add_builtin(
|
|
|
4632
5145
|
doc="WIP",
|
|
4633
5146
|
group="Utility",
|
|
4634
5147
|
hidden=True,
|
|
5148
|
+
is_differentiable=False,
|
|
4635
5149
|
)
|
|
4636
5150
|
|
|
4637
5151
|
add_builtin(
|
|
@@ -4647,6 +5161,7 @@ add_builtin(
|
|
|
4647
5161
|
doc="WIP",
|
|
4648
5162
|
group="Utility",
|
|
4649
5163
|
hidden=True,
|
|
5164
|
+
is_differentiable=False,
|
|
4650
5165
|
)
|
|
4651
5166
|
|
|
4652
5167
|
add_builtin(
|
|
@@ -4656,6 +5171,7 @@ add_builtin(
|
|
|
4656
5171
|
doc="WIP",
|
|
4657
5172
|
group="Utility",
|
|
4658
5173
|
hidden=True,
|
|
5174
|
+
is_differentiable=False,
|
|
4659
5175
|
)
|
|
4660
5176
|
|
|
4661
5177
|
add_builtin(
|
|
@@ -4707,6 +5223,7 @@ add_builtin(
|
|
|
4707
5223
|
:param low: The lower bound of the bounding box in BVH space
|
|
4708
5224
|
:param high: The upper bound of the bounding box in BVH space""",
|
|
4709
5225
|
export=False,
|
|
5226
|
+
is_differentiable=False,
|
|
4710
5227
|
)
|
|
4711
5228
|
|
|
4712
5229
|
add_builtin(
|
|
@@ -4722,6 +5239,7 @@ add_builtin(
|
|
|
4722
5239
|
:param start: The start of the ray in BVH space
|
|
4723
5240
|
:param dir: The direction of the ray in BVH space""",
|
|
4724
5241
|
export=False,
|
|
5242
|
+
is_differentiable=False,
|
|
4725
5243
|
)
|
|
4726
5244
|
|
|
4727
5245
|
add_builtin(
|
|
@@ -4732,6 +5250,7 @@ add_builtin(
|
|
|
4732
5250
|
doc="""Move to the next bound returned by the query.
|
|
4733
5251
|
The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
|
|
4734
5252
|
export=False,
|
|
5253
|
+
is_differentiable=False,
|
|
4735
5254
|
)
|
|
4736
5255
|
|
|
4737
5256
|
add_builtin(
|
|
@@ -5066,12 +5585,13 @@ add_builtin(
|
|
|
5066
5585
|
group="Geometry",
|
|
5067
5586
|
doc="""Construct an axis-aligned bounding box query against a :class:`Mesh`.
|
|
5068
5587
|
|
|
5069
|
-
This query can be used to iterate over all triangles inside a volume.
|
|
5588
|
+
This query can be used to iterate over all bounding boxes of the triangles inside a volume.
|
|
5070
5589
|
|
|
5071
5590
|
:param id: The mesh identifier
|
|
5072
5591
|
:param low: The lower bound of the bounding box in mesh space
|
|
5073
5592
|
:param high: The upper bound of the bounding box in mesh space""",
|
|
5074
5593
|
export=False,
|
|
5594
|
+
is_differentiable=False,
|
|
5075
5595
|
)
|
|
5076
5596
|
|
|
5077
5597
|
add_builtin(
|
|
@@ -5079,10 +5599,11 @@ add_builtin(
|
|
|
5079
5599
|
input_types={"query": MeshQueryAABB, "index": int},
|
|
5080
5600
|
value_type=builtins.bool,
|
|
5081
5601
|
group="Geometry",
|
|
5082
|
-
doc="""Move to the next triangle
|
|
5602
|
+
doc="""Move to the next triangle whose bounding box overlaps the query bounding box.
|
|
5083
5603
|
|
|
5084
5604
|
The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
|
|
5085
5605
|
export=False,
|
|
5606
|
+
is_differentiable=False,
|
|
5086
5607
|
)
|
|
5087
5608
|
|
|
5088
5609
|
add_builtin(
|
|
@@ -5112,6 +5633,7 @@ add_builtin(
|
|
|
5112
5633
|
|
|
5113
5634
|
This query can be used to iterate over all neighboring points within a fixed radius from the query point.""",
|
|
5114
5635
|
export=False,
|
|
5636
|
+
is_differentiable=False,
|
|
5115
5637
|
)
|
|
5116
5638
|
|
|
5117
5639
|
add_builtin(
|
|
@@ -5123,6 +5645,7 @@ add_builtin(
|
|
|
5123
5645
|
|
|
5124
5646
|
The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
|
|
5125
5647
|
export=False,
|
|
5648
|
+
is_differentiable=False,
|
|
5126
5649
|
)
|
|
5127
5650
|
|
|
5128
5651
|
add_builtin(
|
|
@@ -5136,6 +5659,7 @@ add_builtin(
|
|
|
5136
5659
|
|
|
5137
5660
|
Returns -1 if the :class:`HashGrid` has not been reserved.""",
|
|
5138
5661
|
export=False,
|
|
5662
|
+
is_differentiable=False,
|
|
5139
5663
|
)
|
|
5140
5664
|
|
|
5141
5665
|
add_builtin(
|
|
@@ -5145,15 +5669,34 @@ add_builtin(
|
|
|
5145
5669
|
group="Geometry",
|
|
5146
5670
|
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5147
5671
|
|
|
5672
|
+
This function works in single precision and may return incorrect results in some cases.
|
|
5673
|
+
|
|
5674
|
+
Returns > 0 if triangles intersect.""",
|
|
5675
|
+
export=False,
|
|
5676
|
+
is_differentiable=False,
|
|
5677
|
+
)
|
|
5678
|
+
|
|
5679
|
+
|
|
5680
|
+
add_builtin(
|
|
5681
|
+
"intersect_tri_tri",
|
|
5682
|
+
input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
|
|
5683
|
+
value_type=int,
|
|
5684
|
+
group="Geometry",
|
|
5685
|
+
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5686
|
+
|
|
5687
|
+
This function works with double precision, results are more accurate than the single precision version.
|
|
5688
|
+
|
|
5148
5689
|
Returns > 0 if triangles intersect.""",
|
|
5149
5690
|
export=False,
|
|
5691
|
+
is_differentiable=False,
|
|
5150
5692
|
)
|
|
5151
5693
|
|
|
5694
|
+
|
|
5152
5695
|
add_builtin(
|
|
5153
5696
|
"mesh_get",
|
|
5154
5697
|
input_types={"id": uint64},
|
|
5155
5698
|
value_type=Mesh,
|
|
5156
|
-
|
|
5699
|
+
is_differentiable=False,
|
|
5157
5700
|
group="Geometry",
|
|
5158
5701
|
doc="""Retrieves the mesh given its index.""",
|
|
5159
5702
|
export=False,
|
|
@@ -5166,6 +5709,7 @@ add_builtin(
|
|
|
5166
5709
|
group="Geometry",
|
|
5167
5710
|
doc="""Evaluates the face normal the mesh given a face index.""",
|
|
5168
5711
|
export=False,
|
|
5712
|
+
is_differentiable=False,
|
|
5169
5713
|
)
|
|
5170
5714
|
|
|
5171
5715
|
add_builtin(
|
|
@@ -5175,6 +5719,7 @@ add_builtin(
|
|
|
5175
5719
|
group="Geometry",
|
|
5176
5720
|
doc="""Returns the point of the mesh given a index.""",
|
|
5177
5721
|
export=False,
|
|
5722
|
+
is_differentiable=False,
|
|
5178
5723
|
)
|
|
5179
5724
|
|
|
5180
5725
|
add_builtin(
|
|
@@ -5184,6 +5729,7 @@ add_builtin(
|
|
|
5184
5729
|
group="Geometry",
|
|
5185
5730
|
doc="""Returns the velocity of the mesh given a index.""",
|
|
5186
5731
|
export=False,
|
|
5732
|
+
is_differentiable=False,
|
|
5187
5733
|
)
|
|
5188
5734
|
|
|
5189
5735
|
add_builtin(
|
|
@@ -5193,6 +5739,7 @@ add_builtin(
|
|
|
5193
5739
|
group="Geometry",
|
|
5194
5740
|
doc="""Returns the point-index of the mesh given a face-vertex index.""",
|
|
5195
5741
|
export=False,
|
|
5742
|
+
is_differentiable=False,
|
|
5196
5743
|
)
|
|
5197
5744
|
|
|
5198
5745
|
|
|
@@ -5233,12 +5780,32 @@ add_builtin(
|
|
|
5233
5780
|
# ---------------------------------
|
|
5234
5781
|
# Iterators
|
|
5235
5782
|
|
|
5236
|
-
add_builtin("iter_next", input_types={"range": range_t}, value_type=int, group="Utility", export=False, hidden=True)
|
|
5237
5783
|
add_builtin(
|
|
5238
|
-
"iter_next",
|
|
5784
|
+
"iter_next",
|
|
5785
|
+
input_types={"range": range_t},
|
|
5786
|
+
value_type=int,
|
|
5787
|
+
group="Utility",
|
|
5788
|
+
export=False,
|
|
5789
|
+
hidden=True,
|
|
5790
|
+
is_differentiable=False,
|
|
5791
|
+
)
|
|
5792
|
+
add_builtin(
|
|
5793
|
+
"iter_next",
|
|
5794
|
+
input_types={"query": HashGridQuery},
|
|
5795
|
+
value_type=int,
|
|
5796
|
+
group="Utility",
|
|
5797
|
+
export=False,
|
|
5798
|
+
hidden=True,
|
|
5799
|
+
is_differentiable=False,
|
|
5239
5800
|
)
|
|
5240
5801
|
add_builtin(
|
|
5241
|
-
"iter_next",
|
|
5802
|
+
"iter_next",
|
|
5803
|
+
input_types={"query": MeshQueryAABB},
|
|
5804
|
+
value_type=int,
|
|
5805
|
+
group="Utility",
|
|
5806
|
+
export=False,
|
|
5807
|
+
hidden=True,
|
|
5808
|
+
is_differentiable=False,
|
|
5242
5809
|
)
|
|
5243
5810
|
|
|
5244
5811
|
add_builtin(
|
|
@@ -5249,6 +5816,7 @@ add_builtin(
|
|
|
5249
5816
|
group="Utility",
|
|
5250
5817
|
doc="""Returns the range in reversed order.""",
|
|
5251
5818
|
export=False,
|
|
5819
|
+
is_differentiable=False,
|
|
5252
5820
|
)
|
|
5253
5821
|
|
|
5254
5822
|
# ---------------------------------
|
|
@@ -5268,8 +5836,8 @@ _volume_supported_value_types = {
|
|
|
5268
5836
|
|
|
5269
5837
|
|
|
5270
5838
|
def _is_volume_type_supported(dtype):
    """Return ``True`` if ``dtype`` matches any entry of ``_volume_supported_value_types``.

    Matching is done with ``types_equal`` against each supported value type in
    turn; ``False`` is returned when none of them matches.
    """
    return any(types_equal(supported, dtype) for supported in _volume_supported_value_types)
|
|
5275
5843
|
|
|
@@ -5397,6 +5965,7 @@ add_builtin(
|
|
|
5397
5965
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type `dtype`.
|
|
5398
5966
|
|
|
5399
5967
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5968
|
+
is_differentiable=False,
|
|
5400
5969
|
)
|
|
5401
5970
|
|
|
5402
5971
|
|
|
@@ -5417,6 +5986,7 @@ add_builtin(
|
|
|
5417
5986
|
export=False,
|
|
5418
5987
|
group="Volumes",
|
|
5419
5988
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5989
|
+
is_differentiable=False,
|
|
5420
5990
|
)
|
|
5421
5991
|
|
|
5422
5992
|
add_builtin(
|
|
@@ -5447,6 +6017,7 @@ add_builtin(
|
|
|
5447
6017
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5448
6018
|
|
|
5449
6019
|
If the voxel at this index does not exist, this function returns the background value""",
|
|
6020
|
+
is_differentiable=False,
|
|
5450
6021
|
)
|
|
5451
6022
|
|
|
5452
6023
|
add_builtin(
|
|
@@ -5455,6 +6026,7 @@ add_builtin(
|
|
|
5455
6026
|
group="Volumes",
|
|
5456
6027
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5457
6028
|
export=False,
|
|
6029
|
+
is_differentiable=False,
|
|
5458
6030
|
)
|
|
5459
6031
|
|
|
5460
6032
|
add_builtin(
|
|
@@ -5475,6 +6047,7 @@ add_builtin(
|
|
|
5475
6047
|
doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5476
6048
|
|
|
5477
6049
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
6050
|
+
is_differentiable=False,
|
|
5478
6051
|
)
|
|
5479
6052
|
|
|
5480
6053
|
add_builtin(
|
|
@@ -5483,6 +6056,7 @@ add_builtin(
|
|
|
5483
6056
|
group="Volumes",
|
|
5484
6057
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5485
6058
|
export=False,
|
|
6059
|
+
is_differentiable=False,
|
|
5486
6060
|
)
|
|
5487
6061
|
|
|
5488
6062
|
add_builtin(
|
|
@@ -5501,6 +6075,7 @@ add_builtin(
|
|
|
5501
6075
|
doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5502
6076
|
|
|
5503
6077
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
6078
|
+
is_differentiable=False,
|
|
5504
6079
|
)
|
|
5505
6080
|
|
|
5506
6081
|
add_builtin(
|
|
@@ -5509,6 +6084,7 @@ add_builtin(
|
|
|
5509
6084
|
group="Volumes",
|
|
5510
6085
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5511
6086
|
export=False,
|
|
6087
|
+
is_differentiable=False,
|
|
5512
6088
|
)
|
|
5513
6089
|
|
|
5514
6090
|
|
|
@@ -5590,6 +6166,7 @@ add_builtin(
|
|
|
5590
6166
|
If the voxel at this index does not exist, this function returns -1.
|
|
5591
6167
|
This function is available for both index grids and classical volumes.
|
|
5592
6168
|
""",
|
|
6169
|
+
is_differentiable=False,
|
|
5593
6170
|
)
|
|
5594
6171
|
|
|
5595
6172
|
add_builtin(
|
|
@@ -5631,6 +6208,7 @@ add_builtin(
|
|
|
5631
6208
|
value_type=uint32,
|
|
5632
6209
|
group="Random",
|
|
5633
6210
|
doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
|
|
6211
|
+
is_differentiable=False,
|
|
5634
6212
|
)
|
|
5635
6213
|
|
|
5636
6214
|
add_builtin(
|
|
@@ -5642,6 +6220,7 @@ add_builtin(
|
|
|
5642
6220
|
|
|
5643
6221
|
This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
|
|
5644
6222
|
but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
|
|
6223
|
+
is_differentiable=False,
|
|
5645
6224
|
)
|
|
5646
6225
|
|
|
5647
6226
|
add_builtin(
|
|
@@ -5650,6 +6229,7 @@ add_builtin(
|
|
|
5650
6229
|
value_type=int,
|
|
5651
6230
|
group="Random",
|
|
5652
6231
|
doc="Return a random integer in the range [-2^31, 2^31).",
|
|
6232
|
+
is_differentiable=False,
|
|
5653
6233
|
)
|
|
5654
6234
|
add_builtin(
|
|
5655
6235
|
"randi",
|
|
@@ -5657,6 +6237,7 @@ add_builtin(
|
|
|
5657
6237
|
value_type=int,
|
|
5658
6238
|
group="Random",
|
|
5659
6239
|
doc="Return a random integer between [low, high).",
|
|
6240
|
+
is_differentiable=False,
|
|
5660
6241
|
)
|
|
5661
6242
|
add_builtin(
|
|
5662
6243
|
"randu",
|
|
@@ -5664,6 +6245,7 @@ add_builtin(
|
|
|
5664
6245
|
value_type=uint32,
|
|
5665
6246
|
group="Random",
|
|
5666
6247
|
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
6248
|
+
is_differentiable=False,
|
|
5667
6249
|
)
|
|
5668
6250
|
add_builtin(
|
|
5669
6251
|
"randu",
|
|
@@ -5671,6 +6253,7 @@ add_builtin(
|
|
|
5671
6253
|
value_type=uint32,
|
|
5672
6254
|
group="Random",
|
|
5673
6255
|
doc="Return a random unsigned integer between [low, high).",
|
|
6256
|
+
is_differentiable=False,
|
|
5674
6257
|
)
|
|
5675
6258
|
add_builtin(
|
|
5676
6259
|
"randf",
|
|
@@ -5678,6 +6261,7 @@ add_builtin(
|
|
|
5678
6261
|
value_type=float,
|
|
5679
6262
|
group="Random",
|
|
5680
6263
|
doc="Return a random float between [0.0, 1.0).",
|
|
6264
|
+
is_differentiable=False,
|
|
5681
6265
|
)
|
|
5682
6266
|
add_builtin(
|
|
5683
6267
|
"randf",
|
|
@@ -5685,6 +6269,7 @@ add_builtin(
|
|
|
5685
6269
|
value_type=float,
|
|
5686
6270
|
group="Random",
|
|
5687
6271
|
doc="Return a random float between [low, high).",
|
|
6272
|
+
is_differentiable=False,
|
|
5688
6273
|
)
|
|
5689
6274
|
add_builtin(
|
|
5690
6275
|
"randn",
|
|
@@ -5692,6 +6277,7 @@ add_builtin(
|
|
|
5692
6277
|
value_type=float,
|
|
5693
6278
|
group="Random",
|
|
5694
6279
|
doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
|
|
6280
|
+
is_differentiable=False,
|
|
5695
6281
|
)
|
|
5696
6282
|
|
|
5697
6283
|
add_builtin(
|
|
@@ -5700,6 +6286,7 @@ add_builtin(
|
|
|
5700
6286
|
value_type=int,
|
|
5701
6287
|
group="Random",
|
|
5702
6288
|
doc="Inverse-transform sample a cumulative distribution function.",
|
|
6289
|
+
is_differentiable=False,
|
|
5703
6290
|
)
|
|
5704
6291
|
add_builtin(
|
|
5705
6292
|
"sample_triangle",
|
|
@@ -5707,6 +6294,7 @@ add_builtin(
|
|
|
5707
6294
|
value_type=vec2,
|
|
5708
6295
|
group="Random",
|
|
5709
6296
|
doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
|
|
6297
|
+
is_differentiable=False,
|
|
5710
6298
|
)
|
|
5711
6299
|
add_builtin(
|
|
5712
6300
|
"sample_unit_ring",
|
|
@@ -5714,6 +6302,7 @@ add_builtin(
|
|
|
5714
6302
|
value_type=vec2,
|
|
5715
6303
|
group="Random",
|
|
5716
6304
|
doc="Uniformly sample a ring in the xy plane.",
|
|
6305
|
+
is_differentiable=False,
|
|
5717
6306
|
)
|
|
5718
6307
|
add_builtin(
|
|
5719
6308
|
"sample_unit_disk",
|
|
@@ -5721,6 +6310,7 @@ add_builtin(
|
|
|
5721
6310
|
value_type=vec2,
|
|
5722
6311
|
group="Random",
|
|
5723
6312
|
doc="Uniformly sample a disk in the xy plane.",
|
|
6313
|
+
is_differentiable=False,
|
|
5724
6314
|
)
|
|
5725
6315
|
add_builtin(
|
|
5726
6316
|
"sample_unit_sphere_surface",
|
|
@@ -5728,6 +6318,7 @@ add_builtin(
|
|
|
5728
6318
|
value_type=vec3,
|
|
5729
6319
|
group="Random",
|
|
5730
6320
|
doc="Uniformly sample a unit sphere surface.",
|
|
6321
|
+
is_differentiable=False,
|
|
5731
6322
|
)
|
|
5732
6323
|
add_builtin(
|
|
5733
6324
|
"sample_unit_sphere",
|
|
@@ -5735,6 +6326,7 @@ add_builtin(
|
|
|
5735
6326
|
value_type=vec3,
|
|
5736
6327
|
group="Random",
|
|
5737
6328
|
doc="Uniformly sample a unit sphere.",
|
|
6329
|
+
is_differentiable=False,
|
|
5738
6330
|
)
|
|
5739
6331
|
add_builtin(
|
|
5740
6332
|
"sample_unit_hemisphere_surface",
|
|
@@ -5742,6 +6334,7 @@ add_builtin(
|
|
|
5742
6334
|
value_type=vec3,
|
|
5743
6335
|
group="Random",
|
|
5744
6336
|
doc="Uniformly sample a unit hemisphere surface.",
|
|
6337
|
+
is_differentiable=False,
|
|
5745
6338
|
)
|
|
5746
6339
|
add_builtin(
|
|
5747
6340
|
"sample_unit_hemisphere",
|
|
@@ -5749,6 +6342,7 @@ add_builtin(
|
|
|
5749
6342
|
value_type=vec3,
|
|
5750
6343
|
group="Random",
|
|
5751
6344
|
doc="Uniformly sample a unit hemisphere.",
|
|
6345
|
+
is_differentiable=False,
|
|
5752
6346
|
)
|
|
5753
6347
|
add_builtin(
|
|
5754
6348
|
"sample_unit_square",
|
|
@@ -5756,6 +6350,7 @@ add_builtin(
|
|
|
5756
6350
|
value_type=vec2,
|
|
5757
6351
|
group="Random",
|
|
5758
6352
|
doc="Uniformly sample a unit square.",
|
|
6353
|
+
is_differentiable=False,
|
|
5759
6354
|
)
|
|
5760
6355
|
add_builtin(
|
|
5761
6356
|
"sample_unit_cube",
|
|
@@ -5763,6 +6358,7 @@ add_builtin(
|
|
|
5763
6358
|
value_type=vec3,
|
|
5764
6359
|
group="Random",
|
|
5765
6360
|
doc="Uniformly sample a unit cube.",
|
|
6361
|
+
is_differentiable=False,
|
|
5766
6362
|
)
|
|
5767
6363
|
|
|
5768
6364
|
add_builtin(
|
|
@@ -5774,6 +6370,7 @@ add_builtin(
|
|
|
5774
6370
|
|
|
5775
6371
|
:param state: RNG state
|
|
5776
6372
|
:param lam: The expected value of the distribution""",
|
|
6373
|
+
is_differentiable=False,
|
|
5777
6374
|
)
|
|
5778
6375
|
|
|
5779
6376
|
add_builtin(
|
|
@@ -5841,7 +6438,7 @@ add_builtin(
|
|
|
5841
6438
|
value_type=vec2,
|
|
5842
6439
|
group="Random",
|
|
5843
6440
|
doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
|
|
5844
|
-
|
|
6441
|
+
is_differentiable=False,
|
|
5845
6442
|
)
|
|
5846
6443
|
add_builtin(
|
|
5847
6444
|
"curlnoise",
|
|
@@ -5850,7 +6447,7 @@ add_builtin(
|
|
|
5850
6447
|
value_type=vec3,
|
|
5851
6448
|
group="Random",
|
|
5852
6449
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5853
|
-
|
|
6450
|
+
is_differentiable=False,
|
|
5854
6451
|
)
|
|
5855
6452
|
add_builtin(
|
|
5856
6453
|
"curlnoise",
|
|
@@ -5859,7 +6456,7 @@ add_builtin(
|
|
|
5859
6456
|
value_type=vec3,
|
|
5860
6457
|
group="Random",
|
|
5861
6458
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5862
|
-
|
|
6459
|
+
is_differentiable=False,
|
|
5863
6460
|
)
|
|
5864
6461
|
|
|
5865
6462
|
|
|
@@ -5891,10 +6488,17 @@ add_builtin(
|
|
|
5891
6488
|
dispatch_func=printf_dispatch_func,
|
|
5892
6489
|
group="Utility",
|
|
5893
6490
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
6491
|
+
is_differentiable=False,
|
|
6492
|
+
)
|
|
6493
|
+
|
|
6494
|
+
add_builtin(
|
|
6495
|
+
"print",
|
|
6496
|
+
input_types={"value": Any},
|
|
6497
|
+
doc="Print variable to stdout",
|
|
6498
|
+
export=False,
|
|
6499
|
+
group="Utility",
|
|
5894
6500
|
)
|
|
5895
6501
|
|
|
5896
|
-
add_builtin("print", input_types={"value": Any}, doc="Print variable to stdout", export=False, group="Utility")
|
|
5897
|
-
|
|
5898
6502
|
add_builtin(
|
|
5899
6503
|
"breakpoint",
|
|
5900
6504
|
input_types={},
|
|
@@ -5903,6 +6507,7 @@ add_builtin(
|
|
|
5903
6507
|
group="Utility",
|
|
5904
6508
|
namespace="",
|
|
5905
6509
|
native_func="__debugbreak",
|
|
6510
|
+
is_differentiable=False,
|
|
5906
6511
|
)
|
|
5907
6512
|
|
|
5908
6513
|
# helpers
|
|
@@ -5920,6 +6525,7 @@ add_builtin(
|
|
|
5920
6525
|
This function may not be called from user-defined Warp functions.""",
|
|
5921
6526
|
namespace="",
|
|
5922
6527
|
native_func="builtin_tid1d",
|
|
6528
|
+
is_differentiable=False,
|
|
5923
6529
|
)
|
|
5924
6530
|
|
|
5925
6531
|
add_builtin(
|
|
@@ -5930,6 +6536,7 @@ add_builtin(
|
|
|
5930
6536
|
doc="Returns the number of threads in the current block.",
|
|
5931
6537
|
namespace="",
|
|
5932
6538
|
native_func="builtin_block_dim",
|
|
6539
|
+
is_differentiable=False,
|
|
5933
6540
|
)
|
|
5934
6541
|
|
|
5935
6542
|
add_builtin(
|
|
@@ -5944,6 +6551,7 @@ add_builtin(
|
|
|
5944
6551
|
This function may not be called from user-defined Warp functions.""",
|
|
5945
6552
|
namespace="",
|
|
5946
6553
|
native_func="builtin_tid2d",
|
|
6554
|
+
is_differentiable=False,
|
|
5947
6555
|
)
|
|
5948
6556
|
|
|
5949
6557
|
add_builtin(
|
|
@@ -5958,6 +6566,7 @@ add_builtin(
|
|
|
5958
6566
|
This function may not be called from user-defined Warp functions.""",
|
|
5959
6567
|
namespace="",
|
|
5960
6568
|
native_func="builtin_tid3d",
|
|
6569
|
+
is_differentiable=False,
|
|
5961
6570
|
)
|
|
5962
6571
|
|
|
5963
6572
|
add_builtin(
|
|
@@ -5972,17 +6581,37 @@ add_builtin(
|
|
|
5972
6581
|
This function may not be called from user-defined Warp functions.""",
|
|
5973
6582
|
namespace="",
|
|
5974
6583
|
native_func="builtin_tid4d",
|
|
6584
|
+
is_differentiable=False,
|
|
5975
6585
|
)
|
|
5976
6586
|
|
|
5977
6587
|
|
|
6588
|
+
def copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Infer the return type of ``copy(a)``.

    For every input except a shared-storage tile the argument type is returned
    unchanged. For a tile whose storage is ``"shared"``, a new tile type with
    the same dtype, shape, storage, strides and layout is returned with
    ``owner=True``.

    :param arg_types: Mapping of argument names to types; ``"a"`` is read.
    :param arg_values: Mapping of argument names to values (unused here).
    """
    a = arg_types["a"]

    # anything that is not a shared tile keeps its type as-is
    if not (is_tile(a) and a.storage == "shared"):
        return a

    # shared tile: return an owning tile type so that a copy is forced
    return tile(
        dtype=a.dtype,
        shape=a.shape,
        storage=a.storage,
        strides=a.strides,
        layout=a.layout,
        owner=True,
    )
|
|
6603
|
+
|
|
6604
|
+
|
|
5978
6605
|
add_builtin(
|
|
5979
6606
|
"copy",
|
|
5980
6607
|
input_types={"a": Any},
|
|
5981
|
-
value_func=
|
|
6608
|
+
value_func=copy_value_func,
|
|
5982
6609
|
hidden=True,
|
|
5983
6610
|
export=False,
|
|
5984
6611
|
group="Utility",
|
|
5985
6612
|
)
|
|
6613
|
+
|
|
6614
|
+
|
|
5986
6615
|
add_builtin(
|
|
5987
6616
|
"assign",
|
|
5988
6617
|
input_types={"dest": Any, "src": Any},
|
|
@@ -5992,61 +6621,88 @@ add_builtin(
|
|
|
5992
6621
|
)
|
|
5993
6622
|
|
|
5994
6623
|
|
|
5995
|
-
def
|
|
5996
|
-
|
|
5997
|
-
|
|
5998
|
-
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
5999
|
-
category=DeprecationWarning,
|
|
6000
|
-
)
|
|
6001
|
-
|
|
6002
|
-
func_args = tuple(args.values())
|
|
6003
|
-
template_args = ()
|
|
6624
|
+
def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Value function for the removed ``select()`` builtin.

    :param arg_types: Mapping of argument names to types, or ``None`` to
        request the generic return type.
    :param arg_values: Mapping of argument names to values (unused here).
    :returns: ``Any`` when ``arg_types`` is ``None``.
    :raises RuntimeError: For any other ``arg_types``, pointing users to
        ``wp.where()``.
    """
    # any call with concrete argument types is rejected
    if arg_types is not None:
        raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")

    return Any
|
|
6006
6629
|
|
|
6007
6630
|
|
|
6008
6631
|
add_builtin(
|
|
6009
6632
|
"select",
|
|
6010
6633
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
6011
|
-
value_func=
|
|
6012
|
-
dispatch_func=select_dispatch_func,
|
|
6634
|
+
value_func=select_value_func,
|
|
6013
6635
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6014
6636
|
|
|
6015
|
-
..
|
|
6637
|
+
.. versionremoved:: 1.10
|
|
6016
6638
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6017
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6639
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6640
|
+
|
|
6641
|
+
.. deprecated:: 1.7""",
|
|
6018
6642
|
group="Utility",
|
|
6019
6643
|
)
|
|
6020
6644
|
for t in int_types:
|
|
6021
6645
|
add_builtin(
|
|
6022
6646
|
"select",
|
|
6023
6647
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
6024
|
-
value_func=
|
|
6025
|
-
dispatch_func=select_dispatch_func,
|
|
6648
|
+
value_func=select_value_func,
|
|
6026
6649
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6027
6650
|
|
|
6028
|
-
..
|
|
6651
|
+
.. versionremoved:: 1.10
|
|
6029
6652
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6030
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6653
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6654
|
+
|
|
6655
|
+
.. deprecated:: 1.7""",
|
|
6031
6656
|
group="Utility",
|
|
6032
6657
|
)
|
|
6033
6658
|
add_builtin(
|
|
6034
6659
|
"select",
|
|
6035
6660
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
6036
|
-
value_func=
|
|
6037
|
-
dispatch_func=select_dispatch_func,
|
|
6661
|
+
value_func=select_value_func,
|
|
6038
6662
|
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6039
6663
|
|
|
6040
|
-
..
|
|
6664
|
+
.. versionremoved:: 1.10
|
|
6041
6665
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6042
|
-
``where(arr, value_if_true, value_if_false)``.
|
|
6666
|
+
``where(arr, value_if_true, value_if_false)``.
|
|
6667
|
+
|
|
6668
|
+
.. deprecated:: 1.7""",
|
|
6043
6669
|
group="Utility",
|
|
6044
6670
|
)
|
|
6045
6671
|
|
|
6672
|
+
|
|
6673
|
+
def where_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6674
|
+
if arg_types is None:
|
|
6675
|
+
return Any
|
|
6676
|
+
|
|
6677
|
+
v_true = arg_types["value_if_true"]
|
|
6678
|
+
v_false = arg_types["value_if_false"]
|
|
6679
|
+
|
|
6680
|
+
if not types_equal(v_true, v_false):
|
|
6681
|
+
raise RuntimeError(f"where() true value type ({v_true}) must be of the same type as the false type ({v_false})")
|
|
6682
|
+
|
|
6683
|
+
if is_tile(v_false):
|
|
6684
|
+
if v_true.storage == "register":
|
|
6685
|
+
return v_true
|
|
6686
|
+
if v_false.storage == "register":
|
|
6687
|
+
return v_false
|
|
6688
|
+
|
|
6689
|
+
# both v_true and v_false are shared
|
|
6690
|
+
return tile(
|
|
6691
|
+
dtype=v_true.dtype,
|
|
6692
|
+
shape=v_true.shape,
|
|
6693
|
+
storage=v_true.storage,
|
|
6694
|
+
strides=v_true.strides,
|
|
6695
|
+
layout=v_true.layout,
|
|
6696
|
+
owner=True,
|
|
6697
|
+
)
|
|
6698
|
+
|
|
6699
|
+
return v_true
|
|
6700
|
+
|
|
6701
|
+
|
|
6046
6702
|
add_builtin(
|
|
6047
6703
|
"where",
|
|
6048
6704
|
input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
|
|
6049
|
-
value_func=
|
|
6705
|
+
value_func=where_value_func,
|
|
6050
6706
|
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6051
6707
|
group="Utility",
|
|
6052
6708
|
)
|
|
@@ -6054,14 +6710,14 @@ for t in int_types:
|
|
|
6054
6710
|
add_builtin(
|
|
6055
6711
|
"where",
|
|
6056
6712
|
input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
|
|
6057
|
-
value_func=
|
|
6713
|
+
value_func=where_value_func,
|
|
6058
6714
|
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6059
6715
|
group="Utility",
|
|
6060
6716
|
)
|
|
6061
6717
|
add_builtin(
|
|
6062
6718
|
"where",
|
|
6063
6719
|
input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
|
|
6064
|
-
value_func=
|
|
6720
|
+
value_func=where_value_func,
|
|
6065
6721
|
doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
6066
6722
|
group="Utility",
|
|
6067
6723
|
)
|
|
@@ -6099,7 +6755,7 @@ add_builtin(
|
|
|
6099
6755
|
group="Utility",
|
|
6100
6756
|
hidden=True,
|
|
6101
6757
|
export=False,
|
|
6102
|
-
|
|
6758
|
+
is_differentiable=False,
|
|
6103
6759
|
)
|
|
6104
6760
|
|
|
6105
6761
|
|
|
@@ -6140,7 +6796,7 @@ add_builtin(
|
|
|
6140
6796
|
native_func="fixedarray_t",
|
|
6141
6797
|
group="Utility",
|
|
6142
6798
|
export=False,
|
|
6143
|
-
|
|
6799
|
+
is_differentiable=False,
|
|
6144
6800
|
hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
|
|
6145
6801
|
)
|
|
6146
6802
|
|
|
@@ -6183,14 +6839,13 @@ for array_type in array_types:
|
|
|
6183
6839
|
# does argument checking and type propagation for view()
|
|
6184
6840
|
def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6185
6841
|
arr_type = arg_types["arr"]
|
|
6186
|
-
idx_types = tuple(arg_types[x] for x in "
|
|
6842
|
+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
6187
6843
|
|
|
6188
6844
|
if not is_array(arr_type):
|
|
6189
6845
|
raise RuntimeError("view() first argument must be an array")
|
|
6190
6846
|
|
|
6191
6847
|
idx_count = len(idx_types)
|
|
6192
|
-
|
|
6193
|
-
if idx_count >= arr_type.ndim:
|
|
6848
|
+
if idx_count > arr_type.ndim:
|
|
6194
6849
|
raise RuntimeError(
|
|
6195
6850
|
f"Trying to create an array view with {idx_count} indices, "
|
|
6196
6851
|
f"but the array only has {arr_type.ndim} dimension(s). "
|
|
@@ -6198,14 +6853,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6198
6853
|
f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
|
|
6199
6854
|
)
|
|
6200
6855
|
|
|
6201
|
-
|
|
6202
|
-
|
|
6203
|
-
|
|
6204
|
-
|
|
6856
|
+
has_slice = any(is_slice(x) for x in idx_types)
|
|
6857
|
+
if has_slice:
|
|
6858
|
+
# check index types
|
|
6859
|
+
for t in idx_types:
|
|
6860
|
+
if not (type_is_int(t) or is_slice(t)):
|
|
6861
|
+
raise RuntimeError(
|
|
6862
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6863
|
+
)
|
|
6864
|
+
|
|
6865
|
+
# Each integer index collapses one dimension.
|
|
6866
|
+
int_count = sum(x.step == 0 for x in idx_types)
|
|
6867
|
+
ndim = arr_type.ndim - int_count
|
|
6868
|
+
assert ndim > 0
|
|
6869
|
+
else:
|
|
6870
|
+
if idx_count == arr_type.ndim:
|
|
6871
|
+
raise RuntimeError("Expected to call `address()` instead of `view()`")
|
|
6872
|
+
|
|
6873
|
+
# check index types
|
|
6874
|
+
for t in idx_types:
|
|
6875
|
+
if not type_is_int(t):
|
|
6876
|
+
raise RuntimeError(
|
|
6877
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6878
|
+
)
|
|
6879
|
+
|
|
6880
|
+
# create an array view with leading dimensions removed
|
|
6881
|
+
ndim = arr_type.ndim - idx_count
|
|
6882
|
+
assert ndim > 0
|
|
6205
6883
|
|
|
6206
|
-
# create an array view with leading dimensions removed
|
|
6207
6884
|
dtype = arr_type.dtype
|
|
6208
|
-
ndim = arr_type.ndim - idx_count
|
|
6209
6885
|
if isinstance(arr_type, (fabricarray, indexedfabricarray)):
|
|
6210
6886
|
# fabric array of arrays: return array attribute as a regular array
|
|
6211
6887
|
return array(dtype=dtype, ndim=ndim)
|
|
@@ -6216,8 +6892,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6216
6892
|
for array_type in array_types:
|
|
6217
6893
|
add_builtin(
|
|
6218
6894
|
"view",
|
|
6219
|
-
input_types={
|
|
6220
|
-
|
|
6895
|
+
input_types={
|
|
6896
|
+
"arr": array_type(dtype=Any),
|
|
6897
|
+
"i": Any,
|
|
6898
|
+
"j": Any,
|
|
6899
|
+
"k": Any,
|
|
6900
|
+
"l": Any,
|
|
6901
|
+
},
|
|
6902
|
+
defaults={
|
|
6903
|
+
"j": None,
|
|
6904
|
+
"k": None,
|
|
6905
|
+
"l": None,
|
|
6906
|
+
},
|
|
6221
6907
|
constraint=sametypes,
|
|
6222
6908
|
hidden=True,
|
|
6223
6909
|
value_func=view_value_func,
|
|
@@ -6321,6 +7007,7 @@ add_builtin(
|
|
|
6321
7007
|
hidden=True,
|
|
6322
7008
|
skip_replay=True,
|
|
6323
7009
|
group="Utility",
|
|
7010
|
+
is_differentiable=False,
|
|
6324
7011
|
)
|
|
6325
7012
|
|
|
6326
7013
|
|
|
@@ -6337,6 +7024,7 @@ add_builtin(
|
|
|
6337
7024
|
dispatch_func=load_dispatch_func,
|
|
6338
7025
|
hidden=True,
|
|
6339
7026
|
group="Utility",
|
|
7027
|
+
is_differentiable=False,
|
|
6340
7028
|
)
|
|
6341
7029
|
|
|
6342
7030
|
|
|
@@ -6412,6 +7100,13 @@ def create_atomic_op_value_func(op: str):
|
|
|
6412
7100
|
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
6413
7101
|
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
6414
7102
|
)
|
|
7103
|
+
elif op in ("and", "or", "xor"):
|
|
7104
|
+
supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
|
|
7105
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
|
|
7106
|
+
raise RuntimeError(
|
|
7107
|
+
f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
|
|
7108
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
7109
|
+
)
|
|
6415
7110
|
else:
|
|
6416
7111
|
raise NotImplementedError
|
|
6417
7112
|
|
|
@@ -6445,7 +7140,8 @@ for array_type in array_types:
|
|
|
6445
7140
|
value_func=create_atomic_op_value_func("add"),
|
|
6446
7141
|
dispatch_func=atomic_op_dispatch_func,
|
|
6447
7142
|
doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
6448
|
-
|
|
7143
|
+
|
|
7144
|
+
This function is automatically invoked when using the syntax ``arr[i] += value``.""",
|
|
6449
7145
|
group="Utility",
|
|
6450
7146
|
skip_replay=True,
|
|
6451
7147
|
)
|
|
@@ -6457,7 +7153,8 @@ for array_type in array_types:
|
|
|
6457
7153
|
value_func=create_atomic_op_value_func("add"),
|
|
6458
7154
|
dispatch_func=atomic_op_dispatch_func,
|
|
6459
7155
|
doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
6460
|
-
|
|
7156
|
+
|
|
7157
|
+
This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
|
|
6461
7158
|
group="Utility",
|
|
6462
7159
|
skip_replay=True,
|
|
6463
7160
|
)
|
|
@@ -6469,7 +7166,8 @@ for array_type in array_types:
|
|
|
6469
7166
|
value_func=create_atomic_op_value_func("add"),
|
|
6470
7167
|
dispatch_func=atomic_op_dispatch_func,
|
|
6471
7168
|
doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
6472
|
-
|
|
7169
|
+
|
|
7170
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
|
|
6473
7171
|
group="Utility",
|
|
6474
7172
|
skip_replay=True,
|
|
6475
7173
|
)
|
|
@@ -6481,7 +7179,8 @@ for array_type in array_types:
|
|
|
6481
7179
|
value_func=create_atomic_op_value_func("add"),
|
|
6482
7180
|
dispatch_func=atomic_op_dispatch_func,
|
|
6483
7181
|
doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
6484
|
-
|
|
7182
|
+
|
|
7183
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
|
|
6485
7184
|
group="Utility",
|
|
6486
7185
|
skip_replay=True,
|
|
6487
7186
|
)
|
|
@@ -6494,7 +7193,8 @@ for array_type in array_types:
|
|
|
6494
7193
|
value_func=create_atomic_op_value_func("sub"),
|
|
6495
7194
|
dispatch_func=atomic_op_dispatch_func,
|
|
6496
7195
|
doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
6497
|
-
|
|
7196
|
+
|
|
7197
|
+
This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
|
|
6498
7198
|
group="Utility",
|
|
6499
7199
|
skip_replay=True,
|
|
6500
7200
|
)
|
|
@@ -6506,7 +7206,8 @@ for array_type in array_types:
|
|
|
6506
7206
|
value_func=create_atomic_op_value_func("sub"),
|
|
6507
7207
|
dispatch_func=atomic_op_dispatch_func,
|
|
6508
7208
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
6509
|
-
|
|
7209
|
+
|
|
7210
|
+
This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
|
|
6510
7211
|
group="Utility",
|
|
6511
7212
|
skip_replay=True,
|
|
6512
7213
|
)
|
|
@@ -6518,7 +7219,8 @@ for array_type in array_types:
|
|
|
6518
7219
|
value_func=create_atomic_op_value_func("sub"),
|
|
6519
7220
|
dispatch_func=atomic_op_dispatch_func,
|
|
6520
7221
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
6521
|
-
|
|
7222
|
+
|
|
7223
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
|
|
6522
7224
|
group="Utility",
|
|
6523
7225
|
skip_replay=True,
|
|
6524
7226
|
)
|
|
@@ -6530,7 +7232,8 @@ for array_type in array_types:
|
|
|
6530
7232
|
value_func=create_atomic_op_value_func("sub"),
|
|
6531
7233
|
dispatch_func=atomic_op_dispatch_func,
|
|
6532
7234
|
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
6533
|
-
|
|
7235
|
+
|
|
7236
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
|
|
6534
7237
|
group="Utility",
|
|
6535
7238
|
skip_replay=True,
|
|
6536
7239
|
)
|
|
@@ -6653,6 +7356,7 @@ for array_type in array_types:
|
|
|
6653
7356
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6654
7357
|
group="Utility",
|
|
6655
7358
|
skip_replay=True,
|
|
7359
|
+
is_differentiable=False,
|
|
6656
7360
|
)
|
|
6657
7361
|
add_builtin(
|
|
6658
7362
|
"atomic_cas",
|
|
@@ -6666,6 +7370,7 @@ for array_type in array_types:
|
|
|
6666
7370
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6667
7371
|
group="Utility",
|
|
6668
7372
|
skip_replay=True,
|
|
7373
|
+
is_differentiable=False,
|
|
6669
7374
|
)
|
|
6670
7375
|
add_builtin(
|
|
6671
7376
|
"atomic_cas",
|
|
@@ -6679,6 +7384,7 @@ for array_type in array_types:
|
|
|
6679
7384
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6680
7385
|
group="Utility",
|
|
6681
7386
|
skip_replay=True,
|
|
7387
|
+
is_differentiable=False,
|
|
6682
7388
|
)
|
|
6683
7389
|
add_builtin(
|
|
6684
7390
|
"atomic_cas",
|
|
@@ -6700,6 +7406,7 @@ for array_type in array_types:
|
|
|
6700
7406
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6701
7407
|
group="Utility",
|
|
6702
7408
|
skip_replay=True,
|
|
7409
|
+
is_differentiable=False,
|
|
6703
7410
|
)
|
|
6704
7411
|
|
|
6705
7412
|
add_builtin(
|
|
@@ -6714,6 +7421,7 @@ for array_type in array_types:
|
|
|
6714
7421
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6715
7422
|
group="Utility",
|
|
6716
7423
|
skip_replay=True,
|
|
7424
|
+
is_differentiable=False,
|
|
6717
7425
|
)
|
|
6718
7426
|
add_builtin(
|
|
6719
7427
|
"atomic_exch",
|
|
@@ -6727,6 +7435,7 @@ for array_type in array_types:
|
|
|
6727
7435
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6728
7436
|
group="Utility",
|
|
6729
7437
|
skip_replay=True,
|
|
7438
|
+
is_differentiable=False,
|
|
6730
7439
|
)
|
|
6731
7440
|
add_builtin(
|
|
6732
7441
|
"atomic_exch",
|
|
@@ -6740,6 +7449,7 @@ for array_type in array_types:
|
|
|
6740
7449
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6741
7450
|
group="Utility",
|
|
6742
7451
|
skip_replay=True,
|
|
7452
|
+
is_differentiable=False,
|
|
6743
7453
|
)
|
|
6744
7454
|
add_builtin(
|
|
6745
7455
|
"atomic_exch",
|
|
@@ -6755,6 +7465,177 @@ for array_type in array_types:
|
|
|
6755
7465
|
skip_replay=True,
|
|
6756
7466
|
)
|
|
6757
7467
|
|
|
7468
|
+
add_builtin(
|
|
7469
|
+
"atomic_and",
|
|
7470
|
+
hidden=hidden,
|
|
7471
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7472
|
+
constraint=atomic_op_constraint,
|
|
7473
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7474
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7475
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7476
|
+
|
|
7477
|
+
This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
|
|
7478
|
+
group="Utility",
|
|
7479
|
+
skip_replay=True,
|
|
7480
|
+
is_differentiable=False,
|
|
7481
|
+
)
|
|
7482
|
+
add_builtin(
|
|
7483
|
+
"atomic_and",
|
|
7484
|
+
hidden=hidden,
|
|
7485
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7486
|
+
constraint=atomic_op_constraint,
|
|
7487
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7488
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7489
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7490
|
+
|
|
7491
|
+
This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
|
|
7492
|
+
group="Utility",
|
|
7493
|
+
skip_replay=True,
|
|
7494
|
+
is_differentiable=False,
|
|
7495
|
+
)
|
|
7496
|
+
add_builtin(
|
|
7497
|
+
"atomic_and",
|
|
7498
|
+
hidden=hidden,
|
|
7499
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7500
|
+
constraint=atomic_op_constraint,
|
|
7501
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7502
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7503
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7504
|
+
|
|
7505
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
|
|
7506
|
+
group="Utility",
|
|
7507
|
+
skip_replay=True,
|
|
7508
|
+
is_differentiable=False,
|
|
7509
|
+
)
|
|
7510
|
+
add_builtin(
|
|
7511
|
+
"atomic_and",
|
|
7512
|
+
hidden=hidden,
|
|
7513
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7514
|
+
constraint=atomic_op_constraint,
|
|
7515
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7516
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7517
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7518
|
+
|
|
7519
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
|
|
7520
|
+
group="Utility",
|
|
7521
|
+
skip_replay=True,
|
|
7522
|
+
is_differentiable=False,
|
|
7523
|
+
)
|
|
7524
|
+
|
|
7525
|
+
add_builtin(
|
|
7526
|
+
"atomic_or",
|
|
7527
|
+
hidden=hidden,
|
|
7528
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7529
|
+
constraint=atomic_op_constraint,
|
|
7530
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7531
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7532
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7533
|
+
|
|
7534
|
+
This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
|
|
7535
|
+
group="Utility",
|
|
7536
|
+
skip_replay=True,
|
|
7537
|
+
is_differentiable=False,
|
|
7538
|
+
)
|
|
7539
|
+
add_builtin(
|
|
7540
|
+
"atomic_or",
|
|
7541
|
+
hidden=hidden,
|
|
7542
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7543
|
+
constraint=atomic_op_constraint,
|
|
7544
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7545
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7546
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7547
|
+
|
|
7548
|
+
This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
|
|
7549
|
+
group="Utility",
|
|
7550
|
+
skip_replay=True,
|
|
7551
|
+
is_differentiable=False,
|
|
7552
|
+
)
|
|
7553
|
+
add_builtin(
|
|
7554
|
+
"atomic_or",
|
|
7555
|
+
hidden=hidden,
|
|
7556
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7557
|
+
constraint=atomic_op_constraint,
|
|
7558
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7559
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7560
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7561
|
+
|
|
7562
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
|
|
7563
|
+
group="Utility",
|
|
7564
|
+
skip_replay=True,
|
|
7565
|
+
is_differentiable=False,
|
|
7566
|
+
)
|
|
7567
|
+
add_builtin(
|
|
7568
|
+
"atomic_or",
|
|
7569
|
+
hidden=hidden,
|
|
7570
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7571
|
+
constraint=atomic_op_constraint,
|
|
7572
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7573
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7574
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7575
|
+
|
|
7576
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
|
|
7577
|
+
group="Utility",
|
|
7578
|
+
skip_replay=True,
|
|
7579
|
+
is_differentiable=False,
|
|
7580
|
+
)
|
|
7581
|
+
|
|
7582
|
+
add_builtin(
|
|
7583
|
+
"atomic_xor",
|
|
7584
|
+
hidden=hidden,
|
|
7585
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7586
|
+
constraint=atomic_op_constraint,
|
|
7587
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7588
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7589
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7590
|
+
|
|
7591
|
+
This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
|
|
7592
|
+
group="Utility",
|
|
7593
|
+
skip_replay=True,
|
|
7594
|
+
is_differentiable=False,
|
|
7595
|
+
)
|
|
7596
|
+
add_builtin(
|
|
7597
|
+
"atomic_xor",
|
|
7598
|
+
hidden=hidden,
|
|
7599
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7600
|
+
constraint=atomic_op_constraint,
|
|
7601
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7602
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7603
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7604
|
+
|
|
7605
|
+
This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
|
|
7606
|
+
group="Utility",
|
|
7607
|
+
skip_replay=True,
|
|
7608
|
+
is_differentiable=False,
|
|
7609
|
+
)
|
|
7610
|
+
add_builtin(
|
|
7611
|
+
"atomic_xor",
|
|
7612
|
+
hidden=hidden,
|
|
7613
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7614
|
+
constraint=atomic_op_constraint,
|
|
7615
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7616
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7617
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7618
|
+
|
|
7619
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
|
|
7620
|
+
group="Utility",
|
|
7621
|
+
skip_replay=True,
|
|
7622
|
+
is_differentiable=False,
|
|
7623
|
+
)
|
|
7624
|
+
add_builtin(
|
|
7625
|
+
"atomic_xor",
|
|
7626
|
+
hidden=hidden,
|
|
7627
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7628
|
+
constraint=atomic_op_constraint,
|
|
7629
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7630
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7631
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7632
|
+
|
|
7633
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
|
|
7634
|
+
group="Utility",
|
|
7635
|
+
skip_replay=True,
|
|
7636
|
+
is_differentiable=False,
|
|
7637
|
+
)
|
|
7638
|
+
|
|
6758
7639
|
|
|
6759
7640
|
# used to index into builtin types, i.e.: y = vec3[1]
|
|
6760
7641
|
def vector_extract_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
@@ -6903,6 +7784,7 @@ add_builtin(
|
|
|
6903
7784
|
hidden=True,
|
|
6904
7785
|
group="Utility",
|
|
6905
7786
|
skip_replay=True,
|
|
7787
|
+
is_differentiable=False,
|
|
6906
7788
|
)
|
|
6907
7789
|
# implements &quaternion[index]
|
|
6908
7790
|
add_builtin(
|
|
@@ -6913,6 +7795,7 @@ add_builtin(
|
|
|
6913
7795
|
hidden=True,
|
|
6914
7796
|
group="Utility",
|
|
6915
7797
|
skip_replay=True,
|
|
7798
|
+
is_differentiable=False,
|
|
6916
7799
|
)
|
|
6917
7800
|
# implements &transformation[index]
|
|
6918
7801
|
add_builtin(
|
|
@@ -6923,6 +7806,7 @@ add_builtin(
|
|
|
6923
7806
|
hidden=True,
|
|
6924
7807
|
group="Utility",
|
|
6925
7808
|
skip_replay=True,
|
|
7809
|
+
is_differentiable=False,
|
|
6926
7810
|
)
|
|
6927
7811
|
# implements &(*vector)[index]
|
|
6928
7812
|
add_builtin(
|
|
@@ -6933,6 +7817,7 @@ add_builtin(
|
|
|
6933
7817
|
hidden=True,
|
|
6934
7818
|
group="Utility",
|
|
6935
7819
|
skip_replay=True,
|
|
7820
|
+
is_differentiable=False,
|
|
6936
7821
|
)
|
|
6937
7822
|
# implements &(*matrix)[i, j]
|
|
6938
7823
|
add_builtin(
|
|
@@ -6943,6 +7828,7 @@ add_builtin(
|
|
|
6943
7828
|
hidden=True,
|
|
6944
7829
|
group="Utility",
|
|
6945
7830
|
skip_replay=True,
|
|
7831
|
+
is_differentiable=False,
|
|
6946
7832
|
)
|
|
6947
7833
|
# implements &(*quaternion)[index]
|
|
6948
7834
|
add_builtin(
|
|
@@ -6953,6 +7839,7 @@ add_builtin(
|
|
|
6953
7839
|
hidden=True,
|
|
6954
7840
|
group="Utility",
|
|
6955
7841
|
skip_replay=True,
|
|
7842
|
+
is_differentiable=False,
|
|
6956
7843
|
)
|
|
6957
7844
|
# implements &(*transformation)[index]
|
|
6958
7845
|
add_builtin(
|
|
@@ -6963,6 +7850,7 @@ add_builtin(
|
|
|
6963
7850
|
hidden=True,
|
|
6964
7851
|
group="Utility",
|
|
6965
7852
|
skip_replay=True,
|
|
7853
|
+
is_differentiable=False,
|
|
6966
7854
|
)
|
|
6967
7855
|
|
|
6968
7856
|
|
|
@@ -7158,6 +8046,43 @@ add_builtin(
|
|
|
7158
8046
|
)
|
|
7159
8047
|
|
|
7160
8048
|
|
|
8049
|
+
# implements vector[idx] &= scalar
|
|
8050
|
+
add_builtin(
|
|
8051
|
+
"bit_and_inplace",
|
|
8052
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8053
|
+
value_type=None,
|
|
8054
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8055
|
+
hidden=True,
|
|
8056
|
+
export=False,
|
|
8057
|
+
group="Utility",
|
|
8058
|
+
is_differentiable=False,
|
|
8059
|
+
)
|
|
8060
|
+
|
|
8061
|
+
# implements vector[idx] |= scalar
|
|
8062
|
+
add_builtin(
|
|
8063
|
+
"bit_or_inplace",
|
|
8064
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8065
|
+
value_type=None,
|
|
8066
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8067
|
+
hidden=True,
|
|
8068
|
+
export=False,
|
|
8069
|
+
group="Utility",
|
|
8070
|
+
is_differentiable=False,
|
|
8071
|
+
)
|
|
8072
|
+
|
|
8073
|
+
# implements vector[idx] ^= scalar
|
|
8074
|
+
add_builtin(
|
|
8075
|
+
"bit_xor_inplace",
|
|
8076
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8077
|
+
value_type=None,
|
|
8078
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8079
|
+
hidden=True,
|
|
8080
|
+
export=False,
|
|
8081
|
+
group="Utility",
|
|
8082
|
+
is_differentiable=False,
|
|
8083
|
+
)
|
|
8084
|
+
|
|
8085
|
+
|
|
7161
8086
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
7162
8087
|
mat_type = arg_types["a"]
|
|
7163
8088
|
row_type = mat_type._wp_row_type_
|
|
@@ -7173,6 +8098,7 @@ add_builtin(
|
|
|
7173
8098
|
hidden=True,
|
|
7174
8099
|
group="Utility",
|
|
7175
8100
|
skip_replay=True,
|
|
8101
|
+
is_differentiable=False,
|
|
7176
8102
|
)
|
|
7177
8103
|
|
|
7178
8104
|
|
|
@@ -7191,6 +8117,7 @@ add_builtin(
|
|
|
7191
8117
|
hidden=True,
|
|
7192
8118
|
group="Utility",
|
|
7193
8119
|
skip_replay=True,
|
|
8120
|
+
is_differentiable=False,
|
|
7194
8121
|
)
|
|
7195
8122
|
|
|
7196
8123
|
|
|
@@ -7390,6 +8317,78 @@ add_builtin(
|
|
|
7390
8317
|
)
|
|
7391
8318
|
|
|
7392
8319
|
|
|
8320
|
+
# implements matrix[i] &= value
|
|
8321
|
+
add_builtin(
|
|
8322
|
+
"bit_and_inplace",
|
|
8323
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8324
|
+
value_type=None,
|
|
8325
|
+
hidden=True,
|
|
8326
|
+
export=False,
|
|
8327
|
+
group="Utility",
|
|
8328
|
+
is_differentiable=False,
|
|
8329
|
+
)
|
|
8330
|
+
|
|
8331
|
+
|
|
8332
|
+
# implements matrix[i,j] &= value
|
|
8333
|
+
add_builtin(
|
|
8334
|
+
"bit_and_inplace",
|
|
8335
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8336
|
+
value_type=None,
|
|
8337
|
+
hidden=True,
|
|
8338
|
+
export=False,
|
|
8339
|
+
group="Utility",
|
|
8340
|
+
is_differentiable=False,
|
|
8341
|
+
)
|
|
8342
|
+
|
|
8343
|
+
|
|
8344
|
+
# implements matrix[i] |= value
|
|
8345
|
+
add_builtin(
|
|
8346
|
+
"bit_or_inplace",
|
|
8347
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8348
|
+
value_type=None,
|
|
8349
|
+
hidden=True,
|
|
8350
|
+
export=False,
|
|
8351
|
+
group="Utility",
|
|
8352
|
+
is_differentiable=False,
|
|
8353
|
+
)
|
|
8354
|
+
|
|
8355
|
+
|
|
8356
|
+
# implements matrix[i,j] |= value
|
|
8357
|
+
add_builtin(
|
|
8358
|
+
"bit_or_inplace",
|
|
8359
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8360
|
+
value_type=None,
|
|
8361
|
+
hidden=True,
|
|
8362
|
+
export=False,
|
|
8363
|
+
group="Utility",
|
|
8364
|
+
is_differentiable=False,
|
|
8365
|
+
)
|
|
8366
|
+
|
|
8367
|
+
|
|
8368
|
+
# implements matrix[i] ^= value
|
|
8369
|
+
add_builtin(
|
|
8370
|
+
"bit_xor_inplace",
|
|
8371
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8372
|
+
value_type=None,
|
|
8373
|
+
hidden=True,
|
|
8374
|
+
export=False,
|
|
8375
|
+
group="Utility",
|
|
8376
|
+
is_differentiable=False,
|
|
8377
|
+
)
|
|
8378
|
+
|
|
8379
|
+
|
|
8380
|
+
# implements matrix[i,j] ^= value
|
|
8381
|
+
add_builtin(
|
|
8382
|
+
"bit_xor_inplace",
|
|
8383
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8384
|
+
value_type=None,
|
|
8385
|
+
hidden=True,
|
|
8386
|
+
export=False,
|
|
8387
|
+
group="Utility",
|
|
8388
|
+
is_differentiable=False,
|
|
8389
|
+
)
|
|
8390
|
+
|
|
8391
|
+
|
|
7393
8392
|
for t in scalar_types + vector_types + (bool,):
|
|
7394
8393
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
7395
8394
|
continue
|
|
@@ -7401,6 +8400,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7401
8400
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7402
8401
|
group="Utility",
|
|
7403
8402
|
hidden=True,
|
|
8403
|
+
is_differentiable=False,
|
|
7404
8404
|
)
|
|
7405
8405
|
|
|
7406
8406
|
add_builtin(
|
|
@@ -7411,6 +8411,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7411
8411
|
group="Utility",
|
|
7412
8412
|
hidden=True,
|
|
7413
8413
|
export=False,
|
|
8414
|
+
is_differentiable=False,
|
|
7414
8415
|
)
|
|
7415
8416
|
|
|
7416
8417
|
|
|
@@ -7429,6 +8430,7 @@ add_builtin(
|
|
|
7429
8430
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7430
8431
|
group="Utility",
|
|
7431
8432
|
hidden=True,
|
|
8433
|
+
is_differentiable=False,
|
|
7432
8434
|
)
|
|
7433
8435
|
add_builtin(
|
|
7434
8436
|
"expect_neq",
|
|
@@ -7439,6 +8441,7 @@ add_builtin(
|
|
|
7439
8441
|
group="Utility",
|
|
7440
8442
|
hidden=True,
|
|
7441
8443
|
export=False,
|
|
8444
|
+
is_differentiable=False,
|
|
7442
8445
|
)
|
|
7443
8446
|
|
|
7444
8447
|
add_builtin(
|
|
@@ -7449,6 +8452,7 @@ add_builtin(
|
|
|
7449
8452
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7450
8453
|
group="Utility",
|
|
7451
8454
|
hidden=True,
|
|
8455
|
+
is_differentiable=False,
|
|
7452
8456
|
)
|
|
7453
8457
|
add_builtin(
|
|
7454
8458
|
"expect_neq",
|
|
@@ -7459,6 +8463,7 @@ add_builtin(
|
|
|
7459
8463
|
group="Utility",
|
|
7460
8464
|
hidden=True,
|
|
7461
8465
|
export=False,
|
|
8466
|
+
is_differentiable=False,
|
|
7462
8467
|
)
|
|
7463
8468
|
|
|
7464
8469
|
add_builtin(
|
|
@@ -7549,6 +8554,7 @@ add_builtin(
|
|
|
7549
8554
|
value_type=None,
|
|
7550
8555
|
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7551
8556
|
group="Utility",
|
|
8557
|
+
is_differentiable=False,
|
|
7552
8558
|
)
|
|
7553
8559
|
add_builtin(
|
|
7554
8560
|
"expect_near",
|
|
@@ -7558,6 +8564,7 @@ add_builtin(
|
|
|
7558
8564
|
value_type=None,
|
|
7559
8565
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7560
8566
|
group="Utility",
|
|
8567
|
+
is_differentiable=False,
|
|
7561
8568
|
)
|
|
7562
8569
|
add_builtin(
|
|
7563
8570
|
"expect_near",
|
|
@@ -7567,6 +8574,7 @@ add_builtin(
|
|
|
7567
8574
|
value_type=None,
|
|
7568
8575
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7569
8576
|
group="Utility",
|
|
8577
|
+
is_differentiable=False,
|
|
7570
8578
|
)
|
|
7571
8579
|
add_builtin(
|
|
7572
8580
|
"expect_near",
|
|
@@ -7580,6 +8588,7 @@ add_builtin(
|
|
|
7580
8588
|
value_type=None,
|
|
7581
8589
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7582
8590
|
group="Utility",
|
|
8591
|
+
is_differentiable=False,
|
|
7583
8592
|
)
|
|
7584
8593
|
|
|
7585
8594
|
# ---------------------------------
|
|
@@ -7590,6 +8599,7 @@ add_builtin(
|
|
|
7590
8599
|
input_types={"arr": array(dtype=Scalar), "value": Scalar},
|
|
7591
8600
|
value_type=int,
|
|
7592
8601
|
doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
|
|
8602
|
+
is_differentiable=False,
|
|
7593
8603
|
)
|
|
7594
8604
|
|
|
7595
8605
|
add_builtin(
|
|
@@ -7597,6 +8607,7 @@ add_builtin(
|
|
|
7597
8607
|
input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
|
|
7598
8608
|
value_type=int,
|
|
7599
8609
|
doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
|
|
8610
|
+
is_differentiable=False,
|
|
7600
8611
|
)
|
|
7601
8612
|
|
|
7602
8613
|
# ---------------------------------
|
|
@@ -7672,12 +8683,157 @@ add_builtin(
|
|
|
7672
8683
|
)
|
|
7673
8684
|
|
|
7674
8685
|
# bitwise operators
|
|
7675
|
-
add_builtin(
|
|
7676
|
-
|
|
7677
|
-
|
|
7678
|
-
|
|
7679
|
-
|
|
7680
|
-
|
|
8686
|
+
add_builtin(
|
|
8687
|
+
"bit_and",
|
|
8688
|
+
input_types={"a": Int, "b": Int},
|
|
8689
|
+
value_func=sametypes_create_value_func(Int),
|
|
8690
|
+
group="Operators",
|
|
8691
|
+
is_differentiable=False,
|
|
8692
|
+
)
|
|
8693
|
+
add_builtin(
|
|
8694
|
+
"bit_and",
|
|
8695
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8696
|
+
constraint=sametypes,
|
|
8697
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8698
|
+
doc="",
|
|
8699
|
+
group="Operators",
|
|
8700
|
+
is_differentiable=False,
|
|
8701
|
+
)
|
|
8702
|
+
add_builtin(
|
|
8703
|
+
"bit_and",
|
|
8704
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8705
|
+
constraint=sametypes,
|
|
8706
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8707
|
+
doc="",
|
|
8708
|
+
group="Operators",
|
|
8709
|
+
is_differentiable=False,
|
|
8710
|
+
)
|
|
8711
|
+
|
|
8712
|
+
add_builtin(
|
|
8713
|
+
"bit_or",
|
|
8714
|
+
input_types={"a": Int, "b": Int},
|
|
8715
|
+
value_func=sametypes_create_value_func(Int),
|
|
8716
|
+
group="Operators",
|
|
8717
|
+
is_differentiable=False,
|
|
8718
|
+
)
|
|
8719
|
+
add_builtin(
|
|
8720
|
+
"bit_or",
|
|
8721
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8722
|
+
constraint=sametypes,
|
|
8723
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8724
|
+
doc="",
|
|
8725
|
+
group="Operators",
|
|
8726
|
+
is_differentiable=False,
|
|
8727
|
+
)
|
|
8728
|
+
add_builtin(
|
|
8729
|
+
"bit_or",
|
|
8730
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8731
|
+
constraint=sametypes,
|
|
8732
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8733
|
+
doc="",
|
|
8734
|
+
group="Operators",
|
|
8735
|
+
is_differentiable=False,
|
|
8736
|
+
)
|
|
8737
|
+
|
|
8738
|
+
add_builtin(
|
|
8739
|
+
"bit_xor",
|
|
8740
|
+
input_types={"a": Int, "b": Int},
|
|
8741
|
+
value_func=sametypes_create_value_func(Int),
|
|
8742
|
+
group="Operators",
|
|
8743
|
+
is_differentiable=False,
|
|
8744
|
+
)
|
|
8745
|
+
add_builtin(
|
|
8746
|
+
"bit_xor",
|
|
8747
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8748
|
+
constraint=sametypes,
|
|
8749
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8750
|
+
doc="",
|
|
8751
|
+
group="Operators",
|
|
8752
|
+
is_differentiable=False,
|
|
8753
|
+
)
|
|
8754
|
+
add_builtin(
|
|
8755
|
+
"bit_xor",
|
|
8756
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8757
|
+
constraint=sametypes,
|
|
8758
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8759
|
+
doc="",
|
|
8760
|
+
group="Operators",
|
|
8761
|
+
is_differentiable=False,
|
|
8762
|
+
)
|
|
8763
|
+
|
|
8764
|
+
add_builtin(
|
|
8765
|
+
"lshift",
|
|
8766
|
+
input_types={"a": Int, "b": Int},
|
|
8767
|
+
value_func=sametypes_create_value_func(Int),
|
|
8768
|
+
group="Operators",
|
|
8769
|
+
is_differentiable=False,
|
|
8770
|
+
)
|
|
8771
|
+
add_builtin(
|
|
8772
|
+
"lshift",
|
|
8773
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8774
|
+
constraint=sametypes,
|
|
8775
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8776
|
+
doc="",
|
|
8777
|
+
group="Operators",
|
|
8778
|
+
is_differentiable=False,
|
|
8779
|
+
)
|
|
8780
|
+
add_builtin(
|
|
8781
|
+
"lshift",
|
|
8782
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8783
|
+
constraint=sametypes,
|
|
8784
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8785
|
+
doc="",
|
|
8786
|
+
group="Operators",
|
|
8787
|
+
is_differentiable=False,
|
|
8788
|
+
)
|
|
8789
|
+
|
|
8790
|
+
add_builtin(
|
|
8791
|
+
"rshift",
|
|
8792
|
+
input_types={"a": Int, "b": Int},
|
|
8793
|
+
value_func=sametypes_create_value_func(Int),
|
|
8794
|
+
group="Operators",
|
|
8795
|
+
is_differentiable=False,
|
|
8796
|
+
)
|
|
8797
|
+
add_builtin(
|
|
8798
|
+
"rshift",
|
|
8799
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8800
|
+
constraint=sametypes,
|
|
8801
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8802
|
+
doc="",
|
|
8803
|
+
group="Operators",
|
|
8804
|
+
is_differentiable=False,
|
|
8805
|
+
)
|
|
8806
|
+
add_builtin(
|
|
8807
|
+
"rshift",
|
|
8808
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8809
|
+
constraint=sametypes,
|
|
8810
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8811
|
+
doc="",
|
|
8812
|
+
group="Operators",
|
|
8813
|
+
is_differentiable=False,
|
|
8814
|
+
)
|
|
8815
|
+
|
|
8816
|
+
add_builtin(
|
|
8817
|
+
"invert",
|
|
8818
|
+
input_types={"a": Int},
|
|
8819
|
+
value_func=sametypes_create_value_func(Int),
|
|
8820
|
+
group="Operators",
|
|
8821
|
+
is_differentiable=False,
|
|
8822
|
+
)
|
|
8823
|
+
add_builtin(
|
|
8824
|
+
"invert",
|
|
8825
|
+
input_types={"a": vector(length=Any, dtype=Int)},
|
|
8826
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8827
|
+
group="Operators",
|
|
8828
|
+
is_differentiable=False,
|
|
8829
|
+
)
|
|
8830
|
+
add_builtin(
|
|
8831
|
+
"invert",
|
|
8832
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
|
|
8833
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8834
|
+
group="Operators",
|
|
8835
|
+
is_differentiable=False,
|
|
8836
|
+
)
|
|
7681
8837
|
|
|
7682
8838
|
|
|
7683
8839
|
add_builtin(
|
|
@@ -7878,6 +9034,7 @@ add_builtin(
|
|
|
7878
9034
|
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
7879
9035
|
doc="Modulo operation using truncated division.",
|
|
7880
9036
|
group="Operators",
|
|
9037
|
+
is_differentiable=False,
|
|
7881
9038
|
)
|
|
7882
9039
|
|
|
7883
9040
|
add_builtin(
|
|
@@ -7937,6 +9094,7 @@ add_builtin(
|
|
|
7937
9094
|
value_func=sametypes_create_value_func(Scalar),
|
|
7938
9095
|
doc="",
|
|
7939
9096
|
group="Operators",
|
|
9097
|
+
is_differentiable=False,
|
|
7940
9098
|
)
|
|
7941
9099
|
|
|
7942
9100
|
add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
@@ -7984,12 +9142,28 @@ add_builtin(
|
|
|
7984
9142
|
group="Operators",
|
|
7985
9143
|
)
|
|
7986
9144
|
|
|
7987
|
-
add_builtin(
|
|
9145
|
+
add_builtin(
|
|
9146
|
+
"unot",
|
|
9147
|
+
input_types={"a": builtins.bool},
|
|
9148
|
+
value_type=builtins.bool,
|
|
9149
|
+
doc="",
|
|
9150
|
+
group="Operators",
|
|
9151
|
+
is_differentiable=False,
|
|
9152
|
+
)
|
|
7988
9153
|
for t in int_types:
|
|
7989
|
-
add_builtin(
|
|
9154
|
+
add_builtin(
|
|
9155
|
+
"unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
|
|
9156
|
+
)
|
|
7990
9157
|
|
|
7991
9158
|
|
|
7992
|
-
add_builtin(
|
|
9159
|
+
add_builtin(
|
|
9160
|
+
"unot",
|
|
9161
|
+
input_types={"a": array(dtype=Any)},
|
|
9162
|
+
value_type=builtins.bool,
|
|
9163
|
+
doc="",
|
|
9164
|
+
group="Operators",
|
|
9165
|
+
is_differentiable=False,
|
|
9166
|
+
)
|
|
7993
9167
|
|
|
7994
9168
|
|
|
7995
9169
|
# Tile operators
|
|
@@ -8061,6 +9235,45 @@ add_builtin(
|
|
|
8061
9235
|
export=False,
|
|
8062
9236
|
)
|
|
8063
9237
|
|
|
9238
|
+
add_builtin(
|
|
9239
|
+
"bit_and",
|
|
9240
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9241
|
+
value_func=tile_binary_map_value_func,
|
|
9242
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9243
|
+
# variadic=True,
|
|
9244
|
+
native_func="tile_bit_and",
|
|
9245
|
+
doc="Bitwise AND each element of two tiles together",
|
|
9246
|
+
group="Tile Primitives",
|
|
9247
|
+
export=False,
|
|
9248
|
+
is_differentiable=False,
|
|
9249
|
+
)
|
|
9250
|
+
|
|
9251
|
+
add_builtin(
|
|
9252
|
+
"bit_or",
|
|
9253
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9254
|
+
value_func=tile_binary_map_value_func,
|
|
9255
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9256
|
+
# variadic=True,
|
|
9257
|
+
native_func="tile_bit_or",
|
|
9258
|
+
doc="Bitwise OR each element of two tiles together",
|
|
9259
|
+
group="Tile Primitives",
|
|
9260
|
+
export=False,
|
|
9261
|
+
is_differentiable=False,
|
|
9262
|
+
)
|
|
9263
|
+
|
|
9264
|
+
add_builtin(
|
|
9265
|
+
"bit_xor",
|
|
9266
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9267
|
+
value_func=tile_binary_map_value_func,
|
|
9268
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9269
|
+
# variadic=True,
|
|
9270
|
+
native_func="tile_bit_xor",
|
|
9271
|
+
doc="Bitwise XOR each element of two tiles together",
|
|
9272
|
+
group="Tile Primitives",
|
|
9273
|
+
export=False,
|
|
9274
|
+
is_differentiable=False,
|
|
9275
|
+
)
|
|
9276
|
+
|
|
8064
9277
|
|
|
8065
9278
|
add_builtin(
|
|
8066
9279
|
"mul",
|
|
@@ -8122,6 +9335,45 @@ add_builtin(
|
|
|
8122
9335
|
)
|
|
8123
9336
|
|
|
8124
9337
|
|
|
9338
|
+
add_builtin(
|
|
9339
|
+
"bit_and_inplace",
|
|
9340
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9341
|
+
value_type=None,
|
|
9342
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9343
|
+
export=False,
|
|
9344
|
+
hidden=True,
|
|
9345
|
+
native_func="tile_bit_and_inplace",
|
|
9346
|
+
group="Operators",
|
|
9347
|
+
is_differentiable=False,
|
|
9348
|
+
)
|
|
9349
|
+
|
|
9350
|
+
|
|
9351
|
+
add_builtin(
|
|
9352
|
+
"bit_or_inplace",
|
|
9353
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9354
|
+
value_type=None,
|
|
9355
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9356
|
+
export=False,
|
|
9357
|
+
hidden=True,
|
|
9358
|
+
native_func="tile_bit_or_inplace",
|
|
9359
|
+
group="Operators",
|
|
9360
|
+
is_differentiable=False,
|
|
9361
|
+
)
|
|
9362
|
+
|
|
9363
|
+
|
|
9364
|
+
add_builtin(
|
|
9365
|
+
"bit_xor_inplace",
|
|
9366
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9367
|
+
value_type=None,
|
|
9368
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9369
|
+
export=False,
|
|
9370
|
+
hidden=True,
|
|
9371
|
+
native_func="tile_bit_xor_inplace",
|
|
9372
|
+
group="Operators",
|
|
9373
|
+
is_differentiable=False,
|
|
9374
|
+
)
|
|
9375
|
+
|
|
9376
|
+
|
|
8125
9377
|
def tile_diag_add_value_func(arg_types, arg_values):
|
|
8126
9378
|
if arg_types is None:
|
|
8127
9379
|
return tile(dtype=Any, shape=Tuple[int, int])
|
|
@@ -8163,7 +9415,7 @@ def tile_diag_add_lto_dispatch_func(
|
|
|
8163
9415
|
return_values: List[Var],
|
|
8164
9416
|
arg_values: Mapping[str, Var],
|
|
8165
9417
|
options: Mapping[str, Any],
|
|
8166
|
-
builder: warp.context.ModuleBuilder,
|
|
9418
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8167
9419
|
):
|
|
8168
9420
|
a = arg_values["a"]
|
|
8169
9421
|
d = arg_values["d"]
|
|
@@ -8183,6 +9435,7 @@ add_builtin(
|
|
|
8183
9435
|
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
8184
9436
|
group="Tile Primitives",
|
|
8185
9437
|
export=False,
|
|
9438
|
+
is_differentiable=False,
|
|
8186
9439
|
)
|
|
8187
9440
|
|
|
8188
9441
|
|
|
@@ -8239,7 +9492,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8239
9492
|
return_values: List[Var],
|
|
8240
9493
|
arg_values: Mapping[str, Var],
|
|
8241
9494
|
options: Mapping[str, Any],
|
|
8242
|
-
builder: warp.context.ModuleBuilder,
|
|
9495
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8243
9496
|
):
|
|
8244
9497
|
a = arg_values["a"]
|
|
8245
9498
|
b = arg_values["b"]
|
|
@@ -8277,7 +9530,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8277
9530
|
num_threads = options["block_dim"]
|
|
8278
9531
|
arch = options["output_arch"]
|
|
8279
9532
|
|
|
8280
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9533
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8281
9534
|
# CPU/no-MathDx dispatch
|
|
8282
9535
|
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
8283
9536
|
else:
|
|
@@ -8290,7 +9543,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8290
9543
|
|
|
8291
9544
|
# generate the LTOs
|
|
8292
9545
|
# C += A * B
|
|
8293
|
-
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
9546
|
+
(fun_forward, lto_forward) = warp._src.build.build_lto_dot(
|
|
8294
9547
|
M,
|
|
8295
9548
|
N,
|
|
8296
9549
|
K,
|
|
@@ -8306,7 +9559,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8306
9559
|
)
|
|
8307
9560
|
if warp.config.enable_backward:
|
|
8308
9561
|
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
8309
|
-
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
9562
|
+
(fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
|
|
8310
9563
|
M,
|
|
8311
9564
|
K,
|
|
8312
9565
|
N,
|
|
@@ -8321,7 +9574,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8321
9574
|
builder,
|
|
8322
9575
|
)
|
|
8323
9576
|
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
8324
|
-
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
9577
|
+
(fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
|
|
8325
9578
|
K,
|
|
8326
9579
|
N,
|
|
8327
9580
|
M,
|
|
@@ -8438,7 +9691,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8438
9691
|
return_values: List[Var],
|
|
8439
9692
|
arg_values: Mapping[str, Var],
|
|
8440
9693
|
options: Mapping[str, Any],
|
|
8441
|
-
builder: warp.context.ModuleBuilder,
|
|
9694
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8442
9695
|
direction: str | None = None,
|
|
8443
9696
|
):
|
|
8444
9697
|
inout = arg_values["inout"]
|
|
@@ -8467,12 +9720,12 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8467
9720
|
arch = options["output_arch"]
|
|
8468
9721
|
ept = size // num_threads
|
|
8469
9722
|
|
|
8470
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9723
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8471
9724
|
# CPU/no-MathDx dispatch
|
|
8472
9725
|
return ([], [], [], 0)
|
|
8473
9726
|
else:
|
|
8474
9727
|
# generate the LTO
|
|
8475
|
-
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
9728
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
|
|
8476
9729
|
arch, size, ept, direction, dir, precision, builder
|
|
8477
9730
|
)
|
|
8478
9731
|
|
|
@@ -8510,6 +9763,7 @@ add_builtin(
|
|
|
8510
9763
|
group="Tile Primitives",
|
|
8511
9764
|
export=False,
|
|
8512
9765
|
namespace="",
|
|
9766
|
+
is_differentiable=False,
|
|
8513
9767
|
)
|
|
8514
9768
|
|
|
8515
9769
|
add_builtin(
|
|
@@ -8531,6 +9785,7 @@ add_builtin(
|
|
|
8531
9785
|
group="Tile Primitives",
|
|
8532
9786
|
export=False,
|
|
8533
9787
|
namespace="",
|
|
9788
|
+
is_differentiable=False,
|
|
8534
9789
|
)
|
|
8535
9790
|
|
|
8536
9791
|
|
|
@@ -8575,7 +9830,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8575
9830
|
return_values: List[Var],
|
|
8576
9831
|
arg_values: Mapping[str, Var],
|
|
8577
9832
|
options: Mapping[str, Any],
|
|
8578
|
-
builder: warp.context.ModuleBuilder,
|
|
9833
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8579
9834
|
):
|
|
8580
9835
|
a = arg_values["A"]
|
|
8581
9836
|
# force source tile to shared memory
|
|
@@ -8595,7 +9850,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8595
9850
|
|
|
8596
9851
|
arch = options["output_arch"]
|
|
8597
9852
|
|
|
8598
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9853
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8599
9854
|
# CPU/no-MathDx dispatch
|
|
8600
9855
|
return ((0, a, out), [], [], 0)
|
|
8601
9856
|
else:
|
|
@@ -8610,7 +9865,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8610
9865
|
req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
|
|
8611
9866
|
|
|
8612
9867
|
# generate the LTO
|
|
8613
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9868
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8614
9869
|
M,
|
|
8615
9870
|
N,
|
|
8616
9871
|
1,
|
|
@@ -8655,6 +9910,7 @@ add_builtin(
|
|
|
8655
9910
|
group="Tile Primitives",
|
|
8656
9911
|
export=False,
|
|
8657
9912
|
namespace="",
|
|
9913
|
+
is_differentiable=False,
|
|
8658
9914
|
)
|
|
8659
9915
|
|
|
8660
9916
|
|
|
@@ -8698,7 +9954,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8698
9954
|
return_values: List[Var],
|
|
8699
9955
|
arg_values: Mapping[str, Var],
|
|
8700
9956
|
options: Mapping[str, Any],
|
|
8701
|
-
builder: warp.context.ModuleBuilder,
|
|
9957
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8702
9958
|
):
|
|
8703
9959
|
L = arg_values["L"]
|
|
8704
9960
|
y = arg_values["y"]
|
|
@@ -8727,7 +9983,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8727
9983
|
|
|
8728
9984
|
arch = options["output_arch"]
|
|
8729
9985
|
|
|
8730
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9986
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8731
9987
|
# CPU/no-MathDx dispatch
|
|
8732
9988
|
return ((0, L, y, x), [], [], 0)
|
|
8733
9989
|
else:
|
|
@@ -8743,7 +9999,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8743
9999
|
req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8744
10000
|
|
|
8745
10001
|
# generate the LTO
|
|
8746
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10002
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8747
10003
|
M,
|
|
8748
10004
|
N,
|
|
8749
10005
|
NRHS,
|
|
@@ -8785,6 +10041,7 @@ add_builtin(
|
|
|
8785
10041
|
group="Tile Primitives",
|
|
8786
10042
|
export=False,
|
|
8787
10043
|
namespace="",
|
|
10044
|
+
is_differentiable=False,
|
|
8788
10045
|
)
|
|
8789
10046
|
|
|
8790
10047
|
|
|
@@ -8794,7 +10051,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8794
10051
|
return_values: List[Var],
|
|
8795
10052
|
arg_values: Mapping[str, Var],
|
|
8796
10053
|
options: Mapping[str, Any],
|
|
8797
|
-
builder: warp.context.ModuleBuilder,
|
|
10054
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8798
10055
|
):
|
|
8799
10056
|
L = arg_values["L"]
|
|
8800
10057
|
y = arg_values["y"]
|
|
@@ -8823,7 +10080,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8823
10080
|
|
|
8824
10081
|
arch = options["output_arch"]
|
|
8825
10082
|
|
|
8826
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10083
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8827
10084
|
# CPU/no-MathDx dispatch
|
|
8828
10085
|
return ((0, L, y, z), [], [], 0)
|
|
8829
10086
|
else:
|
|
@@ -8839,7 +10096,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
8839
10096
|
req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8840
10097
|
|
|
8841
10098
|
# generate the LTO
|
|
8842
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10099
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8843
10100
|
M,
|
|
8844
10101
|
N,
|
|
8845
10102
|
NRHS,
|
|
@@ -8917,6 +10174,7 @@ add_builtin(
|
|
|
8917
10174
|
group="Tile Primitives",
|
|
8918
10175
|
export=False,
|
|
8919
10176
|
namespace="",
|
|
10177
|
+
is_differentiable=False,
|
|
8920
10178
|
)
|
|
8921
10179
|
|
|
8922
10180
|
|
|
@@ -8926,7 +10184,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8926
10184
|
return_values: List[Var],
|
|
8927
10185
|
arg_values: Mapping[str, Var],
|
|
8928
10186
|
options: Mapping[str, Any],
|
|
8929
|
-
builder: warp.context.ModuleBuilder,
|
|
10187
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8930
10188
|
):
|
|
8931
10189
|
U = arg_values["U"]
|
|
8932
10190
|
z = arg_values["z"]
|
|
@@ -8955,7 +10213,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8955
10213
|
|
|
8956
10214
|
arch = options["output_arch"]
|
|
8957
10215
|
|
|
8958
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10216
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8959
10217
|
# CPU/no-MathDx dispatch
|
|
8960
10218
|
return ((0, U, z, x), [], [], 0)
|
|
8961
10219
|
else:
|
|
@@ -8971,7 +10229,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
8971
10229
|
req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
|
|
8972
10230
|
|
|
8973
10231
|
# generate the LTO
|
|
8974
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10232
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8975
10233
|
M,
|
|
8976
10234
|
N,
|
|
8977
10235
|
NRHS,
|
|
@@ -9049,6 +10307,7 @@ add_builtin(
|
|
|
9049
10307
|
group="Tile Primitives",
|
|
9050
10308
|
export=False,
|
|
9051
10309
|
namespace="",
|
|
10310
|
+
is_differentiable=False,
|
|
9052
10311
|
)
|
|
9053
10312
|
|
|
9054
10313
|
|
|
@@ -9068,6 +10327,7 @@ add_builtin(
|
|
|
9068
10327
|
The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
|
|
9069
10328
|
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
9070
10329
|
group="Code Generation",
|
|
10330
|
+
is_differentiable=False,
|
|
9071
10331
|
)
|
|
9072
10332
|
|
|
9073
10333
|
|
|
@@ -9092,6 +10352,7 @@ add_builtin(
|
|
|
9092
10352
|
doc="Return the number of elements in a vector.",
|
|
9093
10353
|
group="Utility",
|
|
9094
10354
|
export=False,
|
|
10355
|
+
is_differentiable=False,
|
|
9095
10356
|
)
|
|
9096
10357
|
|
|
9097
10358
|
add_builtin(
|
|
@@ -9101,6 +10362,7 @@ add_builtin(
|
|
|
9101
10362
|
doc="Return the number of elements in a quaternion.",
|
|
9102
10363
|
group="Utility",
|
|
9103
10364
|
export=False,
|
|
10365
|
+
is_differentiable=False,
|
|
9104
10366
|
)
|
|
9105
10367
|
|
|
9106
10368
|
add_builtin(
|
|
@@ -9110,6 +10372,7 @@ add_builtin(
|
|
|
9110
10372
|
doc="Return the number of rows in a matrix.",
|
|
9111
10373
|
group="Utility",
|
|
9112
10374
|
export=False,
|
|
10375
|
+
is_differentiable=False,
|
|
9113
10376
|
)
|
|
9114
10377
|
|
|
9115
10378
|
add_builtin(
|
|
@@ -9119,6 +10382,7 @@ add_builtin(
|
|
|
9119
10382
|
doc="Return the number of elements in a transformation.",
|
|
9120
10383
|
group="Utility",
|
|
9121
10384
|
export=False,
|
|
10385
|
+
is_differentiable=False,
|
|
9122
10386
|
)
|
|
9123
10387
|
|
|
9124
10388
|
add_builtin(
|
|
@@ -9128,6 +10392,7 @@ add_builtin(
|
|
|
9128
10392
|
doc="Return the size of the first dimension in an array.",
|
|
9129
10393
|
group="Utility",
|
|
9130
10394
|
export=False,
|
|
10395
|
+
is_differentiable=False,
|
|
9131
10396
|
)
|
|
9132
10397
|
|
|
9133
10398
|
add_builtin(
|
|
@@ -9137,6 +10402,62 @@ add_builtin(
|
|
|
9137
10402
|
doc="Return the number of rows in a tile.",
|
|
9138
10403
|
group="Utility",
|
|
9139
10404
|
export=False,
|
|
10405
|
+
is_differentiable=False,
|
|
10406
|
+
)
|
|
10407
|
+
|
|
10408
|
+
|
|
10409
|
+
def cast_value_func(arg_types, arg_values):
|
|
10410
|
+
# Return generic type for doc builds.
|
|
10411
|
+
if arg_types is None:
|
|
10412
|
+
return Any
|
|
10413
|
+
|
|
10414
|
+
return arg_values["dtype"]
|
|
10415
|
+
|
|
10416
|
+
|
|
10417
|
+
def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
10418
|
+
func_args = (args["a"],)
|
|
10419
|
+
template_args = (args["dtype"],)
|
|
10420
|
+
return (func_args, template_args)
|
|
10421
|
+
|
|
10422
|
+
|
|
10423
|
+
add_builtin(
|
|
10424
|
+
"cast",
|
|
10425
|
+
input_types={"a": Any, "dtype": Any},
|
|
10426
|
+
value_func=cast_value_func,
|
|
10427
|
+
dispatch_func=cast_dispatch_func,
|
|
10428
|
+
group="Utility",
|
|
10429
|
+
export=False,
|
|
10430
|
+
is_differentiable=False,
|
|
10431
|
+
doc="""Reinterpret a value as a different type while preserving its bit pattern.
|
|
10432
|
+
|
|
10433
|
+
:param a: The value to cast
|
|
10434
|
+
:param dtype: The target type
|
|
10435
|
+
|
|
10436
|
+
Example:
|
|
10437
|
+
|
|
10438
|
+
.. code-block:: python
|
|
10439
|
+
|
|
10440
|
+
@wp.struct
|
|
10441
|
+
class MyStruct:
|
|
10442
|
+
f: wp.float16
|
|
10443
|
+
i: wp.int16
|
|
10444
|
+
|
|
10445
|
+
|
|
10446
|
+
@wp.kernel
|
|
10447
|
+
def compute():
|
|
10448
|
+
x = wp.int32(0x40000000)
|
|
10449
|
+
x_casted = wp.cast(x, wp.float32)
|
|
10450
|
+
wp.expect_eq(x_casted, 2.0) # 0x40000000
|
|
10451
|
+
|
|
10452
|
+
s = MyStruct()
|
|
10453
|
+
s.f = wp.float16(2.0) # 0x4000
|
|
10454
|
+
s.i = wp.int16(4096) # 0x1000
|
|
10455
|
+
s_casted = wp.cast(s, wp.int32)
|
|
10456
|
+
wp.expect_eq(s_casted, 0x10004000)
|
|
10457
|
+
|
|
10458
|
+
|
|
10459
|
+
wp.launch(compute, dim=1)
|
|
10460
|
+
""",
|
|
9140
10461
|
)
|
|
9141
10462
|
|
|
9142
10463
|
|
|
@@ -9163,7 +10484,7 @@ add_builtin(
|
|
|
9163
10484
|
doc="Construct a tuple from a list of values",
|
|
9164
10485
|
group="Utility",
|
|
9165
10486
|
hidden=True,
|
|
9166
|
-
|
|
10487
|
+
is_differentiable=False,
|
|
9167
10488
|
export=False,
|
|
9168
10489
|
)
|
|
9169
10490
|
|
|
@@ -9200,7 +10521,7 @@ add_builtin(
|
|
|
9200
10521
|
dispatch_func=tuple_extract_dispatch_func,
|
|
9201
10522
|
group="Utility",
|
|
9202
10523
|
hidden=True,
|
|
9203
|
-
|
|
10524
|
+
is_differentiable=False,
|
|
9204
10525
|
)
|
|
9205
10526
|
|
|
9206
10527
|
|
|
@@ -9211,6 +10532,7 @@ add_builtin(
|
|
|
9211
10532
|
doc="Return the number of elements in a tuple.",
|
|
9212
10533
|
group="Utility",
|
|
9213
10534
|
export=False,
|
|
10535
|
+
is_differentiable=False,
|
|
9214
10536
|
)
|
|
9215
10537
|
|
|
9216
10538
|
# ---------------------------------
|
|
@@ -9229,5 +10551,5 @@ add_builtin(
|
|
|
9229
10551
|
export=False,
|
|
9230
10552
|
group="Utility",
|
|
9231
10553
|
hidden=True,
|
|
9232
|
-
|
|
10554
|
+
is_differentiable=False,
|
|
9233
10555
|
)
|