warp-lang 1.5.0__py3-none-manylinux2014_x86_64.whl → 1.6.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (132) hide show
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
@@ -1,179 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import numpy as np
9
- import torch
10
-
11
- import warp as wp
12
-
13
- wp.init()
14
- wp.set_module_options({"enable_backward": False, "fast_math": True})
15
- wp.set_device("cuda:0")
16
-
17
- wp.build.clear_kernel_cache()
18
-
19
-
20
- @wp.kernel
21
- def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
22
- # output index
23
- i, j = wp.tid()
24
-
25
- sum = float(0.0)
26
-
27
- for k in range(0, A.shape[1]):
28
- sum += A[i, k] * B[k, j]
29
-
30
- C[i, j] = sum
31
-
32
-
33
- TILE_M = wp.constant(64)
34
- TILE_N = wp.constant(64)
35
- TILE_K = wp.constant(8)
36
-
37
-
38
- @wp.kernel
39
- def gemm_tiled(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
40
- # output tile index
41
- i, j = wp.tid()
42
-
43
- sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
44
-
45
- _M = A.shape[0]
46
- _N = B.shape[1]
47
- K = A.shape[1]
48
-
49
- count = int(K / 8) # TODO: code-gen bug if you use a constant before passing it to a kwd arg (in this case TILE_K)
50
-
51
- for k in range(count):
52
- a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
53
- b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
54
-
55
- # sum += a*b
56
- wp.tile_matmul(a, b, sum)
57
-
58
- wp.tile_store(C, i, j, sum)
59
-
60
-
61
- def benchmark_numpy(A, B, C):
62
- timers = {}
63
- iters = 10
64
-
65
- # warm up
66
- for _i in range(10):
67
- _C = A @ B
68
-
69
- with wp.ScopedTimer("NumPy", dict=timers):
70
- for _i in range(iters):
71
- _C = A @ B
72
-
73
- return min(timers["NumPy"])
74
-
75
-
76
- def benchmark_warp_simt(A, B, C):
77
- timers = {}
78
- iters = 10
79
-
80
- A_wp = wp.array(A)
81
- B_wp = wp.array(B)
82
- C_wp = wp.array(C)
83
-
84
- # warm up
85
- for _i in range(10):
86
- wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp])
87
-
88
- with wp.ScopedTimer("Warp (SIMT)", dict=timers, print=False, synchronize=True):
89
- for _i in range(iters):
90
- wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp])
91
-
92
- return min(timers["Warp (SIMT)"])
93
-
94
-
95
- def benchmark_warp_tiled(A, B, C):
96
- timers = {}
97
- iters = 10
98
-
99
- # must match with the tile_matmul() partition size
100
- SUB_TILE_M = 4
101
- SUB_TILE_N = 4
102
-
103
- num_threads = int(TILE_M / SUB_TILE_M) * int(TILE_N / SUB_TILE_N)
104
- A_wp = wp.array(A)
105
- B_wp = wp.array(B)
106
- C_wp = wp.array(C)
107
-
108
- # warm up
109
- wp.capture_begin()
110
-
111
- for _i in range(iters):
112
- wp.launch(gemm_tiled, dim=(int(M / TILE_M), int(N / TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads)
113
-
114
- graph = wp.capture_end()
115
-
116
- with wp.ScopedTimer("Warp (Tiled)", dict=timers, print=False, synchronize=True):
117
- # for i in range(iters):
118
- # wp.launch(gemm_tiled, dim=(int(M/TILE_M), int(N/TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads)
119
- wp.capture_launch(graph)
120
-
121
- return min(timers["Warp (Tiled)"])
122
-
123
-
124
- def benchmark_torch(A, B, C):
125
- A_tc = torch.from_numpy(A).to("cuda:0")
126
- B_tc = torch.from_numpy(B).to("cuda:0")
127
- C_tc = torch.from_numpy(C).to("cuda:0")
128
-
129
- # warm-up
130
- for _i in range(10):
131
- torch.matmul(A_tc, B_tc, out=C_tc)
132
-
133
- timers = {}
134
- iters = 10
135
-
136
- torch.cuda.synchronize()
137
-
138
- with wp.ScopedTimer("Torch", dict=timers, print=False):
139
- for _i in range(iters):
140
- torch.matmul(A_tc, B_tc) # , out=C_tc)
141
-
142
- torch.cuda.synchronize()
143
-
144
- return min(timers["Torch"])
145
-
146
-
147
- results_torch = []
148
- results_warp_simt = []
149
- results_warp_tiled = []
150
-
151
- print("{:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s}".format("M", "N", "K", "Torch", "Warp (SIMT)", "Warp (Tiled)"))
152
- print("--------------------------------------------------------")
153
-
154
- for i in range(2, 33):
155
- # for i in range(8,9):
156
-
157
- M = i * 128
158
- N = M
159
- K = N
160
-
161
- # M = TILE_M*21
162
- # K = TILE_K*7
163
- # N = TILE_M*12
164
-
165
- rng = np.random.default_rng(42)
166
-
167
- A = rng.random((M, K), dtype=np.float32)
168
- B = rng.random((K, N), dtype=np.float32)
169
- C = np.zeros((M, N), dtype=np.float32)
170
-
171
- results_torch.append(benchmark_torch(A, B, C))
172
- results_warp_simt.append(0.0) # benchmark_warp_simt(A, B, C))
173
- results_warp_tiled.append(benchmark_warp_tiled(A, B, C))
174
-
175
- print(
176
- "{:>8d} {:>8d} {:>8d} {:>8f} {:>8f} {:>8f}".format(
177
- M, N, K, results_torch[-1], results_warp_simt[-1], results_warp_tiled[-1]
178
- )
179
- )
warp/native/tile_gemm.h DELETED
@@ -1,341 +0,0 @@
1
- /** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
- */
8
-
9
- #pragma once
10
-
11
- #include "builtin.h"
12
-
13
- #define USE_CUTE 0
14
-
15
- #if USE_CUTE
16
- #include "cutlass/include/cute/tensor.hpp"
17
- #include "cutlass/include/cute/algorithm/cooperative_gemm.hpp"
18
- #endif // USE_CUTE
19
-
20
- namespace wp
21
- {
22
-
23
- /*
24
- // 2D tile zero
25
- template <typename T, int M, int N, int Index>
26
- inline CUDA_CALLABLE array_t<T> tile_zeros()
27
- {
28
- const int length = M*N;
29
-
30
- WP_TILE_SHARED __align__(16) T data[length];
31
-
32
- WP_PRAGMA_UNROLL
33
- for (int t=threadIdx.x; t < length; t += blockDim.x)
34
- {
35
- data[t] = T(0.0);
36
- }
37
-
38
- return array_t<T>(data, M, N, nullptr);
39
- }
40
-
41
- // 2D tile load
42
- template <typename T, int M, int N, int Index>
43
- inline CUDA_CALLABLE array_t<T> tile_load(const array_t<T>& src, int i, int j)
44
- {
45
- const int length = M*N;
46
-
47
- WP_TILE_SHARED __align__(16) T data[length];
48
-
49
- //---------------
50
- // naive-synchronous load
51
- //
52
- // WP_PRAGMA_UNROLL
53
- // for (int t=threadIdx.x; t < length; t += blockDim.x)
54
- // {
55
- // data[t] = index(src, i*M + t/N, j*N + t%N);
56
- // }
57
-
58
- //---------------
59
- // async 128 bit loads (assumes row-major i.e.: stride 1 on y axis and 4-element alignment on dimension)
60
- const int s = 4;
61
-
62
- WP_PRAGMA_UNROLL
63
- for (int t=threadIdx.x*s; t < length; t += blockDim.x*s)
64
- {
65
- __pipeline_memcpy_async(&data[t],
66
- &index(src, i*M + t/N, j*N + t%N),
67
- sizeof(T)*s);
68
- }
69
-
70
- __pipeline_commit();
71
-
72
-
73
- return array_t<T>(data, M, N, nullptr);
74
- }
75
-
76
- // 2D tile store
77
- template <typename T>
78
- inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int i, int j, const array_t<T>& src)
79
- {
80
- const int M = src.shape[0];
81
- const int N = src.shape[1];
82
-
83
- const int length = M*N;
84
-
85
- // cooperatively store the tile, using a block-stride iterator
86
- WP_PRAGMA_UNROLL
87
- for (int t=threadIdx.x; t < length; t += blockDim.x)
88
- {
89
- index(dest, i*M + t/N, j*N + t%N) = src.data[t];
90
- }
91
- }
92
- */
93
-
94
- template <typename T>
95
- inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
96
- {
97
- return p[i*stride + j];
98
- }
99
-
100
- template <typename T>
101
- inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
102
- {
103
- return p[i*stride + j];
104
- }
105
-
106
- template <unsigned M, unsigned N, typename T>
107
- struct partition_t
108
- {
109
- inline partition_t(array_t<T> A)
110
- {
111
- data = A;
112
-
113
- // todo: do ceil div for non-multiples of M,N
114
- shape[0] = A.shape[0]/M;
115
- shape[1] = A.shape[1]/N;
116
- }
117
-
118
- // underlying data
119
- array_t<T> data;
120
-
121
- // partition dimensions
122
- int shape[2];
123
- };
124
-
125
- template <unsigned M, unsigned N, typename T>
126
- inline int partition_size(const partition_t<M, N, T>& tile)
127
- {
128
- return tile.shape[0]*tile.shape[1];
129
- }
130
-
131
- // returns the x, y coordinates of a tile given a linear index
132
- template <unsigned M, unsigned N, typename T>
133
- inline void partition_coord(const partition_t<M, N, T>& tile, const int t, int& i, int& j)
134
- {
135
- i = t/tile.shape[1];
136
- j = t%tile.shape[1];
137
- }
138
-
139
- template <unsigned M, unsigned N, typename T>
140
- inline mat_t<M, N, T> partition_load(const partition_t<M, N, T>& tile, int i, int j)
141
- {
142
- mat_t<M, N, T> out;
143
-
144
- const int tile_i = i*M;
145
- const int tile_j = j*N;
146
-
147
- WP_PRAGMA_UNROLL
148
- for (int i=0; i < M; ++i)
149
- {
150
- WP_PRAGMA_UNROLL
151
- for (int j=0; j < N; ++j)
152
- {
153
- out.data[i][j] = index(tile.data, tile_i + i, tile_j + j);
154
- }
155
- }
156
-
157
- return out;
158
- }
159
-
160
- template <unsigned M, unsigned N, typename T>
161
- inline void partition_store(const partition_t<M, N, T>& tile, int i, int j, const mat_t<M, N, T>& value)
162
- {
163
- mat_t<M, N, T> out;
164
-
165
- const int tile_i = M*i;
166
- const int tile_j = N*j;
167
-
168
- WP_PRAGMA_UNROLL
169
- for (int i=0; i < M; ++i)
170
- {
171
- WP_PRAGMA_UNROLL
172
- for (int j=0; j < N; ++j)
173
- {
174
- index(tile.data, tile_i + i, tile_j + j) = value.data[i][j];
175
- }
176
- }
177
- }
178
-
179
-
180
- #if !USE_CUTE
181
-
182
- template <typename T>
183
- inline CUDA_CALLABLE void gemm(const array_t<T>& A, const array_t<T>& B, const array_t<T>& out)
184
- {
185
- const int TILE_M = 4;
186
- const int TILE_N = 4;
187
- const int TILE_K = 4;
188
-
189
- partition_t A_tile = partition_t<TILE_M, TILE_K, T>(A);
190
- partition_t B_tile = partition_t<TILE_K, TILE_N, T>(B);
191
- partition_t C_tile = partition_t<TILE_M, TILE_N, T>(out);
192
-
193
- const int length = partition_size(C_tile);
194
-
195
- __pipeline_wait_prior(0);
196
-
197
- WP_TILE_SYNC();
198
-
199
- for (int t=threadIdx.x; t < length; t += blockDim.x)
200
- {
201
- int i, j;
202
- partition_coord(C_tile, t, i, j);
203
-
204
- // accumulator
205
- mat_t<TILE_M, TILE_N, T> sum = partition_load(C_tile, i, j);
206
-
207
- WP_PRAGMA_UNROLL
208
- for (int k=0; k < A_tile.shape[1]; k++)
209
- {
210
- const mat_t<TILE_M, TILE_K, T> a = partition_load(A_tile, i, k);
211
- const mat_t<TILE_K, TILE_N, T> b = partition_load(B_tile, k, j);
212
-
213
- sum += mul(a, b);
214
- }
215
-
216
- partition_store(C_tile, i, j, sum);
217
- }
218
-
219
- WP_TILE_SYNC();
220
- }
221
-
222
-
223
- // 2D gemm accumulate out += A*B
224
- template <typename TileA, typename TileB, typename TileC>
225
- inline CUDA_CALLABLE void tile_matmul_scalar(const TileA& A,
226
- const TileB& B,
227
- TileC& out)
228
- {
229
- const int length = tile_size(out);
230
-
231
- WP_TILE_SYNC();
232
-
233
- using T = typename TileA::Type;
234
-
235
- WP_PRAGMA_UNROLL
236
- for (int t=threadIdx.x; t < length; t += WP_TILE_BLOCK_DIM)
237
- {
238
- // compute output index
239
- const int i = t/out.N;
240
- const int j = t%out.N;
241
-
242
- T sum(0.0);
243
-
244
- WP_PRAGMA_UNROLL
245
- for (int k=0; k < A.N; ++k)
246
- {
247
- T a = A(i,k);
248
- T b = B(k,j);
249
-
250
- sum += a*b; // todo: use fmaf()
251
- }
252
-
253
- out(i,j) += sum;
254
- }
255
-
256
- WP_TILE_SYNC();
257
- }
258
-
259
- #else
260
-
261
-
262
- template <typename T>
263
- inline CUDA_CALLABLE void tile_matmul(const array_t<T>& A, const array_t<T>& B, const array_t<T>& out)
264
- {
265
- using namespace cute;
266
-
267
- __pipeline_wait_prior(0);
268
-
269
- // ensure smem tile is ready
270
- WP_TILE_SYNC();
271
-
272
- // Define CTA matrix size (static)
273
- auto bM = Int<64>{};
274
- auto bN = Int<64>{};
275
- auto bK = Int<8>{};
276
-
277
- // Define the smem layouts (static)
278
- auto sA = make_layout(make_shape(bM, bK), LayoutRight{});
279
- auto sB = make_layout(make_shape(bN, bK));
280
- auto sC = make_layout(make_shape(bM, bN), LayoutRight{});
281
-
282
- Tensor s_a_tensor = make_tensor(make_smem_ptr<float>(A.data), sA);
283
- Tensor s_b_tensor = make_tensor(make_smem_ptr<float>(B.data), sB);
284
- Tensor s_c_tensor = make_tensor(make_smem_ptr<float>(out.data), sC);
285
-
286
-
287
- // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
288
- // Layout<Shape<_16,_8,_1>>{}); // 16x8x1 UniversalFMA, assumes blockDim=128
289
-
290
-
291
- // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
292
- // Layout<Shape<_8,_16>,Stride<_16,_1>>{}); // 8x16x1 UniversalFMA, assumes blockDim=128
293
-
294
-
295
-
296
- TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
297
- Layout<Shape<_2,_64>,Stride<_64,_1>>{}); // 8x16x1 UniversalFMA, assumes blockDim=128
298
-
299
-
300
- cooperative_gemm< AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>,
301
- AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>,
302
- AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>
303
- >(
304
- threadIdx.x, tiled_mma,
305
- 1.0f, s_a_tensor, s_b_tensor, 1.0f, s_c_tensor,
306
- cute::identity(), cute::identity(), cute::identity(), cute::identity()
307
- );
308
-
309
- WP_TILE_SYNC();
310
-
311
- }
312
-
313
- #endif // USE_CUTE
314
-
315
-
316
- #if 0
317
-
318
- template <typename TileA, typename TileB, typename TileC>
319
- void tile_matmul(TileA& a, TileB& b, TileC& c)
320
- {
321
- static_assert(wp::is_same<typename TileA::Type, typename TileB::Type>::value, "Error, tile datatypes must match");
322
- static_assert(TileA::N == TileB::M, "Error, inner dimensions must match");
323
- static_assert(TileC::M == TileA::M, "Error, first output dimension must match");
324
- static_assert(TileC::N == TileB::N, "Error, second output dimension must match");
325
-
326
- tile_matmul_scalar(a, b, c);
327
- }
328
-
329
-
330
- template <typename TileA, typename TileB, typename TileC,
331
- typename AdjTileA, typename AdjTileB, typename AdjTileC>
332
- void adj_tile_matmul(TileA& a, TileB& b, TileC& c,
333
- AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c)
334
- {
335
- tile_matmul_scalar(adj_c, wp::tile_transpose(b), adj_a);
336
- tile_matmul_scalar(wp::tile_transpose(a), adj_c, adj_b);
337
- }
338
-
339
- #endif // 0
340
-
341
- } // namespace wp