whispercpp 1.2.0.2 → 1.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +46 -86
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -7
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/ggml/include/ggml.h +2285 -0
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/include/whisper.h +672 -0
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1608 -159
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/src/whisper.cpp +7393 -0
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -8616
- data/ext/ggml.h +0 -748
- data/ext/whisper.cpp +0 -4829
- data/ext/whisper.h +0 -402
@@ -0,0 +1,378 @@
|
|
1
|
+
#include "norm.hpp"
|
2
|
+
|
3
|
+
// Layer-norm kernel: normalizes one row of x (length ncols) to zero mean /
// unit variance and writes the result to dst.  One work-group handles one
// row; each thread strides over the columns with step block_size.
// s_sum is per-warp scratch in local memory; it may be nullptr when
// block_size == WARP_SIZE (a single warp needs no cross-warp stage).
static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
    const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
        item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);

    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    // Accumulate (sum, sum of squares) in one float2 so a single reduction
    // produces both moments.
    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums within each warp (warp_reduce_sum is declared in
    // shared headers -- presumably a sub-group shuffle reduction; not
    // visible in this file)
    mean_var = warp_reduce_sum(mean_var, item_ct1);
    if (block_size > WARP_SIZE) {
        // Cross-warp stage: lane 0 of each warp publishes its partial to
        // local memory, then every warp folds the published partials.
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        /*
        DPCT1118:0: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = 0.f;
        // NOTE(review): nreduce == nwarps / WARP_SIZE -- the launcher asserts
        // work_group_size % (WARP_SIZE * WARP_SIZE) == 0, which keeps this
        // non-zero on this path; verify if launch parameters ever change.
        size_t nreduce = nwarps / WARP_SIZE;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            mean_var += s_sum[lane_id + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }

    // var = E[x^2] - E[x]^2
    const float mean = mean_var.x() / ncols;
    const float var = mean_var.y() / ncols - mean * mean;
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
    }
}
|
50
|
+
|
51
|
+
static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
|
52
|
+
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
53
|
+
int start = item_ct1.get_group(2) * group_size;
|
54
|
+
int end = start + group_size;
|
55
|
+
const int nthreads = item_ct1.get_local_range(2);
|
56
|
+
const int nwarps = nthreads / WARP_SIZE;
|
57
|
+
start += item_ct1.get_local_id(2);
|
58
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
59
|
+
|
60
|
+
if (end >= ne_elements) {
|
61
|
+
end = ne_elements;
|
62
|
+
}
|
63
|
+
|
64
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
65
|
+
|
66
|
+
for (int j = start; j < end; j += block_size) {
|
67
|
+
tmp += x[j];
|
68
|
+
}
|
69
|
+
|
70
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
71
|
+
if (block_size > WARP_SIZE) {
|
72
|
+
|
73
|
+
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
74
|
+
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
75
|
+
if (lane_id == 0) {
|
76
|
+
s_sum[warp_id] = tmp;
|
77
|
+
}
|
78
|
+
/*
|
79
|
+
DPCT1118:1: SYCL group functions and algorithms must be encountered in
|
80
|
+
converged control flow. You may need to adjust the code.
|
81
|
+
*/
|
82
|
+
/*
|
83
|
+
DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
|
84
|
+
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
85
|
+
better performance if there is no access to global memory.
|
86
|
+
*/
|
87
|
+
item_ct1.barrier();
|
88
|
+
tmp = 0.f;
|
89
|
+
for (size_t i = 0; i < nreduce; i += 1)
|
90
|
+
{
|
91
|
+
tmp += s_sum[lane_id + i * WARP_SIZE];
|
92
|
+
}
|
93
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
94
|
+
}
|
95
|
+
|
96
|
+
float mean = tmp / group_size;
|
97
|
+
tmp = 0.0f;
|
98
|
+
|
99
|
+
for (int j = start; j < end; j += block_size) {
|
100
|
+
float xi = x[j] - mean;
|
101
|
+
dst[j] = xi;
|
102
|
+
tmp += xi * xi;
|
103
|
+
}
|
104
|
+
|
105
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
106
|
+
if (block_size > WARP_SIZE) {
|
107
|
+
|
108
|
+
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
109
|
+
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
110
|
+
if (lane_id == 0) {
|
111
|
+
s_sum[warp_id] = tmp;
|
112
|
+
}
|
113
|
+
/*
|
114
|
+
DPCT1118:2: SYCL group functions and algorithms must be encountered in
|
115
|
+
converged control flow. You may need to adjust the code.
|
116
|
+
*/
|
117
|
+
/*
|
118
|
+
DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
|
119
|
+
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
120
|
+
better performance if there is no access to global memory.
|
121
|
+
*/
|
122
|
+
item_ct1.barrier();
|
123
|
+
tmp = 0.f;
|
124
|
+
for (size_t i = 0; i < nreduce; i += 1)
|
125
|
+
{
|
126
|
+
tmp += s_sum[lane_id + i * WARP_SIZE];
|
127
|
+
}
|
128
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
129
|
+
}
|
130
|
+
|
131
|
+
float variance = tmp / group_size;
|
132
|
+
float scale = sycl::rsqrt(variance + eps);
|
133
|
+
for (int j = start; j < end; j += block_size) {
|
134
|
+
dst[j] *= scale;
|
135
|
+
}
|
136
|
+
}
|
137
|
+
|
138
|
+
// RMS-norm kernel: scales one row of x (length ncols) by
// 1 / sqrt(mean(x^2) + eps).  Unlike norm_f32, no mean is subtracted.
// One work-group handles one row; s_sum is per-warp local scratch
// (nullptr when a single warp covers the row, i.e. block_size == WARP_SIZE).
static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
        item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp

    // Each thread accumulates the squared elements of its column stripe.
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        // Cross-warp stage: per-warp partials go through local memory.
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        // NOTE(review): non-zero only because the launcher asserts
        // work_group_size % (WARP_SIZE * WARP_SIZE) == 0 on this path.
        size_t nreduce = nwarps / WARP_SIZE;
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}
|
182
|
+
|
183
|
+
// Host-side launcher for norm_f32.  Rows shorter than 1024 elements are
// handled by a single warp per row (no local scratch); wider rows use the
// device's maximum work-group size with one sycl::float2 local slot per warp.
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
    const int nrows, const float eps,
    queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                    norm_f32(x, dst, ncols, eps, item_ct1,
                        nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        // The kernel's cross-warp stage (nreduce) relies on a whole number
        // of "warps of warps".
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            // One local-memory slot per warp for the partial (sum, sum_sq).
            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                    norm_f32(x, dst, ncols, eps, item_ct1,
                        get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}
|
224
|
+
|
225
|
+
// Host-side launcher for group_norm_f32: one work-group per group.  Small
// groups (< 1024 elements) use a single warp with no local scratch; larger
// groups use the device's maximum work-group size with one float local slot
// per warp.
static void group_norm_f32_sycl(const float* x, float* dst,
    const int num_groups, const float eps, const int group_size,
    const int ne_elements, queue_ptr stream, int device) {
    if (group_size < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            // Copied into a local so the kernel lambda captures by value
            // (dpct-generated pattern).
            const float eps_ct4 = eps;
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                    group_norm_f32(
                        x, dst, group_size, ne_elements, eps_ct4, item_ct1,
                        nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        // Required by the kernel's cross-warp reduction (nreduce).
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */

        stream->submit([&](sycl::handler& cgh) {
            // One local-memory slot per warp for the partial sums.
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                cgh);

            const float eps_ct4 = eps;

            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                    group_norm_f32(x, dst, group_size, ne_elements,
                        eps_ct4, item_ct1,
                        get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}
|
271
|
+
|
272
|
+
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
273
|
+
const int nrows, const float eps,
|
274
|
+
queue_ptr stream, int device) {
|
275
|
+
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
276
|
+
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
277
|
+
if (ncols < 1024) {
|
278
|
+
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
279
|
+
stream->submit([&](sycl::handler& cgh) {
|
280
|
+
cgh.parallel_for(
|
281
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
282
|
+
block_dims),
|
283
|
+
[=](sycl::nd_item<3> item_ct1)
|
284
|
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
285
|
+
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
286
|
+
nullptr, WARP_SIZE);
|
287
|
+
});
|
288
|
+
});
|
289
|
+
}
|
290
|
+
else {
|
291
|
+
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
292
|
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
293
|
+
const sycl::range<3> block_dims(1, 1, work_group_size);
|
294
|
+
/*
|
295
|
+
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
296
|
+
the limit. To get the device limit, query
|
297
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
298
|
+
*/
|
299
|
+
stream->submit([&](sycl::handler& cgh) {
|
300
|
+
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
301
|
+
cgh);
|
302
|
+
cgh.parallel_for(
|
303
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
304
|
+
block_dims),
|
305
|
+
[=](sycl::nd_item<3> item_ct1)
|
306
|
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
307
|
+
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
308
|
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
309
|
+
});
|
310
|
+
});
|
311
|
+
}
|
312
|
+
}
|
313
|
+
|
314
|
+
// LayerNorm forward: runs the SYCL norm kernel over every row of src0
// (F32 in, F32 out).  src1 is not used by this op.
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
    ggml_tensor* dst, const float* src0_dd,
    const float* src1_dd, float* dst_dd,
    const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // eps is stored as raw float bits at the start of the op parameter block.
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    norm_f32_sycl(src0_dd, dst_dd, ncols, nrows, eps, main_stream, ctx.device);

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_dd);
}
|
334
|
+
|
335
|
+
// GroupNorm forward: normalizes src0 (F32) within channel groups and writes
// the result to dst.  src1 is not used by this op.
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
    const ggml_tensor* src1, ggml_tensor* dst,
    const float* src0_dd, const float* src1_dd,
    float* dst_dd,
    const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // op_params layout: [0] = number of groups (int32),
    // [1] = eps stored as raw float bits.
    const int num_groups = dst->op_params[0];
    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    // Elements per group: channels are divided into num_groups groups,
    // rounding up so every channel lands in a group.
    const int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    const int ne_elements = src0->ne[0] * src0->ne[1] * src0->ne[2];
    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, ne_elements, main_stream, ctx.device);

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_dd);
    GGML_UNUSED(ctx);
}
|
357
|
+
|
358
|
+
// RMSNorm forward: runs the SYCL RMS-norm kernel over every row of src0
// (F32 in, F32 out).  src1 is not used by this op.
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
    const ggml_tensor* src1, ggml_tensor* dst,
    const float* src0_dd, const float* src1_dd,
    float* dst_dd,
    const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // eps is stored as raw float bits at the start of the op parameter block.
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    rms_norm_f32_sycl(src0_dd, dst_dd, ncols, nrows, eps, main_stream, ctx.device);

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_dd);
}
|
@@ -0,0 +1,56 @@
|
|
1
|
+
#include <sycl/sycl.hpp>
|
2
|
+
#include <oneapi/mkl.hpp>
|
3
|
+
#include "outprod.hpp"
|
4
|
+
|
5
|
+
|
6
|
+
// Outer product: dst = src0 * src1^T, computed as a column-major F32 GEMM
// through oneMKL on the context's SYCL queue.
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
    const ggml_tensor* src1, ggml_tensor* dst) {


    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    // Brings the ne0x/ne1x (dims) and nb0x/nb1x (byte strides) locals for
    // src0/src1/dst into scope -- presumably the standard ggml macro; its
    // definition is not visible in this file.
    GGML_TENSOR_BINARY_OP_LOCALS

    // Get SYCL queue
    dpct::queue_ptr stream = ctx.stream();

    // Dimension checks
    GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
    GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
    GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols

    // Get data pointers
    const float* src0_d = (const float*)src0->data;
    const float* src1_d = (const float*)src1->data;
    float* dst_d = (float*)dst->data;

    // GEMM parameters
    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Handle transposition of src1: when src1 is stored transposed its data
    // is already laid out the way the product needs (nontrans); otherwise
    // ask MKL to transpose it.  ldb is the leading dimension in elements,
    // derived from the relevant byte stride.
    const bool src1_T = ggml_is_transposed(src1);
    const oneapi::mkl::transpose src1_op =
        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);

    try {
        // Perform matrix multiplication using oneMKL GEMM
#ifdef GGML_SYCL_NVIDIA
        // On NVIDIA devices the cuBLAS backend must be selected explicitly.
        oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
            oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
            ne00, src1_d, ldb, beta, dst_d, ne0);
#else
        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
            src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
#endif
    }
    catch (sycl::exception const& exc) {
        // A failed GEMM is unrecoverable here; log and abort.
        std::cerr << exc.what() << std::endl;
        GGML_ASSERT(false);
    }
}
|