whispercpp 1.3.0 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (132) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,276 @@
1
+ #include "rope.hpp"
2
+
3
// Pair of YaRN correction dimensions, wrapped in a struct so that both values
// can be handed to the kernels by value (they are copied into the kernel lambdas).
struct rope_corr_dims {
    float v[2]; // v[0] is used as the ramp's low bound, v[1] as its high bound
};
6
+
7
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
8
+ const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
9
+ return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
10
+ }
11
+
12
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
13
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14
+ static void rope_yarn(
15
+ float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
16
+ float * cos_theta, float * sin_theta) {
17
+ // Get n-d rotational scaling corrected for extrapolation
18
+ float theta_interp = freq_scale * theta_extrap;
19
+ float theta = theta_interp;
20
+ if (ext_factor != 0.0f) {
21
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
22
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
23
+
24
+ // Get n-d magnitude scaling corrected for interpolation
25
+ mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
26
+ }
27
+ *cos_theta = sycl::cos(theta) * mscale;
28
+ *sin_theta = sycl::sin(theta) * mscale;
29
+ }
30
+
31
+ template<typename T, bool has_ff>
32
+ static void rope_norm(
33
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
34
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
35
+ const sycl::nd_item<3> &item_ct1) {
36
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
37
+ item_ct1.get_local_id(1));
38
+
39
+ if (i0 >= ne0) {
40
+ return;
41
+ }
42
+
43
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
44
+ item_ct1.get_local_id(2);
45
+
46
+ if (i0 >= n_dims) {
47
+ const int i = row*ne0 + i0;
48
+
49
+ dst[i + 0] = x[i + 0];
50
+ dst[i + 1] = x[i + 1];
51
+
52
+ return;
53
+ }
54
+
55
+ const int i = row*ne0 + i0;
56
+ const int i2 = row/p_delta_rows;
57
+
58
+ const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
59
+
60
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
61
+
62
+ float cos_theta;
63
+ float sin_theta;
64
+
65
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
66
+
67
+ const float x0 = x[i + 0];
68
+ const float x1 = x[i + 1];
69
+
70
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
71
+ dst[i + 1] = x0*sin_theta + x1*cos_theta;
72
+ }
73
+
74
+ template<typename T, bool has_ff>
75
+ static void rope_neox(
76
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
77
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
78
+ const sycl::nd_item<3> &item_ct1) {
79
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
80
+ item_ct1.get_local_id(1));
81
+
82
+ if (i0 >= ne0) {
83
+ return;
84
+ }
85
+
86
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
87
+ item_ct1.get_local_id(2);
88
+
89
+ if (i0 >= n_dims) {
90
+ const int i = row*ne0 + i0;
91
+
92
+ dst[i + 0] = x[i + 0];
93
+ dst[i + 1] = x[i + 1];
94
+
95
+ return;
96
+ }
97
+
98
+ const int i = row*ne0 + i0/2;
99
+ const int i2 = row/p_delta_rows;
100
+
101
+ const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
102
+
103
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
104
+
105
+ float cos_theta;
106
+ float sin_theta;
107
+
108
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
109
+
110
+ const float x0 = x[i + 0];
111
+ const float x1 = x[i + n_dims/2];
112
+
113
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
114
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
115
+ }
116
+
117
+ template <typename T>
118
+ static void rope_norm_sycl(
119
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
120
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
121
+ GGML_ASSERT(ne0 % 2 == 0);
122
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
123
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
124
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
125
+
126
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
127
+
128
+ dpct::has_capability_or_fail(stream->get_device(),
129
+ {sycl::aspect::fp16});
130
+
131
+ if (freq_factors == nullptr) {
132
+ /*
133
+ DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
134
+ the limit. To get the device limit, query
135
+ info::device::max_work_group_size. Adjust the work-group size if needed.
136
+ */
137
+ stream->parallel_for(
138
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
139
+ [=](sycl::nd_item<3> item_ct1) {
140
+ rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
141
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
142
+ item_ct1);
143
+ });
144
+ } else {
145
+ /*
146
+ DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
147
+ the limit. To get the device limit, query
148
+ info::device::max_work_group_size. Adjust the work-group size if needed.
149
+ */
150
+ stream->parallel_for(
151
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
152
+ [=](sycl::nd_item<3> item_ct1) {
153
+ rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
154
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
155
+ item_ct1);
156
+ });
157
+ }
158
+ }
159
+
160
+ template <typename T>
161
+ static void rope_neox_sycl(
162
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
163
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
164
+ GGML_ASSERT(ne0 % 2 == 0);
165
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
166
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
167
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
168
+
169
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
170
+
171
+ dpct::has_capability_or_fail(stream->get_device(),
172
+ {sycl::aspect::fp16});
173
+
174
+ if (freq_factors == nullptr) {
175
+ stream->parallel_for(
176
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
177
+ [=](sycl::nd_item<3> item_ct1) {
178
+ rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
179
+ p_delta_rows, ext_factor, attn_factor,
180
+ corr_dims, theta_scale, freq_factors,
181
+ item_ct1);
182
+ });
183
+ } else {
184
+ stream->parallel_for(
185
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
186
+ [=](sycl::nd_item<3> item_ct1) {
187
+ rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
188
+ p_delta_rows, ext_factor, attn_factor,
189
+ corr_dims, theta_scale, freq_factors,
190
+ item_ct1);
191
+ });
192
+ }
193
+ }
194
+
195
+ void ggml_sycl_op_rope(
196
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
197
+ const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
198
+ const ggml_tensor * src2 = dst->src[2];
199
+
200
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
201
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
202
+ GGML_ASSERT(src0->type == dst->type);
203
+
204
+ const int64_t ne00 = src0->ne[0];
205
+ const int64_t ne01 = src0->ne[1];
206
+ const int64_t nr = ggml_nrows(src0);
207
+
208
+ //const int n_past = ((int32_t *) dst->op_params)[0];
209
+ const int n_dims = ((int32_t *) dst->op_params)[1];
210
+ const int mode = ((int32_t *) dst->op_params)[2];
211
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
212
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
213
+
214
+ // RoPE alteration for extended context
215
+ float freq_base;
216
+ float freq_scale;
217
+ float ext_factor;
218
+ float attn_factor;
219
+ float beta_fast;
220
+ float beta_slow;
221
+
222
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
223
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
224
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
225
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
226
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
227
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
228
+
229
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
230
+
231
+ const int32_t * pos = (const int32_t *) src1_dd;
232
+
233
+ const float * freq_factors = nullptr;
234
+ if (src2 != nullptr) {
235
+ freq_factors = (const float *) src2->data;
236
+ }
237
+
238
+ rope_corr_dims corr_dims;
239
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
240
+
241
+ // compute
242
+ if (is_neox) {
243
+ if (src0->type == GGML_TYPE_F32) {
244
+ rope_neox_sycl(
245
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
246
+ attn_factor, corr_dims, freq_factors, main_stream
247
+ );
248
+ } else if (src0->type == GGML_TYPE_F16) {
249
+ rope_neox_sycl(
250
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
251
+ attn_factor, corr_dims, freq_factors, main_stream
252
+ );
253
+ } else {
254
+ GGML_ABORT("fatal error");
255
+ }
256
+ } else {
257
+ if (src0->type == GGML_TYPE_F32) {
258
+ rope_norm_sycl(
259
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
260
+ attn_factor, corr_dims, freq_factors, main_stream
261
+ );
262
+ } else if (src0->type == GGML_TYPE_F16) {
263
+ rope_norm_sycl(
264
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
265
+ attn_factor, corr_dims, freq_factors, main_stream
266
+ );
267
+ } else {
268
+ GGML_ABORT("fatal error");
269
+ }
270
+ }
271
+
272
+ GGML_UNUSED(src1);
273
+ GGML_UNUSED(dst);
274
+ GGML_UNUSED(src1_dd);
275
+ GGML_UNUSED(ctx);
276
+ }
@@ -0,0 +1,251 @@
1
+ #include "norm.hpp"
2
+
3
+ template <bool vals_smem, int ncols_template, int block_size_template>
4
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
5
+ const int nrows_y, const float scale, const float max_bias, const float m0,
6
+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
7
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
8
+
9
+ const int tid = item_ct1.get_local_id(2);
10
+ const int rowx = item_ct1.get_group(2);
11
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
12
+
13
+ const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
14
+
15
+ const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
16
+ const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
17
+ const int nthreads = block_size;
18
+ const int nwarps = nthreads / WARP_SIZE;
19
+ size_t nreduce = nwarps / WARP_SIZE;
20
+ float slope = 1.0f;
21
+
22
+ // ALiBi
23
+ if (max_bias > 0.0f) {
24
+ const uint32_t h = rowx/nrows_y; // head index
25
+
26
+ const float base = h < n_head_log2 ? m0 : m1;
27
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
28
+
29
+ slope = sycl::pow(base, float(exp));
30
+ }
31
+
32
+ float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
33
+ float max_val = -INFINITY;
34
+
35
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
36
+ const int col = col0 + tid;
37
+
38
+ if (ncols_template == 0 && col >= ncols) {
39
+ break;
40
+ }
41
+
42
+ const int ix = rowx*ncols + col;
43
+ const int iy = rowy*ncols + col;
44
+
45
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
46
+
47
+ vals[col] = val;
48
+ max_val = sycl::max(max_val, val);
49
+ }
50
+
51
+ // find the max value in the block
52
+ max_val = warp_reduce_max(max_val, item_ct1);
53
+ if (block_size > WARP_SIZE) {
54
+ if (warp_id == 0) {
55
+ buf[lane_id] = -INFINITY;
56
+ for (size_t i = 1; i < nreduce; i += 1) {
57
+ buf[lane_id + i * WARP_SIZE] = -INFINITY;
58
+ }
59
+ }
60
+ item_ct1.barrier(sycl::access::fence_space::local_space);
61
+
62
+ if (lane_id == 0) {
63
+ buf[warp_id] = max_val;
64
+ }
65
+ item_ct1.barrier(sycl::access::fence_space::local_space);
66
+ max_val = buf[lane_id];
67
+ for (size_t i = 1; i < nreduce; i += 1) {
68
+ max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
69
+ }
70
+ max_val = warp_reduce_max(max_val, item_ct1);
71
+ }
72
+
73
+ float tmp = 0.f;
74
+ #pragma unroll
75
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
76
+ const int col = col0 + tid;
77
+ if (ncols_template == 0 && col >= ncols) {
78
+ break;
79
+ }
80
+
81
+ const float val = sycl::native::exp(vals[col] - max_val);
82
+ tmp += val;
83
+ vals[col] = val;
84
+ }
85
+
86
+ // find the sum of exps in the block
87
+ tmp = warp_reduce_sum(tmp, item_ct1);
88
+ if (block_size > WARP_SIZE) {
89
+ item_ct1.barrier(sycl::access::fence_space::local_space);
90
+ if (warp_id == 0) {
91
+ buf[lane_id] = 0.f;
92
+ for (size_t i = 1; i < nreduce; i += 1) {
93
+ buf[lane_id + i * WARP_SIZE] = 0.f;
94
+ }
95
+ }
96
+ item_ct1.barrier(sycl::access::fence_space::local_space);
97
+
98
+ if (lane_id == 0) {
99
+ buf[warp_id] = tmp;
100
+ }
101
+ item_ct1.barrier(sycl::access::fence_space::local_space);
102
+
103
+ tmp = buf[lane_id];
104
+ for (size_t i = 1; i < nreduce; i += 1) {
105
+ tmp += buf[lane_id + i * WARP_SIZE];
106
+ }
107
+ tmp = warp_reduce_sum(tmp, item_ct1);
108
+ }
109
+
110
+ const float inv_sum = 1.f / tmp;
111
+
112
+ #pragma unroll
113
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
114
+ const int col = col0 + tid;
115
+
116
+ if (ncols_template == 0 && col >= ncols) {
117
+ return;
118
+ }
119
+
120
+ const int idst = rowx*ncols + col;
121
+ dst[idst] = vals[col] * inv_sum;
122
+ }
123
+ }
124
+
125
+ template <bool vals_smem, int ncols_template, int block_size_template>
126
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
127
+ const int nrows_y, const float scale, const float max_bias, const float m0,
128
+ const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
129
+ const size_t n_local_scratch, queue_ptr stream) {
130
+ stream->submit([&](sycl::handler &cgh) {
131
+ sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
132
+
133
+ cgh.parallel_for(
134
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
135
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
136
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
137
+ nrows_y, scale, max_bias, m0,
138
+ m1, n_head_log2, item_ct1,
139
+ get_pointer(local_buf_acc));
140
+ });
141
+ });
142
+ }
143
+
144
+ static void soft_max_f32_sycl(const float * x, const float * mask,
145
+ float * dst, const int ncols_x, const int nrows_x,
146
+ const int nrows_y, const float scale, const float max_bias,
147
+ queue_ptr stream, int device) {
148
+ int nth = WARP_SIZE;
149
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
150
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
151
+ if (nth>max_block_size) nth = max_block_size;
152
+
153
+ const sycl::range<3> block_dims(1, 1, nth);
154
+ const sycl::range<3> block_nums(1, 1, nrows_x);
155
+ const size_t n_val_tmp = nth / WARP_SIZE;
156
+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
157
+
158
+ const uint32_t n_head_kv = nrows_x/nrows_y;
159
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
160
+
161
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
162
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
163
+
164
+ const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
165
+ if (n_local_scratch*sizeof(float) < local_mem_size) {
166
+ if (ncols_x > max_block_size) {
167
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
168
+ max_bias, m0, m1, n_head_log2, block_nums,
169
+ block_dims, n_local_scratch, stream);
170
+ return;
171
+ }
172
+ switch (ncols_x) {
173
+ case 32:
174
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
175
+ max_bias, m0, m1, n_head_log2, block_nums,
176
+ block_dims, n_local_scratch, stream);
177
+ break;
178
+ case 64:
179
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
180
+ max_bias, m0, m1, n_head_log2, block_nums,
181
+ block_dims, n_local_scratch, stream);
182
+ break;
183
+ case 128:
184
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
185
+ max_bias, m0, m1, n_head_log2, block_nums,
186
+ block_dims, n_local_scratch, stream);
187
+ break;
188
+ case 256:
189
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
190
+ max_bias, m0, m1, n_head_log2, block_nums,
191
+ block_dims, n_local_scratch, stream);
192
+ break;
193
+ case 512:
194
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
195
+ max_bias, m0, m1, n_head_log2, block_nums,
196
+ block_dims, n_local_scratch, stream);
197
+ break;
198
+ case 1024:
199
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
200
+ max_bias, m0, m1, n_head_log2, block_nums,
201
+ block_dims, n_local_scratch, stream);
202
+ break;
203
+ case 2048:
204
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
205
+ max_bias, m0, m1, n_head_log2, block_nums,
206
+ block_dims, n_local_scratch, stream);
207
+ break;
208
+ case 4096:
209
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
210
+ max_bias, m0, m1, n_head_log2, block_nums,
211
+ block_dims, n_local_scratch, stream);
212
+ break;
213
+ default:
214
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
215
+ max_bias, m0, m1, n_head_log2, block_nums,
216
+ block_dims, n_local_scratch, stream);
217
+ break;
218
+ }
219
+ } else {
220
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
221
+ max_bias, m0, m1, n_head_log2, block_nums,
222
+ block_dims, WARP_SIZE, stream);
223
+ }
224
+ }
225
+
226
+ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
227
+ const ggml_tensor *src1, ggml_tensor *dst,
228
+ const float *src0_dd, const float *src1_dd,
229
+ float *dst_dd,
230
+ const queue_ptr &main_stream) {
231
+
232
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
233
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
234
+
235
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
236
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
237
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
238
+
239
+ const int64_t ne00 = src0->ne[0];
240
+ const int64_t nrows_x = ggml_nrows(src0);
241
+ const int64_t nrows_y = src0->ne[1];
242
+
243
+ float scale = 1.0f;
244
+ float max_bias = 0.0f;
245
+
246
+ memcpy(&scale, dst->op_params + 0, sizeof(float));
247
+ memcpy(&max_bias, dst->op_params + 1, sizeof(float));
248
+
249
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
250
+ nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
251
+ }
@@ -0,0 +1,72 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "tsembd.hpp"
14
+
15
+ static void timestep_embedding_f32(
16
+ const float * timesteps, float * dst, const int nb1,
17
+ const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
18
+ // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
19
+ // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
20
+ int i = item_ct1.get_group(1);
21
+ int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
22
+ float * embed_data = (float *)((char *)dst + i*nb1);
23
+
24
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
25
+ embed_data[dim] = 0.f;
26
+ }
27
+
28
+ int half = dim / 2;
29
+ if (j >= half) {
30
+ return;
31
+ }
32
+
33
+ float timestep = timesteps[i];
34
+ float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
35
+ float arg = timestep * freq;
36
+ embed_data[j] = sycl::cos(arg);
37
+ embed_data[j + half] = sycl::sin(arg);
38
+ }
39
+
40
+ static void timestep_embedding_f32_sycl(
41
+ const float * x, float * dst, const int ne00, const int nb1,
42
+ const int dim, const int max_period, const queue_ptr& stream) {
43
+ // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
44
+ int half_ceil = dim / 2;
45
+ int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
46
+ sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
47
+ sycl::range<3> gridDim(1, ne00, num_blocks);
48
+ stream->parallel_for(
49
+ sycl::nd_range<3>(
50
+ gridDim * block_dims, block_dims),
51
+ [=](sycl::nd_item<3> item_ct1) {
52
+ timestep_embedding_f32(
53
+ x, dst, nb1, dim, max_period, item_ct1
54
+ );
55
+ });
56
+ }
57
+
58
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
59
+ const ggml_tensor *src1, ggml_tensor * dst) {
60
+ const float * src0_d = (const float *)src0->data;
61
+ float * dst_d = (float *)dst->data;
62
+ dpct::queue_ptr stream = ctx.stream();
63
+
64
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
65
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
66
+
67
+ const int dim = dst->op_params[0];
68
+ const int max_period = dst->op_params[1];
69
+
70
+ timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
71
+ GGML_UNUSED(src1);
72
+ }