warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -43,18 +43,10 @@
|
|
|
43
43
|
};
|
|
44
44
|
#endif
|
|
45
45
|
|
|
46
|
-
|
|
47
|
-
#ifndef WP_TILE_BLOCK_DIM
|
|
48
|
-
#define WP_TILE_BLOCK_DIM 256
|
|
49
|
-
#endif
|
|
50
|
-
|
|
51
|
-
#if !defined(__CUDA_ARCH__)
|
|
52
|
-
#define WP_TILE_SHARED static
|
|
53
|
-
#define WP_TILE_SYNC void
|
|
54
|
-
|
|
55
|
-
#else
|
|
56
|
-
#define WP_TILE_SHARED __shared__
|
|
46
|
+
#if defined(__CUDA_ARCH__)
|
|
57
47
|
#define WP_TILE_SYNC __syncthreads
|
|
48
|
+
#else
|
|
49
|
+
#define WP_TILE_SYNC void
|
|
58
50
|
#endif
|
|
59
51
|
|
|
60
52
|
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
|
|
@@ -140,7 +132,6 @@
|
|
|
140
132
|
[ ] LayerNorm
|
|
141
133
|
[ ] SoftMax
|
|
142
134
|
[ ] GEMM
|
|
143
|
-
[ ] warp.sim (CRBA)
|
|
144
135
|
[ ] Batched MLP
|
|
145
136
|
[ ] Layer norm
|
|
146
137
|
[ ] FNO + Burgers equation
|
|
@@ -149,7 +140,6 @@
|
|
|
149
140
|
[ ] MeshCNN (Modulus, Oliver)
|
|
150
141
|
[ ] BioNemo (Ali)
|
|
151
142
|
[ ] Skinning (David/Or/Vismay)
|
|
152
|
-
[ ] warp.sim (VBD)
|
|
153
143
|
[ ] Error checking
|
|
154
144
|
[ ] Ensure functions passed to tile_map() are compatible with tile type
|
|
155
145
|
[ ] Ensure that args passed to tile ops are compatible
|
|
@@ -213,6 +203,12 @@ struct is_same<T, T> {
|
|
|
213
203
|
static constexpr bool value = true;
|
|
214
204
|
};
|
|
215
205
|
|
|
206
|
+
// Helper for dependent static_assert failures
|
|
207
|
+
template <typename T>
|
|
208
|
+
struct always_false {
|
|
209
|
+
static constexpr bool value = false;
|
|
210
|
+
};
|
|
211
|
+
|
|
216
212
|
|
|
217
213
|
template <int N>
|
|
218
214
|
struct tile_coord_t
|
|
@@ -338,6 +334,113 @@ template <int... V>
|
|
|
338
334
|
using tile_stride_t = tile_tuple_t<V...>;
|
|
339
335
|
|
|
340
336
|
|
|
337
|
+
// helper to remove a dimension from a shape (used for axis reductions)
|
|
338
|
+
template<int Axis, typename Shape>
|
|
339
|
+
struct tile_shape_remove_dim {
|
|
340
|
+
static_assert(Axis >= 0 && Axis < Shape::N, "Axis out of bounds for tile_shape_remove_dim");
|
|
341
|
+
};
|
|
342
|
+
|
|
343
|
+
// 1D -> scalar
|
|
344
|
+
template<int D0>
|
|
345
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0>> {
|
|
346
|
+
using type = tile_shape_t<1>;
|
|
347
|
+
};
|
|
348
|
+
|
|
349
|
+
// 2D -> 1D
|
|
350
|
+
template<int D0, int D1>
|
|
351
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1>> {
|
|
352
|
+
using type = tile_shape_t<D1>;
|
|
353
|
+
};
|
|
354
|
+
|
|
355
|
+
template<int D0, int D1>
|
|
356
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1>> {
|
|
357
|
+
using type = tile_shape_t<D0>;
|
|
358
|
+
};
|
|
359
|
+
|
|
360
|
+
// 3D -> 2D
|
|
361
|
+
template<int D0, int D1, int D2>
|
|
362
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2>> {
|
|
363
|
+
using type = tile_shape_t<D1, D2>;
|
|
364
|
+
};
|
|
365
|
+
|
|
366
|
+
template<int D0, int D1, int D2>
|
|
367
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2>> {
|
|
368
|
+
using type = tile_shape_t<D0, D2>;
|
|
369
|
+
};
|
|
370
|
+
|
|
371
|
+
template<int D0, int D1, int D2>
|
|
372
|
+
struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2>> {
|
|
373
|
+
using type = tile_shape_t<D0, D1>;
|
|
374
|
+
};
|
|
375
|
+
|
|
376
|
+
// 4D -> 3D
|
|
377
|
+
template<int D0, int D1, int D2, int D3>
|
|
378
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2, D3>> {
|
|
379
|
+
using type = tile_shape_t<D1, D2, D3>;
|
|
380
|
+
};
|
|
381
|
+
|
|
382
|
+
template<int D0, int D1, int D2, int D3>
|
|
383
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2, D3>> {
|
|
384
|
+
using type = tile_shape_t<D0, D2, D3>;
|
|
385
|
+
};
|
|
386
|
+
|
|
387
|
+
template<int D0, int D1, int D2, int D3>
|
|
388
|
+
struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2, D3>> {
|
|
389
|
+
using type = tile_shape_t<D0, D1, D3>;
|
|
390
|
+
};
|
|
391
|
+
|
|
392
|
+
template<int D0, int D1, int D2, int D3>
|
|
393
|
+
struct tile_shape_remove_dim<3, tile_shape_t<D0, D1, D2, D3>> {
|
|
394
|
+
using type = tile_shape_t<D0, D1, D2>;
|
|
395
|
+
};
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
// helper to insert an axis value into a coordinate (inverse of removing dimension)
|
|
399
|
+
// used for mapping output coordinates back to input coordinates during axis reduction
|
|
400
|
+
template<int Axis, int N>
|
|
401
|
+
CUDA_CALLABLE constexpr auto tile_coord_insert_axis(const tile_coord_t<N>& coord, int axis_val)
|
|
402
|
+
{
|
|
403
|
+
static_assert(Axis >= 0 && Axis <= N, "Axis out of bounds for tile_coord_insert_axis");
|
|
404
|
+
|
|
405
|
+
if constexpr (N == 0)
|
|
406
|
+
{
|
|
407
|
+
// Scalar -> 1D
|
|
408
|
+
static_assert(Axis == 0, "Invalid axis for scalar coordinate");
|
|
409
|
+
return tile_coord(axis_val);
|
|
410
|
+
}
|
|
411
|
+
else if constexpr (N == 1)
|
|
412
|
+
{
|
|
413
|
+
// 1D -> 2D
|
|
414
|
+
if constexpr (Axis == 0)
|
|
415
|
+
return tile_coord(axis_val, coord[0]);
|
|
416
|
+
else
|
|
417
|
+
return tile_coord(coord[0], axis_val);
|
|
418
|
+
}
|
|
419
|
+
else if constexpr (N == 2)
|
|
420
|
+
{
|
|
421
|
+
// 2D -> 3D
|
|
422
|
+
if constexpr (Axis == 0)
|
|
423
|
+
return tile_coord(axis_val, coord[0], coord[1]);
|
|
424
|
+
else if constexpr (Axis == 1)
|
|
425
|
+
return tile_coord(coord[0], axis_val, coord[1]);
|
|
426
|
+
else
|
|
427
|
+
return tile_coord(coord[0], coord[1], axis_val);
|
|
428
|
+
}
|
|
429
|
+
else // N == 3
|
|
430
|
+
{
|
|
431
|
+
// 3D -> 4D
|
|
432
|
+
if constexpr (Axis == 0)
|
|
433
|
+
return tile_coord(axis_val, coord[0], coord[1], coord[2]);
|
|
434
|
+
else if constexpr (Axis == 1)
|
|
435
|
+
return tile_coord(coord[0], axis_val, coord[1], coord[2]);
|
|
436
|
+
else if constexpr (Axis == 2)
|
|
437
|
+
return tile_coord(coord[0], coord[1], axis_val, coord[2]);
|
|
438
|
+
else
|
|
439
|
+
return tile_coord(coord[0], coord[1], coord[2], axis_val);
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
|
|
341
444
|
// represents a tile stored in global memory with dynamic strides
|
|
342
445
|
// used to represent the source and offset for tile loads to register/shared
|
|
343
446
|
// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
|
|
@@ -581,7 +684,11 @@ struct tile_register_t
|
|
|
581
684
|
const int thread = Layout::thread_from_linear(linear);
|
|
582
685
|
const int reg = Layout::register_from_linear(linear);
|
|
583
686
|
|
|
584
|
-
|
|
687
|
+
#if defined(__CUDA_ARCH__)
|
|
688
|
+
__shared__ Type scratch;
|
|
689
|
+
#else
|
|
690
|
+
Type scratch;
|
|
691
|
+
#endif
|
|
585
692
|
|
|
586
693
|
// ensure any previously scheduled threads have finished reading from scratch
|
|
587
694
|
WP_TILE_SYNC();
|
|
@@ -735,43 +842,124 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
|
735
842
|
return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
|
|
736
843
|
}
|
|
737
844
|
|
|
738
|
-
|
|
845
|
+
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
846
|
+
// On the CPU we use a fixed size block of stack memory for shared tile allocations.
|
|
847
|
+
// We store a pointer to the current allocation storage either in a reserved register
|
|
848
|
+
// (AArch64) or a static variable (x86-64).
|
|
849
|
+
#if !defined(__CUDA_ARCH__)
|
|
850
|
+
class tile_shared_storage_t;
|
|
851
|
+
#if defined(__aarch64__)
|
|
852
|
+
// x28 is is the last callee-saved register on AArch64. This allows us to call externally
|
|
853
|
+
// compiled functions without worrying about clobbering the pointer.
|
|
854
|
+
// We pass -target-feature +reserve-x28 to Clang to exclude it from register allocation.
|
|
855
|
+
register tile_shared_storage_t* shared_tile_storage asm("x28");
|
|
856
|
+
#else
|
|
857
|
+
// Ideally this would be thread_local, but LLVM's JIT doesn't support TLS yet
|
|
858
|
+
// There is also no support for something like -ffixed-r15 either
|
|
859
|
+
static tile_shared_storage_t* shared_tile_storage;
|
|
860
|
+
#endif
|
|
861
|
+
#endif
|
|
862
|
+
#endif
|
|
863
|
+
|
|
864
|
+
// This class manages a block of "shared" memory for use by tiles.
|
|
865
|
+
// On the GPU this maps to dynamic shared memory, while on the CPU we allocate
|
|
866
|
+
// a fixed size block of memory on the stack and manage allocations from it.
|
|
867
|
+
// An instance of this class gets created at the start of a kernel.
|
|
868
|
+
class tile_shared_storage_t
|
|
739
869
|
{
|
|
870
|
+
private:
|
|
871
|
+
#if !defined(__CUDA_ARCH__)
|
|
872
|
+
#define WP_MAX_CPU_SHARED 256*1024
|
|
873
|
+
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
874
|
+
tile_shared_storage_t* old_value;
|
|
875
|
+
unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
876
|
+
char dynamic_smem_base[WP_MAX_CPU_SHARED]; // on CPU allocate a fixed 256k block to use for shared allocs
|
|
877
|
+
#endif
|
|
878
|
+
#endif
|
|
879
|
+
|
|
740
880
|
// we maintain a per-thread offset into dynamic
|
|
741
881
|
// shared memory that allows us to keep track of
|
|
742
882
|
// current use across dynamic function calls
|
|
743
|
-
|
|
883
|
+
static inline CUDA_CALLABLE unsigned int* get_smem_base()
|
|
884
|
+
{
|
|
885
|
+
#if defined(__CUDA_ARCH__)
|
|
886
|
+
__shared__ unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
887
|
+
return smem_base;
|
|
888
|
+
#elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
889
|
+
return shared_tile_storage->smem_base;
|
|
890
|
+
#else
|
|
891
|
+
static unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
892
|
+
return smem_base;
|
|
893
|
+
#endif
|
|
894
|
+
}
|
|
895
|
+
|
|
896
|
+
static inline CUDA_CALLABLE char* get_dynamic_smem_base()
|
|
897
|
+
{
|
|
898
|
+
#if defined(__CUDA_ARCH__)
|
|
899
|
+
extern __shared__ char dynamic_smem_base[];
|
|
900
|
+
return dynamic_smem_base;
|
|
901
|
+
#elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
902
|
+
return shared_tile_storage->dynamic_smem_base;
|
|
903
|
+
#else
|
|
904
|
+
static char dynamic_smem_base[WP_MAX_CPU_SHARED];
|
|
905
|
+
return dynamic_smem_base;
|
|
906
|
+
#endif
|
|
907
|
+
}
|
|
744
908
|
|
|
745
|
-
|
|
909
|
+
public:
|
|
910
|
+
// cppcheck-suppress uninitMemberVar
|
|
911
|
+
inline CUDA_CALLABLE tile_shared_storage_t()
|
|
746
912
|
{
|
|
913
|
+
#if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
914
|
+
// On the CPU save a pointer to this instance in a reserved register
|
|
915
|
+
// or static variable so it can be accessed from anywhere within a kernel.
|
|
916
|
+
old_value = shared_tile_storage;
|
|
917
|
+
shared_tile_storage = this;
|
|
918
|
+
#endif
|
|
919
|
+
|
|
920
|
+
init();
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
inline CUDA_CALLABLE ~tile_shared_storage_t()
|
|
924
|
+
{
|
|
925
|
+
check();
|
|
926
|
+
|
|
927
|
+
#if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
928
|
+
shared_tile_storage = old_value;
|
|
929
|
+
#endif
|
|
930
|
+
}
|
|
931
|
+
|
|
932
|
+
static inline CUDA_CALLABLE void init()
|
|
933
|
+
{
|
|
934
|
+
unsigned int* smem_base = get_smem_base();
|
|
935
|
+
|
|
747
936
|
smem_base[WP_TILE_THREAD_IDX] = 0;
|
|
748
|
-
return nullptr;
|
|
749
937
|
}
|
|
750
|
-
|
|
938
|
+
|
|
939
|
+
static inline CUDA_CALLABLE void check()
|
|
751
940
|
{
|
|
941
|
+
unsigned int* smem_base = get_smem_base();
|
|
942
|
+
|
|
752
943
|
assert(smem_base[WP_TILE_THREAD_IDX] == 0);
|
|
753
|
-
return nullptr;
|
|
754
944
|
}
|
|
755
|
-
|
|
945
|
+
|
|
946
|
+
static inline CUDA_CALLABLE void* alloc(int num_bytes)
|
|
756
947
|
{
|
|
757
|
-
|
|
758
|
-
|
|
948
|
+
unsigned int* smem_base = get_smem_base();
|
|
949
|
+
char* dynamic_smem_base = get_dynamic_smem_base();
|
|
950
|
+
|
|
951
|
+
const unsigned int offset = smem_base[WP_TILE_THREAD_IDX];
|
|
952
|
+
|
|
759
953
|
// one entry per-thread so no need for synchronization
|
|
760
954
|
smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
|
|
761
|
-
assert(smem_base[WP_TILE_THREAD_IDX] >= 0);
|
|
762
955
|
|
|
763
|
-
#
|
|
764
|
-
|
|
765
|
-
#else
|
|
766
|
-
// on CPU allocate a fixed 256k block to use for shared allocs
|
|
767
|
-
static const int max_cpu_shared = 256*1024;
|
|
768
|
-
static char dynamic_smem_base[max_cpu_shared];
|
|
769
|
-
|
|
770
|
-
assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
|
|
956
|
+
#if !defined(__CUDA_ARCH__)
|
|
957
|
+
assert(smem_base[WP_TILE_THREAD_IDX] <= WP_MAX_CPU_SHARED);
|
|
771
958
|
#endif
|
|
959
|
+
|
|
772
960
|
return &(dynamic_smem_base[offset]);
|
|
773
961
|
}
|
|
774
|
-
}
|
|
962
|
+
};
|
|
775
963
|
|
|
776
964
|
|
|
777
965
|
template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
|
|
@@ -939,10 +1127,10 @@ struct tile_shared_t
|
|
|
939
1127
|
{
|
|
940
1128
|
// update our per-thread shared memory allocator
|
|
941
1129
|
if (data.ptr)
|
|
942
|
-
|
|
1130
|
+
tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
|
|
943
1131
|
|
|
944
1132
|
if (grad.ptr)
|
|
945
|
-
|
|
1133
|
+
tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
|
|
946
1134
|
}
|
|
947
1135
|
}
|
|
948
1136
|
|
|
@@ -1095,6 +1283,46 @@ struct tile_shared_t
|
|
|
1095
1283
|
adj_x -= grad(c);
|
|
1096
1284
|
}
|
|
1097
1285
|
|
|
1286
|
+
// perform AND between a scalar value and a single tile element
|
|
1287
|
+
inline CUDA_CALLABLE void bit_and_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1288
|
+
{
|
|
1289
|
+
// since multiple threads may access the same element
|
|
1290
|
+
// we need to access using atomic operations
|
|
1291
|
+
wp::atomic_and(&data(c), x);
|
|
1292
|
+
|
|
1293
|
+
WP_TILE_SYNC();
|
|
1294
|
+
}
|
|
1295
|
+
|
|
1296
|
+
// backward of inplace scalar AND
|
|
1297
|
+
inline CUDA_CALLABLE void adj_bit_and_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
// perform OR between a scalar value and a single tile element
|
|
1301
|
+
inline CUDA_CALLABLE void bit_or_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1302
|
+
{
|
|
1303
|
+
// since multiple threads may access the same element
|
|
1304
|
+
// we need to access using atomic operations
|
|
1305
|
+
wp::atomic_or(&data(c), x);
|
|
1306
|
+
|
|
1307
|
+
WP_TILE_SYNC();
|
|
1308
|
+
}
|
|
1309
|
+
|
|
1310
|
+
// backward of inplace scalar OR
|
|
1311
|
+
inline CUDA_CALLABLE void adj_bit_or_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1312
|
+
|
|
1313
|
+
// perform XOR between a scalar value and a single tile element
|
|
1314
|
+
inline CUDA_CALLABLE void bit_xor_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1315
|
+
{
|
|
1316
|
+
// since multiple threads may access the same element
|
|
1317
|
+
// we need to access using atomic operations
|
|
1318
|
+
wp::atomic_xor(&data(c), x);
|
|
1319
|
+
|
|
1320
|
+
WP_TILE_SYNC();
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
// backward of inplace scalar XOR
|
|
1324
|
+
inline CUDA_CALLABLE void adj_bit_xor_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1325
|
+
|
|
1098
1326
|
// copy register tile to shared
|
|
1099
1327
|
template <typename Tile>
|
|
1100
1328
|
inline CUDA_CALLABLE void assign(const Tile& tile)
|
|
@@ -1549,7 +1777,11 @@ void tile_register_t<T, L>::print() const
|
|
|
1549
1777
|
{
|
|
1550
1778
|
// create a temporary shared tile so that
|
|
1551
1779
|
// we can print it deterministically
|
|
1552
|
-
|
|
1780
|
+
#if defined(__CUDA_ARCH__)
|
|
1781
|
+
__shared__ T smem[L::Size];
|
|
1782
|
+
#else
|
|
1783
|
+
T smem[L::Size];
|
|
1784
|
+
#endif
|
|
1553
1785
|
tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
|
|
1554
1786
|
|
|
1555
1787
|
scratch.assign(*this);
|
|
@@ -1609,37 +1841,6 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
|
|
|
1609
1841
|
{
|
|
1610
1842
|
}
|
|
1611
1843
|
|
|
1612
|
-
// select specialization for shared tiles
|
|
1613
|
-
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1614
|
-
inline CUDA_CALLABLE auto select(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
|
|
1615
|
-
{
|
|
1616
|
-
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1617
|
-
return (!!cond) ? b.copy_to_register() : a;
|
|
1618
|
-
}
|
|
1619
|
-
|
|
1620
|
-
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1621
|
-
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
|
|
1622
|
-
{
|
|
1623
|
-
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1624
|
-
return (!!cond) ? b : a.copy_to_register();
|
|
1625
|
-
}
|
|
1626
|
-
|
|
1627
|
-
template <typename C, typename T, typename L, bool Owner>
|
|
1628
|
-
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
|
|
1629
|
-
{
|
|
1630
|
-
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1631
|
-
return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
|
|
1632
|
-
}
|
|
1633
|
-
|
|
1634
|
-
template <typename C, typename T, typename L, bool LOwner, bool ROwner>
|
|
1635
|
-
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
|
|
1636
|
-
{
|
|
1637
|
-
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1638
|
-
return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
|
|
1639
|
-
}
|
|
1640
|
-
|
|
1641
|
-
// adj_select same as in builtin.h
|
|
1642
|
-
|
|
1643
1844
|
// where specialization for register/shared tiles
|
|
1644
1845
|
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1645
1846
|
inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
|
|
@@ -1690,7 +1891,7 @@ template <typename T, typename Shape, typename Strides, bool RequiresGrad>
|
|
|
1690
1891
|
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
1691
1892
|
{
|
|
1692
1893
|
constexpr int size = Shape::size();
|
|
1693
|
-
T* data = (T*)
|
|
1894
|
+
T* data = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
|
|
1694
1895
|
T* grad = nullptr;
|
|
1695
1896
|
|
|
1696
1897
|
#if FP_CHECK
|
|
@@ -1709,7 +1910,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1709
1910
|
|
|
1710
1911
|
if (RequiresGrad)
|
|
1711
1912
|
{
|
|
1712
|
-
grad = (T*)
|
|
1913
|
+
grad = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
|
|
1713
1914
|
|
|
1714
1915
|
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1715
1916
|
grad[i] = T(0);
|
|
@@ -1887,6 +2088,14 @@ inline CUDA_CALLABLE auto tile_ones()
|
|
|
1887
2088
|
return T(1);
|
|
1888
2089
|
}
|
|
1889
2090
|
|
|
2091
|
+
// value-initialized tile
|
|
2092
|
+
template <typename T, unsigned... Shape>
|
|
2093
|
+
inline CUDA_CALLABLE auto tile_full(T x)
|
|
2094
|
+
{
|
|
2095
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
2096
|
+
return x;
|
|
2097
|
+
}
|
|
2098
|
+
|
|
1890
2099
|
// tile with evenly spaced values
|
|
1891
2100
|
template <typename T, int Len>
|
|
1892
2101
|
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
|
|
@@ -2438,6 +2647,43 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
|
|
|
2438
2647
|
}
|
|
2439
2648
|
|
|
2440
2649
|
|
|
2650
|
+
// tile & tile
|
|
2651
|
+
template <typename TileA, typename TileB>
|
|
2652
|
+
inline CUDA_CALLABLE auto tile_bit_and(TileA& a, TileB& b)
|
|
2653
|
+
{
|
|
2654
|
+
return tile_binary_map(bit_and, a, b, a);
|
|
2655
|
+
}
|
|
2656
|
+
|
|
2657
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2658
|
+
inline CUDA_CALLABLE void adj_tile_bit_and(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2659
|
+
{
|
|
2660
|
+
}
|
|
2661
|
+
|
|
2662
|
+
// tile | tile
|
|
2663
|
+
template <typename TileA, typename TileB>
|
|
2664
|
+
inline CUDA_CALLABLE auto tile_bit_or(TileA& a, TileB& b)
|
|
2665
|
+
{
|
|
2666
|
+
return tile_binary_map(bit_or, a, b, a);
|
|
2667
|
+
}
|
|
2668
|
+
|
|
2669
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2670
|
+
inline CUDA_CALLABLE void adj_tile_bit_or(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2671
|
+
{
|
|
2672
|
+
}
|
|
2673
|
+
|
|
2674
|
+
// tile ^ tile
|
|
2675
|
+
template <typename TileA, typename TileB>
|
|
2676
|
+
inline CUDA_CALLABLE auto tile_bit_xor(TileA& a, TileB& b)
|
|
2677
|
+
{
|
|
2678
|
+
return tile_binary_map(bit_xor, a, b, a);
|
|
2679
|
+
}
|
|
2680
|
+
|
|
2681
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2682
|
+
inline CUDA_CALLABLE void adj_tile_bit_xor(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2683
|
+
{
|
|
2684
|
+
}
|
|
2685
|
+
|
|
2686
|
+
|
|
2441
2687
|
template <typename TileA, typename TileB>
|
|
2442
2688
|
inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
|
|
2443
2689
|
{
|
|
@@ -2557,24 +2803,227 @@ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj
|
|
|
2557
2803
|
adj_b.grad_add(adj_b_reg);
|
|
2558
2804
|
}
|
|
2559
2805
|
|
|
2806
|
+
template <typename TileA, typename TileB>
|
|
2807
|
+
inline CUDA_CALLABLE void tile_bit_and_inplace(TileA& a, TileB& b)
|
|
2808
|
+
{
|
|
2809
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2810
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2811
|
+
|
|
2812
|
+
// verify shapes and sizes are compatible
|
|
2813
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise AND");
|
|
2814
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise AND");
|
|
2815
|
+
|
|
2816
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2817
|
+
auto a_reg = a.copy_to_register();
|
|
2818
|
+
auto b_reg = b.copy_to_register();
|
|
2819
|
+
|
|
2820
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2821
|
+
|
|
2822
|
+
WP_PRAGMA_UNROLL
|
|
2823
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2824
|
+
{
|
|
2825
|
+
const int linear = Layout::linear_from_register(i);
|
|
2826
|
+
|
|
2827
|
+
if(!Layout::valid(linear))
|
|
2828
|
+
break;
|
|
2829
|
+
|
|
2830
|
+
a_reg.data[i] &= b_reg.data[i];
|
|
2831
|
+
}
|
|
2832
|
+
|
|
2833
|
+
a.assign(a_reg);
|
|
2834
|
+
}
|
|
2835
|
+
|
|
2836
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2837
|
+
inline CUDA_CALLABLE void adj_tile_bit_and_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2838
|
+
|
|
2839
|
+
template <typename TileA, typename TileB>
|
|
2840
|
+
inline CUDA_CALLABLE void tile_bit_or_inplace(TileA& a, TileB& b)
|
|
2841
|
+
{
|
|
2842
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2843
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2844
|
+
|
|
2845
|
+
// verify shapes and sizes are compatible
|
|
2846
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise OR");
|
|
2847
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise OR");
|
|
2848
|
+
|
|
2849
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2850
|
+
auto a_reg = a.copy_to_register();
|
|
2851
|
+
auto b_reg = b.copy_to_register();
|
|
2852
|
+
|
|
2853
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2854
|
+
|
|
2855
|
+
WP_PRAGMA_UNROLL
|
|
2856
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2857
|
+
{
|
|
2858
|
+
const int linear = Layout::linear_from_register(i);
|
|
2859
|
+
|
|
2860
|
+
if(!Layout::valid(linear))
|
|
2861
|
+
break;
|
|
2862
|
+
|
|
2863
|
+
a_reg.data[i] |= b_reg.data[i];
|
|
2864
|
+
}
|
|
2865
|
+
|
|
2866
|
+
a.assign(a_reg);
|
|
2867
|
+
}
|
|
2868
|
+
|
|
2869
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2870
|
+
inline CUDA_CALLABLE void adj_tile_bit_or_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2871
|
+
|
|
2872
|
+
template <typename TileA, typename TileB>
|
|
2873
|
+
inline CUDA_CALLABLE void tile_bit_xor_inplace(TileA& a, TileB& b)
|
|
2874
|
+
{
|
|
2875
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2876
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2877
|
+
|
|
2878
|
+
// verify shapes and sizes are compatible
|
|
2879
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise XOR");
|
|
2880
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise XOR");
|
|
2881
|
+
|
|
2882
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2883
|
+
auto a_reg = a.copy_to_register();
|
|
2884
|
+
auto b_reg = b.copy_to_register();
|
|
2885
|
+
|
|
2886
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2887
|
+
|
|
2888
|
+
WP_PRAGMA_UNROLL
|
|
2889
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2890
|
+
{
|
|
2891
|
+
const int linear = Layout::linear_from_register(i);
|
|
2892
|
+
|
|
2893
|
+
if(!Layout::valid(linear))
|
|
2894
|
+
break;
|
|
2895
|
+
|
|
2896
|
+
a_reg.data[i] ^= b_reg.data[i];
|
|
2897
|
+
}
|
|
2898
|
+
|
|
2899
|
+
a.assign(a_reg);
|
|
2900
|
+
}
|
|
2901
|
+
|
|
2902
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2903
|
+
inline CUDA_CALLABLE void adj_tile_bit_xor_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2904
|
+
|
|
2560
2905
|
|
|
2561
2906
|
template<typename Tile>
|
|
2562
|
-
typename Tile::Type tile_extract(Tile& t, int i) {
|
|
2907
|
+
typename Tile::Type tile_extract(Tile& t, int i) {
|
|
2908
|
+
return t.extract(tile_coord(i));
|
|
2909
|
+
}
|
|
2563
2910
|
template<typename Tile>
|
|
2564
|
-
|
|
2911
|
+
auto tile_extract(Tile& t, int i, int j) {
|
|
2912
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2913
|
+
return t.extract(tile_coord(i))[j];
|
|
2914
|
+
} else {
|
|
2915
|
+
return t.extract(tile_coord(i,j));
|
|
2916
|
+
}
|
|
2917
|
+
}
|
|
2565
2918
|
template<typename Tile>
|
|
2566
|
-
|
|
2919
|
+
auto tile_extract(Tile& t, int i, int j, int k) {
|
|
2920
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2921
|
+
return t.extract(tile_coord(i,j))[k];
|
|
2922
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2923
|
+
return t.extract(tile_coord(i)).data[j][k];
|
|
2924
|
+
} else {
|
|
2925
|
+
return t.extract(tile_coord(i,j,k));
|
|
2926
|
+
}
|
|
2927
|
+
}
|
|
2567
2928
|
template<typename Tile>
|
|
2568
|
-
|
|
2929
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l) {
|
|
2930
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2931
|
+
return t.extract(tile_coord(i,j,k))[l];
|
|
2932
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2933
|
+
return t.extract(tile_coord(i,j)).data[k][l];
|
|
2934
|
+
} else {
|
|
2935
|
+
return t.extract(tile_coord(i,j,k,l));
|
|
2936
|
+
}
|
|
2937
|
+
}
|
|
2938
|
+
template<typename Tile>
|
|
2939
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l, int m) {
|
|
2940
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2941
|
+
return t.extract(tile_coord(i,j,k,l))[m];
|
|
2942
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2943
|
+
return t.extract(tile_coord(i,j,k)).data[l][m];
|
|
2944
|
+
} else {
|
|
2945
|
+
static_assert(always_false<Tile>::value,
|
|
2946
|
+
"tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
2947
|
+
}
|
|
2948
|
+
}
|
|
2949
|
+
template<typename Tile>
|
|
2950
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l, int m, int n) {
|
|
2951
|
+
if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2952
|
+
return t.extract(tile_coord(i,j,k,l)).data[m][n];
|
|
2953
|
+
} else {
|
|
2954
|
+
static_assert(always_false<Tile>::value,
|
|
2955
|
+
"tile_extract with 6 indices requires a tile of matrices (4D tile)");
|
|
2956
|
+
}
|
|
2957
|
+
}
|
|
2569
2958
|
|
|
2570
2959
|
template<typename Tile, typename AdjTile>
|
|
2571
|
-
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
|
|
2572
|
-
|
|
2573
|
-
|
|
2574
|
-
template<typename Tile, typename AdjTile>
|
|
2575
|
-
void adj_tile_extract(Tile& t, int i, int j,
|
|
2576
|
-
|
|
2577
|
-
|
|
2960
|
+
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
|
|
2961
|
+
adj_t.adj_extract(tile_coord(i), adj_ret);
|
|
2962
|
+
}
|
|
2963
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2964
|
+
void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, AdjType adj_ret) {
|
|
2965
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2966
|
+
typename Tile::Type vector_adj{};
|
|
2967
|
+
vector_adj[j] = adj_ret;
|
|
2968
|
+
adj_t.adj_extract(tile_coord(i), vector_adj);
|
|
2969
|
+
} else {
|
|
2970
|
+
adj_t.adj_extract(tile_coord(i, j), adj_ret);
|
|
2971
|
+
}
|
|
2972
|
+
}
|
|
2973
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2974
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, AdjType adj_ret) {
|
|
2975
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2976
|
+
typename Tile::Type vector_adj{};
|
|
2977
|
+
vector_adj[k] = adj_ret;
|
|
2978
|
+
adj_t.adj_extract(tile_coord(i, j), vector_adj);
|
|
2979
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2980
|
+
typename Tile::Type matrix_adj{};
|
|
2981
|
+
matrix_adj.data[j][k] = adj_ret;
|
|
2982
|
+
adj_t.adj_extract(tile_coord(i), matrix_adj);
|
|
2983
|
+
} else {
|
|
2984
|
+
adj_t.adj_extract(tile_coord(i, j, k), adj_ret);
|
|
2985
|
+
}
|
|
2986
|
+
}
|
|
2987
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2988
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, AdjType adj_ret) {
|
|
2989
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2990
|
+
typename Tile::Type vector_adj{};
|
|
2991
|
+
vector_adj[l] = adj_ret;
|
|
2992
|
+
adj_t.adj_extract(tile_coord(i, j, k), vector_adj);
|
|
2993
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2994
|
+
typename Tile::Type matrix_adj{};
|
|
2995
|
+
matrix_adj.data[k][l] = adj_ret;
|
|
2996
|
+
adj_t.adj_extract(tile_coord(i, j), matrix_adj);
|
|
2997
|
+
} else {
|
|
2998
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret);
|
|
2999
|
+
}
|
|
3000
|
+
}
|
|
3001
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
3002
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, AdjType adj_ret) {
|
|
3003
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
3004
|
+
typename Tile::Type vector_adj{};
|
|
3005
|
+
vector_adj[m] = adj_ret;
|
|
3006
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), vector_adj);
|
|
3007
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
3008
|
+
typename Tile::Type matrix_adj{};
|
|
3009
|
+
matrix_adj.data[l][m] = adj_ret;
|
|
3010
|
+
adj_t.adj_extract(tile_coord(i, j, k), matrix_adj);
|
|
3011
|
+
} else {
|
|
3012
|
+
static_assert(always_false<Tile>::value,
|
|
3013
|
+
"adj_tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
3014
|
+
}
|
|
3015
|
+
}
|
|
3016
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
3017
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, int n, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, AdjType adj_ret) {
|
|
3018
|
+
if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
3019
|
+
typename Tile::Type matrix_adj{};
|
|
3020
|
+
matrix_adj.data[m][n] = adj_ret;
|
|
3021
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), matrix_adj);
|
|
3022
|
+
} else {
|
|
3023
|
+
static_assert(always_false<Tile>::value,
|
|
3024
|
+
"adj_tile_extract with 6 indices requires a tile of matrices (4D tile)");
|
|
3025
|
+
}
|
|
3026
|
+
}
|
|
2578
3027
|
|
|
2579
3028
|
|
|
2580
3029
|
template<typename Tile>
|
|
@@ -2595,6 +3044,33 @@ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) {
|
|
|
2595
3044
|
template<typename Tile>
|
|
2596
3045
|
void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
|
|
2597
3046
|
|
|
3047
|
+
template<typename Tile>
|
|
3048
|
+
void tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i), value); }
|
|
3049
|
+
template<typename Tile>
|
|
3050
|
+
void tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j), value); }
|
|
3051
|
+
template<typename Tile>
|
|
3052
|
+
void tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k), value); }
|
|
3053
|
+
template<typename Tile>
|
|
3054
|
+
void tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k,l), value); }
|
|
3055
|
+
|
|
3056
|
+
template<typename Tile>
|
|
3057
|
+
void tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i), value); }
|
|
3058
|
+
template<typename Tile>
|
|
3059
|
+
void tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j), value); }
|
|
3060
|
+
template<typename Tile>
|
|
3061
|
+
void tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k), value); }
|
|
3062
|
+
template<typename Tile>
|
|
3063
|
+
void tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k,l), value); }
|
|
3064
|
+
|
|
3065
|
+
template<typename Tile>
|
|
3066
|
+
void tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i), value); }
|
|
3067
|
+
template<typename Tile>
|
|
3068
|
+
void tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j), value); }
|
|
3069
|
+
template<typename Tile>
|
|
3070
|
+
void tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k), value); }
|
|
3071
|
+
template<typename Tile>
|
|
3072
|
+
void tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k,l), value); }
|
|
3073
|
+
|
|
2598
3074
|
template<typename Tile, typename AdjTile>
|
|
2599
3075
|
void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
|
|
2600
3076
|
template<typename Tile, typename AdjTile>
|
|
@@ -2613,6 +3089,33 @@ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type valu
|
|
|
2613
3089
|
template<typename Tile, typename AdjTile>
|
|
2614
3090
|
void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
|
|
2615
3091
|
|
|
3092
|
+
template<typename Tile, typename AdjTile>
|
|
3093
|
+
void adj_tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
|
|
3094
|
+
template<typename Tile, typename AdjTile>
|
|
3095
|
+
void adj_tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
|
|
3096
|
+
template<typename Tile, typename AdjTile>
|
|
3097
|
+
void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
|
|
3098
|
+
template<typename Tile, typename AdjTile>
|
|
3099
|
+
void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
|
|
3100
|
+
|
|
3101
|
+
template<typename Tile, typename AdjTile>
|
|
3102
|
+
void adj_tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
|
|
3103
|
+
template<typename Tile, typename AdjTile>
|
|
3104
|
+
void adj_tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
|
|
3105
|
+
template<typename Tile, typename AdjTile>
|
|
3106
|
+
void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
|
|
3107
|
+
template<typename Tile, typename AdjTile>
|
|
3108
|
+
void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
|
|
3109
|
+
|
|
3110
|
+
template<typename Tile, typename AdjTile>
|
|
3111
|
+
void adj_tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
|
|
3112
|
+
template<typename Tile, typename AdjTile>
|
|
3113
|
+
void adj_tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
|
|
3114
|
+
template<typename Tile, typename AdjTile>
|
|
3115
|
+
void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
|
|
3116
|
+
template<typename Tile, typename AdjTile>
|
|
3117
|
+
void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
|
|
3118
|
+
|
|
2616
3119
|
namespace partitioned_gemm
|
|
2617
3120
|
{
|
|
2618
3121
|
|
|
@@ -3000,7 +3503,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
3000
3503
|
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
|
|
3001
3504
|
do { \
|
|
3002
3505
|
void function_name(dtype*, char*); \
|
|
3003
|
-
char* buffer = (char*)wp::
|
|
3506
|
+
char* buffer = (char*)wp::tile_shared_storage_t::alloc(shared_memory_size); \
|
|
3004
3507
|
__align__(16) dtype data[ept]; \
|
|
3005
3508
|
for(int b = 0; b < (int)batch_size; b++) { \
|
|
3006
3509
|
dtype* inout = Xinout.data + (int)b * (int)ept; \
|
|
@@ -3009,7 +3512,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
3009
3512
|
memcpy(inout, data, sizeof(dtype) * ept); \
|
|
3010
3513
|
WP_TILE_SYNC(); \
|
|
3011
3514
|
} \
|
|
3012
|
-
wp::
|
|
3515
|
+
wp::tile_shared_storage_t::alloc(-shared_memory_size); \
|
|
3013
3516
|
} while (0)
|
|
3014
3517
|
|
|
3015
3518
|
#define tile_ifft tile_fft
|
|
@@ -3053,7 +3556,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
|
3053
3556
|
#else
|
|
3054
3557
|
|
|
3055
3558
|
// TODO: for batched Cholesky, need one info per batch
|
|
3056
|
-
|
|
3559
|
+
__shared__ int info[1];
|
|
3057
3560
|
|
|
3058
3561
|
if (WP_TILE_THREAD_IDX == 0) {
|
|
3059
3562
|
info[0] = 0;
|
|
@@ -3385,21 +3888,62 @@ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
|
|
|
3385
3888
|
template <typename TileA, typename Scalar>
|
|
3386
3889
|
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
|
|
3387
3890
|
{
|
|
3388
|
-
|
|
3891
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3892
|
+
dest.data(tile_coord(i))[j] = src;
|
|
3893
|
+
} else {
|
|
3894
|
+
dest.data(tile_coord(i, j)) = src;
|
|
3895
|
+
}
|
|
3389
3896
|
WP_TILE_SYNC();
|
|
3390
3897
|
}
|
|
3391
3898
|
template <typename TileA, typename Scalar>
|
|
3392
3899
|
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
|
|
3393
3900
|
{
|
|
3394
|
-
|
|
3901
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3902
|
+
dest.data(tile_coord(i, j))[k] = src;
|
|
3903
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3904
|
+
dest.data(tile_coord(i)).data[j][k] = src;
|
|
3905
|
+
} else {
|
|
3906
|
+
dest.data(tile_coord(i, j, k)) = src;
|
|
3907
|
+
}
|
|
3395
3908
|
WP_TILE_SYNC();
|
|
3396
3909
|
}
|
|
3397
3910
|
template <typename TileA, typename Scalar>
|
|
3398
3911
|
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
|
|
3399
3912
|
{
|
|
3400
|
-
|
|
3913
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3914
|
+
dest.data(tile_coord(i, j, k))[l] = src;
|
|
3915
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3916
|
+
dest.data(tile_coord(i, j)).data[k][l] = src;
|
|
3917
|
+
} else {
|
|
3918
|
+
dest.data(tile_coord(i, j, k, l)) = src;
|
|
3919
|
+
}
|
|
3920
|
+
WP_TILE_SYNC();
|
|
3921
|
+
}
|
|
3922
|
+
template <typename TileA, typename Scalar>
|
|
3923
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src)
|
|
3924
|
+
{
|
|
3925
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3926
|
+
dest.data(tile_coord(i, j, k, l))[m] = src;
|
|
3927
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3928
|
+
dest.data(tile_coord(i, j, k)).data[l][m] = src;
|
|
3929
|
+
} else {
|
|
3930
|
+
static_assert(always_false<TileA>::value,
|
|
3931
|
+
"assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
3932
|
+
}
|
|
3401
3933
|
WP_TILE_SYNC();
|
|
3402
3934
|
}
|
|
3935
|
+
template <typename TileA, typename Scalar>
|
|
3936
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src)
|
|
3937
|
+
{
|
|
3938
|
+
if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3939
|
+
dest.data(tile_coord(i, j, k, l)).data[m][n] = src;
|
|
3940
|
+
} else {
|
|
3941
|
+
static_assert(always_false<TileA>::value,
|
|
3942
|
+
"assign with 6 indices requires a tile of matrices (4D tile)");
|
|
3943
|
+
}
|
|
3944
|
+
WP_TILE_SYNC();
|
|
3945
|
+
}
|
|
3946
|
+
|
|
3403
3947
|
|
|
3404
3948
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
3405
3949
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
|
|
@@ -3419,7 +3963,11 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& sr
|
|
|
3419
3963
|
return;
|
|
3420
3964
|
}
|
|
3421
3965
|
|
|
3422
|
-
|
|
3966
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3967
|
+
adj_src += dest.grad(tile_coord(i))[j];
|
|
3968
|
+
} else {
|
|
3969
|
+
adj_src += dest.grad(tile_coord(i, j));
|
|
3970
|
+
}
|
|
3423
3971
|
}
|
|
3424
3972
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
3425
3973
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
|
|
@@ -3429,7 +3977,13 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Sca
|
|
|
3429
3977
|
return;
|
|
3430
3978
|
}
|
|
3431
3979
|
|
|
3432
|
-
|
|
3980
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3981
|
+
adj_src += dest.grad(tile_coord(i, j))[k];
|
|
3982
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3983
|
+
adj_src += dest.grad(tile_coord(i)).data[j][k];
|
|
3984
|
+
} else {
|
|
3985
|
+
adj_src += dest.grad(tile_coord(i, j, k));
|
|
3986
|
+
}
|
|
3433
3987
|
}
|
|
3434
3988
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
3435
3989
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
|
|
@@ -3439,7 +3993,45 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, co
|
|
|
3439
3993
|
return;
|
|
3440
3994
|
}
|
|
3441
3995
|
|
|
3442
|
-
|
|
3996
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
3997
|
+
adj_src += dest.grad(tile_coord(i, j, k))[l];
|
|
3998
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
3999
|
+
adj_src += dest.grad(tile_coord(i, j)).data[k][l];
|
|
4000
|
+
} else {
|
|
4001
|
+
adj_src += dest.grad(tile_coord(i, j, k, l));
|
|
4002
|
+
}
|
|
4003
|
+
}
|
|
4004
|
+
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
4005
|
+
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, Scalar& adj_src)
|
|
4006
|
+
{
|
|
4007
|
+
if (dest.grad.ptr == nullptr)
|
|
4008
|
+
{
|
|
4009
|
+
return;
|
|
4010
|
+
}
|
|
4011
|
+
|
|
4012
|
+
if constexpr(is_vector<typename TileA::Type>::value) {
|
|
4013
|
+
adj_src += dest.grad(tile_coord(i, j, k, l))[m];
|
|
4014
|
+
} else if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
4015
|
+
adj_src += dest.grad(tile_coord(i, j, k)).data[l][m];
|
|
4016
|
+
} else {
|
|
4017
|
+
static_assert(always_false<TileA>::value,
|
|
4018
|
+
"adj_assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
4019
|
+
}
|
|
4020
|
+
}
|
|
4021
|
+
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
4022
|
+
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, Scalar& adj_src)
|
|
4023
|
+
{
|
|
4024
|
+
if (dest.grad.ptr == nullptr)
|
|
4025
|
+
{
|
|
4026
|
+
return;
|
|
4027
|
+
}
|
|
4028
|
+
|
|
4029
|
+
if constexpr(is_matrix<typename TileA::Type>::value) {
|
|
4030
|
+
adj_src += dest.grad(tile_coord(i, j, k, l)).data[m][n];
|
|
4031
|
+
} else {
|
|
4032
|
+
static_assert(always_false<TileA>::value,
|
|
4033
|
+
"adj_assign with 6 indices requires a tile of matrices (4D tile)");
|
|
4034
|
+
}
|
|
3443
4035
|
}
|
|
3444
4036
|
|
|
3445
4037
|
template <typename TileA, typename TileB, typename Coord>
|