warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -20,11 +20,11 @@ import functools
|
|
|
20
20
|
import math
|
|
21
21
|
from typing import Any, Callable, Mapping, Sequence
|
|
22
22
|
|
|
23
|
-
import warp.build
|
|
24
|
-
import warp.context
|
|
25
|
-
import warp.utils
|
|
26
|
-
from warp.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
-
from warp.types import *
|
|
23
|
+
import warp._src.build
|
|
24
|
+
import warp._src.context
|
|
25
|
+
import warp._src.utils
|
|
26
|
+
from warp._src.codegen import Reference, Var, get_arg_value, strip_reference
|
|
27
|
+
from warp._src.types import *
|
|
28
28
|
|
|
29
29
|
from .context import add_builtin
|
|
30
30
|
|
|
@@ -61,11 +61,11 @@ def sametypes_create_value_func(default: TypeVar):
|
|
|
61
61
|
|
|
62
62
|
def extract_tuple(arg, as_constant=False):
|
|
63
63
|
if isinstance(arg, Var):
|
|
64
|
-
if isinstance(arg.type, warp.types.tuple_t):
|
|
64
|
+
if isinstance(arg.type, warp._src.types.tuple_t):
|
|
65
65
|
out = arg.type.values
|
|
66
66
|
else:
|
|
67
67
|
out = (arg,)
|
|
68
|
-
elif isinstance(arg, warp.types.tuple_t):
|
|
68
|
+
elif isinstance(arg, warp._src.types.tuple_t):
|
|
69
69
|
out = arg.values
|
|
70
70
|
elif not isinstance(arg, Sequence):
|
|
71
71
|
out = (arg,)
|
|
@@ -82,7 +82,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
82
82
|
if arg_types is None:
|
|
83
83
|
return int
|
|
84
84
|
|
|
85
|
-
length = warp.types.type_length(arg_types["a"])
|
|
85
|
+
length = warp._src.types.type_length(arg_types["a"])
|
|
86
86
|
return Var(None, type=int, constant=length)
|
|
87
87
|
|
|
88
88
|
|
|
@@ -126,7 +126,7 @@ add_builtin(
|
|
|
126
126
|
value_func=sametypes_create_value_func(Scalar),
|
|
127
127
|
doc="Return -1 if ``x`` < 0, return 1 otherwise.",
|
|
128
128
|
group="Scalar Math",
|
|
129
|
-
|
|
129
|
+
is_differentiable=False,
|
|
130
130
|
)
|
|
131
131
|
|
|
132
132
|
add_builtin(
|
|
@@ -135,7 +135,7 @@ add_builtin(
|
|
|
135
135
|
value_func=sametypes_create_value_func(Scalar),
|
|
136
136
|
doc="Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.",
|
|
137
137
|
group="Scalar Math",
|
|
138
|
-
|
|
138
|
+
is_differentiable=False,
|
|
139
139
|
)
|
|
140
140
|
add_builtin(
|
|
141
141
|
"nonzero",
|
|
@@ -143,7 +143,7 @@ add_builtin(
|
|
|
143
143
|
value_func=sametypes_create_value_func(Scalar),
|
|
144
144
|
doc="Return 1.0 if ``x`` is not equal to zero, return 0.0 otherwise.",
|
|
145
145
|
group="Scalar Math",
|
|
146
|
-
|
|
146
|
+
is_differentiable=False,
|
|
147
147
|
)
|
|
148
148
|
|
|
149
149
|
add_builtin(
|
|
@@ -285,7 +285,36 @@ add_builtin(
|
|
|
285
285
|
group="Scalar Math",
|
|
286
286
|
require_original_output_arg=True,
|
|
287
287
|
)
|
|
288
|
-
|
|
288
|
+
add_builtin(
|
|
289
|
+
"erf",
|
|
290
|
+
input_types={"x": Float},
|
|
291
|
+
value_func=sametypes_create_value_func(Float),
|
|
292
|
+
doc="Return the error function of ``x``.",
|
|
293
|
+
group="Scalar Math",
|
|
294
|
+
)
|
|
295
|
+
add_builtin(
|
|
296
|
+
"erfc",
|
|
297
|
+
input_types={"x": Float},
|
|
298
|
+
value_func=sametypes_create_value_func(Float),
|
|
299
|
+
doc="Return the complementary error function of ``x``.",
|
|
300
|
+
group="Scalar Math",
|
|
301
|
+
)
|
|
302
|
+
add_builtin(
|
|
303
|
+
"erfinv",
|
|
304
|
+
input_types={"x": Float},
|
|
305
|
+
value_func=sametypes_create_value_func(Float),
|
|
306
|
+
doc="Return the inverse error function of ``x``.",
|
|
307
|
+
group="Scalar Math",
|
|
308
|
+
require_original_output_arg=True,
|
|
309
|
+
)
|
|
310
|
+
add_builtin(
|
|
311
|
+
"erfcinv",
|
|
312
|
+
input_types={"x": Float},
|
|
313
|
+
value_func=sametypes_create_value_func(Float),
|
|
314
|
+
doc="Return the inverse complementary error function of ``x``.",
|
|
315
|
+
group="Scalar Math",
|
|
316
|
+
require_original_output_arg=True,
|
|
317
|
+
)
|
|
289
318
|
add_builtin(
|
|
290
319
|
"round",
|
|
291
320
|
input_types={"x": Float},
|
|
@@ -295,7 +324,7 @@ add_builtin(
|
|
|
295
324
|
|
|
296
325
|
This is the most intuitive form of rounding in the colloquial sense, but can be slower than other options like :func:`warp.rint()`.
|
|
297
326
|
Differs from :func:`numpy.round()`, which behaves the same way as :func:`numpy.rint()`.""",
|
|
298
|
-
|
|
327
|
+
is_differentiable=False,
|
|
299
328
|
)
|
|
300
329
|
|
|
301
330
|
add_builtin(
|
|
@@ -306,7 +335,7 @@ add_builtin(
|
|
|
306
335
|
doc="""Return the nearest integer value to ``x``, rounding halfway cases to nearest even integer.
|
|
307
336
|
|
|
308
337
|
It is generally faster than :func:`warp.round()`. Equivalent to :func:`numpy.rint()`.""",
|
|
309
|
-
|
|
338
|
+
is_differentiable=False,
|
|
310
339
|
)
|
|
311
340
|
|
|
312
341
|
add_builtin(
|
|
@@ -319,7 +348,7 @@ add_builtin(
|
|
|
319
348
|
In other words, it discards the fractional part of ``x``.
|
|
320
349
|
It is similar to casting ``float(int(a))``, but preserves the negative sign when ``x`` is in the range [-0.0, -1.0).
|
|
321
350
|
Equivalent to :func:`numpy.trunc()` and :func:`numpy.fix()`.""",
|
|
322
|
-
|
|
351
|
+
is_differentiable=False,
|
|
323
352
|
)
|
|
324
353
|
|
|
325
354
|
add_builtin(
|
|
@@ -328,7 +357,7 @@ add_builtin(
|
|
|
328
357
|
value_func=sametypes_create_value_func(Float),
|
|
329
358
|
group="Scalar Math",
|
|
330
359
|
doc="""Return the largest integer that is less than or equal to ``x``.""",
|
|
331
|
-
|
|
360
|
+
is_differentiable=False,
|
|
332
361
|
)
|
|
333
362
|
|
|
334
363
|
add_builtin(
|
|
@@ -337,7 +366,7 @@ add_builtin(
|
|
|
337
366
|
value_func=sametypes_create_value_func(Float),
|
|
338
367
|
group="Scalar Math",
|
|
339
368
|
doc="""Return the smallest integer that is greater than or equal to ``x``.""",
|
|
340
|
-
|
|
369
|
+
is_differentiable=False,
|
|
341
370
|
)
|
|
342
371
|
|
|
343
372
|
add_builtin(
|
|
@@ -348,7 +377,7 @@ add_builtin(
|
|
|
348
377
|
doc="""Retrieve the fractional part of ``x``.
|
|
349
378
|
|
|
350
379
|
In other words, it discards the integer part of ``x`` and is equivalent to ``x - trunc(x)``.""",
|
|
351
|
-
|
|
380
|
+
is_differentiable=False,
|
|
352
381
|
)
|
|
353
382
|
|
|
354
383
|
add_builtin(
|
|
@@ -357,7 +386,7 @@ add_builtin(
|
|
|
357
386
|
value_type=builtins.bool,
|
|
358
387
|
group="Scalar Math",
|
|
359
388
|
doc="""Return ``True`` if ``a`` is a finite number, otherwise return ``False``.""",
|
|
360
|
-
|
|
389
|
+
is_differentiable=False,
|
|
361
390
|
)
|
|
362
391
|
add_builtin(
|
|
363
392
|
"isfinite",
|
|
@@ -365,7 +394,7 @@ add_builtin(
|
|
|
365
394
|
value_type=builtins.bool,
|
|
366
395
|
group="Vector Math",
|
|
367
396
|
doc="Return ``True`` if all elements of the vector ``a`` are finite, otherwise return ``False``.",
|
|
368
|
-
|
|
397
|
+
is_differentiable=False,
|
|
369
398
|
)
|
|
370
399
|
add_builtin(
|
|
371
400
|
"isfinite",
|
|
@@ -373,7 +402,7 @@ add_builtin(
|
|
|
373
402
|
value_type=builtins.bool,
|
|
374
403
|
group="Vector Math",
|
|
375
404
|
doc="Return ``True`` if all elements of the quaternion ``a`` are finite, otherwise return ``False``.",
|
|
376
|
-
|
|
405
|
+
is_differentiable=False,
|
|
377
406
|
)
|
|
378
407
|
add_builtin(
|
|
379
408
|
"isfinite",
|
|
@@ -381,7 +410,7 @@ add_builtin(
|
|
|
381
410
|
value_type=builtins.bool,
|
|
382
411
|
group="Vector Math",
|
|
383
412
|
doc="Return ``True`` if all elements of the matrix ``a`` are finite, otherwise return ``False``.",
|
|
384
|
-
|
|
413
|
+
is_differentiable=False,
|
|
385
414
|
)
|
|
386
415
|
|
|
387
416
|
add_builtin(
|
|
@@ -390,7 +419,7 @@ add_builtin(
|
|
|
390
419
|
value_type=builtins.bool,
|
|
391
420
|
doc="Return ``True`` if ``a`` is NaN, otherwise return ``False``.",
|
|
392
421
|
group="Scalar Math",
|
|
393
|
-
|
|
422
|
+
is_differentiable=False,
|
|
394
423
|
)
|
|
395
424
|
add_builtin(
|
|
396
425
|
"isnan",
|
|
@@ -398,7 +427,7 @@ add_builtin(
|
|
|
398
427
|
value_type=builtins.bool,
|
|
399
428
|
group="Vector Math",
|
|
400
429
|
doc="Return ``True`` if any element of the vector ``a`` is NaN, otherwise return ``False``.",
|
|
401
|
-
|
|
430
|
+
is_differentiable=False,
|
|
402
431
|
)
|
|
403
432
|
add_builtin(
|
|
404
433
|
"isnan",
|
|
@@ -406,7 +435,7 @@ add_builtin(
|
|
|
406
435
|
value_type=builtins.bool,
|
|
407
436
|
group="Vector Math",
|
|
408
437
|
doc="Return ``True`` if any element of the quaternion ``a`` is NaN, otherwise return ``False``.",
|
|
409
|
-
|
|
438
|
+
is_differentiable=False,
|
|
410
439
|
)
|
|
411
440
|
add_builtin(
|
|
412
441
|
"isnan",
|
|
@@ -414,7 +443,7 @@ add_builtin(
|
|
|
414
443
|
value_type=builtins.bool,
|
|
415
444
|
group="Vector Math",
|
|
416
445
|
doc="Return ``True`` if any element of the matrix ``a`` is NaN, otherwise return ``False``.",
|
|
417
|
-
|
|
446
|
+
is_differentiable=False,
|
|
418
447
|
)
|
|
419
448
|
|
|
420
449
|
add_builtin(
|
|
@@ -423,7 +452,7 @@ add_builtin(
|
|
|
423
452
|
value_type=builtins.bool,
|
|
424
453
|
group="Scalar Math",
|
|
425
454
|
doc="""Return ``True`` if ``a`` is positive or negative infinity, otherwise return ``False``.""",
|
|
426
|
-
|
|
455
|
+
is_differentiable=False,
|
|
427
456
|
)
|
|
428
457
|
add_builtin(
|
|
429
458
|
"isinf",
|
|
@@ -431,7 +460,7 @@ add_builtin(
|
|
|
431
460
|
value_type=builtins.bool,
|
|
432
461
|
group="Vector Math",
|
|
433
462
|
doc="Return ``True`` if any element of the vector ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
434
|
-
|
|
463
|
+
is_differentiable=False,
|
|
435
464
|
)
|
|
436
465
|
add_builtin(
|
|
437
466
|
"isinf",
|
|
@@ -439,7 +468,7 @@ add_builtin(
|
|
|
439
468
|
value_type=builtins.bool,
|
|
440
469
|
group="Vector Math",
|
|
441
470
|
doc="Return ``True`` if any element of the quaternion ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
442
|
-
|
|
471
|
+
is_differentiable=False,
|
|
443
472
|
)
|
|
444
473
|
add_builtin(
|
|
445
474
|
"isinf",
|
|
@@ -447,7 +476,7 @@ add_builtin(
|
|
|
447
476
|
value_type=builtins.bool,
|
|
448
477
|
group="Vector Math",
|
|
449
478
|
doc="Return ``True`` if any element of the matrix ``a`` is positive or negative infinity, otherwise return ``False``.",
|
|
450
|
-
|
|
479
|
+
is_differentiable=False,
|
|
451
480
|
)
|
|
452
481
|
|
|
453
482
|
|
|
@@ -555,7 +584,7 @@ add_builtin(
|
|
|
555
584
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
556
585
|
doc="Return the index of the minimum element of a vector ``a``.",
|
|
557
586
|
group="Vector Math",
|
|
558
|
-
|
|
587
|
+
is_differentiable=False,
|
|
559
588
|
)
|
|
560
589
|
add_builtin(
|
|
561
590
|
"argmax",
|
|
@@ -563,7 +592,7 @@ add_builtin(
|
|
|
563
592
|
value_func=lambda arg_types, arg_values: warp.uint32,
|
|
564
593
|
doc="Return the index of the maximum element of a vector ``a``.",
|
|
565
594
|
group="Vector Math",
|
|
566
|
-
|
|
595
|
+
is_differentiable=False,
|
|
567
596
|
)
|
|
568
597
|
|
|
569
598
|
add_builtin(
|
|
@@ -888,7 +917,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
888
917
|
|
|
889
918
|
if dtype is None:
|
|
890
919
|
dtype = value_type
|
|
891
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
920
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
892
921
|
raise RuntimeError(
|
|
893
922
|
f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
|
|
894
923
|
)
|
|
@@ -909,7 +938,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
909
938
|
|
|
910
939
|
if dtype is None:
|
|
911
940
|
dtype = value_type
|
|
912
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
941
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
913
942
|
raise RuntimeError(
|
|
914
943
|
f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
|
|
915
944
|
)
|
|
@@ -992,7 +1021,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
992
1021
|
|
|
993
1022
|
if dtype is None:
|
|
994
1023
|
dtype = value_type
|
|
995
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1024
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
996
1025
|
raise RuntimeError(
|
|
997
1026
|
f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
|
|
998
1027
|
)
|
|
@@ -1002,7 +1031,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1002
1031
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
1003
1032
|
|
|
1004
1033
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
1005
|
-
warp.utils.warn(
|
|
1034
|
+
warp._src.utils.warn(
|
|
1006
1035
|
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
1007
1036
|
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
1008
1037
|
DeprecationWarning,
|
|
@@ -1031,7 +1060,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
1031
1060
|
|
|
1032
1061
|
if dtype is None:
|
|
1033
1062
|
dtype = value_type
|
|
1034
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1063
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1035
1064
|
raise RuntimeError(
|
|
1036
1065
|
f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
|
|
1037
1066
|
)
|
|
@@ -1203,49 +1232,18 @@ add_builtin(
|
|
|
1203
1232
|
doc="Create an identity matrix with shape=(n,n) with the type given by ``dtype``.",
|
|
1204
1233
|
group="Vector Math",
|
|
1205
1234
|
export=False,
|
|
1206
|
-
|
|
1235
|
+
is_differentiable=False,
|
|
1207
1236
|
)
|
|
1208
1237
|
|
|
1209
1238
|
|
|
1210
1239
|
def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1211
|
-
warp.utils.warn(
|
|
1212
|
-
"the built-in `wp.matrix()` function to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1213
|
-
"and 3D scale vector will be deprecated in favor of `wp.transform_compose()`.",
|
|
1214
|
-
DeprecationWarning,
|
|
1215
|
-
)
|
|
1216
1240
|
if arg_types is None:
|
|
1217
1241
|
return matrix(shape=(4, 4), dtype=Float)
|
|
1218
1242
|
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
value_type = scalar_infer_type(value_arg_types)
|
|
1224
|
-
except RuntimeError:
|
|
1225
|
-
raise RuntimeError(
|
|
1226
|
-
"all values given when constructing a transformation matrix must have the same type"
|
|
1227
|
-
) from None
|
|
1228
|
-
|
|
1229
|
-
if dtype is None:
|
|
1230
|
-
dtype = value_type
|
|
1231
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1232
|
-
raise RuntimeError(
|
|
1233
|
-
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1234
|
-
)
|
|
1235
|
-
|
|
1236
|
-
return matrix(shape=(4, 4), dtype=dtype)
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1240
|
-
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1241
|
-
# Further validate the given argument values if needed and map them
|
|
1242
|
-
# to the underlying C++ function's runtime and template params.
|
|
1243
|
-
|
|
1244
|
-
dtype = return_type._wp_scalar_type_
|
|
1245
|
-
|
|
1246
|
-
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1247
|
-
template_args = (4, 4, dtype)
|
|
1248
|
-
return (func_args, template_args)
|
|
1243
|
+
raise RuntimeError(
|
|
1244
|
+
"the built-in `wp.matrix()` to construct a 4x4 matrix from a 3D position, quaternion, "
|
|
1245
|
+
"and 3D scale vector has been removed in favor of `wp.transform_compose()`."
|
|
1246
|
+
)
|
|
1249
1247
|
|
|
1250
1248
|
|
|
1251
1249
|
add_builtin(
|
|
@@ -1259,13 +1257,14 @@ add_builtin(
|
|
|
1259
1257
|
defaults={"dtype": None},
|
|
1260
1258
|
value_func=matrix_transform_value_func,
|
|
1261
1259
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1262
|
-
dispatch_func=matrix_transform_dispatch_func,
|
|
1263
1260
|
native_func="mat_t",
|
|
1264
1261
|
doc="""Construct a 4x4 transformation matrix that applies the transformations as
|
|
1265
1262
|
Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
|
|
1266
1263
|
|
|
1267
|
-
..
|
|
1268
|
-
This function has been
|
|
1264
|
+
.. versionremoved:: 1.10
|
|
1265
|
+
This function has been removed in favor of :func:`warp.math.transform_compose()`.
|
|
1266
|
+
|
|
1267
|
+
.. deprecated:: 1.8""",
|
|
1269
1268
|
group="Vector Math",
|
|
1270
1269
|
export=False,
|
|
1271
1270
|
)
|
|
@@ -1460,7 +1459,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
|
|
|
1460
1459
|
|
|
1461
1460
|
if dtype is None:
|
|
1462
1461
|
dtype = value_type
|
|
1463
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1462
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1464
1463
|
raise RuntimeError(
|
|
1465
1464
|
f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
|
|
1466
1465
|
)
|
|
@@ -1568,7 +1567,7 @@ add_builtin(
|
|
|
1568
1567
|
group="Quaternion Math",
|
|
1569
1568
|
doc="Construct an identity quaternion with zero imaginary part and real part of 1.0",
|
|
1570
1569
|
export=True,
|
|
1571
|
-
|
|
1570
|
+
is_differentiable=False,
|
|
1572
1571
|
)
|
|
1573
1572
|
|
|
1574
1573
|
add_builtin(
|
|
@@ -1697,7 +1696,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1697
1696
|
value_type = strip_reference(variadic_arg_types[0])
|
|
1698
1697
|
if dtype is None:
|
|
1699
1698
|
dtype = value_type
|
|
1700
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1699
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1701
1700
|
raise RuntimeError(
|
|
1702
1701
|
f"the value used to fill this transform is expected to be of the type `{dtype.__name__}`"
|
|
1703
1702
|
)
|
|
@@ -1710,7 +1709,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1710
1709
|
|
|
1711
1710
|
if dtype is None:
|
|
1712
1711
|
dtype = value_type
|
|
1713
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1712
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1714
1713
|
raise RuntimeError(
|
|
1715
1714
|
f"all values used to initialize this transform are expected to be of the type `{dtype.__name__}`"
|
|
1716
1715
|
)
|
|
@@ -1735,7 +1734,7 @@ def transformation_pq_value_func(arg_types: Mapping[str, type], arg_values: Mapp
|
|
|
1735
1734
|
dtype = arg_values.get("dtype", None)
|
|
1736
1735
|
if dtype is None:
|
|
1737
1736
|
dtype = value_type
|
|
1738
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1737
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1739
1738
|
raise RuntimeError(
|
|
1740
1739
|
f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
|
|
1741
1740
|
)
|
|
@@ -1750,9 +1749,19 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1750
1749
|
|
|
1751
1750
|
dtype = return_type._wp_scalar_type_
|
|
1752
1751
|
|
|
1753
|
-
variadic_args =
|
|
1752
|
+
variadic_args = args.get("args", ())
|
|
1753
|
+
variadic_arg_count = len(variadic_args)
|
|
1754
|
+
|
|
1755
|
+
if variadic_arg_count == 7:
|
|
1756
|
+
func_args = variadic_args
|
|
1757
|
+
else:
|
|
1758
|
+
func_args = tuple(v for k, v in args.items() if k != "dtype")
|
|
1759
|
+
if "p" in args and "q" not in args:
|
|
1760
|
+
quat_ident = warp._src.codegen.Var(
|
|
1761
|
+
label=None, type=quaternion(dtype=dtype), constant=quaternion(dtype=dtype)(0, 0, 0, 1)
|
|
1762
|
+
)
|
|
1763
|
+
func_args += (quat_ident,)
|
|
1754
1764
|
|
|
1755
|
-
func_args = variadic_args
|
|
1756
1765
|
template_args = (dtype,)
|
|
1757
1766
|
return (func_args, template_args)
|
|
1758
1767
|
|
|
@@ -1760,7 +1769,7 @@ def transformation_dispatch_func(input_types: Mapping[str, type], return_type: A
|
|
|
1760
1769
|
add_builtin(
|
|
1761
1770
|
"transformation",
|
|
1762
1771
|
input_types={"p": vector(length=3, dtype=Float), "q": quaternion(dtype=Float), "dtype": Float},
|
|
1763
|
-
defaults={"dtype": None},
|
|
1772
|
+
defaults={"q": None, "dtype": None},
|
|
1764
1773
|
value_func=transformation_pq_value_func,
|
|
1765
1774
|
export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
|
|
1766
1775
|
dispatch_func=transformation_dispatch_func,
|
|
@@ -1784,7 +1793,6 @@ add_builtin(
|
|
|
1784
1793
|
doc="Construct a spatial transform vector of given dtype.",
|
|
1785
1794
|
group="Spatial Math",
|
|
1786
1795
|
export=False,
|
|
1787
|
-
missing_grad=True,
|
|
1788
1796
|
)
|
|
1789
1797
|
|
|
1790
1798
|
|
|
@@ -1819,7 +1827,7 @@ add_builtin(
|
|
|
1819
1827
|
group="Transformations",
|
|
1820
1828
|
doc="Construct an identity transform with zero translation and identity rotation.",
|
|
1821
1829
|
export=True,
|
|
1822
|
-
|
|
1830
|
+
is_differentiable=False,
|
|
1823
1831
|
)
|
|
1824
1832
|
|
|
1825
1833
|
add_builtin(
|
|
@@ -1953,7 +1961,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
|
|
|
1953
1961
|
|
|
1954
1962
|
if dtype is None:
|
|
1955
1963
|
dtype = value_type
|
|
1956
|
-
elif not warp.types.scalars_equal(value_type, dtype):
|
|
1964
|
+
elif not warp._src.types.scalars_equal(value_type, dtype):
|
|
1957
1965
|
raise RuntimeError(
|
|
1958
1966
|
f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
|
|
1959
1967
|
)
|
|
@@ -2147,7 +2155,7 @@ add_builtin(
|
|
|
2147
2155
|
value_func=tile_zeros_value_func,
|
|
2148
2156
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2149
2157
|
variadic=False,
|
|
2150
|
-
|
|
2158
|
+
is_differentiable=False,
|
|
2151
2159
|
doc="""Allocate a tile of zero-initialized items.
|
|
2152
2160
|
|
|
2153
2161
|
:param shape: Shape of the output tile
|
|
@@ -2167,7 +2175,7 @@ add_builtin(
|
|
|
2167
2175
|
value_func=tile_zeros_value_func,
|
|
2168
2176
|
dispatch_func=tile_zeros_dispatch_func,
|
|
2169
2177
|
variadic=False,
|
|
2170
|
-
|
|
2178
|
+
is_differentiable=False,
|
|
2171
2179
|
hidden=True,
|
|
2172
2180
|
group="Tile Primitives",
|
|
2173
2181
|
export=False,
|
|
@@ -2219,7 +2227,7 @@ add_builtin(
|
|
|
2219
2227
|
defaults={"storage": "register"},
|
|
2220
2228
|
value_func=tile_ones_value_func,
|
|
2221
2229
|
dispatch_func=tile_ones_dispatch_func,
|
|
2222
|
-
|
|
2230
|
+
is_differentiable=False,
|
|
2223
2231
|
doc="""Allocate a tile of one-initialized items.
|
|
2224
2232
|
|
|
2225
2233
|
:param shape: Shape of the output tile
|
|
@@ -2238,7 +2246,86 @@ add_builtin(
|
|
|
2238
2246
|
defaults={"storage": "register"},
|
|
2239
2247
|
value_func=tile_ones_value_func,
|
|
2240
2248
|
dispatch_func=tile_ones_dispatch_func,
|
|
2241
|
-
|
|
2249
|
+
is_differentiable=False,
|
|
2250
|
+
hidden=True,
|
|
2251
|
+
group="Tile Primitives",
|
|
2252
|
+
export=False,
|
|
2253
|
+
)
|
|
2254
|
+
|
|
2255
|
+
|
|
2256
|
+
def tile_full_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
2257
|
+
# return generic type (for doc builds)
|
|
2258
|
+
if arg_types is None:
|
|
2259
|
+
return tile(dtype=Any, shape=Tuple[int, ...])
|
|
2260
|
+
|
|
2261
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2262
|
+
|
|
2263
|
+
if None in shape:
|
|
2264
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2265
|
+
|
|
2266
|
+
if "value" not in arg_values:
|
|
2267
|
+
raise TypeError("tile_full() missing required keyword argument 'value'")
|
|
2268
|
+
|
|
2269
|
+
if "dtype" not in arg_values:
|
|
2270
|
+
raise TypeError("tile_full() missing required keyword argument 'dtype'")
|
|
2271
|
+
|
|
2272
|
+
if "storage" not in arg_values:
|
|
2273
|
+
raise TypeError("tile_full() missing required keyword argument 'storage'")
|
|
2274
|
+
|
|
2275
|
+
if arg_values["storage"] not in {"shared", "register"}:
|
|
2276
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
2277
|
+
|
|
2278
|
+
dtype = arg_values["dtype"]
|
|
2279
|
+
|
|
2280
|
+
return tile(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
2281
|
+
|
|
2282
|
+
|
|
2283
|
+
def tile_full_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2284
|
+
shape = extract_tuple(arg_values["shape"], as_constant=True)
|
|
2285
|
+
|
|
2286
|
+
if None in shape:
|
|
2287
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
2288
|
+
|
|
2289
|
+
dtype = arg_values["dtype"]
|
|
2290
|
+
value = arg_values["value"]
|
|
2291
|
+
|
|
2292
|
+
func_args = [value]
|
|
2293
|
+
|
|
2294
|
+
template_args = []
|
|
2295
|
+
template_args.append(dtype)
|
|
2296
|
+
template_args.extend(shape)
|
|
2297
|
+
|
|
2298
|
+
return (func_args, template_args)
|
|
2299
|
+
|
|
2300
|
+
|
|
2301
|
+
add_builtin(
|
|
2302
|
+
"tile_full",
|
|
2303
|
+
input_types={"shape": Tuple[int, ...], "value": Any, "dtype": Any, "storage": str},
|
|
2304
|
+
defaults={"storage": "register"},
|
|
2305
|
+
value_func=tile_full_value_func,
|
|
2306
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2307
|
+
is_differentiable=False,
|
|
2308
|
+
doc="""Allocate a tile filled with the specified value.
|
|
2309
|
+
|
|
2310
|
+
:param shape: Shape of the output tile
|
|
2311
|
+
:param value: Value to fill the tile with
|
|
2312
|
+
:param dtype: Data type of output tile's elements
|
|
2313
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2314
|
+
(default) or ``"shared"`` for shared memory.
|
|
2315
|
+
:returns: A tile filled with the specified value""",
|
|
2316
|
+
group="Tile Primitives",
|
|
2317
|
+
export=False,
|
|
2318
|
+
)
|
|
2319
|
+
|
|
2320
|
+
|
|
2321
|
+
# overload for scalar shape
|
|
2322
|
+
add_builtin(
|
|
2323
|
+
"tile_full",
|
|
2324
|
+
input_types={"shape": int, "value": Any, "dtype": Any, "storage": str},
|
|
2325
|
+
defaults={"storage": "register"},
|
|
2326
|
+
value_func=tile_full_value_func,
|
|
2327
|
+
dispatch_func=tile_full_dispatch_func,
|
|
2328
|
+
is_differentiable=False,
|
|
2242
2329
|
hidden=True,
|
|
2243
2330
|
group="Tile Primitives",
|
|
2244
2331
|
export=False,
|
|
@@ -2300,13 +2387,13 @@ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, a
|
|
|
2300
2387
|
args = arg_values["args"]
|
|
2301
2388
|
|
|
2302
2389
|
if len(args) == 1:
|
|
2303
|
-
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2390
|
+
start = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=0)
|
|
2304
2391
|
stop = args[0]
|
|
2305
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2392
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2306
2393
|
elif len(args) == 2:
|
|
2307
2394
|
start = args[0]
|
|
2308
2395
|
stop = args[1]
|
|
2309
|
-
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2396
|
+
step = warp._src.codegen.Var(label=None, type=return_type.dtype, constant=1)
|
|
2310
2397
|
elif len(args) == 3:
|
|
2311
2398
|
start = args[0]
|
|
2312
2399
|
stop = args[1]
|
|
@@ -2329,7 +2416,7 @@ add_builtin(
|
|
|
2329
2416
|
value_func=tile_arange_value_func,
|
|
2330
2417
|
dispatch_func=tile_arange_dispatch_func,
|
|
2331
2418
|
variadic=True,
|
|
2332
|
-
|
|
2419
|
+
is_differentiable=False,
|
|
2333
2420
|
doc="""Generate a tile of linearly spaced elements.
|
|
2334
2421
|
|
|
2335
2422
|
:param args: Variable-length positional arguments, interpreted as:
|
|
@@ -3124,7 +3211,7 @@ add_builtin(
|
|
|
3124
3211
|
:param shape: Shape of the returned slice
|
|
3125
3212
|
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
3126
3213
|
group="Tile Primitives",
|
|
3127
|
-
|
|
3214
|
+
is_differentiable=False,
|
|
3128
3215
|
export=False,
|
|
3129
3216
|
)
|
|
3130
3217
|
|
|
@@ -3371,7 +3458,32 @@ add_builtin(
|
|
|
3371
3458
|
|
|
3372
3459
|
add_builtin(
|
|
3373
3460
|
"assign",
|
|
3374
|
-
input_types={"dst": tile(dtype=Any, shape=Tuple[int,
|
|
3461
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "src": Any},
|
|
3462
|
+
value_func=tile_assign_value_func,
|
|
3463
|
+
group="Tile Primitives",
|
|
3464
|
+
export=False,
|
|
3465
|
+
hidden=True,
|
|
3466
|
+
)
|
|
3467
|
+
|
|
3468
|
+
add_builtin(
|
|
3469
|
+
"assign",
|
|
3470
|
+
input_types={"dst": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "src": Any},
|
|
3471
|
+
value_func=tile_assign_value_func,
|
|
3472
|
+
group="Tile Primitives",
|
|
3473
|
+
export=False,
|
|
3474
|
+
hidden=True,
|
|
3475
|
+
)
|
|
3476
|
+
|
|
3477
|
+
add_builtin(
|
|
3478
|
+
"assign",
|
|
3479
|
+
input_types={
|
|
3480
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3481
|
+
"i": int,
|
|
3482
|
+
"j": int,
|
|
3483
|
+
"k": int,
|
|
3484
|
+
"l": int,
|
|
3485
|
+
"src": Any,
|
|
3486
|
+
},
|
|
3375
3487
|
value_func=tile_assign_value_func,
|
|
3376
3488
|
group="Tile Primitives",
|
|
3377
3489
|
export=False,
|
|
@@ -3380,7 +3492,15 @@ add_builtin(
|
|
|
3380
3492
|
|
|
3381
3493
|
add_builtin(
|
|
3382
3494
|
"assign",
|
|
3383
|
-
input_types={
|
|
3495
|
+
input_types={
|
|
3496
|
+
"dst": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3497
|
+
"i": int,
|
|
3498
|
+
"j": int,
|
|
3499
|
+
"k": int,
|
|
3500
|
+
"l": int,
|
|
3501
|
+
"m": int,
|
|
3502
|
+
"src": Any,
|
|
3503
|
+
},
|
|
3384
3504
|
value_func=tile_assign_value_func,
|
|
3385
3505
|
group="Tile Primitives",
|
|
3386
3506
|
export=False,
|
|
@@ -3395,6 +3515,8 @@ add_builtin(
|
|
|
3395
3515
|
"j": int,
|
|
3396
3516
|
"k": int,
|
|
3397
3517
|
"l": int,
|
|
3518
|
+
"m": int,
|
|
3519
|
+
"n": int,
|
|
3398
3520
|
"src": Any,
|
|
3399
3521
|
},
|
|
3400
3522
|
value_func=tile_assign_value_func,
|
|
@@ -3416,7 +3538,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3416
3538
|
|
|
3417
3539
|
if preserve_type:
|
|
3418
3540
|
dtype = arg_types["x"]
|
|
3419
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3541
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3420
3542
|
|
|
3421
3543
|
return tile(dtype=dtype, shape=shape)
|
|
3422
3544
|
|
|
@@ -3424,18 +3546,18 @@ def tile_value_func(arg_types, arg_values):
|
|
|
3424
3546
|
if type_is_vector(arg_types["x"]):
|
|
3425
3547
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3426
3548
|
length = arg_types["x"]._shape_[0]
|
|
3427
|
-
shape = (length, warp.codegen.options["block_dim"])
|
|
3549
|
+
shape = (length, warp._src.codegen.options["block_dim"])
|
|
3428
3550
|
elif type_is_quaternion(arg_types["x"]):
|
|
3429
3551
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3430
|
-
shape = (4, warp.codegen.options["block_dim"])
|
|
3552
|
+
shape = (4, warp._src.codegen.options["block_dim"])
|
|
3431
3553
|
elif type_is_matrix(arg_types["x"]):
|
|
3432
3554
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
3433
3555
|
rows = arg_types["x"]._shape_[0]
|
|
3434
3556
|
cols = arg_types["x"]._shape_[1]
|
|
3435
|
-
shape = (rows, cols, warp.codegen.options["block_dim"])
|
|
3557
|
+
shape = (rows, cols, warp._src.codegen.options["block_dim"])
|
|
3436
3558
|
else:
|
|
3437
3559
|
dtype = arg_types["x"]
|
|
3438
|
-
shape = (warp.codegen.options["block_dim"],)
|
|
3560
|
+
shape = (warp._src.codegen.options["block_dim"],)
|
|
3439
3561
|
|
|
3440
3562
|
return tile(dtype=dtype, shape=shape)
|
|
3441
3563
|
|
|
@@ -3525,17 +3647,17 @@ def untile_value_func(arg_types, arg_values):
|
|
|
3525
3647
|
if not is_tile(t):
|
|
3526
3648
|
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
3527
3649
|
|
|
3528
|
-
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
3650
|
+
if t.shape[-1] != warp._src.codegen.options["block_dim"]:
|
|
3529
3651
|
raise ValueError(
|
|
3530
|
-
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
3652
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp._src.codegen.options['block_dim']}"
|
|
3531
3653
|
)
|
|
3532
3654
|
|
|
3533
3655
|
if len(t.shape) == 1:
|
|
3534
3656
|
return t.dtype
|
|
3535
3657
|
elif len(t.shape) == 2:
|
|
3536
|
-
return warp.types.vector(t.shape[0], t.dtype)
|
|
3658
|
+
return warp._src.types.vector(t.shape[0], t.dtype)
|
|
3537
3659
|
elif len(t.shape) == 3:
|
|
3538
|
-
return warp.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3660
|
+
return warp._src.types.matrix((t.shape[0], t.shape[1]), t.dtype)
|
|
3539
3661
|
else:
|
|
3540
3662
|
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
3541
3663
|
|
|
@@ -3597,7 +3719,36 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
3597
3719
|
# force the input tile to shared memory
|
|
3598
3720
|
arg_types["a"].storage = "shared"
|
|
3599
3721
|
|
|
3600
|
-
|
|
3722
|
+
# count the number of indices (all parameters except the tile "a")
|
|
3723
|
+
num_indices = len(arg_types) - 1
|
|
3724
|
+
tile_dtype = arg_types["a"].dtype
|
|
3725
|
+
tile_shape = arg_types["a"].shape
|
|
3726
|
+
|
|
3727
|
+
if type_is_vector(tile_dtype):
|
|
3728
|
+
if num_indices == len(tile_shape):
|
|
3729
|
+
return tile_dtype
|
|
3730
|
+
elif num_indices == len(tile_shape) + 1:
|
|
3731
|
+
return tile_dtype._wp_scalar_type_
|
|
3732
|
+
else:
|
|
3733
|
+
raise IndexError(
|
|
3734
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3735
|
+
)
|
|
3736
|
+
elif type_is_matrix(tile_dtype):
|
|
3737
|
+
if num_indices == len(tile_shape):
|
|
3738
|
+
return tile_dtype
|
|
3739
|
+
elif num_indices == len(tile_shape) + 2:
|
|
3740
|
+
return tile_dtype._wp_scalar_type_
|
|
3741
|
+
else:
|
|
3742
|
+
raise IndexError(
|
|
3743
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for matrix tile shape {tuple(tile_shape)}"
|
|
3744
|
+
)
|
|
3745
|
+
else:
|
|
3746
|
+
# scalar element: index count must exactly match tile rank
|
|
3747
|
+
if num_indices == len(tile_shape):
|
|
3748
|
+
return tile_dtype
|
|
3749
|
+
raise IndexError(
|
|
3750
|
+
f"tile_extract: incorrect number of indices ({num_indices}) for tile shape {tuple(tile_shape)}"
|
|
3751
|
+
)
|
|
3601
3752
|
|
|
3602
3753
|
|
|
3603
3754
|
add_builtin(
|
|
@@ -3621,7 +3772,7 @@ add_builtin(
|
|
|
3621
3772
|
|
|
3622
3773
|
add_builtin(
|
|
3623
3774
|
"tile_extract",
|
|
3624
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3775
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int},
|
|
3625
3776
|
value_func=tile_extract_value_func,
|
|
3626
3777
|
variadic=False,
|
|
3627
3778
|
doc="""Extract a single element from the tile.
|
|
@@ -3632,7 +3783,28 @@ add_builtin(
|
|
|
3632
3783
|
|
|
3633
3784
|
:param a: Tile to extract the element from
|
|
3634
3785
|
:param i: Coordinate of element on first dimension
|
|
3635
|
-
:param j: Coordinate of element on the second dimension
|
|
3786
|
+
:param j: Coordinate of element on the second dimension, or vector index
|
|
3787
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3788
|
+
group="Tile Primitives",
|
|
3789
|
+
hidden=True,
|
|
3790
|
+
export=False,
|
|
3791
|
+
)
|
|
3792
|
+
|
|
3793
|
+
add_builtin(
|
|
3794
|
+
"tile_extract",
|
|
3795
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int},
|
|
3796
|
+
value_func=tile_extract_value_func,
|
|
3797
|
+
variadic=False,
|
|
3798
|
+
doc="""Extract a single element from the tile.
|
|
3799
|
+
|
|
3800
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3801
|
+
|
|
3802
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3803
|
+
|
|
3804
|
+
:param a: Tile to extract the element from
|
|
3805
|
+
:param i: Coordinate of element on first dimension
|
|
3806
|
+
:param j: Coordinate of element on the second dimension, or first matrix index
|
|
3807
|
+
:param k: Coordinate of element on the third dimension, or vector index, or second matrix index
|
|
3636
3808
|
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
3637
3809
|
group="Tile Primitives",
|
|
3638
3810
|
hidden=True,
|
|
@@ -3641,7 +3813,36 @@ add_builtin(
|
|
|
3641
3813
|
|
|
3642
3814
|
add_builtin(
|
|
3643
3815
|
"tile_extract",
|
|
3644
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3816
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int},
|
|
3817
|
+
value_func=tile_extract_value_func,
|
|
3818
|
+
variadic=False,
|
|
3819
|
+
doc="""Extract a single element from the tile.
|
|
3820
|
+
|
|
3821
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
3822
|
+
|
|
3823
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
3824
|
+
|
|
3825
|
+
:param a: Tile to extract the element from
|
|
3826
|
+
:param i: Coordinate of element on first dimension
|
|
3827
|
+
:param j: Coordinate of element on the second dimension
|
|
3828
|
+
:param k: Coordinate of element on the third dimension, or first matrix index
|
|
3829
|
+
:param l: Coordinate of element on the fourth dimension, or vector index, or second matrix index
|
|
3830
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3831
|
+
group="Tile Primitives",
|
|
3832
|
+
hidden=True,
|
|
3833
|
+
export=False,
|
|
3834
|
+
)
|
|
3835
|
+
|
|
3836
|
+
add_builtin(
|
|
3837
|
+
"tile_extract",
|
|
3838
|
+
input_types={
|
|
3839
|
+
"a": tile(dtype=Any, shape=Tuple[int, ...]),
|
|
3840
|
+
"i": int,
|
|
3841
|
+
"j": int,
|
|
3842
|
+
"k": int,
|
|
3843
|
+
"l": int,
|
|
3844
|
+
"m": int,
|
|
3845
|
+
},
|
|
3645
3846
|
value_func=tile_extract_value_func,
|
|
3646
3847
|
variadic=False,
|
|
3647
3848
|
doc="""Extract a single element from the tile.
|
|
@@ -3654,7 +3855,9 @@ add_builtin(
|
|
|
3654
3855
|
:param i: Coordinate of element on first dimension
|
|
3655
3856
|
:param j: Coordinate of element on the second dimension
|
|
3656
3857
|
:param k: Coordinate of element on the third dimension
|
|
3657
|
-
:
|
|
3858
|
+
:param l: Coordinate of element on the fourth dimension, or first matrix index
|
|
3859
|
+
:param m: Vector index, or second matrix index
|
|
3860
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3658
3861
|
group="Tile Primitives",
|
|
3659
3862
|
hidden=True,
|
|
3660
3863
|
export=False,
|
|
@@ -3662,7 +3865,15 @@ add_builtin(
|
|
|
3662
3865
|
|
|
3663
3866
|
add_builtin(
|
|
3664
3867
|
"tile_extract",
|
|
3665
|
-
input_types={
|
|
3868
|
+
input_types={
|
|
3869
|
+
"a": tile(dtype=Any, shape=Tuple[int, int, int, int]),
|
|
3870
|
+
"i": int,
|
|
3871
|
+
"j": int,
|
|
3872
|
+
"k": int,
|
|
3873
|
+
"l": int,
|
|
3874
|
+
"m": int,
|
|
3875
|
+
"n": int,
|
|
3876
|
+
},
|
|
3666
3877
|
value_func=tile_extract_value_func,
|
|
3667
3878
|
variadic=False,
|
|
3668
3879
|
doc="""Extract a single element from the tile.
|
|
@@ -3676,6 +3887,8 @@ add_builtin(
|
|
|
3676
3887
|
:param j: Coordinate of element on the second dimension
|
|
3677
3888
|
:param k: Coordinate of element on the third dimension
|
|
3678
3889
|
:param l: Coordinate of element on the fourth dimension
|
|
3890
|
+
:param m: Vector index, or first matrix index
|
|
3891
|
+
:param n: Second matrix index
|
|
3679
3892
|
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
3680
3893
|
group="Tile Primitives",
|
|
3681
3894
|
hidden=True,
|
|
@@ -3762,50 +3975,161 @@ add_builtin(
|
|
|
3762
3975
|
export=False,
|
|
3763
3976
|
)
|
|
3764
3977
|
|
|
3765
|
-
|
|
3766
|
-
def tile_transpose_value_func(arg_types, arg_values):
|
|
3767
|
-
# return generic type (for doc builds)
|
|
3768
|
-
if arg_types is None:
|
|
3769
|
-
return tile(dtype=Any, shape=Tuple[int, int])
|
|
3770
|
-
|
|
3771
|
-
if len(arg_types) != 1:
|
|
3772
|
-
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
3773
|
-
|
|
3774
|
-
t = arg_types["a"]
|
|
3775
|
-
|
|
3776
|
-
if not is_tile(t):
|
|
3777
|
-
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
3778
|
-
|
|
3779
|
-
layout = None
|
|
3780
|
-
|
|
3781
|
-
# flip layout
|
|
3782
|
-
if t.layout == "rowmajor":
|
|
3783
|
-
layout = "colmajor"
|
|
3784
|
-
elif t.layout == "colmajor":
|
|
3785
|
-
layout = "rowmajor"
|
|
3786
|
-
|
|
3787
|
-
# force the input tile to shared memory
|
|
3788
|
-
t.storage = "shared"
|
|
3789
|
-
|
|
3790
|
-
return tile(
|
|
3791
|
-
dtype=t.dtype,
|
|
3792
|
-
shape=t.shape[::-1],
|
|
3793
|
-
storage=t.storage,
|
|
3794
|
-
strides=t.strides[::-1],
|
|
3795
|
-
layout=layout,
|
|
3796
|
-
owner=False,
|
|
3797
|
-
)
|
|
3798
|
-
|
|
3799
|
-
|
|
3800
3978
|
add_builtin(
|
|
3801
|
-
"
|
|
3802
|
-
input_types={"a": tile(dtype=Any, shape=Tuple[int,
|
|
3803
|
-
value_func=
|
|
3804
|
-
|
|
3805
|
-
|
|
3806
|
-
|
|
3807
|
-
|
|
3808
|
-
|
|
3979
|
+
"tile_bit_and_inplace",
|
|
3980
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
3981
|
+
value_func=tile_inplace_value_func,
|
|
3982
|
+
group="Tile Primitives",
|
|
3983
|
+
hidden=True,
|
|
3984
|
+
export=False,
|
|
3985
|
+
is_differentiable=False,
|
|
3986
|
+
)
|
|
3987
|
+
add_builtin(
|
|
3988
|
+
"tile_bit_and_inplace",
|
|
3989
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
3990
|
+
value_func=tile_inplace_value_func,
|
|
3991
|
+
group="Tile Primitives",
|
|
3992
|
+
hidden=True,
|
|
3993
|
+
export=False,
|
|
3994
|
+
is_differentiable=False,
|
|
3995
|
+
)
|
|
3996
|
+
add_builtin(
|
|
3997
|
+
"tile_bit_and_inplace",
|
|
3998
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
3999
|
+
value_func=tile_inplace_value_func,
|
|
4000
|
+
group="Tile Primitives",
|
|
4001
|
+
hidden=True,
|
|
4002
|
+
export=False,
|
|
4003
|
+
is_differentiable=False,
|
|
4004
|
+
)
|
|
4005
|
+
add_builtin(
|
|
4006
|
+
"tile_bit_and_inplace",
|
|
4007
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4008
|
+
value_func=tile_inplace_value_func,
|
|
4009
|
+
group="Tile Primitives",
|
|
4010
|
+
hidden=True,
|
|
4011
|
+
export=False,
|
|
4012
|
+
is_differentiable=False,
|
|
4013
|
+
)
|
|
4014
|
+
|
|
4015
|
+
add_builtin(
|
|
4016
|
+
"tile_bit_or_inplace",
|
|
4017
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4018
|
+
value_func=tile_inplace_value_func,
|
|
4019
|
+
group="Tile Primitives",
|
|
4020
|
+
hidden=True,
|
|
4021
|
+
export=False,
|
|
4022
|
+
is_differentiable=False,
|
|
4023
|
+
)
|
|
4024
|
+
add_builtin(
|
|
4025
|
+
"tile_bit_or_inplace",
|
|
4026
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4027
|
+
value_func=tile_inplace_value_func,
|
|
4028
|
+
group="Tile Primitives",
|
|
4029
|
+
hidden=True,
|
|
4030
|
+
export=False,
|
|
4031
|
+
is_differentiable=False,
|
|
4032
|
+
)
|
|
4033
|
+
add_builtin(
|
|
4034
|
+
"tile_bit_or_inplace",
|
|
4035
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4036
|
+
value_func=tile_inplace_value_func,
|
|
4037
|
+
group="Tile Primitives",
|
|
4038
|
+
hidden=True,
|
|
4039
|
+
export=False,
|
|
4040
|
+
is_differentiable=False,
|
|
4041
|
+
)
|
|
4042
|
+
add_builtin(
|
|
4043
|
+
"tile_bit_or_inplace",
|
|
4044
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4045
|
+
value_func=tile_inplace_value_func,
|
|
4046
|
+
group="Tile Primitives",
|
|
4047
|
+
hidden=True,
|
|
4048
|
+
export=False,
|
|
4049
|
+
is_differentiable=False,
|
|
4050
|
+
)
|
|
4051
|
+
|
|
4052
|
+
add_builtin(
|
|
4053
|
+
"tile_bit_xor_inplace",
|
|
4054
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "value": Any},
|
|
4055
|
+
value_func=tile_inplace_value_func,
|
|
4056
|
+
group="Tile Primitives",
|
|
4057
|
+
hidden=True,
|
|
4058
|
+
export=False,
|
|
4059
|
+
is_differentiable=False,
|
|
4060
|
+
)
|
|
4061
|
+
add_builtin(
|
|
4062
|
+
"tile_bit_xor_inplace",
|
|
4063
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "value": Any},
|
|
4064
|
+
value_func=tile_inplace_value_func,
|
|
4065
|
+
group="Tile Primitives",
|
|
4066
|
+
hidden=True,
|
|
4067
|
+
export=False,
|
|
4068
|
+
is_differentiable=False,
|
|
4069
|
+
)
|
|
4070
|
+
add_builtin(
|
|
4071
|
+
"tile_bit_xor_inplace",
|
|
4072
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "value": Any},
|
|
4073
|
+
value_func=tile_inplace_value_func,
|
|
4074
|
+
group="Tile Primitives",
|
|
4075
|
+
hidden=True,
|
|
4076
|
+
export=False,
|
|
4077
|
+
is_differentiable=False,
|
|
4078
|
+
)
|
|
4079
|
+
add_builtin(
|
|
4080
|
+
"tile_bit_xor_inplace",
|
|
4081
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "i": int, "j": int, "k": int, "l": int, "value": Any},
|
|
4082
|
+
value_func=tile_inplace_value_func,
|
|
4083
|
+
group="Tile Primitives",
|
|
4084
|
+
hidden=True,
|
|
4085
|
+
export=False,
|
|
4086
|
+
is_differentiable=False,
|
|
4087
|
+
)
|
|
4088
|
+
|
|
4089
|
+
|
|
4090
|
+
def tile_transpose_value_func(arg_types, arg_values):
|
|
4091
|
+
# return generic type (for doc builds)
|
|
4092
|
+
if arg_types is None:
|
|
4093
|
+
return tile(dtype=Any, shape=Tuple[int, int])
|
|
4094
|
+
|
|
4095
|
+
if len(arg_types) != 1:
|
|
4096
|
+
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
4097
|
+
|
|
4098
|
+
t = arg_types["a"]
|
|
4099
|
+
|
|
4100
|
+
if not is_tile(t):
|
|
4101
|
+
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
4102
|
+
|
|
4103
|
+
layout = None
|
|
4104
|
+
|
|
4105
|
+
# flip layout
|
|
4106
|
+
if t.layout == "rowmajor":
|
|
4107
|
+
layout = "colmajor"
|
|
4108
|
+
elif t.layout == "colmajor":
|
|
4109
|
+
layout = "rowmajor"
|
|
4110
|
+
|
|
4111
|
+
# force the input tile to shared memory
|
|
4112
|
+
t.storage = "shared"
|
|
4113
|
+
|
|
4114
|
+
return tile(
|
|
4115
|
+
dtype=t.dtype,
|
|
4116
|
+
shape=t.shape[::-1],
|
|
4117
|
+
storage=t.storage,
|
|
4118
|
+
strides=t.strides[::-1],
|
|
4119
|
+
layout=layout,
|
|
4120
|
+
owner=False,
|
|
4121
|
+
)
|
|
4122
|
+
|
|
4123
|
+
|
|
4124
|
+
add_builtin(
|
|
4125
|
+
"tile_transpose",
|
|
4126
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, int])},
|
|
4127
|
+
value_func=tile_transpose_value_func,
|
|
4128
|
+
variadic=True,
|
|
4129
|
+
doc="""Transpose a tile.
|
|
4130
|
+
|
|
4131
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
4132
|
+
Register tiles will first be transferred to shared memory before transposition.
|
|
3809
4133
|
|
|
3810
4134
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
3811
4135
|
:returns: Tile with ``shape=(N,M)``""",
|
|
@@ -3935,6 +4259,80 @@ add_builtin(
|
|
|
3935
4259
|
)
|
|
3936
4260
|
|
|
3937
4261
|
|
|
4262
|
+
def tile_sum_axis_value_func(arg_types, arg_values):
|
|
4263
|
+
if arg_types is None:
|
|
4264
|
+
return tile(dtype=Scalar, shape=Tuple[int, ...])
|
|
4265
|
+
|
|
4266
|
+
a = arg_types["a"]
|
|
4267
|
+
|
|
4268
|
+
if not is_tile(a):
|
|
4269
|
+
raise TypeError(f"tile_sum() 'a' argument must be a tile, got {a!r}")
|
|
4270
|
+
|
|
4271
|
+
# force input tile to shared
|
|
4272
|
+
a.storage = "shared"
|
|
4273
|
+
|
|
4274
|
+
axis = arg_values["axis"]
|
|
4275
|
+
shape = a.shape
|
|
4276
|
+
|
|
4277
|
+
if axis < 0 or axis >= len(shape):
|
|
4278
|
+
raise ValueError(f"tile_sum() axis {axis} is out of bounds for tile with {len(shape)} dimensions")
|
|
4279
|
+
|
|
4280
|
+
# shape is identical less the axis reduction is along
|
|
4281
|
+
if len(shape) > 1:
|
|
4282
|
+
new_shape = shape[:axis] + shape[axis + 1 :]
|
|
4283
|
+
else:
|
|
4284
|
+
new_shape = (1,)
|
|
4285
|
+
|
|
4286
|
+
return tile(dtype=a.dtype, shape=new_shape)
|
|
4287
|
+
|
|
4288
|
+
|
|
4289
|
+
def tile_sum_axis_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the argument tuples for an axis-wise ``tile_sum(a, axis)`` call.

    :param arg_types: Mapping of argument names to their types (unused here).
    :param return_type: The resolved return type (unused here).
    :param arg_values: Mapping of argument names to their variables; ``"a"`` and
        ``"axis"`` are read from it.
    :returns: A pair of tuples: ``(a,)`` followed by ``(axis_constant,)``, where
        ``axis_constant`` is the ``constant`` attribute of the ``axis`` variable.
        Presumably these are the function and template arguments of the call;
        confirm against the code generator.
    :raises ValueError: If the ``axis`` variable has no ``constant`` attribute or
        the attribute is ``None``.
    """
    tile_var = arg_values["a"]

    # a missing ``constant`` attribute and a ``None`` constant are rejected the same way
    axis_constant = getattr(arg_values["axis"], "constant", None)
    if axis_constant is None:
        raise ValueError("tile_sum() axis must be a compile-time constant")

    return ((tile_var,), (axis_constant,))
|
|
4297
|
+
|
|
4298
|
+
|
|
4299
|
+
add_builtin(
|
|
4300
|
+
"tile_sum",
|
|
4301
|
+
input_types={"a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4302
|
+
value_func=tile_sum_axis_value_func,
|
|
4303
|
+
dispatch_func=tile_sum_axis_dispatch_func,
|
|
4304
|
+
doc="""Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.
|
|
4305
|
+
|
|
4306
|
+
:param a: The input tile. Must reside in shared memory.
|
|
4307
|
+
:param axis: The tile axis to compute the sum across. Must be a compile-time constant.
|
|
4308
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4309
|
+
|
|
4310
|
+
Example:
|
|
4311
|
+
|
|
4312
|
+
.. code-block:: python
|
|
4313
|
+
|
|
4314
|
+
@wp.kernel
|
|
4315
|
+
def compute():
|
|
4316
|
+
|
|
4317
|
+
t = wp.tile_ones(dtype=float, shape=(8, 8))
|
|
4318
|
+
s = wp.tile_sum(t, axis=0)
|
|
4319
|
+
|
|
4320
|
+
print(s)
|
|
4321
|
+
|
|
4322
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
4323
|
+
|
|
4324
|
+
Prints:
|
|
4325
|
+
|
|
4326
|
+
.. code-block:: text
|
|
4327
|
+
|
|
4328
|
+
[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
|
|
4329
|
+
|
|
4330
|
+
""",
|
|
4331
|
+
group="Tile Primitives",
|
|
4332
|
+
export=False,
|
|
4333
|
+
)
|
|
4334
|
+
|
|
4335
|
+
|
|
3938
4336
|
def tile_sort_value_func(arg_types, arg_values):
|
|
3939
4337
|
# return generic type (for doc builds)
|
|
3940
4338
|
if arg_types is None:
|
|
@@ -4011,7 +4409,7 @@ add_builtin(
|
|
|
4011
4409
|
""",
|
|
4012
4410
|
group="Tile Primitives",
|
|
4013
4411
|
export=False,
|
|
4014
|
-
|
|
4412
|
+
is_differentiable=False,
|
|
4015
4413
|
)
|
|
4016
4414
|
|
|
4017
4415
|
|
|
@@ -4065,7 +4463,7 @@ add_builtin(
|
|
|
4065
4463
|
""",
|
|
4066
4464
|
group="Tile Primitives",
|
|
4067
4465
|
export=False,
|
|
4068
|
-
|
|
4466
|
+
is_differentiable=False,
|
|
4069
4467
|
)
|
|
4070
4468
|
|
|
4071
4469
|
|
|
@@ -4119,7 +4517,7 @@ add_builtin(
|
|
|
4119
4517
|
""",
|
|
4120
4518
|
group="Tile Primitives",
|
|
4121
4519
|
export=False,
|
|
4122
|
-
|
|
4520
|
+
is_differentiable=False,
|
|
4123
4521
|
)
|
|
4124
4522
|
|
|
4125
4523
|
|
|
@@ -4172,7 +4570,7 @@ add_builtin(
|
|
|
4172
4570
|
""",
|
|
4173
4571
|
group="Tile Primitives",
|
|
4174
4572
|
export=False,
|
|
4175
|
-
|
|
4573
|
+
is_differentiable=False,
|
|
4176
4574
|
)
|
|
4177
4575
|
|
|
4178
4576
|
|
|
@@ -4225,11 +4623,10 @@ add_builtin(
|
|
|
4225
4623
|
""",
|
|
4226
4624
|
group="Tile Primitives",
|
|
4227
4625
|
export=False,
|
|
4228
|
-
|
|
4626
|
+
is_differentiable=False,
|
|
4229
4627
|
)
|
|
4230
4628
|
|
|
4231
4629
|
|
|
4232
|
-
# does type propagation for load()
|
|
4233
4630
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
4234
4631
|
if arg_types is None:
|
|
4235
4632
|
return tile(dtype=Scalar, shape=(1,))
|
|
@@ -4283,7 +4680,88 @@ add_builtin(
|
|
|
4283
4680
|
""",
|
|
4284
4681
|
group="Tile Primitives",
|
|
4285
4682
|
export=False,
|
|
4286
|
-
|
|
4683
|
+
is_differentiable=False,
|
|
4684
|
+
)
|
|
4685
|
+
|
|
4686
|
+
|
|
4687
|
+
def tile_reduce_axis_value_func(arg_types, arg_values):
    """Compute the return type of the axis-wise ``tile_reduce(op, a, axis)`` builtin.

    :param arg_types: Mapping of argument names to their types, or ``None`` to
        request the generic return type.
    :param arg_values: Mapping of argument names to their values; ``"axis"`` is
        read from it and compared against the rank of ``a``.
    :returns: A tile type with the dtype of ``a`` and the shape of ``a`` with
        the ``axis`` dimension removed (``(1,)`` when ``a`` has a single dimension).
    :raises TypeError: If ``a`` is not a tile.
    :raises ValueError: If ``axis`` is negative or not smaller than the rank of ``a``.

    Side effect: the ``storage`` attribute of the input tile type is set to ``"shared"``.
    """
    # no concrete argument types available: answer with the generic tile type
    if arg_types is None:
        return tile(dtype=Scalar, shape=Tuple[int, ...])

    a = arg_types["a"]
    if not is_tile(a):
        raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")

    # the reduction reads its input from shared memory, so force that storage here
    a.storage = "shared"

    axis = arg_values["axis"]
    shape = a.shape

    if axis < 0 or axis >= len(shape):
        raise ValueError(f"tile_reduce() axis {axis} is out of bounds for tile with {len(shape)} dimensions")

    # the result keeps every dimension except the reduced one;
    # reducing a 1D tile yields a single-element tile rather than an empty shape
    new_shape = shape[:axis] + shape[axis + 1 :] if len(shape) > 1 else (1,)

    return tile(dtype=a.dtype, shape=new_shape)
|
|
4712
|
+
|
|
4713
|
+
|
|
4714
|
+
add_builtin(
|
|
4715
|
+
"tile_reduce",
|
|
4716
|
+
input_types={"op": Callable, "a": tile(dtype=Scalar, shape=Tuple[int, ...]), "axis": int},
|
|
4717
|
+
value_func=tile_reduce_axis_value_func,
|
|
4718
|
+
native_func="tile_reduce_axis",
|
|
4719
|
+
doc="""Apply a custom reduction operator across a tile axis.
|
|
4720
|
+
|
|
4721
|
+
This function cooperatively performs a reduction using the provided operator across an axis of the tile.
|
|
4722
|
+
|
|
4723
|
+
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
4724
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type. Must reside in shared memory.
|
|
4725
|
+
:param axis: The tile axis to perform the reduction across. Must be a compile-time constant.
|
|
4726
|
+
:returns: A tile with the same shape as the input tile less the axis dimension and the same data type as the input tile.
|
|
4727
|
+
|
|
4728
|
+
Example:
|
|
4729
|
+
|
|
4730
|
+
.. code-block:: python
|
|
4731
|
+
|
|
4732
|
+
TILE_M = wp.constant(4)
|
|
4733
|
+
TILE_N = wp.constant(2)
|
|
4734
|
+
|
|
4735
|
+
@wp.kernel
|
|
4736
|
+
def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
|
|
4737
|
+
|
|
4738
|
+
a = wp.tile_load(x, shape=(TILE_M, TILE_N))
|
|
4739
|
+
b = wp.tile_reduce(wp.add, a, axis=1)
|
|
4740
|
+
wp.tile_store(y, b)
|
|
4741
|
+
|
|
4742
|
+
arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)
|
|
4743
|
+
|
|
4744
|
+
x = wp.array(arr, dtype=float)
|
|
4745
|
+
y = wp.zeros(TILE_M, dtype=float)
|
|
4746
|
+
|
|
4747
|
+
wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)
|
|
4748
|
+
|
|
4749
|
+
print(x.numpy())
|
|
4750
|
+
print(y.numpy())
|
|
4751
|
+
|
|
4752
|
+
Prints:
|
|
4753
|
+
|
|
4754
|
+
.. code-block:: text
|
|
4755
|
+
|
|
4756
|
+
[[0. 1.]
|
|
4757
|
+
[2. 3.]
|
|
4758
|
+
[4. 5.]
|
|
4759
|
+
[6. 7.]]
|
|
4760
|
+
[ 1. 5. 9. 13.]
|
|
4761
|
+
""",
|
|
4762
|
+
group="Tile Primitives",
|
|
4763
|
+
export=False,
|
|
4764
|
+
is_differentiable=False,
|
|
4287
4765
|
)
|
|
4288
4766
|
|
|
4289
4767
|
|
|
@@ -4347,7 +4825,7 @@ add_builtin(
|
|
|
4347
4825
|
""",
|
|
4348
4826
|
group="Tile Primitives",
|
|
4349
4827
|
export=False,
|
|
4350
|
-
|
|
4828
|
+
is_differentiable=False,
|
|
4351
4829
|
)
|
|
4352
4830
|
|
|
4353
4831
|
|
|
@@ -4411,7 +4889,7 @@ add_builtin(
|
|
|
4411
4889
|
""",
|
|
4412
4890
|
group="Tile Primitives",
|
|
4413
4891
|
export=False,
|
|
4414
|
-
|
|
4892
|
+
is_differentiable=False,
|
|
4415
4893
|
)
|
|
4416
4894
|
|
|
4417
4895
|
|
|
@@ -4665,7 +5143,7 @@ add_builtin(
|
|
|
4665
5143
|
doc="WIP",
|
|
4666
5144
|
group="Utility",
|
|
4667
5145
|
hidden=True,
|
|
4668
|
-
|
|
5146
|
+
is_differentiable=False,
|
|
4669
5147
|
)
|
|
4670
5148
|
|
|
4671
5149
|
add_builtin(
|
|
@@ -4681,7 +5159,7 @@ add_builtin(
|
|
|
4681
5159
|
doc="WIP",
|
|
4682
5160
|
group="Utility",
|
|
4683
5161
|
hidden=True,
|
|
4684
|
-
|
|
5162
|
+
is_differentiable=False,
|
|
4685
5163
|
)
|
|
4686
5164
|
|
|
4687
5165
|
add_builtin(
|
|
@@ -4691,7 +5169,7 @@ add_builtin(
|
|
|
4691
5169
|
doc="WIP",
|
|
4692
5170
|
group="Utility",
|
|
4693
5171
|
hidden=True,
|
|
4694
|
-
|
|
5172
|
+
is_differentiable=False,
|
|
4695
5173
|
)
|
|
4696
5174
|
|
|
4697
5175
|
add_builtin(
|
|
@@ -4743,7 +5221,7 @@ add_builtin(
|
|
|
4743
5221
|
:param low: The lower bound of the bounding box in BVH space
|
|
4744
5222
|
:param high: The upper bound of the bounding box in BVH space""",
|
|
4745
5223
|
export=False,
|
|
4746
|
-
|
|
5224
|
+
is_differentiable=False,
|
|
4747
5225
|
)
|
|
4748
5226
|
|
|
4749
5227
|
add_builtin(
|
|
@@ -4759,7 +5237,7 @@ add_builtin(
|
|
|
4759
5237
|
:param start: The start of the ray in BVH space
|
|
4760
5238
|
:param dir: The direction of the ray in BVH space""",
|
|
4761
5239
|
export=False,
|
|
4762
|
-
|
|
5240
|
+
is_differentiable=False,
|
|
4763
5241
|
)
|
|
4764
5242
|
|
|
4765
5243
|
add_builtin(
|
|
@@ -4770,7 +5248,7 @@ add_builtin(
|
|
|
4770
5248
|
doc="""Move to the next bound returned by the query.
|
|
4771
5249
|
The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.""",
|
|
4772
5250
|
export=False,
|
|
4773
|
-
|
|
5251
|
+
is_differentiable=False,
|
|
4774
5252
|
)
|
|
4775
5253
|
|
|
4776
5254
|
add_builtin(
|
|
@@ -5111,7 +5589,7 @@ add_builtin(
|
|
|
5111
5589
|
:param low: The lower bound of the bounding box in mesh space
|
|
5112
5590
|
:param high: The upper bound of the bounding box in mesh space""",
|
|
5113
5591
|
export=False,
|
|
5114
|
-
|
|
5592
|
+
is_differentiable=False,
|
|
5115
5593
|
)
|
|
5116
5594
|
|
|
5117
5595
|
add_builtin(
|
|
@@ -5123,7 +5601,7 @@ add_builtin(
|
|
|
5123
5601
|
|
|
5124
5602
|
The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.""",
|
|
5125
5603
|
export=False,
|
|
5126
|
-
|
|
5604
|
+
is_differentiable=False,
|
|
5127
5605
|
)
|
|
5128
5606
|
|
|
5129
5607
|
add_builtin(
|
|
@@ -5153,7 +5631,7 @@ add_builtin(
|
|
|
5153
5631
|
|
|
5154
5632
|
This query can be used to iterate over all neighboring point within a fixed radius from the query point.""",
|
|
5155
5633
|
export=False,
|
|
5156
|
-
|
|
5634
|
+
is_differentiable=False,
|
|
5157
5635
|
)
|
|
5158
5636
|
|
|
5159
5637
|
add_builtin(
|
|
@@ -5165,7 +5643,7 @@ add_builtin(
|
|
|
5165
5643
|
|
|
5166
5644
|
The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.""",
|
|
5167
5645
|
export=False,
|
|
5168
|
-
|
|
5646
|
+
is_differentiable=False,
|
|
5169
5647
|
)
|
|
5170
5648
|
|
|
5171
5649
|
add_builtin(
|
|
@@ -5179,7 +5657,7 @@ add_builtin(
|
|
|
5179
5657
|
|
|
5180
5658
|
Returns -1 if the :class:`HashGrid` has not been reserved.""",
|
|
5181
5659
|
export=False,
|
|
5182
|
-
|
|
5660
|
+
is_differentiable=False,
|
|
5183
5661
|
)
|
|
5184
5662
|
|
|
5185
5663
|
add_builtin(
|
|
@@ -5189,16 +5667,34 @@ add_builtin(
|
|
|
5189
5667
|
group="Geometry",
|
|
5190
5668
|
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5191
5669
|
|
|
5670
|
+
This function works with single precision, may return incorrect results in some case.
|
|
5671
|
+
|
|
5672
|
+
Returns > 0 if triangles intersect.""",
|
|
5673
|
+
export=False,
|
|
5674
|
+
is_differentiable=False,
|
|
5675
|
+
)
|
|
5676
|
+
|
|
5677
|
+
|
|
5678
|
+
add_builtin(
|
|
5679
|
+
"intersect_tri_tri",
|
|
5680
|
+
input_types={"v0": vec3d, "v1": vec3d, "v2": vec3d, "u0": vec3d, "u1": vec3d, "u2": vec3d},
|
|
5681
|
+
value_type=int,
|
|
5682
|
+
group="Geometry",
|
|
5683
|
+
doc="""Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
|
|
5684
|
+
|
|
5685
|
+
This function works with double precision, results are more accurate than the single precision version.
|
|
5686
|
+
|
|
5192
5687
|
Returns > 0 if triangles intersect.""",
|
|
5193
5688
|
export=False,
|
|
5194
|
-
|
|
5689
|
+
is_differentiable=False,
|
|
5195
5690
|
)
|
|
5196
5691
|
|
|
5692
|
+
|
|
5197
5693
|
add_builtin(
|
|
5198
5694
|
"mesh_get",
|
|
5199
5695
|
input_types={"id": uint64},
|
|
5200
5696
|
value_type=Mesh,
|
|
5201
|
-
|
|
5697
|
+
is_differentiable=False,
|
|
5202
5698
|
group="Geometry",
|
|
5203
5699
|
doc="""Retrieves the mesh given its index.""",
|
|
5204
5700
|
export=False,
|
|
@@ -5211,7 +5707,7 @@ add_builtin(
|
|
|
5211
5707
|
group="Geometry",
|
|
5212
5708
|
doc="""Evaluates the face normal the mesh given a face index.""",
|
|
5213
5709
|
export=False,
|
|
5214
|
-
|
|
5710
|
+
is_differentiable=False,
|
|
5215
5711
|
)
|
|
5216
5712
|
|
|
5217
5713
|
add_builtin(
|
|
@@ -5221,7 +5717,7 @@ add_builtin(
|
|
|
5221
5717
|
group="Geometry",
|
|
5222
5718
|
doc="""Returns the point of the mesh given a index.""",
|
|
5223
5719
|
export=False,
|
|
5224
|
-
|
|
5720
|
+
is_differentiable=False,
|
|
5225
5721
|
)
|
|
5226
5722
|
|
|
5227
5723
|
add_builtin(
|
|
@@ -5231,7 +5727,7 @@ add_builtin(
|
|
|
5231
5727
|
group="Geometry",
|
|
5232
5728
|
doc="""Returns the velocity of the mesh given a index.""",
|
|
5233
5729
|
export=False,
|
|
5234
|
-
|
|
5730
|
+
is_differentiable=False,
|
|
5235
5731
|
)
|
|
5236
5732
|
|
|
5237
5733
|
add_builtin(
|
|
@@ -5241,7 +5737,7 @@ add_builtin(
|
|
|
5241
5737
|
group="Geometry",
|
|
5242
5738
|
doc="""Returns the point-index of the mesh given a face-vertex index.""",
|
|
5243
5739
|
export=False,
|
|
5244
|
-
|
|
5740
|
+
is_differentiable=False,
|
|
5245
5741
|
)
|
|
5246
5742
|
|
|
5247
5743
|
|
|
@@ -5289,7 +5785,7 @@ add_builtin(
|
|
|
5289
5785
|
group="Utility",
|
|
5290
5786
|
export=False,
|
|
5291
5787
|
hidden=True,
|
|
5292
|
-
|
|
5788
|
+
is_differentiable=False,
|
|
5293
5789
|
)
|
|
5294
5790
|
add_builtin(
|
|
5295
5791
|
"iter_next",
|
|
@@ -5298,7 +5794,7 @@ add_builtin(
|
|
|
5298
5794
|
group="Utility",
|
|
5299
5795
|
export=False,
|
|
5300
5796
|
hidden=True,
|
|
5301
|
-
|
|
5797
|
+
is_differentiable=False,
|
|
5302
5798
|
)
|
|
5303
5799
|
add_builtin(
|
|
5304
5800
|
"iter_next",
|
|
@@ -5307,7 +5803,7 @@ add_builtin(
|
|
|
5307
5803
|
group="Utility",
|
|
5308
5804
|
export=False,
|
|
5309
5805
|
hidden=True,
|
|
5310
|
-
|
|
5806
|
+
is_differentiable=False,
|
|
5311
5807
|
)
|
|
5312
5808
|
|
|
5313
5809
|
add_builtin(
|
|
@@ -5318,7 +5814,7 @@ add_builtin(
|
|
|
5318
5814
|
group="Utility",
|
|
5319
5815
|
doc="""Returns the range in reversed order.""",
|
|
5320
5816
|
export=False,
|
|
5321
|
-
|
|
5817
|
+
is_differentiable=False,
|
|
5322
5818
|
)
|
|
5323
5819
|
|
|
5324
5820
|
# ---------------------------------
|
|
@@ -5338,8 +5834,8 @@ _volume_supported_value_types = {
|
|
|
5338
5834
|
|
|
5339
5835
|
|
|
5340
5836
|
def _is_volume_type_supported(dtype):
|
|
5341
|
-
for
|
|
5342
|
-
if types_equal(
|
|
5837
|
+
for value_type in _volume_supported_value_types:
|
|
5838
|
+
if types_equal(value_type, dtype):
|
|
5343
5839
|
return True
|
|
5344
5840
|
return False
|
|
5345
5841
|
|
|
@@ -5467,7 +5963,7 @@ add_builtin(
|
|
|
5467
5963
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
|
|
5468
5964
|
|
|
5469
5965
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5470
|
-
|
|
5966
|
+
is_differentiable=False,
|
|
5471
5967
|
)
|
|
5472
5968
|
|
|
5473
5969
|
|
|
@@ -5488,7 +5984,7 @@ add_builtin(
|
|
|
5488
5984
|
export=False,
|
|
5489
5985
|
group="Volumes",
|
|
5490
5986
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5491
|
-
|
|
5987
|
+
is_differentiable=False,
|
|
5492
5988
|
)
|
|
5493
5989
|
|
|
5494
5990
|
add_builtin(
|
|
@@ -5519,7 +6015,7 @@ add_builtin(
|
|
|
5519
6015
|
doc="""Returns the value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5520
6016
|
|
|
5521
6017
|
If the voxel at this index does not exist, this function returns the background value""",
|
|
5522
|
-
|
|
6018
|
+
is_differentiable=False,
|
|
5523
6019
|
)
|
|
5524
6020
|
|
|
5525
6021
|
add_builtin(
|
|
@@ -5528,7 +6024,7 @@ add_builtin(
|
|
|
5528
6024
|
group="Volumes",
|
|
5529
6025
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5530
6026
|
export=False,
|
|
5531
|
-
|
|
6027
|
+
is_differentiable=False,
|
|
5532
6028
|
)
|
|
5533
6029
|
|
|
5534
6030
|
add_builtin(
|
|
@@ -5549,7 +6045,7 @@ add_builtin(
|
|
|
5549
6045
|
doc="""Returns the vector value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5550
6046
|
|
|
5551
6047
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5552
|
-
|
|
6048
|
+
is_differentiable=False,
|
|
5553
6049
|
)
|
|
5554
6050
|
|
|
5555
6051
|
add_builtin(
|
|
@@ -5558,7 +6054,7 @@ add_builtin(
|
|
|
5558
6054
|
group="Volumes",
|
|
5559
6055
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5560
6056
|
export=False,
|
|
5561
|
-
|
|
6057
|
+
is_differentiable=False,
|
|
5562
6058
|
)
|
|
5563
6059
|
|
|
5564
6060
|
add_builtin(
|
|
@@ -5577,7 +6073,7 @@ add_builtin(
|
|
|
5577
6073
|
doc="""Returns the :class:`int32` value of voxel with coordinates ``i``, ``j``, ``k``.
|
|
5578
6074
|
|
|
5579
6075
|
If the voxel at this index does not exist, this function returns the background value.""",
|
|
5580
|
-
|
|
6076
|
+
is_differentiable=False,
|
|
5581
6077
|
)
|
|
5582
6078
|
|
|
5583
6079
|
add_builtin(
|
|
@@ -5586,7 +6082,7 @@ add_builtin(
|
|
|
5586
6082
|
group="Volumes",
|
|
5587
6083
|
doc="""Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``.""",
|
|
5588
6084
|
export=False,
|
|
5589
|
-
|
|
6085
|
+
is_differentiable=False,
|
|
5590
6086
|
)
|
|
5591
6087
|
|
|
5592
6088
|
|
|
@@ -5668,7 +6164,7 @@ add_builtin(
|
|
|
5668
6164
|
If the voxel at this index does not exist, this function returns -1.
|
|
5669
6165
|
This function is available for both index grids and classical volumes.
|
|
5670
6166
|
""",
|
|
5671
|
-
|
|
6167
|
+
is_differentiable=False,
|
|
5672
6168
|
)
|
|
5673
6169
|
|
|
5674
6170
|
add_builtin(
|
|
@@ -5710,7 +6206,7 @@ add_builtin(
|
|
|
5710
6206
|
value_type=uint32,
|
|
5711
6207
|
group="Random",
|
|
5712
6208
|
doc="Initialize a new random number generator given a user-defined seed. Returns a 32-bit integer representing the RNG state.",
|
|
5713
|
-
|
|
6209
|
+
is_differentiable=False,
|
|
5714
6210
|
)
|
|
5715
6211
|
|
|
5716
6212
|
add_builtin(
|
|
@@ -5722,7 +6218,7 @@ add_builtin(
|
|
|
5722
6218
|
|
|
5723
6219
|
This alternative constructor can be useful in parallel programs, where a kernel as a whole should share a seed,
|
|
5724
6220
|
but each thread should generate uncorrelated values. In this case usage should be ``r = rand_init(seed, tid)``""",
|
|
5725
|
-
|
|
6221
|
+
is_differentiable=False,
|
|
5726
6222
|
)
|
|
5727
6223
|
|
|
5728
6224
|
add_builtin(
|
|
@@ -5731,7 +6227,7 @@ add_builtin(
|
|
|
5731
6227
|
value_type=int,
|
|
5732
6228
|
group="Random",
|
|
5733
6229
|
doc="Return a random integer in the range [-2^31, 2^31).",
|
|
5734
|
-
|
|
6230
|
+
is_differentiable=False,
|
|
5735
6231
|
)
|
|
5736
6232
|
add_builtin(
|
|
5737
6233
|
"randi",
|
|
@@ -5739,7 +6235,7 @@ add_builtin(
|
|
|
5739
6235
|
value_type=int,
|
|
5740
6236
|
group="Random",
|
|
5741
6237
|
doc="Return a random integer between [low, high).",
|
|
5742
|
-
|
|
6238
|
+
is_differentiable=False,
|
|
5743
6239
|
)
|
|
5744
6240
|
add_builtin(
|
|
5745
6241
|
"randu",
|
|
@@ -5747,7 +6243,7 @@ add_builtin(
|
|
|
5747
6243
|
value_type=uint32,
|
|
5748
6244
|
group="Random",
|
|
5749
6245
|
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
5750
|
-
|
|
6246
|
+
is_differentiable=False,
|
|
5751
6247
|
)
|
|
5752
6248
|
add_builtin(
|
|
5753
6249
|
"randu",
|
|
@@ -5755,7 +6251,7 @@ add_builtin(
|
|
|
5755
6251
|
value_type=uint32,
|
|
5756
6252
|
group="Random",
|
|
5757
6253
|
doc="Return a random unsigned integer between [low, high).",
|
|
5758
|
-
|
|
6254
|
+
is_differentiable=False,
|
|
5759
6255
|
)
|
|
5760
6256
|
add_builtin(
|
|
5761
6257
|
"randf",
|
|
@@ -5763,7 +6259,7 @@ add_builtin(
|
|
|
5763
6259
|
value_type=float,
|
|
5764
6260
|
group="Random",
|
|
5765
6261
|
doc="Return a random float between [0.0, 1.0).",
|
|
5766
|
-
|
|
6262
|
+
is_differentiable=False,
|
|
5767
6263
|
)
|
|
5768
6264
|
add_builtin(
|
|
5769
6265
|
"randf",
|
|
@@ -5771,7 +6267,7 @@ add_builtin(
|
|
|
5771
6267
|
value_type=float,
|
|
5772
6268
|
group="Random",
|
|
5773
6269
|
doc="Return a random float between [low, high).",
|
|
5774
|
-
|
|
6270
|
+
is_differentiable=False,
|
|
5775
6271
|
)
|
|
5776
6272
|
add_builtin(
|
|
5777
6273
|
"randn",
|
|
@@ -5779,7 +6275,7 @@ add_builtin(
|
|
|
5779
6275
|
value_type=float,
|
|
5780
6276
|
group="Random",
|
|
5781
6277
|
doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
|
|
5782
|
-
|
|
6278
|
+
is_differentiable=False,
|
|
5783
6279
|
)
|
|
5784
6280
|
|
|
5785
6281
|
add_builtin(
|
|
@@ -5788,7 +6284,7 @@ add_builtin(
|
|
|
5788
6284
|
value_type=int,
|
|
5789
6285
|
group="Random",
|
|
5790
6286
|
doc="Inverse-transform sample a cumulative distribution function.",
|
|
5791
|
-
|
|
6287
|
+
is_differentiable=False,
|
|
5792
6288
|
)
|
|
5793
6289
|
add_builtin(
|
|
5794
6290
|
"sample_triangle",
|
|
@@ -5796,7 +6292,7 @@ add_builtin(
|
|
|
5796
6292
|
value_type=vec2,
|
|
5797
6293
|
group="Random",
|
|
5798
6294
|
doc="Uniformly sample a triangle. Returns sample barycentric coordinates.",
|
|
5799
|
-
|
|
6295
|
+
is_differentiable=False,
|
|
5800
6296
|
)
|
|
5801
6297
|
add_builtin(
|
|
5802
6298
|
"sample_unit_ring",
|
|
@@ -5804,7 +6300,7 @@ add_builtin(
|
|
|
5804
6300
|
value_type=vec2,
|
|
5805
6301
|
group="Random",
|
|
5806
6302
|
doc="Uniformly sample a ring in the xy plane.",
|
|
5807
|
-
|
|
6303
|
+
is_differentiable=False,
|
|
5808
6304
|
)
|
|
5809
6305
|
add_builtin(
|
|
5810
6306
|
"sample_unit_disk",
|
|
@@ -5812,7 +6308,7 @@ add_builtin(
|
|
|
5812
6308
|
value_type=vec2,
|
|
5813
6309
|
group="Random",
|
|
5814
6310
|
doc="Uniformly sample a disk in the xy plane.",
|
|
5815
|
-
|
|
6311
|
+
is_differentiable=False,
|
|
5816
6312
|
)
|
|
5817
6313
|
add_builtin(
|
|
5818
6314
|
"sample_unit_sphere_surface",
|
|
@@ -5820,7 +6316,7 @@ add_builtin(
|
|
|
5820
6316
|
value_type=vec3,
|
|
5821
6317
|
group="Random",
|
|
5822
6318
|
doc="Uniformly sample a unit sphere surface.",
|
|
5823
|
-
|
|
6319
|
+
is_differentiable=False,
|
|
5824
6320
|
)
|
|
5825
6321
|
add_builtin(
|
|
5826
6322
|
"sample_unit_sphere",
|
|
@@ -5828,7 +6324,7 @@ add_builtin(
|
|
|
5828
6324
|
value_type=vec3,
|
|
5829
6325
|
group="Random",
|
|
5830
6326
|
doc="Uniformly sample a unit sphere.",
|
|
5831
|
-
|
|
6327
|
+
is_differentiable=False,
|
|
5832
6328
|
)
|
|
5833
6329
|
add_builtin(
|
|
5834
6330
|
"sample_unit_hemisphere_surface",
|
|
@@ -5836,7 +6332,7 @@ add_builtin(
|
|
|
5836
6332
|
value_type=vec3,
|
|
5837
6333
|
group="Random",
|
|
5838
6334
|
doc="Uniformly sample a unit hemisphere surface.",
|
|
5839
|
-
|
|
6335
|
+
is_differentiable=False,
|
|
5840
6336
|
)
|
|
5841
6337
|
add_builtin(
|
|
5842
6338
|
"sample_unit_hemisphere",
|
|
@@ -5844,7 +6340,7 @@ add_builtin(
|
|
|
5844
6340
|
value_type=vec3,
|
|
5845
6341
|
group="Random",
|
|
5846
6342
|
doc="Uniformly sample a unit hemisphere.",
|
|
5847
|
-
|
|
6343
|
+
is_differentiable=False,
|
|
5848
6344
|
)
|
|
5849
6345
|
add_builtin(
|
|
5850
6346
|
"sample_unit_square",
|
|
@@ -5852,7 +6348,7 @@ add_builtin(
|
|
|
5852
6348
|
value_type=vec2,
|
|
5853
6349
|
group="Random",
|
|
5854
6350
|
doc="Uniformly sample a unit square.",
|
|
5855
|
-
|
|
6351
|
+
is_differentiable=False,
|
|
5856
6352
|
)
|
|
5857
6353
|
add_builtin(
|
|
5858
6354
|
"sample_unit_cube",
|
|
@@ -5860,7 +6356,7 @@ add_builtin(
|
|
|
5860
6356
|
value_type=vec3,
|
|
5861
6357
|
group="Random",
|
|
5862
6358
|
doc="Uniformly sample a unit cube.",
|
|
5863
|
-
|
|
6359
|
+
is_differentiable=False,
|
|
5864
6360
|
)
|
|
5865
6361
|
|
|
5866
6362
|
add_builtin(
|
|
@@ -5872,7 +6368,7 @@ add_builtin(
|
|
|
5872
6368
|
|
|
5873
6369
|
:param state: RNG state
|
|
5874
6370
|
:param lam: The expected value of the distribution""",
|
|
5875
|
-
|
|
6371
|
+
is_differentiable=False,
|
|
5876
6372
|
)
|
|
5877
6373
|
|
|
5878
6374
|
add_builtin(
|
|
@@ -5940,7 +6436,7 @@ add_builtin(
|
|
|
5940
6436
|
value_type=vec2,
|
|
5941
6437
|
group="Random",
|
|
5942
6438
|
doc="Divergence-free vector field based on the gradient of a Perlin noise function.",
|
|
5943
|
-
|
|
6439
|
+
is_differentiable=False,
|
|
5944
6440
|
)
|
|
5945
6441
|
add_builtin(
|
|
5946
6442
|
"curlnoise",
|
|
@@ -5949,7 +6445,7 @@ add_builtin(
|
|
|
5949
6445
|
value_type=vec3,
|
|
5950
6446
|
group="Random",
|
|
5951
6447
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5952
|
-
|
|
6448
|
+
is_differentiable=False,
|
|
5953
6449
|
)
|
|
5954
6450
|
add_builtin(
|
|
5955
6451
|
"curlnoise",
|
|
@@ -5958,7 +6454,7 @@ add_builtin(
|
|
|
5958
6454
|
value_type=vec3,
|
|
5959
6455
|
group="Random",
|
|
5960
6456
|
doc="Divergence-free vector field based on the curl of three Perlin noise functions.",
|
|
5961
|
-
|
|
6457
|
+
is_differentiable=False,
|
|
5962
6458
|
)
|
|
5963
6459
|
|
|
5964
6460
|
|
|
@@ -5990,7 +6486,7 @@ add_builtin(
|
|
|
5990
6486
|
dispatch_func=printf_dispatch_func,
|
|
5991
6487
|
group="Utility",
|
|
5992
6488
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
5993
|
-
|
|
6489
|
+
is_differentiable=False,
|
|
5994
6490
|
)
|
|
5995
6491
|
|
|
5996
6492
|
add_builtin(
|
|
@@ -6009,7 +6505,7 @@ add_builtin(
|
|
|
6009
6505
|
group="Utility",
|
|
6010
6506
|
namespace="",
|
|
6011
6507
|
native_func="__debugbreak",
|
|
6012
|
-
|
|
6508
|
+
is_differentiable=False,
|
|
6013
6509
|
)
|
|
6014
6510
|
|
|
6015
6511
|
# helpers
|
|
@@ -6027,7 +6523,7 @@ add_builtin(
|
|
|
6027
6523
|
This function may not be called from user-defined Warp functions.""",
|
|
6028
6524
|
namespace="",
|
|
6029
6525
|
native_func="builtin_tid1d",
|
|
6030
|
-
|
|
6526
|
+
is_differentiable=False,
|
|
6031
6527
|
)
|
|
6032
6528
|
|
|
6033
6529
|
add_builtin(
|
|
@@ -6038,7 +6534,7 @@ add_builtin(
|
|
|
6038
6534
|
doc="Returns the number of threads in the current block.",
|
|
6039
6535
|
namespace="",
|
|
6040
6536
|
native_func="builtin_block_dim",
|
|
6041
|
-
|
|
6537
|
+
is_differentiable=False,
|
|
6042
6538
|
)
|
|
6043
6539
|
|
|
6044
6540
|
add_builtin(
|
|
@@ -6053,7 +6549,7 @@ add_builtin(
|
|
|
6053
6549
|
This function may not be called from user-defined Warp functions.""",
|
|
6054
6550
|
namespace="",
|
|
6055
6551
|
native_func="builtin_tid2d",
|
|
6056
|
-
|
|
6552
|
+
is_differentiable=False,
|
|
6057
6553
|
)
|
|
6058
6554
|
|
|
6059
6555
|
add_builtin(
|
|
@@ -6068,7 +6564,7 @@ add_builtin(
|
|
|
6068
6564
|
This function may not be called from user-defined Warp functions.""",
|
|
6069
6565
|
namespace="",
|
|
6070
6566
|
native_func="builtin_tid3d",
|
|
6071
|
-
|
|
6567
|
+
is_differentiable=False,
|
|
6072
6568
|
)
|
|
6073
6569
|
|
|
6074
6570
|
add_builtin(
|
|
@@ -6083,7 +6579,7 @@ add_builtin(
|
|
|
6083
6579
|
This function may not be called from user-defined Warp functions.""",
|
|
6084
6580
|
namespace="",
|
|
6085
6581
|
native_func="builtin_tid4d",
|
|
6086
|
-
|
|
6582
|
+
is_differentiable=False,
|
|
6087
6583
|
)
|
|
6088
6584
|
|
|
6089
6585
|
|
|
@@ -6127,56 +6623,20 @@ def select_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
6127
6623
|
if arg_types is None:
|
|
6128
6624
|
return Any
|
|
6129
6625
|
|
|
6130
|
-
|
|
6131
|
-
v_false = arg_types["value_if_false"]
|
|
6132
|
-
|
|
6133
|
-
if not types_equal(v_true, v_false):
|
|
6134
|
-
raise RuntimeError(
|
|
6135
|
-
f"select() true value type ({v_true}) must be of the same type as the false type ({v_false})"
|
|
6136
|
-
)
|
|
6137
|
-
|
|
6138
|
-
if is_tile(v_false):
|
|
6139
|
-
if v_true.storage == "register":
|
|
6140
|
-
return v_true
|
|
6141
|
-
if v_false.storage == "register":
|
|
6142
|
-
return v_false
|
|
6143
|
-
|
|
6144
|
-
# both v_true and v_false are shared
|
|
6145
|
-
return tile(
|
|
6146
|
-
dtype=v_true.dtype,
|
|
6147
|
-
shape=v_true.shape,
|
|
6148
|
-
storage=v_true.storage,
|
|
6149
|
-
strides=v_true.strides,
|
|
6150
|
-
layout=v_true.layout,
|
|
6151
|
-
owner=True,
|
|
6152
|
-
)
|
|
6153
|
-
|
|
6154
|
-
return v_true
|
|
6155
|
-
|
|
6156
|
-
|
|
6157
|
-
def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
6158
|
-
warp.utils.warn(
|
|
6159
|
-
"wp.select() is deprecated and will be removed in a future\n"
|
|
6160
|
-
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
6161
|
-
category=DeprecationWarning,
|
|
6162
|
-
)
|
|
6163
|
-
|
|
6164
|
-
func_args = tuple(args.values())
|
|
6165
|
-
template_args = ()
|
|
6166
|
-
|
|
6167
|
-
return (func_args, template_args)
|
|
6626
|
+
raise RuntimeError("wp.select() has been removed. Use wp.where(cond, value_if_true, value_if_false) instead.")
|
|
6168
6627
|
|
|
6169
6628
|
|
|
6170
6629
|
add_builtin(
|
|
6171
6630
|
"select",
|
|
6172
6631
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
6173
6632
|
value_func=select_value_func,
|
|
6174
|
-
dispatch_func=select_dispatch_func,
|
|
6175
6633
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6176
6634
|
|
|
6177
|
-
..
|
|
6635
|
+
.. versionremoved:: 1.10
|
|
6178
6636
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6179
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6637
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6638
|
+
|
|
6639
|
+
.. deprecated:: 1.7""",
|
|
6180
6640
|
group="Utility",
|
|
6181
6641
|
)
|
|
6182
6642
|
for t in int_types:
|
|
@@ -6184,24 +6644,26 @@ for t in int_types:
|
|
|
6184
6644
|
"select",
|
|
6185
6645
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
6186
6646
|
value_func=select_value_func,
|
|
6187
|
-
dispatch_func=select_dispatch_func,
|
|
6188
6647
|
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6189
6648
|
|
|
6190
|
-
..
|
|
6649
|
+
.. versionremoved:: 1.10
|
|
6191
6650
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6192
|
-
``where(cond, value_if_true, value_if_false)``.
|
|
6651
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
6652
|
+
|
|
6653
|
+
.. deprecated:: 1.7""",
|
|
6193
6654
|
group="Utility",
|
|
6194
6655
|
)
|
|
6195
6656
|
add_builtin(
|
|
6196
6657
|
"select",
|
|
6197
6658
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
6198
6659
|
value_func=select_value_func,
|
|
6199
|
-
dispatch_func=select_dispatch_func,
|
|
6200
6660
|
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
6201
6661
|
|
|
6202
|
-
..
|
|
6662
|
+
.. versionremoved:: 1.10
|
|
6203
6663
|
Use :func:`where` instead, which has the more intuitive argument order:
|
|
6204
|
-
``where(arr, value_if_true, value_if_false)``.
|
|
6664
|
+
``where(arr, value_if_true, value_if_false)``.
|
|
6665
|
+
|
|
6666
|
+
.. deprecated:: 1.7""",
|
|
6205
6667
|
group="Utility",
|
|
6206
6668
|
)
|
|
6207
6669
|
|
|
@@ -6291,7 +6753,7 @@ add_builtin(
|
|
|
6291
6753
|
group="Utility",
|
|
6292
6754
|
hidden=True,
|
|
6293
6755
|
export=False,
|
|
6294
|
-
|
|
6756
|
+
is_differentiable=False,
|
|
6295
6757
|
)
|
|
6296
6758
|
|
|
6297
6759
|
|
|
@@ -6332,7 +6794,7 @@ add_builtin(
|
|
|
6332
6794
|
native_func="fixedarray_t",
|
|
6333
6795
|
group="Utility",
|
|
6334
6796
|
export=False,
|
|
6335
|
-
|
|
6797
|
+
is_differentiable=False,
|
|
6336
6798
|
hidden=True, # Unhide once we can document both a built-in and a Python scope function sharing the same name.
|
|
6337
6799
|
)
|
|
6338
6800
|
|
|
@@ -6375,14 +6837,13 @@ for array_type in array_types:
|
|
|
6375
6837
|
# does argument checking and type propagation for view()
|
|
6376
6838
|
def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
6377
6839
|
arr_type = arg_types["arr"]
|
|
6378
|
-
idx_types = tuple(arg_types[x] for x in "
|
|
6840
|
+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
6379
6841
|
|
|
6380
6842
|
if not is_array(arr_type):
|
|
6381
6843
|
raise RuntimeError("view() first argument must be an array")
|
|
6382
6844
|
|
|
6383
6845
|
idx_count = len(idx_types)
|
|
6384
|
-
|
|
6385
|
-
if idx_count >= arr_type.ndim:
|
|
6846
|
+
if idx_count > arr_type.ndim:
|
|
6386
6847
|
raise RuntimeError(
|
|
6387
6848
|
f"Trying to create an array view with {idx_count} indices, "
|
|
6388
6849
|
f"but the array only has {arr_type.ndim} dimension(s). "
|
|
@@ -6390,14 +6851,35 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6390
6851
|
f"the expected number of dimensions, e.g.: def func(param: wp.array3d(dtype=float): ..."
|
|
6391
6852
|
)
|
|
6392
6853
|
|
|
6393
|
-
|
|
6394
|
-
|
|
6395
|
-
|
|
6396
|
-
|
|
6854
|
+
has_slice = any(is_slice(x) for x in idx_types)
|
|
6855
|
+
if has_slice:
|
|
6856
|
+
# check index types
|
|
6857
|
+
for t in idx_types:
|
|
6858
|
+
if not (type_is_int(t) or is_slice(t)):
|
|
6859
|
+
raise RuntimeError(
|
|
6860
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6861
|
+
)
|
|
6862
|
+
|
|
6863
|
+
# Each integer index collapses one dimension.
|
|
6864
|
+
int_count = sum(x.step == 0 for x in idx_types)
|
|
6865
|
+
ndim = arr_type.ndim - int_count
|
|
6866
|
+
assert ndim > 0
|
|
6867
|
+
else:
|
|
6868
|
+
if idx_count == arr_type.ndim:
|
|
6869
|
+
raise RuntimeError("Expected to call `address()` instead of `view()`")
|
|
6870
|
+
|
|
6871
|
+
# check index types
|
|
6872
|
+
for t in idx_types:
|
|
6873
|
+
if not type_is_int(t):
|
|
6874
|
+
raise RuntimeError(
|
|
6875
|
+
f"view() index arguments must be of integer or slice types, got index of type {type_repr(t)}"
|
|
6876
|
+
)
|
|
6877
|
+
|
|
6878
|
+
# create an array view with leading dimensions removed
|
|
6879
|
+
ndim = arr_type.ndim - idx_count
|
|
6880
|
+
assert ndim > 0
|
|
6397
6881
|
|
|
6398
|
-
# create an array view with leading dimensions removed
|
|
6399
6882
|
dtype = arr_type.dtype
|
|
6400
|
-
ndim = arr_type.ndim - idx_count
|
|
6401
6883
|
if isinstance(arr_type, (fabricarray, indexedfabricarray)):
|
|
6402
6884
|
# fabric array of arrays: return array attribute as a regular array
|
|
6403
6885
|
return array(dtype=dtype, ndim=ndim)
|
|
@@ -6408,8 +6890,18 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
6408
6890
|
for array_type in array_types:
|
|
6409
6891
|
add_builtin(
|
|
6410
6892
|
"view",
|
|
6411
|
-
input_types={
|
|
6412
|
-
|
|
6893
|
+
input_types={
|
|
6894
|
+
"arr": array_type(dtype=Any),
|
|
6895
|
+
"i": Any,
|
|
6896
|
+
"j": Any,
|
|
6897
|
+
"k": Any,
|
|
6898
|
+
"l": Any,
|
|
6899
|
+
},
|
|
6900
|
+
defaults={
|
|
6901
|
+
"j": None,
|
|
6902
|
+
"k": None,
|
|
6903
|
+
"l": None,
|
|
6904
|
+
},
|
|
6413
6905
|
constraint=sametypes,
|
|
6414
6906
|
hidden=True,
|
|
6415
6907
|
value_func=view_value_func,
|
|
@@ -6513,7 +7005,7 @@ add_builtin(
|
|
|
6513
7005
|
hidden=True,
|
|
6514
7006
|
skip_replay=True,
|
|
6515
7007
|
group="Utility",
|
|
6516
|
-
|
|
7008
|
+
is_differentiable=False,
|
|
6517
7009
|
)
|
|
6518
7010
|
|
|
6519
7011
|
|
|
@@ -6530,7 +7022,7 @@ add_builtin(
|
|
|
6530
7022
|
dispatch_func=load_dispatch_func,
|
|
6531
7023
|
hidden=True,
|
|
6532
7024
|
group="Utility",
|
|
6533
|
-
|
|
7025
|
+
is_differentiable=False,
|
|
6534
7026
|
)
|
|
6535
7027
|
|
|
6536
7028
|
|
|
@@ -6606,6 +7098,13 @@ def create_atomic_op_value_func(op: str):
|
|
|
6606
7098
|
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
6607
7099
|
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
6608
7100
|
)
|
|
7101
|
+
elif op in ("and", "or", "xor"):
|
|
7102
|
+
supported_atomic_types = (warp.int32, warp.int64, warp.uint32, warp.uint64)
|
|
7103
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
|
|
7104
|
+
raise RuntimeError(
|
|
7105
|
+
f"atomic_{op}() operations only work on arrays with [u]int32 or [u]int64 "
|
|
7106
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
7107
|
+
)
|
|
6609
7108
|
else:
|
|
6610
7109
|
raise NotImplementedError
|
|
6611
7110
|
|
|
@@ -6847,7 +7346,7 @@ for array_type in array_types:
|
|
|
6847
7346
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6848
7347
|
group="Utility",
|
|
6849
7348
|
skip_replay=True,
|
|
6850
|
-
|
|
7349
|
+
is_differentiable=False,
|
|
6851
7350
|
)
|
|
6852
7351
|
add_builtin(
|
|
6853
7352
|
"atomic_cas",
|
|
@@ -6861,7 +7360,7 @@ for array_type in array_types:
|
|
|
6861
7360
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6862
7361
|
group="Utility",
|
|
6863
7362
|
skip_replay=True,
|
|
6864
|
-
|
|
7363
|
+
is_differentiable=False,
|
|
6865
7364
|
)
|
|
6866
7365
|
add_builtin(
|
|
6867
7366
|
"atomic_cas",
|
|
@@ -6875,7 +7374,7 @@ for array_type in array_types:
|
|
|
6875
7374
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6876
7375
|
group="Utility",
|
|
6877
7376
|
skip_replay=True,
|
|
6878
|
-
|
|
7377
|
+
is_differentiable=False,
|
|
6879
7378
|
)
|
|
6880
7379
|
add_builtin(
|
|
6881
7380
|
"atomic_cas",
|
|
@@ -6897,7 +7396,7 @@ for array_type in array_types:
|
|
|
6897
7396
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6898
7397
|
group="Utility",
|
|
6899
7398
|
skip_replay=True,
|
|
6900
|
-
|
|
7399
|
+
is_differentiable=False,
|
|
6901
7400
|
)
|
|
6902
7401
|
|
|
6903
7402
|
add_builtin(
|
|
@@ -6912,7 +7411,7 @@ for array_type in array_types:
|
|
|
6912
7411
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6913
7412
|
group="Utility",
|
|
6914
7413
|
skip_replay=True,
|
|
6915
|
-
|
|
7414
|
+
is_differentiable=False,
|
|
6916
7415
|
)
|
|
6917
7416
|
add_builtin(
|
|
6918
7417
|
"atomic_exch",
|
|
@@ -6926,34 +7425,193 @@ for array_type in array_types:
|
|
|
6926
7425
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
6927
7426
|
group="Utility",
|
|
6928
7427
|
skip_replay=True,
|
|
6929
|
-
|
|
7428
|
+
is_differentiable=False,
|
|
7429
|
+
)
|
|
7430
|
+
add_builtin(
|
|
7431
|
+
"atomic_exch",
|
|
7432
|
+
hidden=hidden,
|
|
7433
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7434
|
+
constraint=atomic_op_constraint,
|
|
7435
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
7436
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7437
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k]`` and return the old value.
|
|
7438
|
+
|
|
7439
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7440
|
+
group="Utility",
|
|
7441
|
+
skip_replay=True,
|
|
7442
|
+
is_differentiable=False,
|
|
7443
|
+
)
|
|
7444
|
+
add_builtin(
|
|
7445
|
+
"atomic_exch",
|
|
7446
|
+
hidden=hidden,
|
|
7447
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7448
|
+
constraint=atomic_op_constraint,
|
|
7449
|
+
value_func=create_atomic_op_value_func("exch"),
|
|
7450
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7451
|
+
doc="""Atomically exchange ``value`` with ``arr[i,j,k,l]`` and return the old value.
|
|
7452
|
+
|
|
7453
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7454
|
+
group="Utility",
|
|
7455
|
+
skip_replay=True,
|
|
7456
|
+
)
|
|
7457
|
+
|
|
7458
|
+
add_builtin(
|
|
7459
|
+
"atomic_and",
|
|
7460
|
+
hidden=hidden,
|
|
7461
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7462
|
+
constraint=atomic_op_constraint,
|
|
7463
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7464
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7465
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7466
|
+
This function is automatically invoked when using the syntax ``arr[i] &= value``.""",
|
|
7467
|
+
group="Utility",
|
|
7468
|
+
skip_replay=True,
|
|
7469
|
+
is_differentiable=False,
|
|
7470
|
+
)
|
|
7471
|
+
add_builtin(
|
|
7472
|
+
"atomic_and",
|
|
7473
|
+
hidden=hidden,
|
|
7474
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7475
|
+
constraint=atomic_op_constraint,
|
|
7476
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7477
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7478
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7479
|
+
This function is automatically invoked when using the syntax ``arr[i,j] &= value``.""",
|
|
7480
|
+
group="Utility",
|
|
7481
|
+
skip_replay=True,
|
|
7482
|
+
is_differentiable=False,
|
|
7483
|
+
)
|
|
7484
|
+
add_builtin(
|
|
7485
|
+
"atomic_and",
|
|
7486
|
+
hidden=hidden,
|
|
7487
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7488
|
+
constraint=atomic_op_constraint,
|
|
7489
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7490
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7491
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7492
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] &= value``.""",
|
|
7493
|
+
group="Utility",
|
|
7494
|
+
skip_replay=True,
|
|
7495
|
+
is_differentiable=False,
|
|
7496
|
+
)
|
|
7497
|
+
add_builtin(
|
|
7498
|
+
"atomic_and",
|
|
7499
|
+
hidden=hidden,
|
|
7500
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7501
|
+
constraint=atomic_op_constraint,
|
|
7502
|
+
value_func=create_atomic_op_value_func("and"),
|
|
7503
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7504
|
+
doc="""Atomically performs a bitwise AND between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7505
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] &= value``.""",
|
|
7506
|
+
group="Utility",
|
|
7507
|
+
skip_replay=True,
|
|
7508
|
+
is_differentiable=False,
|
|
7509
|
+
)
|
|
7510
|
+
|
|
7511
|
+
add_builtin(
|
|
7512
|
+
"atomic_or",
|
|
7513
|
+
hidden=hidden,
|
|
7514
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7515
|
+
constraint=atomic_op_constraint,
|
|
7516
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7517
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7518
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7519
|
+
This function is automatically invoked when using the syntax ``arr[i] |= value``.""",
|
|
7520
|
+
group="Utility",
|
|
7521
|
+
skip_replay=True,
|
|
7522
|
+
is_differentiable=False,
|
|
7523
|
+
)
|
|
7524
|
+
add_builtin(
|
|
7525
|
+
"atomic_or",
|
|
7526
|
+
hidden=hidden,
|
|
7527
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7528
|
+
constraint=atomic_op_constraint,
|
|
7529
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7530
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7531
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7532
|
+
This function is automatically invoked when using the syntax ``arr[i,j] |= value``.""",
|
|
7533
|
+
group="Utility",
|
|
7534
|
+
skip_replay=True,
|
|
7535
|
+
is_differentiable=False,
|
|
7536
|
+
)
|
|
7537
|
+
add_builtin(
|
|
7538
|
+
"atomic_or",
|
|
7539
|
+
hidden=hidden,
|
|
7540
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
7541
|
+
constraint=atomic_op_constraint,
|
|
7542
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7543
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7544
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7545
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] |= value``.""",
|
|
7546
|
+
group="Utility",
|
|
7547
|
+
skip_replay=True,
|
|
7548
|
+
is_differentiable=False,
|
|
7549
|
+
)
|
|
7550
|
+
add_builtin(
|
|
7551
|
+
"atomic_or",
|
|
7552
|
+
hidden=hidden,
|
|
7553
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
7554
|
+
constraint=atomic_op_constraint,
|
|
7555
|
+
value_func=create_atomic_op_value_func("or"),
|
|
7556
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7557
|
+
doc="""Atomically performs a bitwise OR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7558
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] |= value``.""",
|
|
7559
|
+
group="Utility",
|
|
7560
|
+
skip_replay=True,
|
|
7561
|
+
is_differentiable=False,
|
|
7562
|
+
)
|
|
7563
|
+
|
|
7564
|
+
add_builtin(
|
|
7565
|
+
"atomic_xor",
|
|
7566
|
+
hidden=hidden,
|
|
7567
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
7568
|
+
constraint=atomic_op_constraint,
|
|
7569
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7570
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7571
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
7572
|
+
This function is automatically invoked when using the syntax ``arr[i] ^= value``.""",
|
|
7573
|
+
group="Utility",
|
|
7574
|
+
skip_replay=True,
|
|
7575
|
+
is_differentiable=False,
|
|
7576
|
+
)
|
|
7577
|
+
add_builtin(
|
|
7578
|
+
"atomic_xor",
|
|
7579
|
+
hidden=hidden,
|
|
7580
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
7581
|
+
constraint=atomic_op_constraint,
|
|
7582
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
7583
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
7584
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
7585
|
+
This function is automatically invoked when using the syntax ``arr[i,j] ^= value``.""",
|
|
7586
|
+
group="Utility",
|
|
7587
|
+
skip_replay=True,
|
|
7588
|
+
is_differentiable=False,
|
|
6930
7589
|
)
|
|
6931
7590
|
add_builtin(
|
|
6932
|
-
"
|
|
7591
|
+
"atomic_xor",
|
|
6933
7592
|
hidden=hidden,
|
|
6934
7593
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
6935
7594
|
constraint=atomic_op_constraint,
|
|
6936
|
-
value_func=create_atomic_op_value_func("
|
|
7595
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
6937
7596
|
dispatch_func=atomic_op_dispatch_func,
|
|
6938
|
-
doc="""Atomically
|
|
6939
|
-
|
|
6940
|
-
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7597
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
7598
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] ^= value``.""",
|
|
6941
7599
|
group="Utility",
|
|
6942
7600
|
skip_replay=True,
|
|
6943
|
-
|
|
7601
|
+
is_differentiable=False,
|
|
6944
7602
|
)
|
|
6945
7603
|
add_builtin(
|
|
6946
|
-
"
|
|
7604
|
+
"atomic_xor",
|
|
6947
7605
|
hidden=hidden,
|
|
6948
7606
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
6949
7607
|
constraint=atomic_op_constraint,
|
|
6950
|
-
value_func=create_atomic_op_value_func("
|
|
7608
|
+
value_func=create_atomic_op_value_func("xor"),
|
|
6951
7609
|
dispatch_func=atomic_op_dispatch_func,
|
|
6952
|
-
doc="""Atomically
|
|
6953
|
-
|
|
6954
|
-
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
7610
|
+
doc="""Atomically performs a bitwise XOR between ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
7611
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] ^= value``.""",
|
|
6955
7612
|
group="Utility",
|
|
6956
7613
|
skip_replay=True,
|
|
7614
|
+
is_differentiable=False,
|
|
6957
7615
|
)
|
|
6958
7616
|
|
|
6959
7617
|
|
|
@@ -7104,7 +7762,7 @@ add_builtin(
|
|
|
7104
7762
|
hidden=True,
|
|
7105
7763
|
group="Utility",
|
|
7106
7764
|
skip_replay=True,
|
|
7107
|
-
|
|
7765
|
+
is_differentiable=False,
|
|
7108
7766
|
)
|
|
7109
7767
|
# implements &quaternion[index]
|
|
7110
7768
|
add_builtin(
|
|
@@ -7115,7 +7773,7 @@ add_builtin(
|
|
|
7115
7773
|
hidden=True,
|
|
7116
7774
|
group="Utility",
|
|
7117
7775
|
skip_replay=True,
|
|
7118
|
-
|
|
7776
|
+
is_differentiable=False,
|
|
7119
7777
|
)
|
|
7120
7778
|
# implements &transformation[index]
|
|
7121
7779
|
add_builtin(
|
|
@@ -7126,7 +7784,7 @@ add_builtin(
|
|
|
7126
7784
|
hidden=True,
|
|
7127
7785
|
group="Utility",
|
|
7128
7786
|
skip_replay=True,
|
|
7129
|
-
|
|
7787
|
+
is_differentiable=False,
|
|
7130
7788
|
)
|
|
7131
7789
|
# implements &(*vector)[index]
|
|
7132
7790
|
add_builtin(
|
|
@@ -7137,7 +7795,7 @@ add_builtin(
|
|
|
7137
7795
|
hidden=True,
|
|
7138
7796
|
group="Utility",
|
|
7139
7797
|
skip_replay=True,
|
|
7140
|
-
|
|
7798
|
+
is_differentiable=False,
|
|
7141
7799
|
)
|
|
7142
7800
|
# implements &(*matrix)[i, j]
|
|
7143
7801
|
add_builtin(
|
|
@@ -7148,7 +7806,7 @@ add_builtin(
|
|
|
7148
7806
|
hidden=True,
|
|
7149
7807
|
group="Utility",
|
|
7150
7808
|
skip_replay=True,
|
|
7151
|
-
|
|
7809
|
+
is_differentiable=False,
|
|
7152
7810
|
)
|
|
7153
7811
|
# implements &(*quaternion)[index]
|
|
7154
7812
|
add_builtin(
|
|
@@ -7159,7 +7817,7 @@ add_builtin(
|
|
|
7159
7817
|
hidden=True,
|
|
7160
7818
|
group="Utility",
|
|
7161
7819
|
skip_replay=True,
|
|
7162
|
-
|
|
7820
|
+
is_differentiable=False,
|
|
7163
7821
|
)
|
|
7164
7822
|
# implements &(*transformation)[index]
|
|
7165
7823
|
add_builtin(
|
|
@@ -7170,7 +7828,7 @@ add_builtin(
|
|
|
7170
7828
|
hidden=True,
|
|
7171
7829
|
group="Utility",
|
|
7172
7830
|
skip_replay=True,
|
|
7173
|
-
|
|
7831
|
+
is_differentiable=False,
|
|
7174
7832
|
)
|
|
7175
7833
|
|
|
7176
7834
|
|
|
@@ -7366,6 +8024,43 @@ add_builtin(
|
|
|
7366
8024
|
)
|
|
7367
8025
|
|
|
7368
8026
|
|
|
8027
|
+
# implements vector[idx] &= scalar
|
|
8028
|
+
add_builtin(
|
|
8029
|
+
"bit_and_inplace",
|
|
8030
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8031
|
+
value_type=None,
|
|
8032
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8033
|
+
hidden=True,
|
|
8034
|
+
export=False,
|
|
8035
|
+
group="Utility",
|
|
8036
|
+
is_differentiable=False,
|
|
8037
|
+
)
|
|
8038
|
+
|
|
8039
|
+
# implements vector[idx] |= scalar
|
|
8040
|
+
add_builtin(
|
|
8041
|
+
"bit_or_inplace",
|
|
8042
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8043
|
+
value_type=None,
|
|
8044
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8045
|
+
hidden=True,
|
|
8046
|
+
export=False,
|
|
8047
|
+
group="Utility",
|
|
8048
|
+
is_differentiable=False,
|
|
8049
|
+
)
|
|
8050
|
+
|
|
8051
|
+
# implements vector[idx] ^= scalar
|
|
8052
|
+
add_builtin(
|
|
8053
|
+
"bit_xor_inplace",
|
|
8054
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": Any, "value": Any},
|
|
8055
|
+
value_type=None,
|
|
8056
|
+
dispatch_func=vector_assign_dispatch_func,
|
|
8057
|
+
hidden=True,
|
|
8058
|
+
export=False,
|
|
8059
|
+
group="Utility",
|
|
8060
|
+
is_differentiable=False,
|
|
8061
|
+
)
|
|
8062
|
+
|
|
8063
|
+
|
|
7369
8064
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
7370
8065
|
mat_type = arg_types["a"]
|
|
7371
8066
|
row_type = mat_type._wp_row_type_
|
|
@@ -7381,7 +8076,7 @@ add_builtin(
|
|
|
7381
8076
|
hidden=True,
|
|
7382
8077
|
group="Utility",
|
|
7383
8078
|
skip_replay=True,
|
|
7384
|
-
|
|
8079
|
+
is_differentiable=False,
|
|
7385
8080
|
)
|
|
7386
8081
|
|
|
7387
8082
|
|
|
@@ -7400,7 +8095,7 @@ add_builtin(
|
|
|
7400
8095
|
hidden=True,
|
|
7401
8096
|
group="Utility",
|
|
7402
8097
|
skip_replay=True,
|
|
7403
|
-
|
|
8098
|
+
is_differentiable=False,
|
|
7404
8099
|
)
|
|
7405
8100
|
|
|
7406
8101
|
|
|
@@ -7600,6 +8295,78 @@ add_builtin(
|
|
|
7600
8295
|
)
|
|
7601
8296
|
|
|
7602
8297
|
|
|
8298
|
+
# implements matrix[i] &= value
|
|
8299
|
+
add_builtin(
|
|
8300
|
+
"bit_and_inplace",
|
|
8301
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8302
|
+
value_type=None,
|
|
8303
|
+
hidden=True,
|
|
8304
|
+
export=False,
|
|
8305
|
+
group="Utility",
|
|
8306
|
+
is_differentiable=False,
|
|
8307
|
+
)
|
|
8308
|
+
|
|
8309
|
+
|
|
8310
|
+
# implements matrix[i,j] &= value
|
|
8311
|
+
add_builtin(
|
|
8312
|
+
"bit_and_inplace",
|
|
8313
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8314
|
+
value_type=None,
|
|
8315
|
+
hidden=True,
|
|
8316
|
+
export=False,
|
|
8317
|
+
group="Utility",
|
|
8318
|
+
is_differentiable=False,
|
|
8319
|
+
)
|
|
8320
|
+
|
|
8321
|
+
|
|
8322
|
+
# implements matrix[i] |= value
|
|
8323
|
+
add_builtin(
|
|
8324
|
+
"bit_or_inplace",
|
|
8325
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8326
|
+
value_type=None,
|
|
8327
|
+
hidden=True,
|
|
8328
|
+
export=False,
|
|
8329
|
+
group="Utility",
|
|
8330
|
+
is_differentiable=False,
|
|
8331
|
+
)
|
|
8332
|
+
|
|
8333
|
+
|
|
8334
|
+
# implements matrix[i,j] |= value
|
|
8335
|
+
add_builtin(
|
|
8336
|
+
"bit_or_inplace",
|
|
8337
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8338
|
+
value_type=None,
|
|
8339
|
+
hidden=True,
|
|
8340
|
+
export=False,
|
|
8341
|
+
group="Utility",
|
|
8342
|
+
is_differentiable=False,
|
|
8343
|
+
)
|
|
8344
|
+
|
|
8345
|
+
|
|
8346
|
+
# implements matrix[i] ^= value
|
|
8347
|
+
add_builtin(
|
|
8348
|
+
"bit_xor_inplace",
|
|
8349
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "value": Any},
|
|
8350
|
+
value_type=None,
|
|
8351
|
+
hidden=True,
|
|
8352
|
+
export=False,
|
|
8353
|
+
group="Utility",
|
|
8354
|
+
is_differentiable=False,
|
|
8355
|
+
)
|
|
8356
|
+
|
|
8357
|
+
|
|
8358
|
+
# implements matrix[i,j] ^= value
|
|
8359
|
+
add_builtin(
|
|
8360
|
+
"bit_xor_inplace",
|
|
8361
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": Any, "j": Any, "value": Any},
|
|
8362
|
+
value_type=None,
|
|
8363
|
+
hidden=True,
|
|
8364
|
+
export=False,
|
|
8365
|
+
group="Utility",
|
|
8366
|
+
is_differentiable=False,
|
|
8367
|
+
)
|
|
8368
|
+
|
|
8369
|
+
|
|
7603
8370
|
for t in scalar_types + vector_types + (bool,):
|
|
7604
8371
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
7605
8372
|
continue
|
|
@@ -7611,7 +8378,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7611
8378
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7612
8379
|
group="Utility",
|
|
7613
8380
|
hidden=True,
|
|
7614
|
-
|
|
8381
|
+
is_differentiable=False,
|
|
7615
8382
|
)
|
|
7616
8383
|
|
|
7617
8384
|
add_builtin(
|
|
@@ -7622,7 +8389,7 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
7622
8389
|
group="Utility",
|
|
7623
8390
|
hidden=True,
|
|
7624
8391
|
export=False,
|
|
7625
|
-
|
|
8392
|
+
is_differentiable=False,
|
|
7626
8393
|
)
|
|
7627
8394
|
|
|
7628
8395
|
|
|
@@ -7641,7 +8408,7 @@ add_builtin(
|
|
|
7641
8408
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7642
8409
|
group="Utility",
|
|
7643
8410
|
hidden=True,
|
|
7644
|
-
|
|
8411
|
+
is_differentiable=False,
|
|
7645
8412
|
)
|
|
7646
8413
|
add_builtin(
|
|
7647
8414
|
"expect_neq",
|
|
@@ -7652,7 +8419,7 @@ add_builtin(
|
|
|
7652
8419
|
group="Utility",
|
|
7653
8420
|
hidden=True,
|
|
7654
8421
|
export=False,
|
|
7655
|
-
|
|
8422
|
+
is_differentiable=False,
|
|
7656
8423
|
)
|
|
7657
8424
|
|
|
7658
8425
|
add_builtin(
|
|
@@ -7663,7 +8430,7 @@ add_builtin(
|
|
|
7663
8430
|
doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
|
|
7664
8431
|
group="Utility",
|
|
7665
8432
|
hidden=True,
|
|
7666
|
-
|
|
8433
|
+
is_differentiable=False,
|
|
7667
8434
|
)
|
|
7668
8435
|
add_builtin(
|
|
7669
8436
|
"expect_neq",
|
|
@@ -7674,7 +8441,7 @@ add_builtin(
|
|
|
7674
8441
|
group="Utility",
|
|
7675
8442
|
hidden=True,
|
|
7676
8443
|
export=False,
|
|
7677
|
-
|
|
8444
|
+
is_differentiable=False,
|
|
7678
8445
|
)
|
|
7679
8446
|
|
|
7680
8447
|
add_builtin(
|
|
@@ -7765,7 +8532,7 @@ add_builtin(
|
|
|
7765
8532
|
value_type=None,
|
|
7766
8533
|
doc="Prints an error to stdout if ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7767
8534
|
group="Utility",
|
|
7768
|
-
|
|
8535
|
+
is_differentiable=False,
|
|
7769
8536
|
)
|
|
7770
8537
|
add_builtin(
|
|
7771
8538
|
"expect_near",
|
|
@@ -7775,7 +8542,7 @@ add_builtin(
|
|
|
7775
8542
|
value_type=None,
|
|
7776
8543
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7777
8544
|
group="Utility",
|
|
7778
|
-
|
|
8545
|
+
is_differentiable=False,
|
|
7779
8546
|
)
|
|
7780
8547
|
add_builtin(
|
|
7781
8548
|
"expect_near",
|
|
@@ -7785,7 +8552,7 @@ add_builtin(
|
|
|
7785
8552
|
value_type=None,
|
|
7786
8553
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7787
8554
|
group="Utility",
|
|
7788
|
-
|
|
8555
|
+
is_differentiable=False,
|
|
7789
8556
|
)
|
|
7790
8557
|
add_builtin(
|
|
7791
8558
|
"expect_near",
|
|
@@ -7799,7 +8566,7 @@ add_builtin(
|
|
|
7799
8566
|
value_type=None,
|
|
7800
8567
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
7801
8568
|
group="Utility",
|
|
7802
|
-
|
|
8569
|
+
is_differentiable=False,
|
|
7803
8570
|
)
|
|
7804
8571
|
|
|
7805
8572
|
# ---------------------------------
|
|
@@ -7810,7 +8577,7 @@ add_builtin(
|
|
|
7810
8577
|
input_types={"arr": array(dtype=Scalar), "value": Scalar},
|
|
7811
8578
|
value_type=int,
|
|
7812
8579
|
doc="Search a sorted array ``arr`` for the closest element greater than or equal to ``value``.",
|
|
7813
|
-
|
|
8580
|
+
is_differentiable=False,
|
|
7814
8581
|
)
|
|
7815
8582
|
|
|
7816
8583
|
add_builtin(
|
|
@@ -7818,7 +8585,7 @@ add_builtin(
|
|
|
7818
8585
|
input_types={"arr": array(dtype=Scalar), "arr_begin": int, "arr_end": int, "value": Scalar},
|
|
7819
8586
|
value_type=int,
|
|
7820
8587
|
doc="Search a sorted array ``arr`` in the range [arr_begin, arr_end) for the closest element greater than or equal to ``value``.",
|
|
7821
|
-
|
|
8588
|
+
is_differentiable=False,
|
|
7822
8589
|
)
|
|
7823
8590
|
|
|
7824
8591
|
# ---------------------------------
|
|
@@ -7899,31 +8666,153 @@ add_builtin(
|
|
|
7899
8666
|
input_types={"a": Int, "b": Int},
|
|
7900
8667
|
value_func=sametypes_create_value_func(Int),
|
|
7901
8668
|
group="Operators",
|
|
7902
|
-
|
|
8669
|
+
is_differentiable=False,
|
|
8670
|
+
)
|
|
8671
|
+
add_builtin(
|
|
8672
|
+
"bit_and",
|
|
8673
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8674
|
+
constraint=sametypes,
|
|
8675
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8676
|
+
doc="",
|
|
8677
|
+
group="Operators",
|
|
8678
|
+
is_differentiable=False,
|
|
8679
|
+
)
|
|
8680
|
+
add_builtin(
|
|
8681
|
+
"bit_and",
|
|
8682
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8683
|
+
constraint=sametypes,
|
|
8684
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8685
|
+
doc="",
|
|
8686
|
+
group="Operators",
|
|
8687
|
+
is_differentiable=False,
|
|
8688
|
+
)
|
|
8689
|
+
|
|
8690
|
+
add_builtin(
|
|
8691
|
+
"bit_or",
|
|
8692
|
+
input_types={"a": Int, "b": Int},
|
|
8693
|
+
value_func=sametypes_create_value_func(Int),
|
|
8694
|
+
group="Operators",
|
|
8695
|
+
is_differentiable=False,
|
|
7903
8696
|
)
|
|
7904
8697
|
add_builtin(
|
|
7905
8698
|
"bit_or",
|
|
8699
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8700
|
+
constraint=sametypes,
|
|
8701
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8702
|
+
doc="",
|
|
8703
|
+
group="Operators",
|
|
8704
|
+
is_differentiable=False,
|
|
8705
|
+
)
|
|
8706
|
+
add_builtin(
|
|
8707
|
+
"bit_or",
|
|
8708
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8709
|
+
constraint=sametypes,
|
|
8710
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8711
|
+
doc="",
|
|
8712
|
+
group="Operators",
|
|
8713
|
+
is_differentiable=False,
|
|
8714
|
+
)
|
|
8715
|
+
|
|
8716
|
+
add_builtin(
|
|
8717
|
+
"bit_xor",
|
|
7906
8718
|
input_types={"a": Int, "b": Int},
|
|
7907
8719
|
value_func=sametypes_create_value_func(Int),
|
|
7908
8720
|
group="Operators",
|
|
7909
|
-
|
|
8721
|
+
is_differentiable=False,
|
|
8722
|
+
)
|
|
8723
|
+
add_builtin(
|
|
8724
|
+
"bit_xor",
|
|
8725
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8726
|
+
constraint=sametypes,
|
|
8727
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8728
|
+
doc="",
|
|
8729
|
+
group="Operators",
|
|
8730
|
+
is_differentiable=False,
|
|
7910
8731
|
)
|
|
7911
8732
|
add_builtin(
|
|
7912
8733
|
"bit_xor",
|
|
8734
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8735
|
+
constraint=sametypes,
|
|
8736
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8737
|
+
doc="",
|
|
8738
|
+
group="Operators",
|
|
8739
|
+
is_differentiable=False,
|
|
8740
|
+
)
|
|
8741
|
+
|
|
8742
|
+
add_builtin(
|
|
8743
|
+
"lshift",
|
|
7913
8744
|
input_types={"a": Int, "b": Int},
|
|
7914
8745
|
value_func=sametypes_create_value_func(Int),
|
|
7915
8746
|
group="Operators",
|
|
7916
|
-
|
|
8747
|
+
is_differentiable=False,
|
|
8748
|
+
)
|
|
8749
|
+
add_builtin(
|
|
8750
|
+
"lshift",
|
|
8751
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8752
|
+
constraint=sametypes,
|
|
8753
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8754
|
+
doc="",
|
|
8755
|
+
group="Operators",
|
|
8756
|
+
is_differentiable=False,
|
|
8757
|
+
)
|
|
8758
|
+
add_builtin(
|
|
8759
|
+
"lshift",
|
|
8760
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8761
|
+
constraint=sametypes,
|
|
8762
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8763
|
+
doc="",
|
|
8764
|
+
group="Operators",
|
|
8765
|
+
is_differentiable=False,
|
|
7917
8766
|
)
|
|
7918
|
-
|
|
8767
|
+
|
|
7919
8768
|
add_builtin(
|
|
7920
8769
|
"rshift",
|
|
7921
8770
|
input_types={"a": Int, "b": Int},
|
|
7922
8771
|
value_func=sametypes_create_value_func(Int),
|
|
7923
8772
|
group="Operators",
|
|
7924
|
-
|
|
8773
|
+
is_differentiable=False,
|
|
8774
|
+
)
|
|
8775
|
+
add_builtin(
|
|
8776
|
+
"rshift",
|
|
8777
|
+
input_types={"a": vector(length=Any, dtype=Int), "b": vector(length=Any, dtype=Int)},
|
|
8778
|
+
constraint=sametypes,
|
|
8779
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8780
|
+
doc="",
|
|
8781
|
+
group="Operators",
|
|
8782
|
+
is_differentiable=False,
|
|
8783
|
+
)
|
|
8784
|
+
add_builtin(
|
|
8785
|
+
"rshift",
|
|
8786
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int), "b": matrix(shape=(Any, Any), dtype=Int)},
|
|
8787
|
+
constraint=sametypes,
|
|
8788
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8789
|
+
doc="",
|
|
8790
|
+
group="Operators",
|
|
8791
|
+
is_differentiable=False,
|
|
8792
|
+
)
|
|
8793
|
+
|
|
8794
|
+
add_builtin(
|
|
8795
|
+
"invert",
|
|
8796
|
+
input_types={"a": Int},
|
|
8797
|
+
value_func=sametypes_create_value_func(Int),
|
|
8798
|
+
group="Operators",
|
|
8799
|
+
is_differentiable=False,
|
|
8800
|
+
)
|
|
8801
|
+
add_builtin(
|
|
8802
|
+
"invert",
|
|
8803
|
+
input_types={"a": vector(length=Any, dtype=Int)},
|
|
8804
|
+
value_func=sametypes_create_value_func(vector(length=Any, dtype=Int)),
|
|
8805
|
+
group="Operators",
|
|
8806
|
+
is_differentiable=False,
|
|
7925
8807
|
)
|
|
7926
|
-
add_builtin(
|
|
8808
|
+
add_builtin(
|
|
8809
|
+
"invert",
|
|
8810
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Int)},
|
|
8811
|
+
value_func=sametypes_create_value_func(matrix(shape=(Any, Any), dtype=Int)),
|
|
8812
|
+
group="Operators",
|
|
8813
|
+
is_differentiable=False,
|
|
8814
|
+
)
|
|
8815
|
+
|
|
7927
8816
|
|
|
7928
8817
|
add_builtin(
|
|
7929
8818
|
"mul", input_types={"a": Scalar, "b": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators"
|
|
@@ -8123,7 +9012,7 @@ add_builtin(
|
|
|
8123
9012
|
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
|
|
8124
9013
|
doc="Modulo operation using truncated division.",
|
|
8125
9014
|
group="Operators",
|
|
8126
|
-
|
|
9015
|
+
is_differentiable=False,
|
|
8127
9016
|
)
|
|
8128
9017
|
|
|
8129
9018
|
add_builtin(
|
|
@@ -8183,7 +9072,7 @@ add_builtin(
|
|
|
8183
9072
|
value_func=sametypes_create_value_func(Scalar),
|
|
8184
9073
|
doc="",
|
|
8185
9074
|
group="Operators",
|
|
8186
|
-
|
|
9075
|
+
is_differentiable=False,
|
|
8187
9076
|
)
|
|
8188
9077
|
|
|
8189
9078
|
add_builtin("pos", input_types={"x": Scalar}, value_func=sametypes_create_value_func(Scalar), group="Operators")
|
|
@@ -8232,14 +9121,26 @@ add_builtin(
|
|
|
8232
9121
|
)
|
|
8233
9122
|
|
|
8234
9123
|
add_builtin(
|
|
8235
|
-
"unot",
|
|
9124
|
+
"unot",
|
|
9125
|
+
input_types={"a": builtins.bool},
|
|
9126
|
+
value_type=builtins.bool,
|
|
9127
|
+
doc="",
|
|
9128
|
+
group="Operators",
|
|
9129
|
+
is_differentiable=False,
|
|
8236
9130
|
)
|
|
8237
9131
|
for t in int_types:
|
|
8238
|
-
add_builtin(
|
|
9132
|
+
add_builtin(
|
|
9133
|
+
"unot", input_types={"a": t}, value_type=builtins.bool, doc="", group="Operators", is_differentiable=False
|
|
9134
|
+
)
|
|
8239
9135
|
|
|
8240
9136
|
|
|
8241
9137
|
add_builtin(
|
|
8242
|
-
"unot",
|
|
9138
|
+
"unot",
|
|
9139
|
+
input_types={"a": array(dtype=Any)},
|
|
9140
|
+
value_type=builtins.bool,
|
|
9141
|
+
doc="",
|
|
9142
|
+
group="Operators",
|
|
9143
|
+
is_differentiable=False,
|
|
8243
9144
|
)
|
|
8244
9145
|
|
|
8245
9146
|
|
|
@@ -8312,6 +9213,45 @@ add_builtin(
|
|
|
8312
9213
|
export=False,
|
|
8313
9214
|
)
|
|
8314
9215
|
|
|
9216
|
+
add_builtin(
|
|
9217
|
+
"bit_and",
|
|
9218
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9219
|
+
value_func=tile_binary_map_value_func,
|
|
9220
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9221
|
+
# variadic=True,
|
|
9222
|
+
native_func="tile_bit_and",
|
|
9223
|
+
doc="Bitwise AND each element of two tiles together",
|
|
9224
|
+
group="Tile Primitives",
|
|
9225
|
+
export=False,
|
|
9226
|
+
is_differentiable=False,
|
|
9227
|
+
)
|
|
9228
|
+
|
|
9229
|
+
add_builtin(
|
|
9230
|
+
"bit_or",
|
|
9231
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9232
|
+
value_func=tile_binary_map_value_func,
|
|
9233
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9234
|
+
# variadic=True,
|
|
9235
|
+
native_func="tile_bit_or",
|
|
9236
|
+
doc="Bitwise OR each element of two tiles together",
|
|
9237
|
+
group="Tile Primitives",
|
|
9238
|
+
export=False,
|
|
9239
|
+
is_differentiable=False,
|
|
9240
|
+
)
|
|
9241
|
+
|
|
9242
|
+
add_builtin(
|
|
9243
|
+
"bit_xor",
|
|
9244
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9245
|
+
value_func=tile_binary_map_value_func,
|
|
9246
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
9247
|
+
# variadic=True,
|
|
9248
|
+
native_func="tile_bit_xor",
|
|
9249
|
+
doc="Bitwise XOR each element of two tiles together",
|
|
9250
|
+
group="Tile Primitives",
|
|
9251
|
+
export=False,
|
|
9252
|
+
is_differentiable=False,
|
|
9253
|
+
)
|
|
9254
|
+
|
|
8315
9255
|
|
|
8316
9256
|
add_builtin(
|
|
8317
9257
|
"mul",
|
|
@@ -8373,6 +9313,45 @@ add_builtin(
|
|
|
8373
9313
|
)
|
|
8374
9314
|
|
|
8375
9315
|
|
|
9316
|
+
add_builtin(
|
|
9317
|
+
"bit_and_inplace",
|
|
9318
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9319
|
+
value_type=None,
|
|
9320
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9321
|
+
export=False,
|
|
9322
|
+
hidden=True,
|
|
9323
|
+
native_func="tile_bit_and_inplace",
|
|
9324
|
+
group="Operators",
|
|
9325
|
+
is_differentiable=False,
|
|
9326
|
+
)
|
|
9327
|
+
|
|
9328
|
+
|
|
9329
|
+
add_builtin(
|
|
9330
|
+
"bit_or_inplace",
|
|
9331
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9332
|
+
value_type=None,
|
|
9333
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9334
|
+
export=False,
|
|
9335
|
+
hidden=True,
|
|
9336
|
+
native_func="tile_bit_or_inplace",
|
|
9337
|
+
group="Operators",
|
|
9338
|
+
is_differentiable=False,
|
|
9339
|
+
)
|
|
9340
|
+
|
|
9341
|
+
|
|
9342
|
+
add_builtin(
|
|
9343
|
+
"bit_xor_inplace",
|
|
9344
|
+
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...]), "b": tile(dtype=Any, shape=Tuple[int, ...])},
|
|
9345
|
+
value_type=None,
|
|
9346
|
+
dispatch_func=tile_inplace_dispatch_func,
|
|
9347
|
+
export=False,
|
|
9348
|
+
hidden=True,
|
|
9349
|
+
native_func="tile_bit_xor_inplace",
|
|
9350
|
+
group="Operators",
|
|
9351
|
+
is_differentiable=False,
|
|
9352
|
+
)
|
|
9353
|
+
|
|
9354
|
+
|
|
8376
9355
|
def tile_diag_add_value_func(arg_types, arg_values):
|
|
8377
9356
|
if arg_types is None:
|
|
8378
9357
|
return tile(dtype=Any, shape=Tuple[int, int])
|
|
@@ -8414,7 +9393,7 @@ def tile_diag_add_lto_dispatch_func(
|
|
|
8414
9393
|
return_values: List[Var],
|
|
8415
9394
|
arg_values: Mapping[str, Var],
|
|
8416
9395
|
options: Mapping[str, Any],
|
|
8417
|
-
builder: warp.context.ModuleBuilder,
|
|
9396
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8418
9397
|
):
|
|
8419
9398
|
a = arg_values["a"]
|
|
8420
9399
|
d = arg_values["d"]
|
|
@@ -8434,7 +9413,7 @@ add_builtin(
|
|
|
8434
9413
|
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
8435
9414
|
group="Tile Primitives",
|
|
8436
9415
|
export=False,
|
|
8437
|
-
|
|
9416
|
+
is_differentiable=False,
|
|
8438
9417
|
)
|
|
8439
9418
|
|
|
8440
9419
|
|
|
@@ -8491,7 +9470,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8491
9470
|
return_values: List[Var],
|
|
8492
9471
|
arg_values: Mapping[str, Var],
|
|
8493
9472
|
options: Mapping[str, Any],
|
|
8494
|
-
builder: warp.context.ModuleBuilder,
|
|
9473
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8495
9474
|
):
|
|
8496
9475
|
a = arg_values["a"]
|
|
8497
9476
|
b = arg_values["b"]
|
|
@@ -8529,7 +9508,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8529
9508
|
num_threads = options["block_dim"]
|
|
8530
9509
|
arch = options["output_arch"]
|
|
8531
9510
|
|
|
8532
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9511
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8533
9512
|
# CPU/no-MathDx dispatch
|
|
8534
9513
|
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
8535
9514
|
else:
|
|
@@ -8542,7 +9521,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8542
9521
|
|
|
8543
9522
|
# generate the LTOs
|
|
8544
9523
|
# C += A * B
|
|
8545
|
-
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
9524
|
+
(fun_forward, lto_forward) = warp._src.build.build_lto_dot(
|
|
8546
9525
|
M,
|
|
8547
9526
|
N,
|
|
8548
9527
|
K,
|
|
@@ -8558,7 +9537,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8558
9537
|
)
|
|
8559
9538
|
if warp.config.enable_backward:
|
|
8560
9539
|
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
8561
|
-
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
9540
|
+
(fun_backward_A, lto_backward_A) = warp._src.build.build_lto_dot(
|
|
8562
9541
|
M,
|
|
8563
9542
|
K,
|
|
8564
9543
|
N,
|
|
@@ -8573,7 +9552,7 @@ def tile_matmul_lto_dispatch_func(
|
|
|
8573
9552
|
builder,
|
|
8574
9553
|
)
|
|
8575
9554
|
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
8576
|
-
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
9555
|
+
(fun_backward_B, lto_backward_B) = warp._src.build.build_lto_dot(
|
|
8577
9556
|
K,
|
|
8578
9557
|
N,
|
|
8579
9558
|
M,
|
|
@@ -8690,7 +9669,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8690
9669
|
return_values: List[Var],
|
|
8691
9670
|
arg_values: Mapping[str, Var],
|
|
8692
9671
|
options: Mapping[str, Any],
|
|
8693
|
-
builder: warp.context.ModuleBuilder,
|
|
9672
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8694
9673
|
direction: str | None = None,
|
|
8695
9674
|
):
|
|
8696
9675
|
inout = arg_values["inout"]
|
|
@@ -8719,12 +9698,12 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
8719
9698
|
arch = options["output_arch"]
|
|
8720
9699
|
ept = size // num_threads
|
|
8721
9700
|
|
|
8722
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9701
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8723
9702
|
# CPU/no-MathDx dispatch
|
|
8724
9703
|
return ([], [], [], 0)
|
|
8725
9704
|
else:
|
|
8726
9705
|
# generate the LTO
|
|
8727
|
-
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
9706
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp._src.build.build_lto_fft(
|
|
8728
9707
|
arch, size, ept, direction, dir, precision, builder
|
|
8729
9708
|
)
|
|
8730
9709
|
|
|
@@ -8762,7 +9741,7 @@ add_builtin(
|
|
|
8762
9741
|
group="Tile Primitives",
|
|
8763
9742
|
export=False,
|
|
8764
9743
|
namespace="",
|
|
8765
|
-
|
|
9744
|
+
is_differentiable=False,
|
|
8766
9745
|
)
|
|
8767
9746
|
|
|
8768
9747
|
add_builtin(
|
|
@@ -8784,7 +9763,7 @@ add_builtin(
|
|
|
8784
9763
|
group="Tile Primitives",
|
|
8785
9764
|
export=False,
|
|
8786
9765
|
namespace="",
|
|
8787
|
-
|
|
9766
|
+
is_differentiable=False,
|
|
8788
9767
|
)
|
|
8789
9768
|
|
|
8790
9769
|
|
|
@@ -8829,7 +9808,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8829
9808
|
return_values: List[Var],
|
|
8830
9809
|
arg_values: Mapping[str, Var],
|
|
8831
9810
|
options: Mapping[str, Any],
|
|
8832
|
-
builder: warp.context.ModuleBuilder,
|
|
9811
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8833
9812
|
):
|
|
8834
9813
|
a = arg_values["A"]
|
|
8835
9814
|
# force source tile to shared memory
|
|
@@ -8849,7 +9828,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8849
9828
|
|
|
8850
9829
|
arch = options["output_arch"]
|
|
8851
9830
|
|
|
8852
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9831
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8853
9832
|
# CPU/no-MathDx dispatch
|
|
8854
9833
|
return ((0, a, out), [], [], 0)
|
|
8855
9834
|
else:
|
|
@@ -8864,7 +9843,7 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
8864
9843
|
req_smem_bytes = a.type.size * type_size_in_bytes(a.type.dtype)
|
|
8865
9844
|
|
|
8866
9845
|
# generate the LTO
|
|
8867
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9846
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
8868
9847
|
M,
|
|
8869
9848
|
N,
|
|
8870
9849
|
1,
|
|
@@ -8909,7 +9888,7 @@ add_builtin(
|
|
|
8909
9888
|
group="Tile Primitives",
|
|
8910
9889
|
export=False,
|
|
8911
9890
|
namespace="",
|
|
8912
|
-
|
|
9891
|
+
is_differentiable=False,
|
|
8913
9892
|
)
|
|
8914
9893
|
|
|
8915
9894
|
|
|
@@ -8953,7 +9932,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8953
9932
|
return_values: List[Var],
|
|
8954
9933
|
arg_values: Mapping[str, Var],
|
|
8955
9934
|
options: Mapping[str, Any],
|
|
8956
|
-
builder: warp.context.ModuleBuilder,
|
|
9935
|
+
builder: warp._src.context.ModuleBuilder,
|
|
8957
9936
|
):
|
|
8958
9937
|
L = arg_values["L"]
|
|
8959
9938
|
y = arg_values["y"]
|
|
@@ -8982,7 +9961,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8982
9961
|
|
|
8983
9962
|
arch = options["output_arch"]
|
|
8984
9963
|
|
|
8985
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
9964
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
8986
9965
|
# CPU/no-MathDx dispatch
|
|
8987
9966
|
return ((0, L, y, x), [], [], 0)
|
|
8988
9967
|
else:
|
|
@@ -8998,7 +9977,7 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
8998
9977
|
req_smem_bytes = (x.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
8999
9978
|
|
|
9000
9979
|
# generate the LTO
|
|
9001
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
9980
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9002
9981
|
M,
|
|
9003
9982
|
N,
|
|
9004
9983
|
NRHS,
|
|
@@ -9040,7 +10019,7 @@ add_builtin(
|
|
|
9040
10019
|
group="Tile Primitives",
|
|
9041
10020
|
export=False,
|
|
9042
10021
|
namespace="",
|
|
9043
|
-
|
|
10022
|
+
is_differentiable=False,
|
|
9044
10023
|
)
|
|
9045
10024
|
|
|
9046
10025
|
|
|
@@ -9050,7 +10029,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9050
10029
|
return_values: List[Var],
|
|
9051
10030
|
arg_values: Mapping[str, Var],
|
|
9052
10031
|
options: Mapping[str, Any],
|
|
9053
|
-
builder: warp.context.ModuleBuilder,
|
|
10032
|
+
builder: warp._src.context.ModuleBuilder,
|
|
9054
10033
|
):
|
|
9055
10034
|
L = arg_values["L"]
|
|
9056
10035
|
y = arg_values["y"]
|
|
@@ -9079,7 +10058,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9079
10058
|
|
|
9080
10059
|
arch = options["output_arch"]
|
|
9081
10060
|
|
|
9082
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10061
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
9083
10062
|
# CPU/no-MathDx dispatch
|
|
9084
10063
|
return ((0, L, y, z), [], [], 0)
|
|
9085
10064
|
else:
|
|
@@ -9095,7 +10074,7 @@ def tile_lower_solve_generic_lto_dispatch_func(
|
|
|
9095
10074
|
req_smem_bytes = (z.type.size + y.type.size + L.type.size) * type_size_in_bytes(L.type.dtype)
|
|
9096
10075
|
|
|
9097
10076
|
# generate the LTO
|
|
9098
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10077
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9099
10078
|
M,
|
|
9100
10079
|
N,
|
|
9101
10080
|
NRHS,
|
|
@@ -9173,7 +10152,7 @@ add_builtin(
|
|
|
9173
10152
|
group="Tile Primitives",
|
|
9174
10153
|
export=False,
|
|
9175
10154
|
namespace="",
|
|
9176
|
-
|
|
10155
|
+
is_differentiable=False,
|
|
9177
10156
|
)
|
|
9178
10157
|
|
|
9179
10158
|
|
|
@@ -9183,7 +10162,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9183
10162
|
return_values: List[Var],
|
|
9184
10163
|
arg_values: Mapping[str, Var],
|
|
9185
10164
|
options: Mapping[str, Any],
|
|
9186
|
-
builder: warp.context.ModuleBuilder,
|
|
10165
|
+
builder: warp._src.context.ModuleBuilder,
|
|
9187
10166
|
):
|
|
9188
10167
|
U = arg_values["U"]
|
|
9189
10168
|
z = arg_values["z"]
|
|
@@ -9212,7 +10191,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9212
10191
|
|
|
9213
10192
|
arch = options["output_arch"]
|
|
9214
10193
|
|
|
9215
|
-
if arch is None or not warp.context.runtime.core.wp_is_mathdx_enabled():
|
|
10194
|
+
if arch is None or not warp._src.context.runtime.core.wp_is_mathdx_enabled():
|
|
9216
10195
|
# CPU/no-MathDx dispatch
|
|
9217
10196
|
return ((0, U, z, x), [], [], 0)
|
|
9218
10197
|
else:
|
|
@@ -9228,7 +10207,7 @@ def tile_upper_solve_generic_lto_dispatch_func(
|
|
|
9228
10207
|
req_smem_bytes = (x.type.size + z.type.size + U.type.size) * type_size_in_bytes(U.type.dtype)
|
|
9229
10208
|
|
|
9230
10209
|
# generate the LTO
|
|
9231
|
-
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
10210
|
+
lto_symbol, lto_code_data = warp._src.build.build_lto_solver(
|
|
9232
10211
|
M,
|
|
9233
10212
|
N,
|
|
9234
10213
|
NRHS,
|
|
@@ -9306,7 +10285,7 @@ add_builtin(
|
|
|
9306
10285
|
group="Tile Primitives",
|
|
9307
10286
|
export=False,
|
|
9308
10287
|
namespace="",
|
|
9309
|
-
|
|
10288
|
+
is_differentiable=False,
|
|
9310
10289
|
)
|
|
9311
10290
|
|
|
9312
10291
|
|
|
@@ -9326,7 +10305,7 @@ add_builtin(
|
|
|
9326
10305
|
The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
|
|
9327
10306
|
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
9328
10307
|
group="Code Generation",
|
|
9329
|
-
|
|
10308
|
+
is_differentiable=False,
|
|
9330
10309
|
)
|
|
9331
10310
|
|
|
9332
10311
|
|
|
@@ -9351,7 +10330,7 @@ add_builtin(
|
|
|
9351
10330
|
doc="Return the number of elements in a vector.",
|
|
9352
10331
|
group="Utility",
|
|
9353
10332
|
export=False,
|
|
9354
|
-
|
|
10333
|
+
is_differentiable=False,
|
|
9355
10334
|
)
|
|
9356
10335
|
|
|
9357
10336
|
add_builtin(
|
|
@@ -9361,7 +10340,7 @@ add_builtin(
|
|
|
9361
10340
|
doc="Return the number of elements in a quaternion.",
|
|
9362
10341
|
group="Utility",
|
|
9363
10342
|
export=False,
|
|
9364
|
-
|
|
10343
|
+
is_differentiable=False,
|
|
9365
10344
|
)
|
|
9366
10345
|
|
|
9367
10346
|
add_builtin(
|
|
@@ -9371,7 +10350,7 @@ add_builtin(
|
|
|
9371
10350
|
doc="Return the number of rows in a matrix.",
|
|
9372
10351
|
group="Utility",
|
|
9373
10352
|
export=False,
|
|
9374
|
-
|
|
10353
|
+
is_differentiable=False,
|
|
9375
10354
|
)
|
|
9376
10355
|
|
|
9377
10356
|
add_builtin(
|
|
@@ -9381,7 +10360,7 @@ add_builtin(
|
|
|
9381
10360
|
doc="Return the number of elements in a transformation.",
|
|
9382
10361
|
group="Utility",
|
|
9383
10362
|
export=False,
|
|
9384
|
-
|
|
10363
|
+
is_differentiable=False,
|
|
9385
10364
|
)
|
|
9386
10365
|
|
|
9387
10366
|
add_builtin(
|
|
@@ -9391,7 +10370,7 @@ add_builtin(
|
|
|
9391
10370
|
doc="Return the size of the first dimension in an array.",
|
|
9392
10371
|
group="Utility",
|
|
9393
10372
|
export=False,
|
|
9394
|
-
|
|
10373
|
+
is_differentiable=False,
|
|
9395
10374
|
)
|
|
9396
10375
|
|
|
9397
10376
|
add_builtin(
|
|
@@ -9401,7 +10380,33 @@ add_builtin(
|
|
|
9401
10380
|
doc="Return the number of rows in a tile.",
|
|
9402
10381
|
group="Utility",
|
|
9403
10382
|
export=False,
|
|
9404
|
-
|
|
10383
|
+
is_differentiable=False,
|
|
10384
|
+
)
|
|
10385
|
+
|
|
10386
|
+
|
|
10387
|
+
def cast_value_func(arg_types, arg_values):
|
|
10388
|
+
# Return generic type for doc builds.
|
|
10389
|
+
if arg_types is None:
|
|
10390
|
+
return Any
|
|
10391
|
+
|
|
10392
|
+
return arg_values["dtype"]
|
|
10393
|
+
|
|
10394
|
+
|
|
10395
|
+
def cast_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
10396
|
+
func_args = (args["a"],)
|
|
10397
|
+
template_args = (args["dtype"],)
|
|
10398
|
+
return (func_args, template_args)
|
|
10399
|
+
|
|
10400
|
+
|
|
10401
|
+
add_builtin(
|
|
10402
|
+
"cast",
|
|
10403
|
+
input_types={"a": Any, "dtype": Any},
|
|
10404
|
+
value_func=cast_value_func,
|
|
10405
|
+
dispatch_func=cast_dispatch_func,
|
|
10406
|
+
doc="Reinterpret a value as a different type while preserving its bit pattern.",
|
|
10407
|
+
group="Utility",
|
|
10408
|
+
export=False,
|
|
10409
|
+
is_differentiable=False,
|
|
9405
10410
|
)
|
|
9406
10411
|
|
|
9407
10412
|
|
|
@@ -9428,7 +10433,7 @@ add_builtin(
|
|
|
9428
10433
|
doc="Construct a tuple from a list of values",
|
|
9429
10434
|
group="Utility",
|
|
9430
10435
|
hidden=True,
|
|
9431
|
-
|
|
10436
|
+
is_differentiable=False,
|
|
9432
10437
|
export=False,
|
|
9433
10438
|
)
|
|
9434
10439
|
|
|
@@ -9465,7 +10470,7 @@ add_builtin(
|
|
|
9465
10470
|
dispatch_func=tuple_extract_dispatch_func,
|
|
9466
10471
|
group="Utility",
|
|
9467
10472
|
hidden=True,
|
|
9468
|
-
|
|
10473
|
+
is_differentiable=False,
|
|
9469
10474
|
)
|
|
9470
10475
|
|
|
9471
10476
|
|
|
@@ -9476,7 +10481,7 @@ add_builtin(
|
|
|
9476
10481
|
doc="Return the number of elements in a tuple.",
|
|
9477
10482
|
group="Utility",
|
|
9478
10483
|
export=False,
|
|
9479
|
-
|
|
10484
|
+
is_differentiable=False,
|
|
9480
10485
|
)
|
|
9481
10486
|
|
|
9482
10487
|
# ---------------------------------
|
|
@@ -9495,5 +10500,5 @@ add_builtin(
|
|
|
9495
10500
|
export=False,
|
|
9496
10501
|
group="Utility",
|
|
9497
10502
|
hidden=True,
|
|
9498
|
-
|
|
10503
|
+
is_differentiable=False,
|
|
9499
10504
|
)
|