whispercpp 1.2.0.2 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (135)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
@@ -0,0 +1,547 @@
1
+ #include "convert.hpp"
2
+ #include "dequantize.hpp"
3
+ #include "presets.hpp"
4
+
5
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
6
+ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
7
+ const sycl::nd_item<3> &item_ct1) {
8
+ const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9
+ item_ct1.get_local_id(2));
10
+
11
+ if (i >= k) {
12
+ return;
13
+ }
14
+
15
+ const int64_t ib = i/qk; // block index
16
+ const int64_t iqs = (i%qk)/qr; // quant index
17
+ const int64_t iybs = i - i%qk; // y block start index
18
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
19
+
20
+ // dequantize
21
+ dfloat2 v;
22
+ dequantize_kernel(vx, ib, iqs, v);
23
+
24
+ y[iybs + iqs + 0] = v.x();
25
+ y[iybs + iqs + y_offset] = v.y();
26
+ }
27
+
28
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
29
+ static void dequantize_block_sycl(const void *__restrict__ vx,
30
+ dst_t *__restrict__ y, const int64_t k,
31
+ dpct::queue_ptr stream) {
32
+ const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
33
+ {
34
+ dpct::has_capability_or_fail(stream->get_device(),
35
+ {sycl::aspect::fp16});
36
+ stream->parallel_for(
37
+ sycl::nd_range<3>(
38
+ sycl::range<3>(1, 1, num_blocks) *
39
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
40
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
41
+ [=](sycl::nd_item<3> item_ct1) {
42
+ dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
43
+ });
44
+ }
45
+ }
46
+
47
+ template <typename dst_t>
48
+ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
49
+ dpct::queue_ptr stream) {
50
+ const int64_t nb = k / QK_K;
51
+ #if QK_K == 256
52
+ {
53
+ dpct::has_capability_or_fail(stream->get_device(),
54
+ {sycl::aspect::fp16});
55
+
56
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
57
+ sycl::range<3>(1, 1, 64),
58
+ sycl::range<3>(1, 1, 64)),
59
+ [=](sycl::nd_item<3> item_ct1) {
60
+ dequantize_block_q2_K(vx, y, item_ct1);
61
+ });
62
+ }
63
+ #else
64
+ {
65
+ dpct::has_capability_or_fail(stream->get_device(),
66
+ {sycl::aspect::fp16});
67
+
68
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
69
+ sycl::range<3>(1, 1, 32),
70
+ sycl::range<3>(1, 1, 32)),
71
+ [=](sycl::nd_item<3> item_ct1) {
72
+ dequantize_block_q2_K(vx, y, item_ct1);
73
+ });
74
+ }
75
+
76
+ #endif
77
+ }
78
+
79
+ template <typename dst_t>
80
+ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
81
+ dpct::queue_ptr stream) {
82
+ const int64_t nb = k / QK_K;
83
+ #if QK_K == 256
84
+ {
85
+ dpct::has_capability_or_fail(stream->get_device(),
86
+ {sycl::aspect::fp16});
87
+
88
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
89
+ sycl::range<3>(1, 1, 64),
90
+ sycl::range<3>(1, 1, 64)),
91
+ [=](sycl::nd_item<3> item_ct1) {
92
+ dequantize_block_q3_K(vx, y, item_ct1);
93
+ });
94
+ }
95
+ #else
96
+ {
97
+ dpct::has_capability_or_fail(stream->get_device(),
98
+ {sycl::aspect::fp16});
99
+
100
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
101
+ sycl::range<3>(1, 1, 32),
102
+ sycl::range<3>(1, 1, 32)),
103
+ [=](sycl::nd_item<3> item_ct1) {
104
+ dequantize_block_q3_K(vx, y, item_ct1);
105
+ });
106
+ }
107
+ #endif
108
+ }
109
+
110
+ template <typename dst_t>
111
+ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
112
+ dpct::queue_ptr stream) {
113
+ const int64_t nb32 = k / 32;
114
+ const int64_t nb = (k + 255) / 256;
115
+ {
116
+ dpct::has_capability_or_fail(stream->get_device(),
117
+ {sycl::aspect::fp16});
118
+
119
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
120
+ sycl::range<3>(1, 1, 32),
121
+ sycl::range<3>(1, 1, 32)),
122
+ [=](sycl::nd_item<3> item_ct1) {
123
+ dequantize_block_q4_0(vx, y, nb32, item_ct1);
124
+ });
125
+ }
126
+ }
127
+
128
+ template <typename dst_t>
129
+ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
130
+ dpct::queue_ptr stream) {
131
+ const int64_t nb32 = k / 32;
132
+ const int64_t nb = (k + 255) / 256;
133
+ {
134
+ dpct::has_capability_or_fail(stream->get_device(),
135
+ {sycl::aspect::fp16});
136
+
137
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
138
+ sycl::range<3>(1, 1, 32),
139
+ sycl::range<3>(1, 1, 32)),
140
+ [=](sycl::nd_item<3> item_ct1) {
141
+ dequantize_block_q4_1(vx, y, nb32, item_ct1);
142
+ });
143
+ }
144
+ }
145
+
146
+
147
+ template <typename dst_t>
148
+ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
149
+ dpct::queue_ptr stream) {
150
+ const int64_t nb = k / QK_K;
151
+ {
152
+ dpct::has_capability_or_fail(stream->get_device(),
153
+ {sycl::aspect::fp16});
154
+
155
+ stream->submit([&](sycl::handler &cgh) {
156
+ sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
157
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
158
+ sycl::range<3>(1, 1, 32),
159
+ sycl::range<3>(1, 1, 32)),
160
+ [=](sycl::nd_item<3> item_ct1) {
161
+ dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
162
+ });
163
+ });
164
+ }
165
+ }
166
+
167
+ template <typename dst_t>
168
+ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
169
+ dpct::queue_ptr stream) {
170
+ const int64_t nb = k / QK_K;
171
+ #if QK_K == 256
172
+ {
173
+ dpct::has_capability_or_fail(stream->get_device(),
174
+ {sycl::aspect::fp16});
175
+
176
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
177
+ sycl::range<3>(1, 1, 64),
178
+ sycl::range<3>(1, 1, 64)),
179
+ [=](sycl::nd_item<3> item_ct1) {
180
+ dequantize_block_q5_K(vx, y, item_ct1);
181
+ });
182
+ }
183
+ #else
184
+ {
185
+ dpct::has_capability_or_fail(stream->get_device(),
186
+ {sycl::aspect::fp16});
187
+
188
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
189
+ sycl::range<3>(1, 1, 32),
190
+ sycl::range<3>(1, 1, 32)),
191
+ [=](sycl::nd_item<3> item_ct1) {
192
+ dequantize_block_q5_K(vx, y, item_ct1);
193
+ });
194
+ }
195
+
196
+ #endif
197
+ }
198
+
199
+ template <typename dst_t>
200
+ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
201
+ dpct::queue_ptr stream) {
202
+ const int64_t nb = k / QK_K;
203
+ #if QK_K == 256
204
+ {
205
+ dpct::has_capability_or_fail(stream->get_device(),
206
+ {sycl::aspect::fp16});
207
+
208
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
209
+ sycl::range<3>(1, 1, 64),
210
+ sycl::range<3>(1, 1, 64)),
211
+ [=](sycl::nd_item<3> item_ct1) {
212
+ dequantize_block_q6_K(vx, y, item_ct1);
213
+ });
214
+ }
215
+ #else
216
+ {
217
+ dpct::has_capability_or_fail(stream->get_device(),
218
+ {sycl::aspect::fp16});
219
+
220
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
221
+ sycl::range<3>(1, 1, 32),
222
+ sycl::range<3>(1, 1, 32)),
223
+ [=](sycl::nd_item<3> item_ct1) {
224
+ dequantize_block_q6_K(vx, y, item_ct1);
225
+ });
226
+ }
227
+
228
+ #endif
229
+ }
230
+
231
+ template <typename dst_t>
232
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
233
+ dpct::queue_ptr stream) {
234
+ const int64_t nb = k / QK_K;
235
+ {
236
+ dpct::has_capability_or_fail(stream->get_device(),
237
+ {sycl::aspect::fp16});
238
+
239
+ stream->submit([&](sycl::handler &cgh) {
240
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
241
+ sycl::range<3>(1, 1, 32),
242
+ sycl::range<3>(1, 1, 32)),
243
+ [=](sycl::nd_item<3> item_ct1) {
244
+ dequantize_block_iq1_s(
245
+ vx, y, item_ct1, iq1s_grid_gpu
246
+ );
247
+ });
248
+ });
249
+ }
250
+ }
251
+
252
+ template <typename dst_t>
253
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
254
+ dpct::queue_ptr stream) {
255
+ const int64_t nb = k / QK_K;
256
+ {
257
+ dpct::has_capability_or_fail(stream->get_device(),
258
+ {sycl::aspect::fp16});
259
+
260
+ stream->submit([&](sycl::handler &cgh) {
261
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
262
+ sycl::range<3>(1, 1, 32),
263
+ sycl::range<3>(1, 1, 32)),
264
+ [=](sycl::nd_item<3> item_ct1) {
265
+ dequantize_block_iq1_m(
266
+ vx, y, item_ct1, iq1s_grid_gpu
267
+ );
268
+ });
269
+ });
270
+ }
271
+ }
272
+
273
+ template <typename dst_t>
274
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
275
+ dpct::queue_ptr stream) {
276
+ const int64_t nb = k / QK_K;
277
+ {
278
+ dpct::has_capability_or_fail(stream->get_device(),
279
+ {sycl::aspect::fp16});
280
+
281
+ stream->submit([&](sycl::handler &cgh) {
282
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
283
+ sycl::range<3>(1, 1, 32),
284
+ sycl::range<3>(1, 1, 32)),
285
+ [=](sycl::nd_item<3> item_ct1) {
286
+ dequantize_block_iq2_xxs(
287
+ vx, y, item_ct1, iq2xxs_grid,
288
+ ksigns_iq2xs, kmask_iq2xs);
289
+ });
290
+ });
291
+ }
292
+ }
293
+
294
+ template <typename dst_t>
295
+ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
296
+ dpct::queue_ptr stream) {
297
+ const int64_t nb = k / QK_K;
298
+ {
299
+ dpct::has_capability_or_fail(stream->get_device(),
300
+ {sycl::aspect::fp16});
301
+
302
+ stream->submit([&](sycl::handler &cgh) {
303
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
304
+ sycl::range<3>(1, 1, 32),
305
+ sycl::range<3>(1, 1, 32)),
306
+ [=](sycl::nd_item<3> item_ct1) {
307
+ dequantize_block_iq2_xs(
308
+ vx, y, item_ct1, iq2xs_grid,
309
+ ksigns_iq2xs, kmask_iq2xs);
310
+ });
311
+ });
312
+ }
313
+ }
314
+
315
+ template <typename dst_t>
316
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
317
+ dpct::queue_ptr stream) {
318
+ const int64_t nb = k / QK_K;
319
+ {
320
+ dpct::has_capability_or_fail(stream->get_device(),
321
+ {sycl::aspect::fp16});
322
+
323
+ stream->submit([&](sycl::handler &cgh) {
324
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
325
+ sycl::range<3>(1, 1, 32),
326
+ sycl::range<3>(1, 1, 32)),
327
+ [=](sycl::nd_item<3> item_ct1) {
328
+ dequantize_block_iq2_s(vx, y, item_ct1);
329
+ });
330
+ });
331
+ }
332
+ }
333
+
334
+
335
+ template <typename dst_t>
336
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
337
+ dpct::queue_ptr stream) {
338
+ const int64_t nb = k / QK_K;
339
+ {
340
+ dpct::has_capability_or_fail(stream->get_device(),
341
+ {sycl::aspect::fp16});
342
+
343
+ stream->submit([&](sycl::handler &cgh) {
344
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
345
+ sycl::range<3>(1, 1, 32),
346
+ sycl::range<3>(1, 1, 32)),
347
+ [=](sycl::nd_item<3> item_ct1) {
348
+ dequantize_block_iq3_xxs(
349
+ vx, y, item_ct1, iq3xxs_grid,
350
+ ksigns_iq2xs, kmask_iq2xs);
351
+ });
352
+ });
353
+ }
354
+ }
355
+
356
+ template <typename dst_t>
357
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
358
+ dpct::queue_ptr stream) {
359
+ const int64_t nb = k / QK_K;
360
+ {
361
+ dpct::has_capability_or_fail(stream->get_device(),
362
+ {sycl::aspect::fp16});
363
+
364
+ stream->submit([&](sycl::handler &cgh) {
365
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
366
+ sycl::range<3>(1, 1, 32),
367
+ sycl::range<3>(1, 1, 32)),
368
+ [=](sycl::nd_item<3> item_ct1) {
369
+ dequantize_block_iq3_s(
370
+ vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
371
+ });
372
+ });
373
+ }
374
+ }
375
+
376
+ template <typename dst_t>
377
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
378
+ dpct::queue_ptr stream) {
379
+ const int64_t nb = (k + QK_K - 1) / QK_K;
380
+ #if QK_K == 64
381
+ dequantize_row_iq4_nl_sycl(vx, y, k, stream);
382
+ #else
383
+ {
384
+ dpct::has_capability_or_fail(stream->get_device(),
385
+ {sycl::aspect::fp16});
386
+
387
+ stream->submit([&](sycl::handler &cgh) {
388
+ cgh.parallel_for(
389
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
390
+ sycl::range<3>(1, 1, 32),
391
+ sycl::range<3>(1, 1, 32)),
392
+ [=](sycl::nd_item<3> item_ct1) {
393
+ dequantize_block_iq4_xs(vx, y, item_ct1);
394
+ });
395
+ });
396
+ }
397
+ #endif
398
+ }
399
+
400
+ template <typename dst_t>
401
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
402
+ dpct::queue_ptr stream) {
403
+ const int64_t nb = (k + QK_K - 1) / QK_K;
404
+ {
405
+ dpct::has_capability_or_fail(stream->get_device(),
406
+ {sycl::aspect::fp16});
407
+
408
+ stream->submit([&](sycl::handler &cgh) {
409
+ cgh.parallel_for(
410
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
411
+ sycl::range<3>(1, 1, 32),
412
+ sycl::range<3>(1, 1, 32)),
413
+ [=](sycl::nd_item<3> item_ct1) {
414
+ dequantize_block_iq4_nl(vx, y, item_ct1);
415
+ });
416
+ });
417
+ }
418
+ }
419
+
420
+ template <typename src_t, typename dst_t>
421
+ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
422
+ const sycl::nd_item<3> &item_ct1) {
423
+ const int64_t work_group_size = item_ct1.get_local_range(2);
424
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
425
+
426
+ // make each work-item deal with more elements since sycl global range can not exceed max int
427
+ const src_t * x = (const src_t *) vx;
428
+ for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
429
+ y[i] = x[i];
430
+ }
431
+ }
432
+
433
+ template <typename src_t, typename dst_t>
434
+ static void convert_unary_sycl(const void *__restrict__ vx,
435
+ dst_t *__restrict__ y, const int64_t k,
436
+ dpct::queue_ptr stream) {
437
+ const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
438
+
439
+ // decrease global range when it exceeds the max int
440
+ int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
441
+ sycl::range<3> block_nums(1, 1, num_blocks);
442
+ sycl::range<3> local_range(1, 1, local_size);
443
+ {
444
+ dpct::has_capability_or_fail(stream->get_device(),
445
+ {sycl::aspect::fp16});
446
+
447
+ stream->parallel_for(
448
+ sycl::nd_range<3>(block_nums * local_range, local_range),
449
+ [=](sycl::nd_item<3> item_ct1) {
450
+ convert_unary<src_t>(vx, y, k, item_ct1);
451
+ });
452
+ }
453
+ }
454
+
455
+ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
456
+ switch (type) {
457
+ case GGML_TYPE_Q4_0:
458
+ return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
459
+ case GGML_TYPE_Q4_1:
460
+ return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
461
+ case GGML_TYPE_Q5_0:
462
+ return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
463
+ case GGML_TYPE_Q5_1:
464
+ return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
465
+ case GGML_TYPE_Q8_0:
466
+ return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
467
+ case GGML_TYPE_Q2_K:
468
+ return dequantize_row_q2_K_sycl;
469
+ case GGML_TYPE_Q3_K:
470
+ return dequantize_row_q3_K_sycl;
471
+ case GGML_TYPE_Q4_K:
472
+ return dequantize_row_q4_K_sycl;
473
+ case GGML_TYPE_Q5_K:
474
+ return dequantize_row_q5_K_sycl;
475
+ case GGML_TYPE_Q6_K:
476
+ return dequantize_row_q6_K_sycl;
477
+ case GGML_TYPE_IQ1_S:
478
+ return dequantize_row_iq1_s_sycl;
479
+ case GGML_TYPE_IQ1_M:
480
+ return dequantize_row_iq1_m_sycl;
481
+ case GGML_TYPE_IQ2_XXS:
482
+ return dequantize_row_iq2_xxs_sycl;
483
+ case GGML_TYPE_IQ2_XS:
484
+ return dequantize_row_iq2_xs_sycl;
485
+ case GGML_TYPE_IQ2_S:
486
+ return dequantize_row_iq2_s_sycl;
487
+ case GGML_TYPE_IQ3_XXS:
488
+ return dequantize_row_iq3_xxs_sycl;
489
+ case GGML_TYPE_IQ3_S:
490
+ return dequantize_row_iq3_s_sycl;
491
+ case GGML_TYPE_IQ4_XS:
492
+ return dequantize_row_iq4_xs_sycl;
493
+ case GGML_TYPE_IQ4_NL:
494
+ return dequantize_row_iq4_nl_sycl;
495
+ case GGML_TYPE_F32:
496
+ return convert_unary_sycl<float>;
497
+ default:
498
+ return nullptr;
499
+ }
500
+ }
501
+
502
+ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
503
+ switch (type) {
504
+ case GGML_TYPE_Q4_0:
505
+ return dequantize_row_q4_0_sycl;
506
+ case GGML_TYPE_Q4_1:
507
+ return dequantize_row_q4_1_sycl;
508
+ case GGML_TYPE_Q5_0:
509
+ return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
510
+ case GGML_TYPE_Q5_1:
511
+ return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
512
+ case GGML_TYPE_Q8_0:
513
+ return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
514
+ case GGML_TYPE_Q2_K:
515
+ return dequantize_row_q2_K_sycl;
516
+ case GGML_TYPE_Q3_K:
517
+ return dequantize_row_q3_K_sycl;
518
+ case GGML_TYPE_Q4_K:
519
+ return dequantize_row_q4_K_sycl;
520
+ case GGML_TYPE_Q5_K:
521
+ return dequantize_row_q5_K_sycl;
522
+ case GGML_TYPE_Q6_K:
523
+ return dequantize_row_q6_K_sycl;
524
+ case GGML_TYPE_IQ1_S:
525
+ return dequantize_row_iq1_s_sycl;
526
+ case GGML_TYPE_IQ1_M:
527
+ return dequantize_row_iq1_m_sycl;
528
+ case GGML_TYPE_IQ2_XXS:
529
+ return dequantize_row_iq2_xxs_sycl;
530
+ case GGML_TYPE_IQ2_XS:
531
+ return dequantize_row_iq2_xs_sycl;
532
+ case GGML_TYPE_IQ2_S:
533
+ return dequantize_row_iq2_s_sycl;
534
+ case GGML_TYPE_IQ3_XXS:
535
+ return dequantize_row_iq3_xxs_sycl;
536
+ case GGML_TYPE_IQ3_S:
537
+ return dequantize_row_iq3_s_sycl;
538
+ case GGML_TYPE_IQ4_XS:
539
+ return dequantize_row_iq4_xs_sycl;
540
+ case GGML_TYPE_IQ4_NL:
541
+ return dequantize_row_iq4_nl_sycl;
542
+ case GGML_TYPE_F16:
543
+ return convert_unary_sycl<sycl::half>;
544
+ default:
545
+ return nullptr;
546
+ }
547
+ }