whispercpp 1.3.0 → 1.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,276 @@
|
|
1
|
+
#include "rope.hpp"
|
2
|
+
|
3
|
+
// Pair of YaRN correction bounds passed by value into the RoPE kernels.
// rope_yarn() hands v[0] to rope_yarn_ramp() as `low` and v[1] as `high`;
// ggml_sycl_op_rope() fills both through ggml_rope_yarn_corr_dims().
struct rope_corr_dims {
    float v[2];
};
|
6
|
+
|
7
|
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
8
|
+
const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
|
9
|
+
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
|
10
|
+
}
|
11
|
+
|
12
|
+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
13
|
+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
14
|
+
static void rope_yarn(
|
15
|
+
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
16
|
+
float * cos_theta, float * sin_theta) {
|
17
|
+
// Get n-d rotational scaling corrected for extrapolation
|
18
|
+
float theta_interp = freq_scale * theta_extrap;
|
19
|
+
float theta = theta_interp;
|
20
|
+
if (ext_factor != 0.0f) {
|
21
|
+
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
22
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
23
|
+
|
24
|
+
// Get n-d magnitude scaling corrected for interpolation
|
25
|
+
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
|
26
|
+
}
|
27
|
+
*cos_theta = sycl::cos(theta) * mscale;
|
28
|
+
*sin_theta = sycl::sin(theta) * mscale;
|
29
|
+
}
|
30
|
+
|
31
|
+
template<typename T, bool has_ff>
|
32
|
+
static void rope_norm(
|
33
|
+
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
34
|
+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
|
35
|
+
const sycl::nd_item<3> &item_ct1) {
|
36
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
37
|
+
item_ct1.get_local_id(1));
|
38
|
+
|
39
|
+
if (i0 >= ne0) {
|
40
|
+
return;
|
41
|
+
}
|
42
|
+
|
43
|
+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
44
|
+
item_ct1.get_local_id(2);
|
45
|
+
|
46
|
+
if (i0 >= n_dims) {
|
47
|
+
const int i = row*ne0 + i0;
|
48
|
+
|
49
|
+
dst[i + 0] = x[i + 0];
|
50
|
+
dst[i + 1] = x[i + 1];
|
51
|
+
|
52
|
+
return;
|
53
|
+
}
|
54
|
+
|
55
|
+
const int i = row*ne0 + i0;
|
56
|
+
const int i2 = row/p_delta_rows;
|
57
|
+
|
58
|
+
const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
|
59
|
+
|
60
|
+
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
61
|
+
|
62
|
+
float cos_theta;
|
63
|
+
float sin_theta;
|
64
|
+
|
65
|
+
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
66
|
+
|
67
|
+
const float x0 = x[i + 0];
|
68
|
+
const float x1 = x[i + 1];
|
69
|
+
|
70
|
+
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
71
|
+
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
72
|
+
}
|
73
|
+
|
74
|
+
template<typename T, bool has_ff>
|
75
|
+
static void rope_neox(
|
76
|
+
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
77
|
+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
|
78
|
+
const sycl::nd_item<3> &item_ct1) {
|
79
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
80
|
+
item_ct1.get_local_id(1));
|
81
|
+
|
82
|
+
if (i0 >= ne0) {
|
83
|
+
return;
|
84
|
+
}
|
85
|
+
|
86
|
+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
87
|
+
item_ct1.get_local_id(2);
|
88
|
+
|
89
|
+
if (i0 >= n_dims) {
|
90
|
+
const int i = row*ne0 + i0;
|
91
|
+
|
92
|
+
dst[i + 0] = x[i + 0];
|
93
|
+
dst[i + 1] = x[i + 1];
|
94
|
+
|
95
|
+
return;
|
96
|
+
}
|
97
|
+
|
98
|
+
const int i = row*ne0 + i0/2;
|
99
|
+
const int i2 = row/p_delta_rows;
|
100
|
+
|
101
|
+
const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
|
102
|
+
|
103
|
+
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
104
|
+
|
105
|
+
float cos_theta;
|
106
|
+
float sin_theta;
|
107
|
+
|
108
|
+
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
109
|
+
|
110
|
+
const float x0 = x[i + 0];
|
111
|
+
const float x1 = x[i + n_dims/2];
|
112
|
+
|
113
|
+
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
114
|
+
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
115
|
+
}
|
116
|
+
|
117
|
+
template <typename T>
|
118
|
+
static void rope_norm_sycl(
|
119
|
+
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
|
120
|
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
121
|
+
GGML_ASSERT(ne0 % 2 == 0);
|
122
|
+
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
123
|
+
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
|
124
|
+
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
125
|
+
|
126
|
+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
127
|
+
|
128
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
129
|
+
{sycl::aspect::fp16});
|
130
|
+
|
131
|
+
if (freq_factors == nullptr) {
|
132
|
+
/*
|
133
|
+
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
|
134
|
+
the limit. To get the device limit, query
|
135
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
136
|
+
*/
|
137
|
+
stream->parallel_for(
|
138
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
139
|
+
[=](sycl::nd_item<3> item_ct1) {
|
140
|
+
rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
|
141
|
+
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
|
142
|
+
item_ct1);
|
143
|
+
});
|
144
|
+
} else {
|
145
|
+
/*
|
146
|
+
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
147
|
+
the limit. To get the device limit, query
|
148
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
149
|
+
*/
|
150
|
+
stream->parallel_for(
|
151
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
152
|
+
[=](sycl::nd_item<3> item_ct1) {
|
153
|
+
rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
|
154
|
+
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
|
155
|
+
item_ct1);
|
156
|
+
});
|
157
|
+
}
|
158
|
+
}
|
159
|
+
|
160
|
+
template <typename T>
|
161
|
+
static void rope_neox_sycl(
|
162
|
+
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
|
163
|
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
164
|
+
GGML_ASSERT(ne0 % 2 == 0);
|
165
|
+
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
166
|
+
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
|
167
|
+
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
168
|
+
|
169
|
+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
170
|
+
|
171
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
172
|
+
{sycl::aspect::fp16});
|
173
|
+
|
174
|
+
if (freq_factors == nullptr) {
|
175
|
+
stream->parallel_for(
|
176
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
177
|
+
[=](sycl::nd_item<3> item_ct1) {
|
178
|
+
rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
|
179
|
+
p_delta_rows, ext_factor, attn_factor,
|
180
|
+
corr_dims, theta_scale, freq_factors,
|
181
|
+
item_ct1);
|
182
|
+
});
|
183
|
+
} else {
|
184
|
+
stream->parallel_for(
|
185
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
186
|
+
[=](sycl::nd_item<3> item_ct1) {
|
187
|
+
rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
|
188
|
+
p_delta_rows, ext_factor, attn_factor,
|
189
|
+
corr_dims, theta_scale, freq_factors,
|
190
|
+
item_ct1);
|
191
|
+
});
|
192
|
+
}
|
193
|
+
}
|
194
|
+
|
195
|
+
// SYCL entry point for the RoPE operator.
// Decodes the operator parameters stored in dst->op_params, derives the YaRN
// correction bounds and dispatches to the NeoX or normal launcher for F32 or
// F16 data. src0 and dst must share the same type; any other type aborts.
//   src0_dd / dst_dd : device buffers of the input / output
//   src1_dd          : device buffer reinterpreted as int32 positions
//   dst->src[2]      : optional frequency-factor tensor (may be null)
void ggml_sycl_op_rope(
    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
    const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
    const ggml_tensor * src2 = dst->src[2];

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t nr   = ggml_nrows(src0);

    // op_params layout as read here: [1] n_dims, [2] mode, [4] n_ctx_orig,
    // [5..10] float-encoded freq_base, freq_scale, ext_factor, attn_factor,
    // beta_fast, beta_slow. Slots [0] (n_past) and [3] (n_ctx) are not used.
    const int32_t * params = (const int32_t *) dst->op_params;

    const int n_dims     = params[1];
    const int mode       = params[2];
    const int n_ctx_orig = params[4];

    // RoPE alteration for extended context
    const auto float_param = [params](int slot) {
        float value;
        memcpy(&value, params + slot, sizeof(float));
        return value;
    };
    const float freq_base   = float_param(5);
    const float freq_scale  = float_param(6);
    const float ext_factor  = float_param(7);
    const float attn_factor = float_param(8);
    const float beta_fast   = float_param(9);
    const float beta_slow   = float_param(10);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

    const int32_t * pos = (const int32_t *) src1_dd;

    const float * freq_factors = src2 != nullptr ? (const float *) src2->data : nullptr;

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (src0->type == GGML_TYPE_F32) {
        const float * in  = (const float *)src0_dd;
        float *       out = (float *)dst_dd;
        if (is_neox) {
            rope_neox_sycl(in, out, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                           attn_factor, corr_dims, freq_factors, main_stream);
        } else {
            rope_norm_sycl(in, out, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                           attn_factor, corr_dims, freq_factors, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F16) {
        const sycl::half * in  = (const sycl::half *)src0_dd;
        sycl::half *       out = (sycl::half *)dst_dd;
        if (is_neox) {
            rope_neox_sycl(in, out, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                           attn_factor, corr_dims, freq_factors, main_stream);
        } else {
            rope_norm_sycl(in, out, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                           attn_factor, corr_dims, freq_factors, main_stream);
        }
    } else {
        GGML_ABORT("fatal error");
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_dd);
    GGML_UNUSED(ctx);
}
|
@@ -0,0 +1,251 @@
|
|
1
|
+
#include "norm.hpp"
|
2
|
+
|
3
|
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
4
|
+
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
5
|
+
const int nrows_y, const float scale, const float max_bias, const float m0,
|
6
|
+
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
7
|
+
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
8
|
+
|
9
|
+
const int tid = item_ct1.get_local_id(2);
|
10
|
+
const int rowx = item_ct1.get_group(2);
|
11
|
+
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
12
|
+
|
13
|
+
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
14
|
+
|
15
|
+
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
16
|
+
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
17
|
+
const int nthreads = block_size;
|
18
|
+
const int nwarps = nthreads / WARP_SIZE;
|
19
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
20
|
+
float slope = 1.0f;
|
21
|
+
|
22
|
+
// ALiBi
|
23
|
+
if (max_bias > 0.0f) {
|
24
|
+
const uint32_t h = rowx/nrows_y; // head index
|
25
|
+
|
26
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
27
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
28
|
+
|
29
|
+
slope = sycl::pow(base, float(exp));
|
30
|
+
}
|
31
|
+
|
32
|
+
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
33
|
+
float max_val = -INFINITY;
|
34
|
+
|
35
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
36
|
+
const int col = col0 + tid;
|
37
|
+
|
38
|
+
if (ncols_template == 0 && col >= ncols) {
|
39
|
+
break;
|
40
|
+
}
|
41
|
+
|
42
|
+
const int ix = rowx*ncols + col;
|
43
|
+
const int iy = rowy*ncols + col;
|
44
|
+
|
45
|
+
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
46
|
+
|
47
|
+
vals[col] = val;
|
48
|
+
max_val = sycl::max(max_val, val);
|
49
|
+
}
|
50
|
+
|
51
|
+
// find the max value in the block
|
52
|
+
max_val = warp_reduce_max(max_val, item_ct1);
|
53
|
+
if (block_size > WARP_SIZE) {
|
54
|
+
if (warp_id == 0) {
|
55
|
+
buf[lane_id] = -INFINITY;
|
56
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
57
|
+
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
58
|
+
}
|
59
|
+
}
|
60
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
61
|
+
|
62
|
+
if (lane_id == 0) {
|
63
|
+
buf[warp_id] = max_val;
|
64
|
+
}
|
65
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
66
|
+
max_val = buf[lane_id];
|
67
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
68
|
+
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
69
|
+
}
|
70
|
+
max_val = warp_reduce_max(max_val, item_ct1);
|
71
|
+
}
|
72
|
+
|
73
|
+
float tmp = 0.f;
|
74
|
+
#pragma unroll
|
75
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
76
|
+
const int col = col0 + tid;
|
77
|
+
if (ncols_template == 0 && col >= ncols) {
|
78
|
+
break;
|
79
|
+
}
|
80
|
+
|
81
|
+
const float val = sycl::native::exp(vals[col] - max_val);
|
82
|
+
tmp += val;
|
83
|
+
vals[col] = val;
|
84
|
+
}
|
85
|
+
|
86
|
+
// find the sum of exps in the block
|
87
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
88
|
+
if (block_size > WARP_SIZE) {
|
89
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
90
|
+
if (warp_id == 0) {
|
91
|
+
buf[lane_id] = 0.f;
|
92
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
93
|
+
buf[lane_id + i * WARP_SIZE] = 0.f;
|
94
|
+
}
|
95
|
+
}
|
96
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
97
|
+
|
98
|
+
if (lane_id == 0) {
|
99
|
+
buf[warp_id] = tmp;
|
100
|
+
}
|
101
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
102
|
+
|
103
|
+
tmp = buf[lane_id];
|
104
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
105
|
+
tmp += buf[lane_id + i * WARP_SIZE];
|
106
|
+
}
|
107
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
108
|
+
}
|
109
|
+
|
110
|
+
const float inv_sum = 1.f / tmp;
|
111
|
+
|
112
|
+
#pragma unroll
|
113
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
114
|
+
const int col = col0 + tid;
|
115
|
+
|
116
|
+
if (ncols_template == 0 && col >= ncols) {
|
117
|
+
return;
|
118
|
+
}
|
119
|
+
|
120
|
+
const int idst = rowx*ncols + col;
|
121
|
+
dst[idst] = vals[col] * inv_sum;
|
122
|
+
}
|
123
|
+
}
|
124
|
+
|
125
|
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
126
|
+
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
127
|
+
const int nrows_y, const float scale, const float max_bias, const float m0,
|
128
|
+
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
129
|
+
const size_t n_local_scratch, queue_ptr stream) {
|
130
|
+
stream->submit([&](sycl::handler &cgh) {
|
131
|
+
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
132
|
+
|
133
|
+
cgh.parallel_for(
|
134
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
135
|
+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
136
|
+
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
137
|
+
nrows_y, scale, max_bias, m0,
|
138
|
+
m1, n_head_log2, item_ct1,
|
139
|
+
get_pointer(local_buf_acc));
|
140
|
+
});
|
141
|
+
});
|
142
|
+
}
|
143
|
+
|
144
|
+
static void soft_max_f32_sycl(const float * x, const float * mask,
|
145
|
+
float * dst, const int ncols_x, const int nrows_x,
|
146
|
+
const int nrows_y, const float scale, const float max_bias,
|
147
|
+
queue_ptr stream, int device) {
|
148
|
+
int nth = WARP_SIZE;
|
149
|
+
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
150
|
+
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
151
|
+
if (nth>max_block_size) nth = max_block_size;
|
152
|
+
|
153
|
+
const sycl::range<3> block_dims(1, 1, nth);
|
154
|
+
const sycl::range<3> block_nums(1, 1, nrows_x);
|
155
|
+
const size_t n_val_tmp = nth / WARP_SIZE;
|
156
|
+
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
|
157
|
+
|
158
|
+
const uint32_t n_head_kv = nrows_x/nrows_y;
|
159
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
160
|
+
|
161
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
162
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
163
|
+
|
164
|
+
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
165
|
+
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
166
|
+
if (ncols_x > max_block_size) {
|
167
|
+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
168
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
169
|
+
block_dims, n_local_scratch, stream);
|
170
|
+
return;
|
171
|
+
}
|
172
|
+
switch (ncols_x) {
|
173
|
+
case 32:
|
174
|
+
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
175
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
176
|
+
block_dims, n_local_scratch, stream);
|
177
|
+
break;
|
178
|
+
case 64:
|
179
|
+
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
180
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
181
|
+
block_dims, n_local_scratch, stream);
|
182
|
+
break;
|
183
|
+
case 128:
|
184
|
+
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
185
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
186
|
+
block_dims, n_local_scratch, stream);
|
187
|
+
break;
|
188
|
+
case 256:
|
189
|
+
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
190
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
191
|
+
block_dims, n_local_scratch, stream);
|
192
|
+
break;
|
193
|
+
case 512:
|
194
|
+
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
195
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
196
|
+
block_dims, n_local_scratch, stream);
|
197
|
+
break;
|
198
|
+
case 1024:
|
199
|
+
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
200
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
201
|
+
block_dims, n_local_scratch, stream);
|
202
|
+
break;
|
203
|
+
case 2048:
|
204
|
+
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
205
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
206
|
+
block_dims, n_local_scratch, stream);
|
207
|
+
break;
|
208
|
+
case 4096:
|
209
|
+
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
210
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
211
|
+
block_dims, n_local_scratch, stream);
|
212
|
+
break;
|
213
|
+
default:
|
214
|
+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
215
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
216
|
+
block_dims, n_local_scratch, stream);
|
217
|
+
break;
|
218
|
+
}
|
219
|
+
} else {
|
220
|
+
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
221
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
222
|
+
block_dims, WARP_SIZE, stream);
|
223
|
+
}
|
224
|
+
}
|
225
|
+
|
226
|
+
// SYCL entry point for the soft-max operator (F32 only).
// Reads scale and max_bias from dst->op_params slots 0 and 1 and forwards the
// row geometry of src0 to soft_max_f32_sycl. src1, when present, is the
// optional F32 mask.
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                           const ggml_tensor *src1, ggml_tensor *dst,
                           const float *src0_dd, const float *src1_dd,
                           float *dst_dd,
                           const queue_ptr &main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

    const int64_t ncols   = src0->ne[0];
    const int64_t nrows_x = ggml_nrows(src0);
    const int64_t nrows_y = src0->ne[1];

    // Defaults are overwritten by the values stored with the operator.
    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, dst->op_params + 1, sizeof(float));

    const float * mask_dd = nullptr;
    if (src1) {
        mask_dd = src1_dd;
    }

    soft_max_f32_sycl(src0_dd, mask_dd, dst_dd, ncols,
                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
}
|
@@ -0,0 +1,72 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#include "tsembd.hpp"
|
14
|
+
|
15
|
+
static void timestep_embedding_f32(
|
16
|
+
const float * timesteps, float * dst, const int nb1,
|
17
|
+
const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
|
18
|
+
// item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
|
19
|
+
// item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
|
20
|
+
int i = item_ct1.get_group(1);
|
21
|
+
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
22
|
+
float * embed_data = (float *)((char *)dst + i*nb1);
|
23
|
+
|
24
|
+
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
|
25
|
+
embed_data[dim] = 0.f;
|
26
|
+
}
|
27
|
+
|
28
|
+
int half = dim / 2;
|
29
|
+
if (j >= half) {
|
30
|
+
return;
|
31
|
+
}
|
32
|
+
|
33
|
+
float timestep = timesteps[i];
|
34
|
+
float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
|
35
|
+
float arg = timestep * freq;
|
36
|
+
embed_data[j] = sycl::cos(arg);
|
37
|
+
embed_data[j + half] = sycl::sin(arg);
|
38
|
+
}
|
39
|
+
|
40
|
+
static void timestep_embedding_f32_sycl(
|
41
|
+
const float * x, float * dst, const int ne00, const int nb1,
|
42
|
+
const int dim, const int max_period, const queue_ptr& stream) {
|
43
|
+
// As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
|
44
|
+
int half_ceil = dim / 2;
|
45
|
+
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
46
|
+
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
|
47
|
+
sycl::range<3> gridDim(1, ne00, num_blocks);
|
48
|
+
stream->parallel_for(
|
49
|
+
sycl::nd_range<3>(
|
50
|
+
gridDim * block_dims, block_dims),
|
51
|
+
[=](sycl::nd_item<3> item_ct1) {
|
52
|
+
timestep_embedding_f32(
|
53
|
+
x, dst, nb1, dim, max_period, item_ct1
|
54
|
+
);
|
55
|
+
});
|
56
|
+
}
|
57
|
+
|
58
|
+
// SYCL entry point for the timestep-embedding operator (F32 only).
// Takes dim and max_period from dst->op_params slots 0 and 1 and runs the
// kernel on the context's stream, one output row (stride dst->nb[1]) per
// element of src0. src1 is unused.
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                     const ggml_tensor *src1, ggml_tensor * dst) {
    const float * timesteps_d = (const float *)src0->data;
    float *       embed_d     = (float *)dst->data;
    dpct::queue_ptr stream    = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int dim        = dst->op_params[0];
    const int max_period = dst->op_params[1];

    timestep_embedding_f32_sycl(timesteps_d, embed_d, src0->ne[0], dst->nb[1], dim, max_period, stream);

    GGML_UNUSED(src1);
}
|