warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -43,18 +43,10 @@
|
|
|
43
43
|
};
|
|
44
44
|
#endif
|
|
45
45
|
|
|
46
|
-
|
|
47
|
-
#ifndef WP_TILE_BLOCK_DIM
|
|
48
|
-
#define WP_TILE_BLOCK_DIM 256
|
|
49
|
-
#endif
|
|
50
|
-
|
|
51
|
-
#if !defined(__CUDA_ARCH__)
|
|
52
|
-
#define WP_TILE_SHARED static
|
|
53
|
-
#define WP_TILE_SYNC void
|
|
54
|
-
|
|
55
|
-
#else
|
|
56
|
-
#define WP_TILE_SHARED __shared__
|
|
46
|
+
#if defined(__CUDA_ARCH__)
|
|
57
47
|
#define WP_TILE_SYNC __syncthreads
|
|
48
|
+
#else
|
|
49
|
+
#define WP_TILE_SYNC void
|
|
58
50
|
#endif
|
|
59
51
|
|
|
60
52
|
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
|
|
@@ -140,7 +132,6 @@
|
|
|
140
132
|
[ ] LayerNorm
|
|
141
133
|
[ ] SoftMax
|
|
142
134
|
[ ] GEMM
|
|
143
|
-
[ ] warp.sim (CRBA)
|
|
144
135
|
[ ] Batched MLP
|
|
145
136
|
[ ] Layer norm
|
|
146
137
|
[ ] FNO + Burgers equation
|
|
@@ -149,7 +140,6 @@
|
|
|
149
140
|
[ ] MeshCNN (Modulus, Oliver)
|
|
150
141
|
[ ] BioNemo (Ali)
|
|
151
142
|
[ ] Skinning (David/Or/Vismay)
|
|
152
|
-
[ ] warp.sim (VBD)
|
|
153
143
|
[ ] Error checking
|
|
154
144
|
[ ] Ensure functions passed to tile_map() are compatible with tile type
|
|
155
145
|
[ ] Ensure that args passed to tile ops are compatible
|
|
@@ -213,6 +203,12 @@ struct is_same<T, T> {
|
|
|
213
203
|
static constexpr bool value = true;
|
|
214
204
|
};
|
|
215
205
|
|
|
206
|
+
// Helper for dependent static_assert failures
|
|
207
|
+
template <typename T>
|
|
208
|
+
struct always_false {
|
|
209
|
+
static constexpr bool value = false;
|
|
210
|
+
};
|
|
211
|
+
|
|
216
212
|
|
|
217
213
|
template <int N>
|
|
218
214
|
struct tile_coord_t
|
|
@@ -338,6 +334,113 @@ template <int... V>
|
|
|
338
334
|
using tile_stride_t = tile_tuple_t<V...>;
|
|
339
335
|
|
|
340
336
|
|
|
337
|
+
// helper to remove a dimension from a shape (used for axis reductions)
|
|
338
|
+
template<int Axis, typename Shape>
|
|
339
|
+
struct tile_shape_remove_dim {
|
|
340
|
+
static_assert(Axis >= 0 && Axis < Shape::N, "Axis out of bounds for tile_shape_remove_dim");
|
|
341
|
+
};
|
|
342
|
+
|
|
343
|
+
// 1D -> scalar
|
|
344
|
+
template<int D0>
|
|
345
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0>> {
|
|
346
|
+
using type = tile_shape_t<1>;
|
|
347
|
+
};
|
|
348
|
+
|
|
349
|
+
// 2D -> 1D
|
|
350
|
+
template<int D0, int D1>
|
|
351
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1>> {
|
|
352
|
+
using type = tile_shape_t<D1>;
|
|
353
|
+
};
|
|
354
|
+
|
|
355
|
+
template<int D0, int D1>
|
|
356
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1>> {
|
|
357
|
+
using type = tile_shape_t<D0>;
|
|
358
|
+
};
|
|
359
|
+
|
|
360
|
+
// 3D -> 2D
|
|
361
|
+
template<int D0, int D1, int D2>
|
|
362
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2>> {
|
|
363
|
+
using type = tile_shape_t<D1, D2>;
|
|
364
|
+
};
|
|
365
|
+
|
|
366
|
+
template<int D0, int D1, int D2>
|
|
367
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2>> {
|
|
368
|
+
using type = tile_shape_t<D0, D2>;
|
|
369
|
+
};
|
|
370
|
+
|
|
371
|
+
template<int D0, int D1, int D2>
|
|
372
|
+
struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2>> {
|
|
373
|
+
using type = tile_shape_t<D0, D1>;
|
|
374
|
+
};
|
|
375
|
+
|
|
376
|
+
// 4D -> 3D
|
|
377
|
+
template<int D0, int D1, int D2, int D3>
|
|
378
|
+
struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2, D3>> {
|
|
379
|
+
using type = tile_shape_t<D1, D2, D3>;
|
|
380
|
+
};
|
|
381
|
+
|
|
382
|
+
template<int D0, int D1, int D2, int D3>
|
|
383
|
+
struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2, D3>> {
|
|
384
|
+
using type = tile_shape_t<D0, D2, D3>;
|
|
385
|
+
};
|
|
386
|
+
|
|
387
|
+
template<int D0, int D1, int D2, int D3>
|
|
388
|
+
struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2, D3>> {
|
|
389
|
+
using type = tile_shape_t<D0, D1, D3>;
|
|
390
|
+
};
|
|
391
|
+
|
|
392
|
+
template<int D0, int D1, int D2, int D3>
|
|
393
|
+
struct tile_shape_remove_dim<3, tile_shape_t<D0, D1, D2, D3>> {
|
|
394
|
+
using type = tile_shape_t<D0, D1, D2>;
|
|
395
|
+
};
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
// helper to insert an axis value into a coordinate (inverse of removing dimension)
|
|
399
|
+
// used for mapping output coordinates back to input coordinates during axis reduction
|
|
400
|
+
template<int Axis, int N>
|
|
401
|
+
CUDA_CALLABLE constexpr auto tile_coord_insert_axis(const tile_coord_t<N>& coord, int axis_val)
|
|
402
|
+
{
|
|
403
|
+
static_assert(Axis >= 0 && Axis <= N, "Axis out of bounds for tile_coord_insert_axis");
|
|
404
|
+
|
|
405
|
+
if constexpr (N == 0)
|
|
406
|
+
{
|
|
407
|
+
// Scalar -> 1D
|
|
408
|
+
static_assert(Axis == 0, "Invalid axis for scalar coordinate");
|
|
409
|
+
return tile_coord(axis_val);
|
|
410
|
+
}
|
|
411
|
+
else if constexpr (N == 1)
|
|
412
|
+
{
|
|
413
|
+
// 1D -> 2D
|
|
414
|
+
if constexpr (Axis == 0)
|
|
415
|
+
return tile_coord(axis_val, coord[0]);
|
|
416
|
+
else
|
|
417
|
+
return tile_coord(coord[0], axis_val);
|
|
418
|
+
}
|
|
419
|
+
else if constexpr (N == 2)
|
|
420
|
+
{
|
|
421
|
+
// 2D -> 3D
|
|
422
|
+
if constexpr (Axis == 0)
|
|
423
|
+
return tile_coord(axis_val, coord[0], coord[1]);
|
|
424
|
+
else if constexpr (Axis == 1)
|
|
425
|
+
return tile_coord(coord[0], axis_val, coord[1]);
|
|
426
|
+
else
|
|
427
|
+
return tile_coord(coord[0], coord[1], axis_val);
|
|
428
|
+
}
|
|
429
|
+
else // N == 3
|
|
430
|
+
{
|
|
431
|
+
// 3D -> 4D
|
|
432
|
+
if constexpr (Axis == 0)
|
|
433
|
+
return tile_coord(axis_val, coord[0], coord[1], coord[2]);
|
|
434
|
+
else if constexpr (Axis == 1)
|
|
435
|
+
return tile_coord(coord[0], axis_val, coord[1], coord[2]);
|
|
436
|
+
else if constexpr (Axis == 2)
|
|
437
|
+
return tile_coord(coord[0], coord[1], axis_val, coord[2]);
|
|
438
|
+
else
|
|
439
|
+
return tile_coord(coord[0], coord[1], coord[2], axis_val);
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
|
|
341
444
|
// represents a tile stored in global memory with dynamic strides
|
|
342
445
|
// used to represent the source and offset for tile loads to register/shared
|
|
343
446
|
// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
|
|
@@ -542,7 +645,7 @@ struct tile_register_t
|
|
|
542
645
|
|
|
543
646
|
// define the += operator which is used during backward pass codegen
|
|
544
647
|
// when returning a register tile from a user defined function
|
|
545
|
-
inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
|
|
648
|
+
inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
|
|
546
649
|
{
|
|
547
650
|
grad_add(rhs);
|
|
548
651
|
return *this;
|
|
@@ -581,7 +684,11 @@ struct tile_register_t
|
|
|
581
684
|
const int thread = Layout::thread_from_linear(linear);
|
|
582
685
|
const int reg = Layout::register_from_linear(linear);
|
|
583
686
|
|
|
584
|
-
|
|
687
|
+
#if defined(__CUDA_ARCH__)
|
|
688
|
+
__shared__ Type scratch;
|
|
689
|
+
#else
|
|
690
|
+
Type scratch;
|
|
691
|
+
#endif
|
|
585
692
|
|
|
586
693
|
// ensure any previously scheduled threads have finished reading from scratch
|
|
587
694
|
WP_TILE_SYNC();
|
|
@@ -658,7 +765,7 @@ struct tile_register_t
|
|
|
658
765
|
data[i] += tile.data[i];
|
|
659
766
|
}
|
|
660
767
|
|
|
661
|
-
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
768
|
+
inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
662
769
|
{
|
|
663
770
|
apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
|
|
664
771
|
}
|
|
@@ -735,42 +842,124 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
|
735
842
|
return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
|
|
736
843
|
}
|
|
737
844
|
|
|
738
|
-
|
|
845
|
+
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
846
|
+
// On the CPU we use a fixed size block of stack memory for shared tile allocations.
|
|
847
|
+
// We store a pointer to the current allocation storage either in a reserved register
|
|
848
|
+
// (AArch64) or a static variable (x86-64).
|
|
849
|
+
#if !defined(__CUDA_ARCH__)
|
|
850
|
+
class tile_shared_storage_t;
|
|
851
|
+
#if defined(__aarch64__)
|
|
852
|
+
// x28 is is the last callee-saved register on AArch64. This allows us to call externally
|
|
853
|
+
// compiled functions without worrying about clobbering the pointer.
|
|
854
|
+
// We pass -target-feature +reserve-x28 to Clang to exclude it from register allocation.
|
|
855
|
+
register tile_shared_storage_t* shared_tile_storage asm("x28");
|
|
856
|
+
#else
|
|
857
|
+
// Ideally this would be thread_local, but LLVM's JIT doesn't support TLS yet
|
|
858
|
+
// There is also no support for something like -ffixed-r15 either
|
|
859
|
+
static tile_shared_storage_t* shared_tile_storage;
|
|
860
|
+
#endif
|
|
861
|
+
#endif
|
|
862
|
+
#endif
|
|
863
|
+
|
|
864
|
+
// This class manages a block of "shared" memory for use by tiles.
|
|
865
|
+
// On the GPU this maps to dynamic shared memory, while on the CPU we allocate
|
|
866
|
+
// a fixed size block of memory on the stack and manage allocations from it.
|
|
867
|
+
// An instance of this class gets created at the start of a kernel.
|
|
868
|
+
class tile_shared_storage_t
|
|
739
869
|
{
|
|
870
|
+
private:
|
|
871
|
+
#if !defined(__CUDA_ARCH__)
|
|
872
|
+
#define WP_MAX_CPU_SHARED 256*1024
|
|
873
|
+
#if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
874
|
+
tile_shared_storage_t* old_value;
|
|
875
|
+
unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
876
|
+
char dynamic_smem_base[WP_MAX_CPU_SHARED]; // on CPU allocate a fixed 256k block to use for shared allocs
|
|
877
|
+
#endif
|
|
878
|
+
#endif
|
|
879
|
+
|
|
740
880
|
// we maintain a per-thread offset into dynamic
|
|
741
881
|
// shared memory that allows us to keep track of
|
|
742
882
|
// current use across dynamic function calls
|
|
743
|
-
|
|
883
|
+
static inline CUDA_CALLABLE unsigned int* get_smem_base()
|
|
884
|
+
{
|
|
885
|
+
#if defined(__CUDA_ARCH__)
|
|
886
|
+
__shared__ unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
887
|
+
return smem_base;
|
|
888
|
+
#elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
889
|
+
return shared_tile_storage->smem_base;
|
|
890
|
+
#else
|
|
891
|
+
static unsigned int smem_base[WP_TILE_BLOCK_DIM];
|
|
892
|
+
return smem_base;
|
|
893
|
+
#endif
|
|
894
|
+
}
|
|
744
895
|
|
|
745
|
-
|
|
896
|
+
static inline CUDA_CALLABLE char* get_dynamic_smem_base()
|
|
746
897
|
{
|
|
898
|
+
#if defined(__CUDA_ARCH__)
|
|
899
|
+
extern __shared__ char dynamic_smem_base[];
|
|
900
|
+
return dynamic_smem_base;
|
|
901
|
+
#elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
902
|
+
return shared_tile_storage->dynamic_smem_base;
|
|
903
|
+
#else
|
|
904
|
+
static char dynamic_smem_base[WP_MAX_CPU_SHARED];
|
|
905
|
+
return dynamic_smem_base;
|
|
906
|
+
#endif
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
public:
|
|
910
|
+
// cppcheck-suppress uninitMemberVar
|
|
911
|
+
inline CUDA_CALLABLE tile_shared_storage_t()
|
|
912
|
+
{
|
|
913
|
+
#if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
914
|
+
// On the CPU save a pointer to this instance in a reserved register
|
|
915
|
+
// or static variable so it can be accessed from anywhere within a kernel.
|
|
916
|
+
old_value = shared_tile_storage;
|
|
917
|
+
shared_tile_storage = this;
|
|
918
|
+
#endif
|
|
919
|
+
|
|
920
|
+
init();
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
inline CUDA_CALLABLE ~tile_shared_storage_t()
|
|
924
|
+
{
|
|
925
|
+
check();
|
|
926
|
+
|
|
927
|
+
#if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
|
|
928
|
+
shared_tile_storage = old_value;
|
|
929
|
+
#endif
|
|
930
|
+
}
|
|
931
|
+
|
|
932
|
+
static inline CUDA_CALLABLE void init()
|
|
933
|
+
{
|
|
934
|
+
unsigned int* smem_base = get_smem_base();
|
|
935
|
+
|
|
747
936
|
smem_base[WP_TILE_THREAD_IDX] = 0;
|
|
748
|
-
return nullptr;
|
|
749
937
|
}
|
|
750
|
-
|
|
938
|
+
|
|
939
|
+
static inline CUDA_CALLABLE void check()
|
|
751
940
|
{
|
|
941
|
+
unsigned int* smem_base = get_smem_base();
|
|
942
|
+
|
|
752
943
|
assert(smem_base[WP_TILE_THREAD_IDX] == 0);
|
|
753
|
-
return nullptr;
|
|
754
944
|
}
|
|
755
|
-
|
|
945
|
+
|
|
946
|
+
static inline CUDA_CALLABLE void* alloc(int num_bytes)
|
|
756
947
|
{
|
|
757
|
-
|
|
758
|
-
|
|
948
|
+
unsigned int* smem_base = get_smem_base();
|
|
949
|
+
char* dynamic_smem_base = get_dynamic_smem_base();
|
|
950
|
+
|
|
951
|
+
const unsigned int offset = smem_base[WP_TILE_THREAD_IDX];
|
|
952
|
+
|
|
759
953
|
// one entry per-thread so no need for synchronization
|
|
760
954
|
smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
|
|
761
955
|
|
|
762
|
-
#
|
|
763
|
-
|
|
764
|
-
#else
|
|
765
|
-
// on CPU allocate a fixed 256k block to use for shared allocs
|
|
766
|
-
static const int max_cpu_shared = 256*1024;
|
|
767
|
-
static char dynamic_smem_base[max_cpu_shared];
|
|
768
|
-
|
|
769
|
-
assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
|
|
956
|
+
#if !defined(__CUDA_ARCH__)
|
|
957
|
+
assert(smem_base[WP_TILE_THREAD_IDX] <= WP_MAX_CPU_SHARED);
|
|
770
958
|
#endif
|
|
959
|
+
|
|
771
960
|
return &(dynamic_smem_base[offset]);
|
|
772
961
|
}
|
|
773
|
-
}
|
|
962
|
+
};
|
|
774
963
|
|
|
775
964
|
|
|
776
965
|
template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
|
|
@@ -905,6 +1094,28 @@ struct tile_shared_t
|
|
|
905
1094
|
{
|
|
906
1095
|
}
|
|
907
1096
|
|
|
1097
|
+
// we delete the copy constructor because in the case the shared tile is owning,
|
|
1098
|
+
// this leads to a double deallocation.
|
|
1099
|
+
// this also forces one to handle copies explicitly
|
|
1100
|
+
inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
|
|
1101
|
+
{
|
|
1102
|
+
static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
// move constructor
|
|
1106
|
+
inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
|
|
1107
|
+
{
|
|
1108
|
+
other.data.ptr = nullptr;
|
|
1109
|
+
other.grad.ptr = nullptr;
|
|
1110
|
+
}
|
|
1111
|
+
|
|
1112
|
+
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
1113
|
+
inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
|
|
1114
|
+
{
|
|
1115
|
+
static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
|
|
1116
|
+
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
1117
|
+
}
|
|
1118
|
+
|
|
908
1119
|
// initialize from an existing tile's memory
|
|
909
1120
|
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
|
|
910
1121
|
{
|
|
@@ -916,10 +1127,10 @@ struct tile_shared_t
|
|
|
916
1127
|
{
|
|
917
1128
|
// update our per-thread shared memory allocator
|
|
918
1129
|
if (data.ptr)
|
|
919
|
-
|
|
1130
|
+
tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
|
|
920
1131
|
|
|
921
1132
|
if (grad.ptr)
|
|
922
|
-
|
|
1133
|
+
tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
|
|
923
1134
|
}
|
|
924
1135
|
}
|
|
925
1136
|
|
|
@@ -932,19 +1143,47 @@ struct tile_shared_t
|
|
|
932
1143
|
|
|
933
1144
|
// construct from another shared tile, this constructor
|
|
934
1145
|
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
1146
|
+
// or `wp::copy()`
|
|
935
1147
|
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
936
1148
|
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
|
|
937
1149
|
{
|
|
938
1150
|
// check dimensions are compatible
|
|
939
1151
|
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
940
1152
|
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
1153
|
+
|
|
1154
|
+
if (Owner)
|
|
1155
|
+
{
|
|
1156
|
+
// if the tile owns the data we need to copy
|
|
1157
|
+
assign(rhs);
|
|
1158
|
+
}
|
|
1159
|
+
else
|
|
1160
|
+
{
|
|
1161
|
+
// alias tile directly
|
|
1162
|
+
data.ptr = rhs.data.ptr;
|
|
1163
|
+
grad.ptr = rhs.grad.ptr;
|
|
1164
|
+
initialized = rhs.initialized;
|
|
1165
|
+
}
|
|
945
1166
|
|
|
946
1167
|
return *this;
|
|
947
|
-
}
|
|
1168
|
+
}
|
|
1169
|
+
|
|
1170
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
|
|
1171
|
+
{
|
|
1172
|
+
if (Owner)
|
|
1173
|
+
{
|
|
1174
|
+
// if the tile owns the data we need to copy
|
|
1175
|
+
assign(rhs);
|
|
1176
|
+
}
|
|
1177
|
+
else
|
|
1178
|
+
{
|
|
1179
|
+
// alias tile directly
|
|
1180
|
+
data.ptr = rhs.data.ptr;
|
|
1181
|
+
grad.ptr = rhs.grad.ptr;
|
|
1182
|
+
initialized = rhs.initialized;
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
return *this;
|
|
1186
|
+
}
|
|
948
1187
|
|
|
949
1188
|
// assign from a global tile (load)
|
|
950
1189
|
|
|
@@ -972,6 +1211,21 @@ struct tile_shared_t
|
|
|
972
1211
|
return *this;
|
|
973
1212
|
}
|
|
974
1213
|
|
|
1214
|
+
// define the += operator which is used during backward pass codegen
|
|
1215
|
+
// when returning a register tile from a user defined function
|
|
1216
|
+
template<typename OtherLayout>
|
|
1217
|
+
inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
|
|
1218
|
+
{
|
|
1219
|
+
grad_add(rhs);
|
|
1220
|
+
return *this;
|
|
1221
|
+
}
|
|
1222
|
+
|
|
1223
|
+
inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
|
|
1224
|
+
{
|
|
1225
|
+
grad_add(rhs);
|
|
1226
|
+
return *this;
|
|
1227
|
+
}
|
|
1228
|
+
|
|
975
1229
|
// in-place zero
|
|
976
1230
|
inline CUDA_CALLABLE void zero()
|
|
977
1231
|
{
|
|
@@ -1029,6 +1283,46 @@ struct tile_shared_t
|
|
|
1029
1283
|
adj_x -= grad(c);
|
|
1030
1284
|
}
|
|
1031
1285
|
|
|
1286
|
+
// perform AND between a scalar value and a single tile element
|
|
1287
|
+
inline CUDA_CALLABLE void bit_and_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1288
|
+
{
|
|
1289
|
+
// since multiple threads may access the same element
|
|
1290
|
+
// we need to access using atomic operations
|
|
1291
|
+
wp::atomic_and(&data(c), x);
|
|
1292
|
+
|
|
1293
|
+
WP_TILE_SYNC();
|
|
1294
|
+
}
|
|
1295
|
+
|
|
1296
|
+
// backward of inplace scalar AND
|
|
1297
|
+
inline CUDA_CALLABLE void adj_bit_and_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
// perform OR between a scalar value and a single tile element
|
|
1301
|
+
inline CUDA_CALLABLE void bit_or_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1302
|
+
{
|
|
1303
|
+
// since multiple threads may access the same element
|
|
1304
|
+
// we need to access using atomic operations
|
|
1305
|
+
wp::atomic_or(&data(c), x);
|
|
1306
|
+
|
|
1307
|
+
WP_TILE_SYNC();
|
|
1308
|
+
}
|
|
1309
|
+
|
|
1310
|
+
// backward of inplace scalar OR
|
|
1311
|
+
inline CUDA_CALLABLE void adj_bit_or_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1312
|
+
|
|
1313
|
+
// perform XOR between a scalar value and a single tile element
|
|
1314
|
+
inline CUDA_CALLABLE void bit_xor_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1315
|
+
{
|
|
1316
|
+
// since multiple threads may access the same element
|
|
1317
|
+
// we need to access using atomic operations
|
|
1318
|
+
wp::atomic_xor(&data(c), x);
|
|
1319
|
+
|
|
1320
|
+
WP_TILE_SYNC();
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
// backward of inplace scalar XOR
|
|
1324
|
+
inline CUDA_CALLABLE void adj_bit_xor_inplace(const typename Layout::Coord& c, Type& adj_x) {}
|
|
1325
|
+
|
|
1032
1326
|
// copy register tile to shared
|
|
1033
1327
|
template <typename Tile>
|
|
1034
1328
|
inline CUDA_CALLABLE void assign(const Tile& tile)
|
|
@@ -1053,6 +1347,27 @@ struct tile_shared_t
|
|
|
1053
1347
|
WP_TILE_SYNC();
|
|
1054
1348
|
}
|
|
1055
1349
|
|
|
1350
|
+
// shared tile deep copy
|
|
1351
|
+
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
1352
|
+
inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
|
|
1353
|
+
{
|
|
1354
|
+
// check dimensions are compatible
|
|
1355
|
+
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
1356
|
+
|
|
1357
|
+
if (initialized)
|
|
1358
|
+
WP_TILE_SYNC();
|
|
1359
|
+
|
|
1360
|
+
WP_PRAGMA_UNROLL
|
|
1361
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1362
|
+
{
|
|
1363
|
+
auto c = Layout::coord_from_linear(i);
|
|
1364
|
+
data(c) = tile.data(c);
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
initialized = true;
|
|
1368
|
+
WP_TILE_SYNC();
|
|
1369
|
+
}
|
|
1370
|
+
|
|
1056
1371
|
// in-place gradient zero
|
|
1057
1372
|
inline CUDA_CALLABLE void grad_zero()
|
|
1058
1373
|
{
|
|
@@ -1092,8 +1407,21 @@ struct tile_shared_t
|
|
|
1092
1407
|
WP_TILE_SYNC();
|
|
1093
1408
|
}
|
|
1094
1409
|
|
|
1410
|
+
// accumulate gradients onto this tile from another shared tile
|
|
1411
|
+
inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
|
|
1412
|
+
{
|
|
1413
|
+
WP_PRAGMA_UNROLL
|
|
1414
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1415
|
+
{
|
|
1416
|
+
auto c = Layout::coord_from_linear(i);
|
|
1417
|
+
grad(c) += tile.grad(c);
|
|
1418
|
+
}
|
|
1419
|
+
|
|
1420
|
+
WP_TILE_SYNC();
|
|
1421
|
+
}
|
|
1422
|
+
|
|
1095
1423
|
// accumulate gradient onto this tile from a global array
|
|
1096
|
-
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1424
|
+
inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1097
1425
|
{
|
|
1098
1426
|
WP_PRAGMA_UNROLL
|
|
1099
1427
|
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
@@ -1449,7 +1777,11 @@ void tile_register_t<T, L>::print() const
|
|
|
1449
1777
|
{
|
|
1450
1778
|
// create a temporary shared tile so that
|
|
1451
1779
|
// we can print it deterministically
|
|
1452
|
-
|
|
1780
|
+
#if defined(__CUDA_ARCH__)
|
|
1781
|
+
__shared__ T smem[L::Size];
|
|
1782
|
+
#else
|
|
1783
|
+
T smem[L::Size];
|
|
1784
|
+
#endif
|
|
1453
1785
|
tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
|
|
1454
1786
|
|
|
1455
1787
|
scratch.assign(*this);
|
|
@@ -1477,9 +1809,16 @@ void tile_register_t<T, L>::print() const
|
|
|
1477
1809
|
// print entry points
|
|
1478
1810
|
template <typename T, typename L>
|
|
1479
1811
|
inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
|
|
1812
|
+
|
|
1813
|
+
template <typename T, typename L>
|
|
1814
|
+
inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
|
|
1815
|
+
|
|
1480
1816
|
template <typename T, typename L, bool Owner>
|
|
1481
1817
|
inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
|
|
1482
1818
|
|
|
1819
|
+
template <typename T, typename L, bool Owner>
|
|
1820
|
+
inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
|
|
1821
|
+
|
|
1483
1822
|
template <typename T, typename L, bool O>
|
|
1484
1823
|
inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
|
|
1485
1824
|
{
|
|
@@ -1502,20 +1841,57 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
|
|
|
1502
1841
|
{
|
|
1503
1842
|
}
|
|
1504
1843
|
|
|
1844
|
+
// where specialization for register/shared tiles
|
|
1845
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1846
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
|
|
1847
|
+
{
|
|
1848
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1849
|
+
return (!!cond) ? a : b.copy_to_register();
|
|
1850
|
+
}
|
|
1851
|
+
|
|
1852
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1853
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
|
|
1854
|
+
{
|
|
1855
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1856
|
+
return (!!cond) ? a.copy_to_register() : b;
|
|
1857
|
+
}
|
|
1505
1858
|
|
|
1506
|
-
template <typename T, typename L>
|
|
1507
|
-
inline CUDA_CALLABLE
|
|
1508
|
-
|
|
1509
|
-
|
|
1859
|
+
template <typename C, typename T, typename L, bool Owner>
|
|
1860
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
|
|
1861
|
+
{
|
|
1862
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1863
|
+
return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
|
|
1864
|
+
}
|
|
1510
1865
|
|
|
1866
|
+
template <typename C, typename T, typename L, bool LOwner, bool ROwner>
|
|
1867
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
|
|
1868
|
+
{
|
|
1869
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1870
|
+
return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
|
|
1871
|
+
}
|
|
1511
1872
|
|
|
1873
|
+
// adj_where same as in builtin.h
|
|
1874
|
+
|
|
1875
|
+
// copy specialization for shared tiles, the lvalue this gets assigned to is owning, thus, this invokes the copy assign path
|
|
1876
|
+
template <typename T, typename L, bool Owner>
|
|
1877
|
+
inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
|
|
1878
|
+
{
|
|
1879
|
+
return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
|
|
1880
|
+
}
|
|
1881
|
+
|
|
1882
|
+
template <typename T, typename L, bool Owner>
|
|
1883
|
+
inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
|
|
1884
|
+
{
|
|
1885
|
+
adj_src += adj_dest;
|
|
1886
|
+
adj_dest.grad_zero();
|
|
1887
|
+
}
|
|
1512
1888
|
|
|
1513
1889
|
// helpers to allocate shared tiles
|
|
1514
1890
|
template <typename T, typename Shape, typename Strides, bool RequiresGrad>
|
|
1515
1891
|
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
1516
1892
|
{
|
|
1517
1893
|
constexpr int size = Shape::size();
|
|
1518
|
-
T* data = (T*)
|
|
1894
|
+
T* data = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
|
|
1519
1895
|
T* grad = nullptr;
|
|
1520
1896
|
|
|
1521
1897
|
#if FP_CHECK
|
|
@@ -1534,7 +1910,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1534
1910
|
|
|
1535
1911
|
if (RequiresGrad)
|
|
1536
1912
|
{
|
|
1537
|
-
grad = (T*)
|
|
1913
|
+
grad = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
|
|
1538
1914
|
|
|
1539
1915
|
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1540
1916
|
grad[i] = T(0);
|
|
@@ -1712,6 +2088,14 @@ inline CUDA_CALLABLE auto tile_ones()
|
|
|
1712
2088
|
return T(1);
|
|
1713
2089
|
}
|
|
1714
2090
|
|
|
2091
|
+
// value-initialized tile
|
|
2092
|
+
template <typename T, unsigned... Shape>
|
|
2093
|
+
inline CUDA_CALLABLE auto tile_full(T x)
|
|
2094
|
+
{
|
|
2095
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
2096
|
+
return x;
|
|
2097
|
+
}
|
|
2098
|
+
|
|
1715
2099
|
// tile with evenly spaced values
|
|
1716
2100
|
template <typename T, int Len>
|
|
1717
2101
|
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
|
|
@@ -2263,6 +2647,43 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
|
|
|
2263
2647
|
}
|
|
2264
2648
|
|
|
2265
2649
|
|
|
2650
|
+
// tile & tile
|
|
2651
|
+
template <typename TileA, typename TileB>
|
|
2652
|
+
inline CUDA_CALLABLE auto tile_bit_and(TileA& a, TileB& b)
|
|
2653
|
+
{
|
|
2654
|
+
return tile_binary_map(bit_and, a, b, a);
|
|
2655
|
+
}
|
|
2656
|
+
|
|
2657
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2658
|
+
inline CUDA_CALLABLE void adj_tile_bit_and(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2659
|
+
{
|
|
2660
|
+
}
|
|
2661
|
+
|
|
2662
|
+
// tile | tile
|
|
2663
|
+
template <typename TileA, typename TileB>
|
|
2664
|
+
inline CUDA_CALLABLE auto tile_bit_or(TileA& a, TileB& b)
|
|
2665
|
+
{
|
|
2666
|
+
return tile_binary_map(bit_or, a, b, a);
|
|
2667
|
+
}
|
|
2668
|
+
|
|
2669
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2670
|
+
inline CUDA_CALLABLE void adj_tile_bit_or(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2671
|
+
{
|
|
2672
|
+
}
|
|
2673
|
+
|
|
2674
|
+
// tile ^ tile
|
|
2675
|
+
template <typename TileA, typename TileB>
|
|
2676
|
+
inline CUDA_CALLABLE auto tile_bit_xor(TileA& a, TileB& b)
|
|
2677
|
+
{
|
|
2678
|
+
return tile_binary_map(bit_xor, a, b, a);
|
|
2679
|
+
}
|
|
2680
|
+
|
|
2681
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
2682
|
+
inline CUDA_CALLABLE void adj_tile_bit_xor(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
2683
|
+
{
|
|
2684
|
+
}
|
|
2685
|
+
|
|
2686
|
+
|
|
2266
2687
|
template <typename TileA, typename TileB>
|
|
2267
2688
|
inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
|
|
2268
2689
|
{
|
|
@@ -2382,24 +2803,227 @@ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj
|
|
|
2382
2803
|
adj_b.grad_add(adj_b_reg);
|
|
2383
2804
|
}
|
|
2384
2805
|
|
|
2806
|
+
template <typename TileA, typename TileB>
|
|
2807
|
+
inline CUDA_CALLABLE void tile_bit_and_inplace(TileA& a, TileB& b)
|
|
2808
|
+
{
|
|
2809
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2810
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2811
|
+
|
|
2812
|
+
// verify shapes and sizes are compatible
|
|
2813
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise AND");
|
|
2814
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise AND");
|
|
2815
|
+
|
|
2816
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2817
|
+
auto a_reg = a.copy_to_register();
|
|
2818
|
+
auto b_reg = b.copy_to_register();
|
|
2819
|
+
|
|
2820
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2821
|
+
|
|
2822
|
+
WP_PRAGMA_UNROLL
|
|
2823
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2824
|
+
{
|
|
2825
|
+
const int linear = Layout::linear_from_register(i);
|
|
2826
|
+
|
|
2827
|
+
if(!Layout::valid(linear))
|
|
2828
|
+
break;
|
|
2829
|
+
|
|
2830
|
+
a_reg.data[i] &= b_reg.data[i];
|
|
2831
|
+
}
|
|
2832
|
+
|
|
2833
|
+
a.assign(a_reg);
|
|
2834
|
+
}
|
|
2835
|
+
|
|
2836
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2837
|
+
inline CUDA_CALLABLE void adj_tile_bit_and_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2838
|
+
|
|
2839
|
+
template <typename TileA, typename TileB>
|
|
2840
|
+
inline CUDA_CALLABLE void tile_bit_or_inplace(TileA& a, TileB& b)
|
|
2841
|
+
{
|
|
2842
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2843
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2844
|
+
|
|
2845
|
+
// verify shapes and sizes are compatible
|
|
2846
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise OR");
|
|
2847
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise OR");
|
|
2848
|
+
|
|
2849
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2850
|
+
auto a_reg = a.copy_to_register();
|
|
2851
|
+
auto b_reg = b.copy_to_register();
|
|
2852
|
+
|
|
2853
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2854
|
+
|
|
2855
|
+
WP_PRAGMA_UNROLL
|
|
2856
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2857
|
+
{
|
|
2858
|
+
const int linear = Layout::linear_from_register(i);
|
|
2859
|
+
|
|
2860
|
+
if(!Layout::valid(linear))
|
|
2861
|
+
break;
|
|
2862
|
+
|
|
2863
|
+
a_reg.data[i] |= b_reg.data[i];
|
|
2864
|
+
}
|
|
2865
|
+
|
|
2866
|
+
a.assign(a_reg);
|
|
2867
|
+
}
|
|
2868
|
+
|
|
2869
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2870
|
+
inline CUDA_CALLABLE void adj_tile_bit_or_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2871
|
+
|
|
2872
|
+
template <typename TileA, typename TileB>
|
|
2873
|
+
inline CUDA_CALLABLE void tile_bit_xor_inplace(TileA& a, TileB& b)
|
|
2874
|
+
{
|
|
2875
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2876
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2877
|
+
|
|
2878
|
+
// verify shapes and sizes are compatible
|
|
2879
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise XOR");
|
|
2880
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise XOR");
|
|
2881
|
+
|
|
2882
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2883
|
+
auto a_reg = a.copy_to_register();
|
|
2884
|
+
auto b_reg = b.copy_to_register();
|
|
2885
|
+
|
|
2886
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2887
|
+
|
|
2888
|
+
WP_PRAGMA_UNROLL
|
|
2889
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2890
|
+
{
|
|
2891
|
+
const int linear = Layout::linear_from_register(i);
|
|
2892
|
+
|
|
2893
|
+
if(!Layout::valid(linear))
|
|
2894
|
+
break;
|
|
2895
|
+
|
|
2896
|
+
a_reg.data[i] ^= b_reg.data[i];
|
|
2897
|
+
}
|
|
2898
|
+
|
|
2899
|
+
a.assign(a_reg);
|
|
2900
|
+
}
|
|
2901
|
+
|
|
2902
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2903
|
+
inline CUDA_CALLABLE void adj_tile_bit_xor_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
|
|
2904
|
+
|
|
2385
2905
|
|
|
2386
2906
|
template<typename Tile>
|
|
2387
|
-
typename Tile::Type tile_extract(Tile& t, int i) {
|
|
2907
|
+
typename Tile::Type tile_extract(Tile& t, int i) {
|
|
2908
|
+
return t.extract(tile_coord(i));
|
|
2909
|
+
}
|
|
2388
2910
|
template<typename Tile>
|
|
2389
|
-
|
|
2911
|
+
auto tile_extract(Tile& t, int i, int j) {
|
|
2912
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2913
|
+
return t.extract(tile_coord(i))[j];
|
|
2914
|
+
} else {
|
|
2915
|
+
return t.extract(tile_coord(i,j));
|
|
2916
|
+
}
|
|
2917
|
+
}
|
|
2390
2918
|
template<typename Tile>
|
|
2391
|
-
|
|
2919
|
+
auto tile_extract(Tile& t, int i, int j, int k) {
|
|
2920
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2921
|
+
return t.extract(tile_coord(i,j))[k];
|
|
2922
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2923
|
+
return t.extract(tile_coord(i)).data[j][k];
|
|
2924
|
+
} else {
|
|
2925
|
+
return t.extract(tile_coord(i,j,k));
|
|
2926
|
+
}
|
|
2927
|
+
}
|
|
2392
2928
|
template<typename Tile>
|
|
2393
|
-
|
|
2929
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l) {
|
|
2930
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2931
|
+
return t.extract(tile_coord(i,j,k))[l];
|
|
2932
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2933
|
+
return t.extract(tile_coord(i,j)).data[k][l];
|
|
2934
|
+
} else {
|
|
2935
|
+
return t.extract(tile_coord(i,j,k,l));
|
|
2936
|
+
}
|
|
2937
|
+
}
|
|
2938
|
+
template<typename Tile>
|
|
2939
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l, int m) {
|
|
2940
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2941
|
+
return t.extract(tile_coord(i,j,k,l))[m];
|
|
2942
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2943
|
+
return t.extract(tile_coord(i,j,k)).data[l][m];
|
|
2944
|
+
} else {
|
|
2945
|
+
static_assert(always_false<Tile>::value,
|
|
2946
|
+
"tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
2947
|
+
}
|
|
2948
|
+
}
|
|
2949
|
+
template<typename Tile>
|
|
2950
|
+
auto tile_extract(Tile& t, int i, int j, int k, int l, int m, int n) {
|
|
2951
|
+
if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2952
|
+
return t.extract(tile_coord(i,j,k,l)).data[m][n];
|
|
2953
|
+
} else {
|
|
2954
|
+
static_assert(always_false<Tile>::value,
|
|
2955
|
+
"tile_extract with 6 indices requires a tile of matrices (4D tile)");
|
|
2956
|
+
}
|
|
2957
|
+
}
|
|
2394
2958
|
|
|
2395
2959
|
template<typename Tile, typename AdjTile>
|
|
2396
|
-
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
|
|
2397
|
-
|
|
2398
|
-
|
|
2399
|
-
template<typename Tile, typename AdjTile>
|
|
2400
|
-
void adj_tile_extract(Tile& t, int i, int j,
|
|
2401
|
-
|
|
2402
|
-
|
|
2960
|
+
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
|
|
2961
|
+
adj_t.adj_extract(tile_coord(i), adj_ret);
|
|
2962
|
+
}
|
|
2963
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2964
|
+
void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, AdjType adj_ret) {
|
|
2965
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2966
|
+
typename Tile::Type vector_adj{};
|
|
2967
|
+
vector_adj[j] = adj_ret;
|
|
2968
|
+
adj_t.adj_extract(tile_coord(i), vector_adj);
|
|
2969
|
+
} else {
|
|
2970
|
+
adj_t.adj_extract(tile_coord(i, j), adj_ret);
|
|
2971
|
+
}
|
|
2972
|
+
}
|
|
2973
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2974
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, AdjType adj_ret) {
|
|
2975
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2976
|
+
typename Tile::Type vector_adj{};
|
|
2977
|
+
vector_adj[k] = adj_ret;
|
|
2978
|
+
adj_t.adj_extract(tile_coord(i, j), vector_adj);
|
|
2979
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2980
|
+
typename Tile::Type matrix_adj{};
|
|
2981
|
+
matrix_adj.data[j][k] = adj_ret;
|
|
2982
|
+
adj_t.adj_extract(tile_coord(i), matrix_adj);
|
|
2983
|
+
} else {
|
|
2984
|
+
adj_t.adj_extract(tile_coord(i, j, k), adj_ret);
|
|
2985
|
+
}
|
|
2986
|
+
}
|
|
2987
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
2988
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, AdjType adj_ret) {
|
|
2989
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
2990
|
+
typename Tile::Type vector_adj{};
|
|
2991
|
+
vector_adj[l] = adj_ret;
|
|
2992
|
+
adj_t.adj_extract(tile_coord(i, j, k), vector_adj);
|
|
2993
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
2994
|
+
typename Tile::Type matrix_adj{};
|
|
2995
|
+
matrix_adj.data[k][l] = adj_ret;
|
|
2996
|
+
adj_t.adj_extract(tile_coord(i, j), matrix_adj);
|
|
2997
|
+
} else {
|
|
2998
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret);
|
|
2999
|
+
}
|
|
3000
|
+
}
|
|
3001
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
3002
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, AdjType adj_ret) {
|
|
3003
|
+
if constexpr(is_vector<typename Tile::Type>::value) {
|
|
3004
|
+
typename Tile::Type vector_adj{};
|
|
3005
|
+
vector_adj[m] = adj_ret;
|
|
3006
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), vector_adj);
|
|
3007
|
+
} else if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
3008
|
+
typename Tile::Type matrix_adj{};
|
|
3009
|
+
matrix_adj.data[l][m] = adj_ret;
|
|
3010
|
+
adj_t.adj_extract(tile_coord(i, j, k), matrix_adj);
|
|
3011
|
+
} else {
|
|
3012
|
+
static_assert(always_false<Tile>::value,
|
|
3013
|
+
"adj_tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
|
|
3014
|
+
}
|
|
3015
|
+
}
|
|
3016
|
+
template<typename Tile, typename AdjTile, typename AdjType>
|
|
3017
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, int n, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, AdjType adj_ret) {
|
|
3018
|
+
if constexpr(is_matrix<typename Tile::Type>::value) {
|
|
3019
|
+
typename Tile::Type matrix_adj{};
|
|
3020
|
+
matrix_adj.data[m][n] = adj_ret;
|
|
3021
|
+
adj_t.adj_extract(tile_coord(i, j, k, l), matrix_adj);
|
|
3022
|
+
} else {
|
|
3023
|
+
static_assert(always_false<Tile>::value,
|
|
3024
|
+
"adj_tile_extract with 6 indices requires a tile of matrices (4D tile)");
|
|
3025
|
+
}
|
|
3026
|
+
}
|
|
2403
3027
|
|
|
2404
3028
|
|
|
2405
3029
|
template<typename Tile>
|
|
@@ -2420,6 +3044,33 @@ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) {
 template<typename Tile>
 void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
 
+template<typename Tile>
+void tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i), value); }
+template<typename Tile>
+void tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j), value); }
+template<typename Tile>
+void tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k), value); }
+template<typename Tile>
+void tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k,l), value); }
+
+template<typename Tile>
+void tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i), value); }
+template<typename Tile>
+void tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j), value); }
+template<typename Tile>
+void tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k), value); }
+template<typename Tile>
+void tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k,l), value); }
+
+template<typename Tile>
+void tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i), value); }
+template<typename Tile>
+void tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j), value); }
+template<typename Tile>
+void tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k), value); }
+template<typename Tile>
+void tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k,l), value); }
+
 template<typename Tile, typename AdjTile>
 void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
 template<typename Tile, typename AdjTile>
@@ -2438,6 +3089,33 @@ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type valu
 template<typename Tile, typename AdjTile>
 void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
 
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+template<typename Tile, typename AdjTile>
+void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
 namespace partitioned_gemm
 {
 
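
Note: the new tile_bit_and/or/xor_inplace overloads above forward each index arity to the tile's per-coordinate bit_*_inplace method, while their adj_tile_bit_*_inplace counterparts are deliberately empty bodies, which is consistent with bitwise masking of integer data having no meaningful gradient to propagate on the backward pass. The small sketch below illustrates that pattern only; DummyTile and the *_sketch functions are illustrative, not Warp code.

// Sketch of the pattern behind the bitwise in-place tile ops above:
// the forward op mutates one element addressed by a coordinate, the
// adjoint is intentionally a no-op. Types here are illustrative.
#include <cstdint>
#include <cstdio>

struct DummyTile
{
    uint32_t data[4] = {0xFFu, 0x0Fu, 0xF0u, 0xAAu};
    uint32_t grad[4] = {};
    void bit_and_inplace(int i, uint32_t value) { data[i] &= value; }
};

template <typename Tile>
void tile_bit_and_inplace_sketch(Tile& t, int i, uint32_t value) { t.bit_and_inplace(i, value); }

template <typename Tile>
void adj_tile_bit_and_inplace_sketch(Tile& /*t*/, int /*i*/, uint32_t /*value*/,
                                     Tile& /*adj_t*/, uint32_t& /*adj_value*/)
{
    // Intentionally empty: no gradient flows through an integer mask.
}

int main()
{
    DummyTile t;
    tile_bit_and_inplace_sketch(t, 0, 0x0Fu);                    // forward: t.data[0] becomes 0x0F
    uint32_t adj_value = 0;
    adj_tile_bit_and_inplace_sketch(t, 0, 0x0Fu, t, adj_value);  // backward: nothing to do
    std::printf("data[0]=0x%X adj_value=%u\n", (unsigned)t.data[0], (unsigned)adj_value);
    return 0;
}
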
@@ -2825,7 +3503,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
 #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
     do { \
         void function_name(dtype*, char*); \
-        char* buffer = (char*)wp::
+        char* buffer = (char*)wp::tile_shared_storage_t::alloc(shared_memory_size); \
         __align__(16) dtype data[ept]; \
         for(int b = 0; b < (int)batch_size; b++) { \
             dtype* inout = Xinout.data + (int)b * (int)ept; \
@@ -2834,7 +3512,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
             memcpy(inout, data, sizeof(dtype) * ept); \
             WP_TILE_SYNC(); \
         } \
-        wp::
+        wp::tile_shared_storage_t::alloc(-shared_memory_size); \
     } while (0)
 
 #define tile_ifft tile_fft
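
Note: in the tile_fft/tile_ifft macro above, the scratch buffer is now obtained from wp::tile_shared_storage_t::alloc(shared_memory_size), and the same call is repeated with a negated size once the batch loop finishes. That reads like a stack-style shared-memory allocator in which a negative request rewinds the allocation offset, releasing the FFT scratch space; this is an inference from the diff, and the host-side sketch below only illustrates such an allocator, not Warp's implementation.

// Hypothetical bump allocator illustrating the alloc(size)/alloc(-size) pairing above.
#include <cassert>
#include <cstdio>

struct shared_storage_sketch
{
    static constexpr int capacity = 48 * 1024;   // illustrative shared-memory budget
    static inline char buffer[capacity];
    static inline int offset = 0;

    // Positive size bumps the offset and returns scratch space;
    // negative size rewinds it (a stack-style release).
    static char* alloc(int size)
    {
        if (size >= 0)
        {
            assert(offset + size <= capacity);
            char* p = buffer + offset;
            offset += size;
            return p;
        }
        offset += size;          // size < 0: give the space back
        assert(offset >= 0);
        return nullptr;
    }
};

int main()
{
    char* scratch = shared_storage_sketch::alloc(1024);    // acquire scratch space
    std::printf("scratch=%p offset=%d\n", (void*)scratch, shared_storage_sketch::offset);
    shared_storage_sketch::alloc(-1024);                    // release it again
    std::printf("offset after release=%d\n", shared_storage_sketch::offset);
    return 0;
}
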
@@ -2878,7 +3556,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
 #else
 
     // TODO: for batched Cholesky, need one info per batch
-
+    __shared__ int info[1];
 
     if (WP_TILE_THREAD_IDX == 0) {
         info[0] = 0;
@@ -3048,7 +3726,7 @@ template <typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
     auto a = tile_transpose(adj_ret);
-    auto b = adj_t;
+    auto& b = adj_t;
 
     adj_t.assign(tile_add(a,b));
 }
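
Note: the adj_tile_transpose change above is a one-character fix: binding b with auto& makes it an alias of adj_t rather than a copy of the adjoint tile, presumably to avoid materializing a full temporary tile in the backward pass. A trivial illustration of the copy-versus-reference difference in plain C++:

#include <cstdio>

int main()
{
    int adj = 1;
    auto  copy = adj;   // independent copy: later writes to adj are not seen
    auto& ref  = adj;   // alias: always reflects the current value of adj
    adj = 5;
    std::printf("copy=%d ref=%d\n", copy, ref);   // prints copy=1 ref=5
    return 0;
}
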
@@ -3210,22 +3888,63 @@ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
 template <typename TileA, typename Scalar>
 inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
 {
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        dest.data(tile_coord(i))[j] = src;
+    } else {
+        dest.data(tile_coord(i, j)) = src;
+    }
     WP_TILE_SYNC();
 }
 template <typename TileA, typename Scalar>
 inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
 {
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j))[k] = src;
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        dest.data(tile_coord(i)).data[j][k] = src;
+    } else {
+        dest.data(tile_coord(i, j, k)) = src;
+    }
     WP_TILE_SYNC();
 }
 template <typename TileA, typename Scalar>
 inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
 {
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j, k))[l] = src;
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j)).data[k][l] = src;
+    } else {
+        dest.data(tile_coord(i, j, k, l)) = src;
+    }
+    WP_TILE_SYNC();
+}
+template <typename TileA, typename Scalar>
+inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src)
+{
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j, k, l))[m] = src;
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j, k)).data[l][m] = src;
+    } else {
+        static_assert(always_false<TileA>::value,
+            "assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+    }
+    WP_TILE_SYNC();
+}
+template <typename TileA, typename Scalar>
+inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src)
+{
+    if constexpr(is_matrix<typename TileA::Type>::value) {
+        dest.data(tile_coord(i, j, k, l)).data[m][n] = src;
+    } else {
+        static_assert(always_false<TileA>::value,
+            "assign with 6 indices requires a tile of matrices (4D tile)");
+    }
     WP_TILE_SYNC();
 }
 
+
 template <typename TileA, typename AdjTileA, typename Scalar>
 inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
 {
@@ -3244,7 +3963,11 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& sr
         return;
     }
 
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i))[j];
+    } else {
+        adj_src += dest.grad(tile_coord(i, j));
+    }
 }
 template <typename TileA, typename AdjTileA, typename Scalar>
 inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
@@ -3254,7 +3977,13 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Sca
         return;
     }
 
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j))[k];
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i)).data[j][k];
+    } else {
+        adj_src += dest.grad(tile_coord(i, j, k));
+    }
 }
 template <typename TileA, typename AdjTileA, typename Scalar>
 inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
@@ -3264,7 +3993,45 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, co
         return;
     }
 
-
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j, k))[l];
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j)).data[k][l];
+    } else {
+        adj_src += dest.grad(tile_coord(i, j, k, l));
+    }
+}
+template <typename TileA, typename AdjTileA, typename Scalar>
+inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, Scalar& adj_src)
+{
+    if (dest.grad.ptr == nullptr)
+    {
+        return;
+    }
+
+    if constexpr(is_vector<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j, k, l))[m];
+    } else if constexpr(is_matrix<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j, k)).data[l][m];
+    } else {
+        static_assert(always_false<TileA>::value,
+            "adj_assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+    }
+}
+template <typename TileA, typename AdjTileA, typename Scalar>
+inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, Scalar& adj_src)
+{
+    if (dest.grad.ptr == nullptr)
+    {
+        return;
+    }
+
+    if constexpr(is_matrix<typename TileA::Type>::value) {
+        adj_src += dest.grad(tile_coord(i, j, k, l)).data[m][n];
+    } else {
+        static_assert(always_false<TileA>::value,
+            "adj_assign with 6 indices requires a tile of matrices (4D tile)");
+    }
 }
 
 template <typename TileA, typename TileB, typename Coord>
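
Note: the assign/adj_assign overloads above all encode the same convention: the leading indices address a tile coordinate, and when the tile's element type is a vector or matrix, the trailing one or two indices address a component inside that element; adj_assign mirrors the forward write by reading the matching slot of the gradient tile. The minimal sketch below shows that convention for a 1D tile of 2x2 matrices; all types here are illustrative, not Warp's.

// Sketch of the index split used by assign/adj_assign above:
// (i) picks the tile element, (j,k) picks the matrix entry.
#include <cstdio>

struct mat22 { float data[2][2] = {}; };

struct MatTileSketch
{
    static constexpr int N = 4;
    mat22 elems[N];
    mat22 grads[N];

    mat22& data(int i) { return elems[i]; }
    mat22& grad(int i) { return grads[i]; }
};

// Forward: write one matrix entry of one tile element.
void assign_sketch(MatTileSketch& dest, int i, int j, int k, float src)
{
    dest.data(i).data[j][k] = src;
}

// Adjoint: read the matching slot of the gradient tile.
void adj_assign_sketch(MatTileSketch& dest, int i, int j, int k, float& adj_src)
{
    adj_src += dest.grad(i).data[j][k];
}

int main()
{
    MatTileSketch t;
    assign_sketch(t, 2, 0, 1, 3.5f);      // writes entry (0,1) of element 2
    t.grads[2].data[0][1] = 0.25f;        // pretend upstream gradient
    float adj_src = 0.0f;
    adj_assign_sketch(t, 2, 0, 1, adj_src);
    std::printf("value=%g adj=%g\n", t.elems[2].data[0][1], adj_src);
    return 0;
}
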