whisper.rn 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +264 -126
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +13 -5
  6. package/cpp/ggml-backend.cpp +207 -17
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  9. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  10. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  11. package/cpp/ggml-cpu/common.h +14 -0
  12. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  13. package/cpp/ggml-cpu/ggml-cpu.c +48 -41
  14. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  15. package/cpp/ggml-cpu/ops.cpp +518 -767
  16. package/cpp/ggml-cpu/ops.h +2 -0
  17. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  18. package/cpp/ggml-cpu/vec.cpp +161 -20
  19. package/cpp/ggml-cpu/vec.h +400 -51
  20. package/cpp/ggml-cpu.h +1 -1
  21. package/cpp/ggml-impl.h +43 -10
  22. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  23. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  24. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  25. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  27. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  28. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  29. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  31. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  33. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  34. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  35. package/cpp/ggml-metal-impl.h +40 -40
  36. package/cpp/ggml-metal.h +1 -6
  37. package/cpp/ggml-quants.c +1 -0
  38. package/cpp/ggml.c +175 -13
  39. package/cpp/ggml.h +84 -5
  40. package/cpp/jsi/RNWhisperJSI.cpp +2 -0
  41. package/cpp/jsi/ThreadPool.h +3 -3
  42. package/cpp/whisper.cpp +85 -70
  43. package/cpp/whisper.h +1 -0
  44. package/ios/CMakeLists.txt +6 -1
  45. package/ios/RNWhisperVadContext.mm +14 -13
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  50. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  84. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  86. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  87. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  92. package/lib/commonjs/version.json +1 -1
  93. package/lib/module/version.json +1 -1
  94. package/package.json +1 -1
  95. package/src/version.json +1 -1
  96. package/whisper-rn.podspec +8 -9
  97. package/cpp/ggml-metal.m +0 -6779
  98. package/cpp/ggml-whisper-sim.metallib +0 -0
  99. package/cpp/ggml-whisper.metallib +0 -0
package/cpp/ggml-cpu/ops.cpp

@@ -41,13 +41,15 @@ static void wsp_ggml_compute_forward_dup_same_cont(
  }
  }

- static void wsp_ggml_compute_forward_dup_f16(
+ template<typename src_t, typename dst_t>
+ static void wsp_ggml_compute_forward_dup_flt(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {

  const wsp_ggml_tensor * src0 = dst->src[0];

  WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
+ WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type) && !wsp_ggml_is_quantized(dst->type));

  WSP_GGML_TENSOR_UNARY_OP_LOCALS

@@ -62,6 +64,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  const int ir0 = dr * ith;
  const int ir1 = MIN(ir0 + dr, nr);

+ // case: type & row size equal
  if (src0->type == dst->type &&
  ne00 == ne0 &&
  nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
@@ -80,11 +83,11 @@ static void wsp_ggml_compute_forward_dup_f16(
  return;
  }

- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+ // case: dst tensor is contiguous
  if (wsp_ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(wsp_ggml_fp16_t)) {
- if (dst->type == WSP_GGML_TYPE_F16) {
+ if (nb00 == sizeof(src_t)) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
+ // same type
  size_t id = 0;
  const size_t rs = ne00 * nb00;
  char * dst_ptr = (char *) dst->data;
@@ -100,750 +103,58 @@ static void wsp_ggml_compute_forward_dup_f16(
  id += rs * (ne01 - ir1);
  }
  }
- } else if (dst->type == WSP_GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
- wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
- }
-
- wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == WSP_GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
- size_t id = 0;
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
  } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == WSP_GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_fp16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = WSP_GGML_CPU_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void wsp_ggml_compute_forward_dup_bf16(
- const wsp_ggml_compute_params * params,
- wsp_ggml_tensor * dst) {
-
- const wsp_ggml_tensor * src0 = dst->src[0];
-
- WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
-
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
- if (wsp_ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(wsp_ggml_bf16_t)) {
- if (dst->type == WSP_GGML_TYPE_BF16) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
- size_t id = 0;
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(src0_ptr[i00]));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F32) {
+ // casting between non-quantized types
  size_t id = 0;
- float * dst_ptr = (float *) dst->data;
+ dst_t * dst_ptr = (dst_t *) dst->data;

  for (int i03 = 0; i03 < ne03; i03++) {
  for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
- wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
- for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
- }
-
- wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == WSP_GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = WSP_GGML_BF16_TO_FP32(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
- size_t id = 0;
- wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
- size_t id = 0;
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*src0_ptr));
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- }
- return;
- }
-
- // dst counters
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == WSP_GGML_TYPE_BF16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_bf16_t));
-
- if (++i10 == ne00) {
- i10 = 0;
- if (++i11 == ne01) {
- i11 = 0;
- if (++i12 == ne02) {
- i12 = 0;
- if (++i13 == ne03) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- *(float *) dst_ptr = WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr);
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- static void wsp_ggml_compute_forward_dup_f32(
- const wsp_ggml_compute_params * params,
- wsp_ggml_tensor * dst) {
-
- const wsp_ggml_tensor * src0 = dst->src[0];
-
- WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
-
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
-
- const int ith = params->ith; // thread index
- const int nth = params->nth; // number of threads
-
- // parallelize by rows
- const int nr = ne01;
- // number of rows per thread
- const int dr = (nr + nth - 1) / nth;
- // row range for this thread
- const int ir0 = dr * ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- if (src0->type == dst->type &&
- ne00 == ne0 &&
- nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
- // copy by rows
- const size_t rs = ne00*nb00;
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- memcpy(
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
- rs);
- }
- }
- }
- return;
- }
-
- if (wsp_ggml_is_contiguous(dst)) {
- // TODO: simplify
- if (nb00 == sizeof(float)) {
- if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
- wsp_ggml_from_float_t const from_float = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
-
- size_t id = 0;
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- from_float(src0_ptr, dst_ptr + id, ne00);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- } else {
- //printf("%s: this is not optimal - fix me\n", __func__);
-
- if (dst->type == WSP_GGML_TYPE_F32) {
- size_t id = 0;
- float * dst_ptr = (float *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = *src0_ptr;
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
- size_t id = 0;
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
- size_t id = 0;
- wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = WSP_GGML_FP32_TO_BF16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
- }
- }
-
- return;
- }
-
- // dst counters
-
- int64_t i10 = 0;
- int64_t i11 = 0;
- int64_t i12 = 0;
- int64_t i13 = 0;
-
- if (dst->type == WSP_GGML_TYPE_F32) {
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- i10 += ne00 * ir0;
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
- }
- }
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
- memcpy(dst_ptr, src0_ptr, sizeof(float));
-
- if (++i10 == ne0) {
- i10 = 0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
- }
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ for (int i00 = 0; i00 < ne00; i00++) {
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+ id++;
  }
  }
+ id += ne00 * (ne01 - ir1);
  }
  }
- i10 += ne00 * (ne01 - ir1);
- while (i10 >= ne0) {
- i10 -= ne0;
- if (++i11 == ne1) {
- i11 = 0;
- if (++i12 == ne2) {
- i12 = 0;
- if (++i13 == ne3) {
- i13 = 0;
- }
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ size_t id = 0;
+ dst_t * dst_ptr = (dst_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+ id++;
  }
  }
+ id += ne00 * (ne01 - ir1);
  }
  }
  }
- } else if (dst->type == WSP_GGML_TYPE_F16) {
+ return;
+ }
+
+ // dst counters
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ if constexpr (std::is_same_v<dst_t, src_t>) {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -864,15 +175,15 @@ static void wsp_ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

- if (++i10 == ne0) {
+ if (++i10 == ne00) {
  i10 = 0;
- if (++i11 == ne1) {
+ if (++i11 == ne01) {
  i11 = 0;
- if (++i12 == ne2) {
+ if (++i12 == ne02) {
  i12 = 0;
- if (++i13 == ne3) {
+ if (++i13 == ne03) {
  i13 = 0;
  }
  }
@@ -895,7 +206,8 @@ static void wsp_ggml_compute_forward_dup_f32(
  }
  }
  }
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
+
+ } else {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -916,7 +228,8 @@ static void wsp_ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_bf16_t *) dst_ptr = WSP_GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

  if (++i10 == ne0) {
  i10 = 0;
@@ -947,8 +260,63 @@ static void wsp_ggml_compute_forward_dup_f32(
  }
  }
  }
+ }
+ }
+
+
+ template<typename src_t>
+ static void wsp_ggml_compute_forward_dup_to_q(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
+ WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type));
+
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (wsp_ggml_is_contiguous(dst) &&
+ nb00 == sizeof(src_t) &&
+ wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
+ // casting non-quantized types --> intermediate f32 --> quantized
+ wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+ }
+
+ wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
  } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
+ // printf("%s %s\n", wsp_ggml_type_name(src0->type), wsp_ggml_type_name(dst->type));
+ WSP_GGML_ABORT("not implemented");
  }
  }

@@ -1102,7 +470,7 @@ static void wsp_ggml_compute_forward_dup_bytes(
  }
  }

- static void wsp_ggml_compute_forward_dup_q(
+ static void wsp_ggml_compute_forward_dup_from_q(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {

@@ -1167,20 +535,35 @@ void wsp_ggml_compute_forward_dup(
  switch (src0->type) {
  case WSP_GGML_TYPE_F16:
  {
- wsp_ggml_compute_forward_dup_f16(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, float >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_fp16_t>(params, dst);
  } break;
  case WSP_GGML_TYPE_BF16:
  {
- wsp_ggml_compute_forward_dup_bf16(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, float >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_bf16_t>(params, dst);
  } break;
  case WSP_GGML_TYPE_F32:
  {
- wsp_ggml_compute_forward_dup_f32(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<float, float >(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_I32) wsp_ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<float>(params, dst);
+ } break;
+ case WSP_GGML_TYPE_I32:
+ {
+ if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+ else WSP_GGML_ABORT("not implemented");
  } break;
  default:
  {
  if (wsp_ggml_is_quantized(src0->type) && dst->type == WSP_GGML_TYPE_F32) {
- wsp_ggml_compute_forward_dup_q(params, dst);
+ wsp_ggml_compute_forward_dup_from_q(params, dst);
  break;
  }
  WSP_GGML_ABORT("fatal error");
@@ -5356,6 +4739,7 @@ void wsp_ggml_compute_forward_get_rows(
  //}
  }

+ template<typename idx_t>
  static void wsp_ggml_compute_forward_set_rows_f32(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {
@@ -5394,7 +4778,7 @@ static void wsp_ggml_compute_forward_set_rows_f32(
  const int64_t i11 = i02%ne11;
  const int64_t i10 = i;

- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

  WSP_GGML_ASSERT(i1 >= 0 && i1 < ne1);

@@ -5411,11 +4795,18 @@ void wsp_ggml_compute_forward_set_rows(
  wsp_ggml_tensor * dst) {

  const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];

  switch (src0->type) {
  case WSP_GGML_TYPE_F32:
  {
- wsp_ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == WSP_GGML_TYPE_I64) {
+ wsp_ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == WSP_GGML_TYPE_I32) {
+ wsp_ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ WSP_GGML_ABORT("src1->type = %d (%s) not supported", src1->type, wsp_ggml_type_name(src1->type));
+ }
  } break;
  default:
  {
@@ -7027,6 +6418,209 @@ void wsp_ggml_compute_forward_im2col_back_f32(
  }
  }

+
+ // wsp_ggml_compute_forward_im2col_3d_f16
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void wsp_ggml_compute_forward_im2col_3d_f16(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
+
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ WSP_GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ wsp_ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = WSP_GGML_CPU_FP32_TO_FP16(*s);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // wsp_ggml_compute_forward_im2col_3d_f32
+ // src0: kernel [OC*IC, KD, KH, KW]
+ // src1: image [N*IC, ID, IH, IW]
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+ static void wsp_ggml_compute_forward_im2col_3d_f32(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
+
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ WSP_GGML_UNUSED(OC);
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iod = 0; iod < OD; iod++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+ } else {
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ void wsp_ggml_compute_forward_im2col_3d(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+ switch (dst->type) {
+ case WSP_GGML_TYPE_F16:
+ {
+ wsp_ggml_compute_forward_im2col_3d_f16(params, dst);
+ } break;
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_im2col_3d_f32(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  void * a, void * b, float * c) {
  const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(type);
@@ -7207,6 +6801,148 @@ void wsp_ggml_compute_forward_conv_2d(
  wsp_ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
  }

+ // wsp_ggml_compute_forward_conv_3d
+
+ static void wsp_ggml_compute_forward_conv_3d_impl(const wsp_ggml_compute_params * params,
+ const wsp_ggml_tensor * kernel,
+ const wsp_ggml_tensor * src,
+ wsp_ggml_tensor * dst,
+ wsp_ggml_type kernel_type) {
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(kernel));
+ WSP_GGML_ASSERT(kernel_type == WSP_GGML_TYPE_F16 || kernel_type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(kernel->type == kernel_type);
+
+ const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(kernel_type);
+
+ const int32_t s0 = dst->op_params[0];
+ const int32_t s1 = dst->op_params[1];
+ const int32_t s2 = dst->op_params[2];
+ const int32_t p0 = dst->op_params[3];
+ const int32_t p1 = dst->op_params[4];
+ const int32_t p2 = dst->op_params[5];
+ const int32_t d0 = dst->op_params[6];
+ const int32_t d1 = dst->op_params[7];
+ const int32_t d2 = dst->op_params[8];
+ const int32_t c = dst->op_params[9];
+ const int32_t n = dst->op_params[10];
+ const int32_t oc = dst->op_params[11];
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t src_d = src->ne[2];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t knl_d = kernel->ne[2];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+ const int64_t dst_d = dst->ne[2];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+ const int64_t knl_n_total = knl_n_per_channel * c;
+ const int64_t patch_total = n * dst_w * dst_h * dst_d;
+
+ const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ WSP_GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
+ const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+ for (int64_t ic = 0; ic < c; ++ic) {
+ for (int64_t kz = 0; kz < knl_d; ++kz) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sz = dst_z * s2 + kz * d2 - p2;
+ const int64_t sy = dst_y * s1 + ky * d1 - p1;
+ const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+ int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const int64_t cn_idx = batch_idx * c + ic;
+ const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == WSP_GGML_TYPE_F32) {
+ *(float *)element_ptr = src_val;
+ } else if (kernel_type == WSP_GGML_TYPE_F16) {
+ *(wsp_ggml_fp16_t *)element_ptr = WSP_GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ wsp_ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+ wsp_ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+ wsp_ggml_barrier(params->threadpool);
+
+ const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ for (int64_t ioc = 0; ioc < oc; ++ioc) {
+ const float value = gemm_output[i * oc + ioc];
+ const int64_t ocn_idx = batch_idx * oc + ioc;
+ float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void wsp_ggml_compute_forward_conv_3d(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ wsp_ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+ }
+
  // wsp_ggml_compute_forward_conv_transpose_2d

  void wsp_ggml_compute_forward_conv_transpose_2d(
@@ -7872,6 +7608,15 @@ static void wsp_ggml_compute_forward_pad_f32(
  WSP_GGML_TENSOR_UNARY_OP_LOCALS

  float * dst_ptr = (float *) dst->data;
+ const int32_t lp0 = wsp_ggml_get_op_params_i32(dst, 0);
+ const int32_t rp0 = wsp_ggml_get_op_params_i32(dst, 1);
+ const int32_t lp1 = wsp_ggml_get_op_params_i32(dst, 2);
+ const int32_t rp1 = wsp_ggml_get_op_params_i32(dst, 3);
+ const int32_t lp2 = wsp_ggml_get_op_params_i32(dst, 4);
+ const int32_t rp2 = wsp_ggml_get_op_params_i32(dst, 5);
+ const int32_t lp3 = wsp_ggml_get_op_params_i32(dst, 6);
+ const int32_t rp3 = wsp_ggml_get_op_params_i32(dst, 7);
+

  // TODO: optimize

@@ -7880,10 +7625,12 @@ static void wsp_ggml_compute_forward_pad_f32(
  for (int64_t i0 = 0; i0 < ne0; ++i0) {
  for (int64_t i3 = 0; i3 < ne3; ++i3) {
  const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
- const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ if ((i0 >= lp0 && i0 < ne0 - rp0) \
+ && (i1 >= lp1 && i1 < ne1 - rp1) \
+ && (i2 >= lp2 && i2 < ne2 - rp2) \
+ && (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+ const float * src_ptr = (const float *)((char *) src0->data + src_idx);
  dst_ptr[dst_idx] = *src_ptr;
  } else {
  dst_ptr[dst_idx] = 0;
@@ -8082,7 +7829,7 @@ static void wsp_ggml_compute_forward_timestep_embedding_f32(
  embed_data[j + half] = sinf(arg);
  }
  if (dim % 2 != 0 && ith == 0) {
- embed_data[dim] = 0.f;
+ embed_data[2 * half] = 0.f;
  }
  }
  }
@@ -8861,8 +8608,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  WSP_GGML_ASSERT(src4->nb[0] == sizeof(float));
  WSP_GGML_ASSERT(src5->nb[0] == sizeof(float));
  WSP_GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
- // allows optimizing the modulo since n_group should be a power of 2
- WSP_GGML_ASSERT((ng & -ng) == ng);
+ WSP_GGML_ASSERT(nh % ng == 0);

  // heads per thread
  const int dh = (nh + nth - 1)/nth;
@@ -8893,6 +8639,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
  const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave

  // dim
  for (int i1 = 0; i1 < nr; ++i1) {
@@ -8915,8 +8662,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // TODO: maybe unroll more?
  for (int j = 0; j < 1; j++) {
  WSP_GGML_F32_VEC t0 = WSP_GGML_F32_VEC_LOAD(s0 + i + j*wsp_ggml_f32_epr + ii*nc);
- WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
- WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
+ WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + g*nc);
+ WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + g*nc);

  t0 = WSP_GGML_F32_VEC_MUL(t0, adA);
  t1 = WSP_GGML_F32_VEC_MUL(t1, axdt);
@@ -8930,6 +8677,9 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  }

  sumf = WSP_GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
  #else
  const int np = (nc & ~(WSP_GGML_F32_STEP - 1));

@@ -8945,8 +8695,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
  for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
  ax[j] = WSP_GGML_F32_VEC_LOAD(s0 + i + j*WSP_GGML_F32_EPR + ii*nc);
- ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
- az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
+ ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + g*nc);
+ az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + g*nc);

  ax[j] = WSP_GGML_F32_VEC_MUL(ax[j], adA);
  ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], axdt);
@@ -8968,7 +8718,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // d_state
  for (int i0 = np; i0 < nc; ++i0) {
  const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
  // state = prev_state * dA + dB * x
  const float state = (s0[i] * dA) + (B[ig] * x_dt);
  // y = rowwise_dotprod(state, C)
@@ -8985,6 +8735,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  for (int h = ih0; h < ih1; ++h) {
  // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave

  // dim
  for (int i1 = 0; i1 < nr; ++i1) {
@@ -8999,8 +8750,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // TODO: what happens when (d_state % svcntw()) != 0?
  for (int64_t k = 0; k < nc; k += svcntw()) {
  svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[h*nc + k]);
- svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
- svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + g*nc]);
  svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[ii*nc + k]);

  svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9020,7 +8771,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
  // d_state
  for (int i0 = 0; i0 < nc; ++i0) {
  const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
  // state = prev_state * dA + dB * x
  const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
  // y = rowwise_dotprod(state, C)
@@ -9881,8 +9632,8 @@ static void wsp_ggml_compute_forward_rwkv_wkv7_f32(
  int64_t h_stride_2d = head_size * head_size;

  #if defined(WSP_GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
+ #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+ // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
  for (int64_t t = 0; t < T; t++) {
  int64_t t_offset = t * t_stride;
  int64_t state_offset = head_size * C * (t / (T / n_seqs));
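
Note on the templated dup path: wsp_ggml_compute_forward_dup_flt and wsp_ggml_compute_forward_dup_to_q above convert elements through a type_conversion_table<T> trait, whose definition is not part of this excerpt (the new lines in package/cpp/ggml-cpu/common.h are a plausible home for it, but that file's contents are not shown here). The following is only a minimal illustrative sketch of the shape such specializations could take, built from the conversion macros that do appear in the diff; the int32_t specialization is an assumption inferred from the F32 <-> I32 dispatch cases, not copied from the package.

// Hypothetical sketch, not the package's actual code.
template<typename T>
struct type_conversion_table;

template<>
struct type_conversion_table<float> {
    // f32 passes through unchanged
    static inline float to_f32(float x)   { return x; }
    static inline float from_f32(float x) { return x; }
};

template<>
struct type_conversion_table<wsp_ggml_fp16_t> {
    // f16 <-> f32 via the CPU conversion macros used elsewhere in the diff
    static inline float           to_f32(wsp_ggml_fp16_t x) { return WSP_GGML_CPU_FP16_TO_FP32(x); }
    static inline wsp_ggml_fp16_t from_f32(float x)         { return WSP_GGML_CPU_FP32_TO_FP16(x); }
};

template<>
struct type_conversion_table<wsp_ggml_bf16_t> {
    // bf16 <-> f32 via the bf16 conversion macros used elsewhere in the diff
    static inline float           to_f32(wsp_ggml_bf16_t x) { return WSP_GGML_BF16_TO_FP32(x); }
    static inline wsp_ggml_bf16_t from_f32(float x)         { return WSP_GGML_FP32_TO_BF16(x); }
};

template<>
struct type_conversion_table<int32_t> {
    // assumed: i32 <-> f32 by plain cast, matching the F32/I32 dup dispatch cases
    static inline float   to_f32(int32_t x)  { return (float) x; }
    static inline int32_t from_f32(float x)  { return (int32_t) x; }
};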