whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ggml/src/ggml-sycl/rope.cpp
@@ -0,0 +1,276 @@
+ #include "rope.hpp"
+
+ struct rope_corr_dims {
+     float v[2];
+ };
+
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
+     const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
+     return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
+ }
+
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+ static void rope_yarn(
+     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+     float * cos_theta, float * sin_theta) {
+     // Get n-d rotational scaling corrected for extrapolation
+     float theta_interp = freq_scale * theta_extrap;
+     float theta = theta_interp;
+     if (ext_factor != 0.0f) {
+         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+         // Get n-d magnitude scaling corrected for interpolation
+         mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
+     }
+     *cos_theta = sycl::cos(theta) * mscale;
+     *sin_theta = sycl::sin(theta) * mscale;
+ }
+
+ template<typename T, bool has_ff>
+ static void rope_norm(
+     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+     const sycl::nd_item<3> &item_ct1) {
+     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
+
+     if (i0 >= ne0) {
+         return;
+     }
+
+     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                     item_ct1.get_local_id(2);
+
+     if (i0 >= n_dims) {
+         const int i = row*ne0 + i0;
+
+         dst[i + 0] = x[i + 0];
+         dst[i + 1] = x[i + 1];
+
+         return;
+     }
+
+     const int i = row*ne0 + i0;
+     const int i2 = row/p_delta_rows;
+
+     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+
+     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+     float cos_theta;
+     float sin_theta;
+
+     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+     const float x0 = x[i + 0];
+     const float x1 = x[i + 1];
+
+     dst[i + 0] = x0*cos_theta - x1*sin_theta;
+     dst[i + 1] = x0*sin_theta + x1*cos_theta;
+ }
+
+ template<typename T, bool has_ff>
+ static void rope_neox(
+     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+     const sycl::nd_item<3> &item_ct1) {
+     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
+
+     if (i0 >= ne0) {
+         return;
+     }
+
+     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                     item_ct1.get_local_id(2);
+
+     if (i0 >= n_dims) {
+         const int i = row*ne0 + i0;
+
+         dst[i + 0] = x[i + 0];
+         dst[i + 1] = x[i + 1];
+
+         return;
+     }
+
+     const int i = row*ne0 + i0/2;
+     const int i2 = row/p_delta_rows;
+
+     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+
+     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+     float cos_theta;
+     float sin_theta;
+
+     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+     const float x0 = x[i + 0];
+     const float x1 = x[i + n_dims/2];
+
+     dst[i + 0] = x0*cos_theta - x1*sin_theta;
+     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+
+ template <typename T>
+ static void rope_norm_sycl(
+     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+     GGML_ASSERT(ne0 % 2 == 0);
+     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+     const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+     dpct::has_capability_or_fail(stream->get_device(),
+                                  {sycl::aspect::fp16});
+
+     if (freq_factors == nullptr) {
+         /*
+         DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         stream->parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) {
+                 rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                                     ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+                                     item_ct1);
+             });
+     } else {
+         /*
+         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         stream->parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) {
+                 rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                                    ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+                                    item_ct1);
+             });
+     }
+ }
+
+ template <typename T>
+ static void rope_neox_sycl(
+     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+     GGML_ASSERT(ne0 % 2 == 0);
+     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+     const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+     dpct::has_capability_or_fail(stream->get_device(),
+                                  {sycl::aspect::fp16});
+
+     if (freq_factors == nullptr) {
+         stream->parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) {
+                 rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
+                                     p_delta_rows, ext_factor, attn_factor,
+                                     corr_dims, theta_scale, freq_factors,
+                                     item_ct1);
+             });
+     } else {
+         stream->parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) {
+                 rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
+                                    p_delta_rows, ext_factor, attn_factor,
+                                    corr_dims, theta_scale, freq_factors,
+                                    item_ct1);
+             });
+     }
+ }
+
+ void ggml_sycl_op_rope(
+     ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+     const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
+     const ggml_tensor * src2 = dst->src[2];
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+     GGML_ASSERT(src0->type == dst->type);
+
+     const int64_t ne00 = src0->ne[0];
+     const int64_t ne01 = src0->ne[1];
+     const int64_t nr = ggml_nrows(src0);
+
+     //const int n_past = ((int32_t *) dst->op_params)[0];
+     const int n_dims = ((int32_t *) dst->op_params)[1];
+     const int mode = ((int32_t *) dst->op_params)[2];
+     //const int n_ctx = ((int32_t *) dst->op_params)[3];
+     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+     // RoPE alteration for extended context
+     float freq_base;
+     float freq_scale;
+     float ext_factor;
+     float attn_factor;
+     float beta_fast;
+     float beta_slow;
+
+     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+     const int32_t * pos = (const int32_t *) src1_dd;
+
+     const float * freq_factors = nullptr;
+     if (src2 != nullptr) {
+         freq_factors = (const float *) src2->data;
+     }
+
+     rope_corr_dims corr_dims;
+     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+     // compute
+     if (is_neox) {
+         if (src0->type == GGML_TYPE_F32) {
+             rope_neox_sycl(
+                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                 attn_factor, corr_dims, freq_factors, main_stream
+             );
+         } else if (src0->type == GGML_TYPE_F16) {
+             rope_neox_sycl(
+                 (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                 attn_factor, corr_dims, freq_factors, main_stream
+             );
+         } else {
+             GGML_ABORT("fatal error");
+         }
+     } else {
+         if (src0->type == GGML_TYPE_F32) {
+             rope_norm_sycl(
+                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                 attn_factor, corr_dims, freq_factors, main_stream
+             );
+         } else if (src0->type == GGML_TYPE_F16) {
+             rope_norm_sycl(
+                 (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                 attn_factor, corr_dims, freq_factors, main_stream
+             );
+         } else {
+             GGML_ABORT("fatal error");
+         }
+     }
+
+     GGML_UNUSED(src1);
+     GGML_UNUSED(dst);
+     GGML_UNUSED(src1_dd);
+     GGML_UNUSED(ctx);
+ }
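For reference, the rotation rope_norm applies to each adjacent pair (x[i0], x[i0+1]) reduces to the following scalar computation when ext_factor == 0 (no YaRN ramp), attn_factor == 1, and freq_factors is null. This is a minimal single-row sketch under those assumptions; rope_norm_ref is an illustrative name, not a function shipped in the package:

    #include <cmath>
    #include <cstdint>

    // Minimal scalar sketch of rope_norm for one row.
    // theta_scale matches powf(freq_base, -2.0f/n_dims) in rope_norm_sycl.
    static void rope_norm_ref(const float * x, float * dst, int ne0, int n_dims,
                              int32_t pos, float freq_base, float freq_scale) {
        const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
        for (int i0 = 0; i0 < ne0; i0 += 2) {
            if (i0 >= n_dims) {                 // dims past n_dims pass through
                dst[i0 + 0] = x[i0 + 0];
                dst[i0 + 1] = x[i0 + 1];
                continue;
            }
            // the angle decays geometrically with the pair index i0/2
            const float theta = freq_scale * pos * std::pow(theta_scale, i0 / 2.0f);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            dst[i0 + 0] = x[i0 + 0]*c - x[i0 + 1]*s;   // 2-D rotation of the pair
            dst[i0 + 1] = x[i0 + 0]*s + x[i0 + 1]*c;
        }
    }

The NeoX variant differs only in pairing: rope_neox rotates x[i] together with x[i + n_dims/2] rather than with the immediately following element.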
data/ext/ggml/src/ggml-sycl/softmax.cpp
@@ -0,0 +1,251 @@
+ #include "norm.hpp"
+
+ template <bool vals_smem, int ncols_template, int block_size_template>
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+                          const int nrows_y, const float scale, const float max_bias, const float m0,
+                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+     const int tid = item_ct1.get_local_id(2);
+     const int rowx = item_ct1.get_group(2);
+     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+     const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+
+     const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+     const int nthreads = block_size;
+     const int nwarps = nthreads / WARP_SIZE;
+     size_t nreduce = nwarps / WARP_SIZE;
+     float slope = 1.0f;
+
+     // ALiBi
+     if (max_bias > 0.0f) {
+         const uint32_t h = rowx/nrows_y; // head index
+
+         const float base = h < n_head_log2 ? m0 : m1;
+         const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+         slope = sycl::pow(base, float(exp));
+     }
+
+     float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+     float max_val = -INFINITY;
+
+     for (int col0 = 0; col0 < ncols; col0 += block_size) {
+         const int col = col0 + tid;
+
+         if (ncols_template == 0 && col >= ncols) {
+             break;
+         }
+
+         const int ix = rowx*ncols + col;
+         const int iy = rowy*ncols + col;
+
+         const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+
+         vals[col] = val;
+         max_val = sycl::max(max_val, val);
+     }
+
+     // find the max value in the block
+     max_val = warp_reduce_max(max_val, item_ct1);
+     if (block_size > WARP_SIZE) {
+         if (warp_id == 0) {
+             buf[lane_id] = -INFINITY;
+             for (size_t i = 1; i < nreduce; i += 1) {
+                 buf[lane_id + i * WARP_SIZE] = -INFINITY;
+             }
+         }
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+
+         if (lane_id == 0) {
+             buf[warp_id] = max_val;
+         }
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+         max_val = buf[lane_id];
+         for (size_t i = 1; i < nreduce; i += 1) {
+             max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+         }
+         max_val = warp_reduce_max(max_val, item_ct1);
+     }
+
+     float tmp = 0.f;
+ #pragma unroll
+     for (int col0 = 0; col0 < ncols; col0 += block_size) {
+         const int col = col0 + tid;
+         if (ncols_template == 0 && col >= ncols) {
+             break;
+         }
+
+         const float val = sycl::native::exp(vals[col] - max_val);
+         tmp += val;
+         vals[col] = val;
+     }
+
+     // find the sum of exps in the block
+     tmp = warp_reduce_sum(tmp, item_ct1);
+     if (block_size > WARP_SIZE) {
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+         if (warp_id == 0) {
+             buf[lane_id] = 0.f;
+             for (size_t i = 1; i < nreduce; i += 1) {
+                 buf[lane_id + i * WARP_SIZE] = 0.f;
+             }
+         }
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+
+         if (lane_id == 0) {
+             buf[warp_id] = tmp;
+         }
+         item_ct1.barrier(sycl::access::fence_space::local_space);
+
+         tmp = buf[lane_id];
+         for (size_t i = 1; i < nreduce; i += 1) {
+             tmp += buf[lane_id + i * WARP_SIZE];
+         }
+         tmp = warp_reduce_sum(tmp, item_ct1);
+     }
+
+     const float inv_sum = 1.f / tmp;
+
+ #pragma unroll
+     for (int col0 = 0; col0 < ncols; col0 += block_size) {
+         const int col = col0 + tid;
+
+         if (ncols_template == 0 && col >= ncols) {
+             return;
+         }
+
+         const int idst = rowx*ncols + col;
+         dst[idst] = vals[col] * inv_sum;
+     }
+ }
+
+ template <bool vals_smem, int ncols_template, int block_size_template>
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+                                    const int nrows_y, const float scale, const float max_bias, const float m0,
+                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                    const size_t n_local_scratch, queue_ptr stream) {
+     stream->submit([&](sycl::handler &cgh) {
+         sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+
+         cgh.parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                                                                              nrows_y, scale, max_bias, m0,
+                                                                              m1, n_head_log2, item_ct1,
+                                                                              get_pointer(local_buf_acc));
+             });
+     });
+ }
+
+ static void soft_max_f32_sycl(const float * x, const float * mask,
+                               float * dst, const int ncols_x, const int nrows_x,
+                               const int nrows_y, const float scale, const float max_bias,
+                               queue_ptr stream, int device) {
+     int nth = WARP_SIZE;
+     int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+     while (nth < ncols_x && nth < max_block_size) nth *= 2;
+     if (nth>max_block_size) nth = max_block_size;
+
+     const sycl::range<3> block_dims(1, 1, nth);
+     const sycl::range<3> block_nums(1, 1, nrows_x);
+     const size_t n_val_tmp = nth / WARP_SIZE;
+     const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
+
+     const uint32_t n_head_kv = nrows_x/nrows_y;
+     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+     if (n_local_scratch*sizeof(float) < local_mem_size) {
+         if (ncols_x > max_block_size) {
+             soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                max_bias, m0, m1, n_head_log2, block_nums,
+                                                block_dims, n_local_scratch, stream);
+             return;
+         }
+         switch (ncols_x) {
+             case 32:
+                 soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                      max_bias, m0, m1, n_head_log2, block_nums,
+                                                      block_dims, n_local_scratch, stream);
+                 break;
+             case 64:
+                 soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                      max_bias, m0, m1, n_head_log2, block_nums,
+                                                      block_dims, n_local_scratch, stream);
+                 break;
+             case 128:
+                 soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                        max_bias, m0, m1, n_head_log2, block_nums,
+                                                        block_dims, n_local_scratch, stream);
+                 break;
+             case 256:
+                 soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                        max_bias, m0, m1, n_head_log2, block_nums,
+                                                        block_dims, n_local_scratch, stream);
+                 break;
+             case 512:
+                 soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                        max_bias, m0, m1, n_head_log2, block_nums,
+                                                        block_dims, n_local_scratch, stream);
+                 break;
+             case 1024:
+                 soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                          max_bias, m0, m1, n_head_log2, block_nums,
+                                                          block_dims, n_local_scratch, stream);
+                 break;
+             case 2048:
+                 soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                          max_bias, m0, m1, n_head_log2, block_nums,
+                                                          block_dims, n_local_scratch, stream);
+                 break;
+             case 4096:
+                 soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                          max_bias, m0, m1, n_head_log2, block_nums,
+                                                          block_dims, n_local_scratch, stream);
+                 break;
+             default:
+                 soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                    max_bias, m0, m1, n_head_log2, block_nums,
+                                                    block_dims, n_local_scratch, stream);
+                 break;
+         }
+     } else {
+         soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                             max_bias, m0, m1, n_head_log2, block_nums,
+                                             block_dims, WARP_SIZE, stream);
+     }
+ }
+
+ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                            const ggml_tensor *src1, ggml_tensor *dst,
+                            const float *src0_dd, const float *src1_dd,
+                            float *dst_dd,
+                            const queue_ptr &main_stream) {
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+     const int64_t ne00 = src0->ne[0];
+     const int64_t nrows_x = ggml_nrows(src0);
+     const int64_t nrows_y = src0->ne[1];
+
+     float scale = 1.0f;
+     float max_bias = 0.0f;
+
+     memcpy(&scale, dst->op_params + 0, sizeof(float));
+     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+     soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
+                       nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+ }
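The kernel above is the standard numerically stable softmax with an optional additive mask scaled by a per-head ALiBi slope: scale the logits, add the mask, subtract the row maximum before exponentiating, then normalize by the sum. A minimal single-row scalar sketch, assuming the slope has already been derived from max_bias/m0/m1 as in the kernel; soft_max_row_ref is an illustrative name, not part of the package:

    #include <algorithm>
    #include <cmath>

    // Minimal scalar sketch of one row of soft_max_f32.
    static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                                 int ncols, float scale, float slope) {
        // scaled logits plus the (optionally ALiBi-sloped) mask, tracking the max
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            dst[c] = x[c]*scale + (mask ? slope*mask[c] : 0.0f);
            max_val = std::max(max_val, dst[c]);
        }
        // exponentiate shifted values for numerical stability and accumulate the sum
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            dst[c] = std::exp(dst[c] - max_val);
            sum += dst[c];
        }
        // normalize so the row sums to 1
        const float inv_sum = 1.0f / sum;
        for (int c = 0; c < ncols; ++c) {
            dst[c] *= inv_sum;
        }
    }

The SYCL version distributes the two reductions (max and sum) across a work-group, staging per-warp partials in local memory whenever the block is wider than one sub-group.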
data/ext/ggml/src/ggml-sycl/tsembd.cpp
@@ -0,0 +1,72 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #include "tsembd.hpp"
+
+ static void timestep_embedding_f32(
+     const float * timesteps, float * dst, const int nb1,
+     const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
+     // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
+     // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
+     int i = item_ct1.get_group(1);
+     int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
+     float * embed_data = (float *)((char *)dst + i*nb1);
+
+     if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+         embed_data[dim] = 0.f;
+     }
+
+     int half = dim / 2;
+     if (j >= half) {
+         return;
+     }
+
+     float timestep = timesteps[i];
+     float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
+     float arg = timestep * freq;
+     embed_data[j] = sycl::cos(arg);
+     embed_data[j + half] = sycl::sin(arg);
+ }
+
+ static void timestep_embedding_f32_sycl(
+     const float * x, float * dst, const int ne00, const int nb1,
+     const int dim, const int max_period, const queue_ptr& stream) {
+     // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
+     int half_ceil = dim / 2;
+     int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+     sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
+     sycl::range<3> gridDim(1, ne00, num_blocks);
+     stream->parallel_for(
+         sycl::nd_range<3>(
+             gridDim * block_dims, block_dims),
+         [=](sycl::nd_item<3> item_ct1) {
+             timestep_embedding_f32(
+                 x, dst, nb1, dim, max_period, item_ct1
+             );
+         });
+ }
+
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                      const ggml_tensor *src1, ggml_tensor * dst) {
+     const float * src0_d = (const float *)src0->data;
+     float * dst_d = (float *)dst->data;
+     dpct::queue_ptr stream = ctx.stream();
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+     const int dim = dst->op_params[0];
+     const int max_period = dst->op_params[1];
+
+     timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+     GGML_UNUSED(src1);
+ }
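The embedding written here is the usual sinusoidal timestep encoding: the first half of each output row holds cosines and the second half sines, at frequencies falling geometrically from 1 toward 1/max_period. A minimal scalar sketch for a single timestep, assuming an even dim for clarity (the kernel zero-pads one extra slot when dim is odd); timestep_embedding_ref is an illustrative name, not part of the package:

    #include <cmath>

    // Minimal scalar sketch of timestep_embedding_f32 for one timestep t.
    static void timestep_embedding_ref(float t, float * out, int dim, int max_period) {
        const int half = dim / 2;
        for (int j = 0; j < half; ++j) {
            // frequency decays geometrically from 1 (j = 0) toward 1/max_period
            const float freq = std::exp(-std::log((float) max_period) * j / half);
            out[j]        = std::cos(t * freq);   // cosine half
            out[j + half] = std::sin(t * freq);   // sine half
        }
    }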