whisper.rn 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/android/src/main/jni.cpp +12 -3
  4. package/cpp/ggml-alloc.c +292 -130
  5. package/cpp/ggml-backend-impl.h +4 -4
  6. package/cpp/ggml-backend-reg.cpp +13 -5
  7. package/cpp/ggml-backend.cpp +207 -17
  8. package/cpp/ggml-backend.h +19 -1
  9. package/cpp/ggml-cpu/amx/amx.cpp +5 -2
  10. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  11. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  12. package/cpp/ggml-cpu/common.h +14 -0
  13. package/cpp/ggml-cpu/ggml-cpu-impl.h +14 -7
  14. package/cpp/ggml-cpu/ggml-cpu.c +65 -44
  15. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  16. package/cpp/ggml-cpu/ops.cpp +542 -775
  17. package/cpp/ggml-cpu/ops.h +2 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  19. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  20. package/cpp/ggml-cpu/unary-ops.h +5 -0
  21. package/cpp/ggml-cpu/vec.cpp +227 -20
  22. package/cpp/ggml-cpu/vec.h +407 -56
  23. package/cpp/ggml-cpu.h +1 -1
  24. package/cpp/ggml-impl.h +94 -12
  25. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  26. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  27. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  28. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  29. package/cpp/ggml-metal/ggml-metal-device.cpp +1565 -0
  30. package/cpp/ggml-metal/ggml-metal-device.h +244 -0
  31. package/cpp/ggml-metal/ggml-metal-device.m +1325 -0
  32. package/cpp/ggml-metal/ggml-metal-impl.h +802 -0
  33. package/cpp/ggml-metal/ggml-metal-ops.cpp +3583 -0
  34. package/cpp/ggml-metal/ggml-metal-ops.h +88 -0
  35. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  36. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  37. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +40 -40
  39. package/cpp/ggml-metal.h +1 -6
  40. package/cpp/ggml-quants.c +1 -0
  41. package/cpp/ggml.c +341 -15
  42. package/cpp/ggml.h +150 -5
  43. package/cpp/jsi/RNWhisperJSI.cpp +9 -2
  44. package/cpp/jsi/ThreadPool.h +3 -3
  45. package/cpp/rn-whisper.h +1 -0
  46. package/cpp/whisper.cpp +89 -72
  47. package/cpp/whisper.h +1 -0
  48. package/ios/CMakeLists.txt +6 -1
  49. package/ios/RNWhisperContext.mm +3 -1
  50. package/ios/RNWhisperVadContext.mm +14 -13
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  61. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  70. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  72. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  74. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  80. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  81. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  82. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +150 -5
  83. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  86. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  87. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -4
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +19 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +94 -12
  92. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  93. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  94. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +150 -5
  95. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  96. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  97. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  98. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  99. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  101. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  102. package/lib/commonjs/version.json +1 -1
  103. package/lib/module/NativeRNWhisper.js.map +1 -1
  104. package/lib/module/version.json +1 -1
  105. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  106. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  107. package/package.json +1 -1
  108. package/src/NativeRNWhisper.ts +2 -0
  109. package/src/version.json +1 -1
  110. package/whisper-rn.podspec +8 -9
  111. package/cpp/ggml-metal.m +0 -6779
  112. package/cpp/ggml-whisper-sim.metallib +0 -0
  113. package/cpp/ggml-whisper.metallib +0 -0
@@ -41,13 +41,15 @@ static void wsp_ggml_compute_forward_dup_same_cont(
  }
  }

- static void wsp_ggml_compute_forward_dup_f16(
+ template<typename src_t, typename dst_t>
+ static void wsp_ggml_compute_forward_dup_flt(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {

  const wsp_ggml_tensor * src0 = dst->src[0];

  WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
+ WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type) && !wsp_ggml_is_quantized(dst->type));

  WSP_GGML_TENSOR_UNARY_OP_LOCALS

@@ -62,6 +64,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  const int ir0 = dr * ith;
  const int ir1 = MIN(ir0 + dr, nr);

+ // case: type & row size equal
  if (src0->type == dst->type &&
  ne00 == ne0 &&
  nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
@@ -80,11 +83,11 @@ static void wsp_ggml_compute_forward_dup_f16(
  return;
  }

- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+ // case: dst tensor is contiguous
  if (wsp_ggml_is_contiguous(dst)) {
- if (nb00 == sizeof(wsp_ggml_fp16_t)) {
- if (dst->type == WSP_GGML_TYPE_F16) {
+ if (nb00 == sizeof(src_t)) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
+ // same type
  size_t id = 0;
  const size_t rs = ne00 * nb00;
  char * dst_ptr = (char *) dst->data;
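Note: the hunks above replace the per-type wsp_ggml_compute_forward_dup_f16 kernel with a single wsp_ggml_compute_forward_dup_flt<src_t, dst_t> template that routes every element through an f32 intermediate via a type_conversion_table trait. A minimal standalone sketch of that trait pattern follows; conv_table, copy_cast and the float/int32 specializations are illustrative stand-ins, not the library's definitions (the real table presumably also covers wsp_ggml_fp16_t and wsp_ggml_bf16_t).

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T> struct conv_table;                     // primary template

template <> struct conv_table<float> {
    static float to_f32(float x)   { return x; }
    static float from_f32(float x) { return x; }
};

template <> struct conv_table<int32_t> {
    static float   to_f32(int32_t x) { return (float) x; }
    static int32_t from_f32(float x) { return (int32_t) x; }
};

// one kernel covers every (src, dst) pair: src_t -> f32 -> dst_t
template <typename src_t, typename dst_t>
static void copy_cast(const src_t * src, dst_t * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        const float tmp = conv_table<src_t>::to_f32(src[i]);
        dst[i] = conv_table<dst_t>::from_f32(tmp);
    }
}

int main() {
    const std::vector<float> src = {1.5f, -2.25f, 3.0f};
    std::vector<int32_t> dst(src.size());
    copy_cast(src.data(), dst.data(), src.size());
    for (int32_t v : dst) std::printf("%d\n", v);             // prints 1, -2, 3
    return 0;
}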
@@ -100,750 +103,58 @@ static void wsp_ggml_compute_forward_dup_f16(
100
103
  id += rs * (ne01 - ir1);
101
104
  }
102
105
  }
103
- } else if (dst->type == WSP_GGML_TYPE_F32) {
106
+ } else {
107
+ // casting between non-quantized types
104
108
  size_t id = 0;
105
- float * dst_ptr = (float *) dst->data;
109
+ dst_t * dst_ptr = (dst_t *) dst->data;
106
110
 
107
111
  for (int i03 = 0; i03 < ne03; i03++) {
108
112
  for (int i02 = 0; i02 < ne02; i02++) {
109
113
  id += ne00 * ir0;
110
114
  for (int i01 = ir0; i01 < ir1; i01++) {
111
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
115
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
112
116
  for (int i00 = 0; i00 < ne00; i00++) {
113
- dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
117
+ float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
118
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
114
119
  id++;
115
120
  }
116
121
  }
117
122
  id += ne00 * (ne01 - ir1);
118
123
  }
119
124
  }
120
- } else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
121
- wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
122
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
123
-
124
- size_t id = 0;
125
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
126
- char * dst_ptr = (char *) dst->data;
127
-
128
- for (int i03 = 0; i03 < ne03; i03++) {
129
- for (int i02 = 0; i02 < ne02; i02++) {
130
- id += rs * ir0;
131
- for (int i01 = ir0; i01 < ir1; i01++) {
132
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
133
-
134
- for (int i00 = 0; i00 < ne00; i00++) {
135
- src0_f32[i00] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
136
- }
137
-
138
- wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
139
- id += rs;
140
- }
141
- id += rs * (ne01 - ir1);
142
- }
143
- }
144
- } else {
145
- WSP_GGML_ABORT("fatal error"); // TODO: implement
146
125
  }
147
126
  } else {
148
127
  //printf("%s: this is not optimal - fix me\n", __func__);
149
128
 
150
- if (dst->type == WSP_GGML_TYPE_F32) {
151
- size_t id = 0;
152
- float * dst_ptr = (float *) dst->data;
153
-
154
- for (int i03 = 0; i03 < ne03; i03++) {
155
- for (int i02 = 0; i02 < ne02; i02++) {
156
- id += ne00 * ir0;
157
- for (int i01 = ir0; i01 < ir1; i01++) {
158
- for (int i00 = 0; i00 < ne00; i00++) {
159
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
160
-
161
- dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(*src0_ptr);
162
- id++;
163
- }
164
- }
165
- id += ne00 * (ne01 - ir1);
166
- }
167
- }
168
- } else if (dst->type == WSP_GGML_TYPE_F16) {
169
- size_t id = 0;
170
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
171
-
172
- for (int i03 = 0; i03 < ne03; i03++) {
173
- for (int i02 = 0; i02 < ne02; i02++) {
174
- id += ne00 * ir0;
175
- for (int i01 = ir0; i01 < ir1; i01++) {
176
- for (int i00 = 0; i00 < ne00; i00++) {
177
- const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
178
-
179
- dst_ptr[id] = *src0_ptr;
180
- id++;
181
- }
182
- }
183
- id += ne00 * (ne01 - ir1);
184
- }
185
- }
186
- } else {
187
- WSP_GGML_ABORT("fatal error"); // TODO: implement
188
- }
189
- }
190
- return;
191
- }
192
-
193
- // dst counters
194
- int64_t i10 = 0;
195
- int64_t i11 = 0;
196
- int64_t i12 = 0;
197
- int64_t i13 = 0;
198
-
199
- if (dst->type == WSP_GGML_TYPE_F16) {
200
- for (int64_t i03 = 0; i03 < ne03; i03++) {
201
- for (int64_t i02 = 0; i02 < ne02; i02++) {
202
- i10 += ne00 * ir0;
203
- while (i10 >= ne0) {
204
- i10 -= ne0;
205
- if (++i11 == ne1) {
206
- i11 = 0;
207
- if (++i12 == ne2) {
208
- i12 = 0;
209
- if (++i13 == ne3) {
210
- i13 = 0;
211
- }
212
- }
213
- }
214
- }
215
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
216
- for (int64_t i00 = 0; i00 < ne00; i00++) {
217
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
218
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
219
-
220
- memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_fp16_t));
221
-
222
- if (++i10 == ne00) {
223
- i10 = 0;
224
- if (++i11 == ne01) {
225
- i11 = 0;
226
- if (++i12 == ne02) {
227
- i12 = 0;
228
- if (++i13 == ne03) {
229
- i13 = 0;
230
- }
231
- }
232
- }
233
- }
234
- }
235
- }
236
- i10 += ne00 * (ne01 - ir1);
237
- while (i10 >= ne0) {
238
- i10 -= ne0;
239
- if (++i11 == ne1) {
240
- i11 = 0;
241
- if (++i12 == ne2) {
242
- i12 = 0;
243
- if (++i13 == ne3) {
244
- i13 = 0;
245
- }
246
- }
247
- }
248
- }
249
- }
250
- }
251
- } else if (dst->type == WSP_GGML_TYPE_F32) {
252
- for (int64_t i03 = 0; i03 < ne03; i03++) {
253
- for (int64_t i02 = 0; i02 < ne02; i02++) {
254
- i10 += ne00 * ir0;
255
- while (i10 >= ne0) {
256
- i10 -= ne0;
257
- if (++i11 == ne1) {
258
- i11 = 0;
259
- if (++i12 == ne2) {
260
- i12 = 0;
261
- if (++i13 == ne3) {
262
- i13 = 0;
263
- }
264
- }
265
- }
266
- }
267
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
268
- for (int64_t i00 = 0; i00 < ne00; i00++) {
269
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
270
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
271
-
272
- *(float *) dst_ptr = WSP_GGML_CPU_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr);
273
-
274
- if (++i10 == ne0) {
275
- i10 = 0;
276
- if (++i11 == ne1) {
277
- i11 = 0;
278
- if (++i12 == ne2) {
279
- i12 = 0;
280
- if (++i13 == ne3) {
281
- i13 = 0;
282
- }
283
- }
284
- }
285
- }
286
- }
287
- }
288
- i10 += ne00 * (ne01 - ir1);
289
- while (i10 >= ne0) {
290
- i10 -= ne0;
291
- if (++i11 == ne1) {
292
- i11 = 0;
293
- if (++i12 == ne2) {
294
- i12 = 0;
295
- if (++i13 == ne3) {
296
- i13 = 0;
297
- }
298
- }
299
- }
300
- }
301
- }
302
- }
303
- } else {
304
- WSP_GGML_ABORT("fatal error"); // TODO: implement
305
- }
306
- }
307
-
308
- static void wsp_ggml_compute_forward_dup_bf16(
309
- const wsp_ggml_compute_params * params,
310
- wsp_ggml_tensor * dst) {
311
-
312
- const wsp_ggml_tensor * src0 = dst->src[0];
313
-
314
- WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
315
-
316
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
317
-
318
- const int ith = params->ith; // thread index
319
- const int nth = params->nth; // number of threads
320
-
321
- // parallelize by rows
322
- const int nr = ne01;
323
- // number of rows per thread
324
- const int dr = (nr + nth - 1) / nth;
325
- // row range for this thread
326
- const int ir0 = dr * ith;
327
- const int ir1 = MIN(ir0 + dr, nr);
328
-
329
- if (src0->type == dst->type &&
330
- ne00 == ne0 &&
331
- nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
332
- // copy by rows
333
- const size_t rs = ne00*nb00;
334
- for (int64_t i03 = 0; i03 < ne03; i03++) {
335
- for (int64_t i02 = 0; i02 < ne02; i02++) {
336
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
337
- memcpy(
338
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
339
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
340
- rs);
341
- }
342
- }
343
- }
344
- return;
345
- }
346
-
347
- // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
348
-
349
- if (wsp_ggml_is_contiguous(dst)) {
350
- if (nb00 == sizeof(wsp_ggml_bf16_t)) {
351
- if (dst->type == WSP_GGML_TYPE_BF16) {
352
- size_t id = 0;
353
- const size_t rs = ne00 * nb00;
354
- char * dst_ptr = (char *) dst->data;
355
-
356
- for (int i03 = 0; i03 < ne03; i03++) {
357
- for (int i02 = 0; i02 < ne02; i02++) {
358
- id += rs * ir0;
359
- for (int i01 = ir0; i01 < ir1; i01++) {
360
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
361
- memcpy(dst_ptr + id, src0_ptr, rs);
362
- id += rs;
363
- }
364
- id += rs * (ne01 - ir1);
365
- }
366
- }
367
- } else if (dst->type == WSP_GGML_TYPE_F16) {
368
- size_t id = 0;
369
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
370
-
371
- for (int i03 = 0; i03 < ne03; i03++) {
372
- for (int i02 = 0; i02 < ne02; i02++) {
373
- id += ne00 * ir0;
374
- for (int i01 = ir0; i01 < ir1; i01++) {
375
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
376
- for (int i00 = 0; i00 < ne00; i00++) {
377
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(src0_ptr[i00]));
378
- id++;
379
- }
380
- }
381
- id += ne00 * (ne01 - ir1);
382
- }
383
- }
384
- } else if (dst->type == WSP_GGML_TYPE_F32) {
385
- size_t id = 0;
386
- float * dst_ptr = (float *) dst->data;
387
-
388
- for (int i03 = 0; i03 < ne03; i03++) {
389
- for (int i02 = 0; i02 < ne02; i02++) {
390
- id += ne00 * ir0;
391
- for (int i01 = ir0; i01 < ir1; i01++) {
392
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
393
- for (int i00 = 0; i00 < ne00; i00++) {
394
- dst_ptr[id] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
395
- id++;
396
- }
397
- }
398
- id += ne00 * (ne01 - ir1);
399
- }
400
- }
401
- } else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
402
- wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
403
- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
404
-
405
- size_t id = 0;
406
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
407
- char * dst_ptr = (char *) dst->data;
408
-
409
- for (int i03 = 0; i03 < ne03; i03++) {
410
- for (int i02 = 0; i02 < ne02; i02++) {
411
- id += rs * ir0;
412
- for (int i01 = ir0; i01 < ir1; i01++) {
413
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
414
-
415
- for (int i00 = 0; i00 < ne00; i00++) {
416
- src0_f32[i00] = WSP_GGML_BF16_TO_FP32(src0_ptr[i00]);
417
- }
129
+ size_t id = 0;
130
+ dst_t * dst_ptr = (dst_t *) dst->data;
418
131
 
419
- wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
420
- id += rs;
421
- }
422
- id += rs * (ne01 - ir1);
423
- }
424
- }
425
- } else {
426
- WSP_GGML_ABORT("fatal error"); // TODO: implement
427
- }
428
- } else {
429
- //printf("%s: this is not optimal - fix me\n", __func__);
430
-
431
- if (dst->type == WSP_GGML_TYPE_F32) {
432
- size_t id = 0;
433
- float * dst_ptr = (float *) dst->data;
434
-
435
- for (int i03 = 0; i03 < ne03; i03++) {
436
- for (int i02 = 0; i02 < ne02; i02++) {
437
- id += ne00 * ir0;
438
- for (int i01 = ir0; i01 < ir1; i01++) {
439
- for (int i00 = 0; i00 < ne00; i00++) {
440
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
441
-
442
- dst_ptr[id] = WSP_GGML_BF16_TO_FP32(*src0_ptr);
443
- id++;
444
- }
445
- }
446
- id += ne00 * (ne01 - ir1);
447
- }
448
- }
449
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
450
- size_t id = 0;
451
- wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
452
-
453
- for (int i03 = 0; i03 < ne03; i03++) {
454
- for (int i02 = 0; i02 < ne02; i02++) {
455
- id += ne00 * ir0;
456
- for (int i01 = ir0; i01 < ir1; i01++) {
457
- for (int i00 = 0; i00 < ne00; i00++) {
458
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
459
-
460
- dst_ptr[id] = *src0_ptr;
461
- id++;
462
- }
463
- }
464
- id += ne00 * (ne01 - ir1);
465
- }
466
- }
467
- } else if (dst->type == WSP_GGML_TYPE_F16) {
468
- size_t id = 0;
469
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
470
-
471
- for (int i03 = 0; i03 < ne03; i03++) {
472
- for (int i02 = 0; i02 < ne02; i02++) {
473
- id += ne00 * ir0;
474
- for (int i01 = ir0; i01 < ir1; i01++) {
475
- for (int i00 = 0; i00 < ne00; i00++) {
476
- const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
477
-
478
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*src0_ptr));
479
- id++;
480
- }
481
- }
482
- id += ne00 * (ne01 - ir1);
483
- }
484
- }
485
- } else {
486
- WSP_GGML_ABORT("fatal error"); // TODO: implement
487
- }
488
- }
489
- return;
490
- }
491
-
492
- // dst counters
493
- int64_t i10 = 0;
494
- int64_t i11 = 0;
495
- int64_t i12 = 0;
496
- int64_t i13 = 0;
497
-
498
- if (dst->type == WSP_GGML_TYPE_BF16) {
499
- for (int64_t i03 = 0; i03 < ne03; i03++) {
500
- for (int64_t i02 = 0; i02 < ne02; i02++) {
501
- i10 += ne00 * ir0;
502
- while (i10 >= ne0) {
503
- i10 -= ne0;
504
- if (++i11 == ne1) {
505
- i11 = 0;
506
- if (++i12 == ne2) {
507
- i12 = 0;
508
- if (++i13 == ne3) {
509
- i13 = 0;
510
- }
511
- }
512
- }
513
- }
514
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
515
- for (int64_t i00 = 0; i00 < ne00; i00++) {
516
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
517
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
518
-
519
- memcpy(dst_ptr, src0_ptr, sizeof(wsp_ggml_bf16_t));
520
-
521
- if (++i10 == ne00) {
522
- i10 = 0;
523
- if (++i11 == ne01) {
524
- i11 = 0;
525
- if (++i12 == ne02) {
526
- i12 = 0;
527
- if (++i13 == ne03) {
528
- i13 = 0;
529
- }
530
- }
531
- }
532
- }
533
- }
534
- }
535
- i10 += ne00 * (ne01 - ir1);
536
- while (i10 >= ne0) {
537
- i10 -= ne0;
538
- if (++i11 == ne1) {
539
- i11 = 0;
540
- if (++i12 == ne2) {
541
- i12 = 0;
542
- if (++i13 == ne3) {
543
- i13 = 0;
544
- }
545
- }
546
- }
547
- }
548
- }
549
- }
550
- } else if (dst->type == WSP_GGML_TYPE_F16) {
551
- for (int64_t i03 = 0; i03 < ne03; i03++) {
552
- for (int64_t i02 = 0; i02 < ne02; i02++) {
553
- i10 += ne00 * ir0;
554
- while (i10 >= ne0) {
555
- i10 -= ne0;
556
- if (++i11 == ne1) {
557
- i11 = 0;
558
- if (++i12 == ne2) {
559
- i12 = 0;
560
- if (++i13 == ne3) {
561
- i13 = 0;
562
- }
563
- }
564
- }
565
- }
566
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
567
- for (int64_t i00 = 0; i00 < ne00; i00++) {
568
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
569
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
570
-
571
- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr));
572
-
573
- if (++i10 == ne0) {
574
- i10 = 0;
575
- if (++i11 == ne1) {
576
- i11 = 0;
577
- if (++i12 == ne2) {
578
- i12 = 0;
579
- if (++i13 == ne3) {
580
- i13 = 0;
581
- }
582
- }
583
- }
584
- }
585
- }
586
- }
587
- i10 += ne00 * (ne01 - ir1);
588
- while (i10 >= ne0) {
589
- i10 -= ne0;
590
- if (++i11 == ne1) {
591
- i11 = 0;
592
- if (++i12 == ne2) {
593
- i12 = 0;
594
- if (++i13 == ne3) {
595
- i13 = 0;
596
- }
597
- }
598
- }
599
- }
600
- }
601
- }
602
- } else if (dst->type == WSP_GGML_TYPE_F32) {
603
- for (int64_t i03 = 0; i03 < ne03; i03++) {
604
- for (int64_t i02 = 0; i02 < ne02; i02++) {
605
- i10 += ne00 * ir0;
606
- while (i10 >= ne0) {
607
- i10 -= ne0;
608
- if (++i11 == ne1) {
609
- i11 = 0;
610
- if (++i12 == ne2) {
611
- i12 = 0;
612
- if (++i13 == ne3) {
613
- i13 = 0;
614
- }
615
- }
616
- }
617
- }
618
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
619
- for (int64_t i00 = 0; i00 < ne00; i00++) {
620
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
621
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
622
-
623
- *(float *) dst_ptr = WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr);
624
-
625
- if (++i10 == ne0) {
626
- i10 = 0;
627
- if (++i11 == ne1) {
628
- i11 = 0;
629
- if (++i12 == ne2) {
630
- i12 = 0;
631
- if (++i13 == ne3) {
632
- i13 = 0;
633
- }
634
- }
635
- }
636
- }
637
- }
638
- }
639
- i10 += ne00 * (ne01 - ir1);
640
- while (i10 >= ne0) {
641
- i10 -= ne0;
642
- if (++i11 == ne1) {
643
- i11 = 0;
644
- if (++i12 == ne2) {
645
- i12 = 0;
646
- if (++i13 == ne3) {
647
- i13 = 0;
648
- }
649
- }
650
- }
651
- }
652
- }
653
- }
654
- } else {
655
- WSP_GGML_ABORT("fatal error"); // TODO: implement
656
- }
657
- }
658
-
659
- static void wsp_ggml_compute_forward_dup_f32(
660
- const wsp_ggml_compute_params * params,
661
- wsp_ggml_tensor * dst) {
662
-
663
- const wsp_ggml_tensor * src0 = dst->src[0];
664
-
665
- WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
666
-
667
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
668
-
669
- const int ith = params->ith; // thread index
670
- const int nth = params->nth; // number of threads
671
-
672
- // parallelize by rows
673
- const int nr = ne01;
674
- // number of rows per thread
675
- const int dr = (nr + nth - 1) / nth;
676
- // row range for this thread
677
- const int ir0 = dr * ith;
678
- const int ir1 = MIN(ir0 + dr, nr);
679
-
680
- if (src0->type == dst->type &&
681
- ne00 == ne0 &&
682
- nb00 == wsp_ggml_type_size(src0->type) && nb0 == wsp_ggml_type_size(dst->type)) {
683
- // copy by rows
684
- const size_t rs = ne00*nb00;
685
- for (int64_t i03 = 0; i03 < ne03; i03++) {
686
- for (int64_t i02 = 0; i02 < ne02; i02++) {
687
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
688
- memcpy(
689
- ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
690
- ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
691
- rs);
692
- }
693
- }
694
- }
695
- return;
696
- }
697
-
698
- if (wsp_ggml_is_contiguous(dst)) {
699
- // TODO: simplify
700
- if (nb00 == sizeof(float)) {
701
- if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
702
- wsp_ggml_from_float_t const from_float = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
703
-
704
- size_t id = 0;
705
- size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
706
- char * dst_ptr = (char *) dst->data;
707
-
708
- for (int i03 = 0; i03 < ne03; i03++) {
709
- for (int i02 = 0; i02 < ne02; i02++) {
710
- id += rs * ir0;
711
- for (int i01 = ir0; i01 < ir1; i01++) {
712
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
713
- from_float(src0_ptr, dst_ptr + id, ne00);
714
- id += rs;
715
- }
716
- id += rs * (ne01 - ir1);
717
- }
718
- }
719
- } else {
720
- WSP_GGML_ABORT("fatal error"); // TODO: implement
721
- }
722
- } else {
723
- //printf("%s: this is not optimal - fix me\n", __func__);
724
-
725
- if (dst->type == WSP_GGML_TYPE_F32) {
726
- size_t id = 0;
727
- float * dst_ptr = (float *) dst->data;
728
-
729
- for (int i03 = 0; i03 < ne03; i03++) {
730
- for (int i02 = 0; i02 < ne02; i02++) {
731
- id += ne00 * ir0;
732
- for (int i01 = ir0; i01 < ir1; i01++) {
733
- for (int i00 = 0; i00 < ne00; i00++) {
734
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
735
-
736
- dst_ptr[id] = *src0_ptr;
737
- id++;
738
- }
739
- }
740
- id += ne00 * (ne01 - ir1);
741
- }
742
- }
743
- } else if (dst->type == WSP_GGML_TYPE_F16) {
744
- size_t id = 0;
745
- wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) dst->data;
746
-
747
- for (int i03 = 0; i03 < ne03; i03++) {
748
- for (int i02 = 0; i02 < ne02; i02++) {
749
- id += ne00 * ir0;
750
- for (int i01 = ir0; i01 < ir1; i01++) {
751
- for (int i00 = 0; i00 < ne00; i00++) {
752
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
753
-
754
- dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(*src0_ptr);
755
- id++;
756
- }
757
- }
758
- id += ne00 * (ne01 - ir1);
759
- }
760
- }
761
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
762
- size_t id = 0;
763
- wsp_ggml_bf16_t * dst_ptr = (wsp_ggml_bf16_t *) dst->data;
764
-
765
- for (int i03 = 0; i03 < ne03; i03++) {
766
- for (int i02 = 0; i02 < ne02; i02++) {
767
- id += ne00 * ir0;
768
- for (int i01 = ir0; i01 < ir1; i01++) {
769
- for (int i00 = 0; i00 < ne00; i00++) {
770
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
771
-
772
- dst_ptr[id] = WSP_GGML_FP32_TO_BF16(*src0_ptr);
773
- id++;
774
- }
775
- }
776
- id += ne00 * (ne01 - ir1);
777
- }
778
- }
779
- } else {
780
- WSP_GGML_ABORT("fatal error"); // TODO: implement
781
- }
782
- }
783
-
784
- return;
785
- }
786
-
787
- // dst counters
788
-
789
- int64_t i10 = 0;
790
- int64_t i11 = 0;
791
- int64_t i12 = 0;
792
- int64_t i13 = 0;
793
-
794
- if (dst->type == WSP_GGML_TYPE_F32) {
795
- for (int64_t i03 = 0; i03 < ne03; i03++) {
796
- for (int64_t i02 = 0; i02 < ne02; i02++) {
797
- i10 += ne00 * ir0;
798
- while (i10 >= ne0) {
799
- i10 -= ne0;
800
- if (++i11 == ne1) {
801
- i11 = 0;
802
- if (++i12 == ne2) {
803
- i12 = 0;
804
- if (++i13 == ne3) {
805
- i13 = 0;
806
- }
807
- }
808
- }
809
- }
810
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
811
- for (int64_t i00 = 0; i00 < ne00; i00++) {
812
- const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
813
- char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
814
-
815
- memcpy(dst_ptr, src0_ptr, sizeof(float));
816
-
817
- if (++i10 == ne0) {
818
- i10 = 0;
819
- if (++i11 == ne1) {
820
- i11 = 0;
821
- if (++i12 == ne2) {
822
- i12 = 0;
823
- if (++i13 == ne3) {
824
- i13 = 0;
825
- }
826
- }
827
- }
828
- }
829
- }
830
- }
831
- i10 += ne00 * (ne01 - ir1);
832
- while (i10 >= ne0) {
833
- i10 -= ne0;
834
- if (++i11 == ne1) {
835
- i11 = 0;
836
- if (++i12 == ne2) {
837
- i12 = 0;
838
- if (++i13 == ne3) {
839
- i13 = 0;
840
- }
132
+ for (int i03 = 0; i03 < ne03; i03++) {
133
+ for (int i02 = 0; i02 < ne02; i02++) {
134
+ id += ne00 * ir0;
135
+ for (int i01 = ir0; i01 < ir1; i01++) {
136
+ for (int i00 = 0; i00 < ne00; i00++) {
137
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
138
+
139
+ float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
140
+ dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
141
+ id++;
841
142
  }
842
143
  }
144
+ id += ne00 * (ne01 - ir1);
843
145
  }
844
146
  }
845
147
  }
846
- } else if (dst->type == WSP_GGML_TYPE_F16) {
148
+ return;
149
+ }
150
+
151
+ // dst counters
152
+ int64_t i10 = 0;
153
+ int64_t i11 = 0;
154
+ int64_t i12 = 0;
155
+ int64_t i13 = 0;
156
+
157
+ if constexpr (std::is_same_v<dst_t, src_t>) {
847
158
  for (int64_t i03 = 0; i03 < ne03; i03++) {
848
159
  for (int64_t i02 = 0; i02 < ne02; i02++) {
849
160
  i10 += ne00 * ir0;
@@ -864,15 +175,15 @@ static void wsp_ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
+ memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

- if (++i10 == ne0) {
+ if (++i10 == ne00) {
  i10 = 0;
- if (++i11 == ne1) {
+ if (++i11 == ne01) {
  i11 = 0;
- if (++i12 == ne2) {
+ if (++i12 == ne02) {
  i12 = 0;
- if (++i13 == ne3) {
+ if (++i13 == ne03) {
  i13 = 0;
  }
  }
@@ -895,7 +206,8 @@ static void wsp_ggml_compute_forward_dup_f32(
  }
  }
  }
- } else if (dst->type == WSP_GGML_TYPE_BF16) {
+
+ } else {
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  i10 += ne00 * ir0;
@@ -916,7 +228,8 @@ static void wsp_ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_bf16_t *) dst_ptr = WSP_GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+ float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+ *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

  if (++i10 == ne0) {
  i10 = 0;
@@ -947,8 +260,63 @@ static void wsp_ggml_compute_forward_dup_f32(
  }
  }
  }
+ }
+ }
+
+
+ template<typename src_t>
+ static void wsp_ggml_compute_forward_dup_to_q(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0));
+ WSP_GGML_ASSERT(!wsp_ggml_is_quantized(src0->type));
+
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (wsp_ggml_is_contiguous(dst) &&
+ nb00 == sizeof(src_t) &&
+ wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
+ // casting non-quantized types --> intermediate f32 --> quantized
+ wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+ }
+
+ wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
  } else {
- WSP_GGML_ABORT("fatal error"); // TODO: implement
+ // printf("%s %s\n", wsp_ggml_type_name(src0->type), wsp_ggml_type_name(dst->type));
+ WSP_GGML_ABORT("not implemented");
  }
  }

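Note: the new wsp_ggml_compute_forward_dup_to_q<src_t> above handles any non-quantized source written into a quantized destination by widening each row into per-thread f32 scratch space and then calling the destination type's from_float row quantizer. A self-contained sketch of that flow, with a toy int8 "quantizer" standing in for the real block formats; all names here are illustrative.

#include <cstddef>
#include <cstdint>
#include <vector>

using from_float_fn = void (*)(const float * src, void * dst, int64_t n);

// toy "quantizer": clamp and truncate to int8 (real ggml types use block formats)
static void fake_q8_row(const float * src, void * dst, int64_t n) {
    int8_t * out = (int8_t *) dst;
    for (int64_t i = 0; i < n; ++i) {
        float v = src[i];
        if (v >  127.f) v =  127.f;
        if (v < -127.f) v = -127.f;
        out[i] = (int8_t) v;
    }
}

template <typename src_t>
static void dup_rows_to_q(const src_t * src, void * dst, int64_t rows, int64_t cols,
                          size_t dst_row_bytes, from_float_fn quantize_row) {
    std::vector<float> scratch(cols);                 // plays the role of params->wdata
    char * out = (char *) dst;
    for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
            scratch[c] = (float) src[r*cols + c];     // src_t -> f32
        }
        quantize_row(scratch.data(), out + r*dst_row_bytes, cols);
    }
}

int main() {
    const float src[4] = {1.2f, -3.7f, 200.0f, -200.0f};
    int8_t out[4];
    dup_rows_to_q<float>(src, out, /*rows=*/1, /*cols=*/4, /*dst_row_bytes=*/4, fake_q8_row);
    return out[0] == 1 ? 0 : 1;                       // out is 1, -3, 127, -127
}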
@@ -1102,7 +470,7 @@ static void wsp_ggml_compute_forward_dup_bytes(
  }
  }

- static void wsp_ggml_compute_forward_dup_q(
+ static void wsp_ggml_compute_forward_dup_from_q(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {

@@ -1167,20 +535,35 @@ void wsp_ggml_compute_forward_dup(
  switch (src0->type) {
  case WSP_GGML_TYPE_F16:
  {
- wsp_ggml_compute_forward_dup_f16(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_fp16_t, float >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_fp16_t>(params, dst);
  } break;
  case WSP_GGML_TYPE_BF16:
  {
- wsp_ggml_compute_forward_dup_bf16(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<wsp_ggml_bf16_t, float >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<wsp_ggml_bf16_t>(params, dst);
  } break;
  case WSP_GGML_TYPE_F32:
  {
- wsp_ggml_compute_forward_dup_f32(params, dst);
+ /**/ if (dst->type == WSP_GGML_TYPE_F16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_fp16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_BF16) wsp_ggml_compute_forward_dup_flt<float, wsp_ggml_bf16_t>(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<float, float >(params, dst);
+ else if (dst->type == WSP_GGML_TYPE_I32) wsp_ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
+ else wsp_ggml_compute_forward_dup_to_q<float>(params, dst);
+ } break;
+ case WSP_GGML_TYPE_I32:
+ {
+ if (dst->type == WSP_GGML_TYPE_F32) wsp_ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+ else WSP_GGML_ABORT("not implemented");
  } break;
  default:
  {
  if (wsp_ggml_is_quantized(src0->type) && dst->type == WSP_GGML_TYPE_F32) {
- wsp_ggml_compute_forward_dup_q(params, dst);
+ wsp_ggml_compute_forward_dup_from_q(params, dst);
  break;
  }
  WSP_GGML_ABORT("fatal error");
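Note: dispatch in wsp_ggml_compute_forward_dup now keys on the (source, destination) type pair, with F32 <-> I32 copies newly routed through the same float template and quantized destinations falling back to the to_q path. A compact sketch of that shape, using illustrative type tags and names rather than the library's enums; F16/BF16 cases are omitted for brevity.

#include <cstdint>
#include <cstdio>

enum class ttype { F32, I32, QUANT };

template <typename src_t, typename dst_t>
static void dup_flt_sketch() { std::printf("float-path copy\n"); }

template <typename src_t>
static void dup_to_q_sketch() { std::printf("quantize path\n"); }

static void dispatch_f32_src(ttype dst) {
    switch (dst) {
        case ttype::F32: dup_flt_sketch<float, float>();   break;
        case ttype::I32: dup_flt_sketch<float, int32_t>(); break;   // new in this release
        default:         dup_to_q_sketch<float>();         break;   // quantized destinations
    }
}

int main() { dispatch_f32_src(ttype::I32); return 0; }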
@@ -4084,31 +3467,27 @@ static void wsp_ggml_compute_forward_norm_f32(

  WSP_GGML_ASSERT(eps >= 0.0f);

- // TODO: optimize
  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

- wsp_ggml_float sum = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (wsp_ggml_float)x[i00];
- }
-
+ float sum = 0.0;
+ wsp_ggml_vec_sum_f32(ne00, &sum, x);
  float mean = sum/ne00;

  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+ float variance = 0;

- wsp_ggml_float sum2 = 0.0;
- for (int64_t i00 = 0; i00 < ne00; i00++) {
- float v = x[i00] - mean;
- y[i00] = v;
- sum2 += (wsp_ggml_float)(v*v);
- }
+ #ifdef WSP_GGML_USE_ACCELERATE
+ mean = -mean;
+ vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+ vDSP_measqv(y, 1, &variance, ne00);
+ #else
+ variance = wsp_ggml_vec_cvar_f32(ne00, y, x, mean);
+ #endif //WSP_GGML_USE_ACCELERATE

- float variance = sum2/ne00;
  const float scale = 1.0f/sqrtf(variance + eps);
-
  wsp_ggml_vec_scale_f32(ne00, y, scale);
  }
  }
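Note: the norm hunk above swaps the scalar mean/variance loops for wsp_ggml_vec_sum_f32 plus either Accelerate's vDSP calls or a new wsp_ggml_vec_cvar_f32 helper (vec.h/vec.cpp change in this release). The scalar reference semantics, as read from the removed code, are sketched below; the helper presumably fuses the centering pass and the squared-sum into one vectorized loop.

#include <cmath>
#include <cstddef>

// y = x - mean is written out (reused by the final scaling), variance = mean(y*y)
static float centered_variance(size_t n, float * y, const float * x, float mean) {
    double sum2 = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const float v = x[i] - mean;
        y[i]  = v;
        sum2 += (double) v * v;
    }
    return (float)(sum2 / n);
}

static void norm_row(size_t n, float * y, const float * x, float eps) {
    double sum = 0.0;
    for (size_t i = 0; i < n; ++i) sum += x[i];
    const float mean     = (float)(sum / n);
    const float variance = centered_variance(n, y, x, mean);
    const float scale    = 1.0f / sqrtf(variance + eps);
    for (size_t i = 0; i < n; ++i) y[i] *= scale;
}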
@@ -5356,6 +4735,7 @@ void wsp_ggml_compute_forward_get_rows(
  //}
  }

+ template<typename idx_t>
  static void wsp_ggml_compute_forward_set_rows_f32(
  const wsp_ggml_compute_params * params,
  wsp_ggml_tensor * dst) {
@@ -5394,7 +4774,7 @@ static void wsp_ggml_compute_forward_set_rows_f32(
  const int64_t i11 = i02%ne11;
  const int64_t i10 = i;

- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

  WSP_GGML_ASSERT(i1 >= 0 && i1 < ne1);

@@ -5411,11 +4791,18 @@ void wsp_ggml_compute_forward_set_rows(
  wsp_ggml_tensor * dst) {

  const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];

  switch (src0->type) {
  case WSP_GGML_TYPE_F32:
  {
- wsp_ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == WSP_GGML_TYPE_I64) {
+ wsp_ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == WSP_GGML_TYPE_I32) {
+ wsp_ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ WSP_GGML_ABORT("src1->type = %d (%s) not supported", src1->type, wsp_ggml_type_name(src1->type));
+ }
  } break;
  default:
  {
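Note: set_rows is now templated on the index type because the row-index tensor (src1) may hold I64 or I32 indices. A minimal sketch of the same idea on plain arrays; the function name is illustrative and the bounds check stands in for the kernel's assert.

#include <cstdint>
#include <cstring>

template <typename idx_t>
static void set_rows_sketch(float * dst, int64_t dst_rows, int64_t cols,
                            const float * src, const idx_t * idx, int64_t n_idx) {
    for (int64_t i = 0; i < n_idx; ++i) {
        const int64_t row = (int64_t) idx[i];       // widening an i32 index is lossless
        if (row < 0 || row >= dst_rows) continue;   // the real kernel asserts instead
        std::memcpy(dst + row*cols, src + i*cols, cols*sizeof(float));
    }
}

// dispatch mirrors the switch above:
//   I64 indices -> set_rows_sketch<int64_t>, I32 indices -> set_rows_sketch<int32_t>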
@@ -7027,6 +6414,209 @@ void wsp_ggml_compute_forward_im2col_back_f32(
7027
6414
  }
7028
6415
  }
7029
6416
 
6417
+
6418
+ // wsp_ggml_compute_forward_im2col_3d_f16
6419
+ // src0: kernel [OC*IC, KD, KH, KW]
6420
+ // src1: image [N*IC, ID, IH, IW]
6421
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6422
+ static void wsp_ggml_compute_forward_im2col_3d_f16(
6423
+ const wsp_ggml_compute_params * params,
6424
+ wsp_ggml_tensor * dst) {
6425
+
6426
+ const wsp_ggml_tensor * src0 = dst->src[0];
6427
+ const wsp_ggml_tensor * src1 = dst->src[1];
6428
+
6429
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
6430
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
6431
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
6432
+
6433
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
6434
+
6435
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6436
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6437
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6438
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6439
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6440
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6441
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6442
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6443
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6444
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6445
+
6446
+
6447
+ const int ith = params->ith;
6448
+ const int nth = params->nth;
6449
+
6450
+ const int64_t N = ne13 / IC;
6451
+ const int64_t ID = ne12;
6452
+ const int64_t IH = ne11;
6453
+ const int64_t IW = ne10;
6454
+
6455
+ const int64_t OC = ne03 / IC;
6456
+ WSP_GGML_UNUSED(OC);
6457
+ const int64_t KD = ne02;
6458
+ const int64_t KH = ne01;
6459
+ const int64_t KW = ne00;
6460
+
6461
+ const int64_t OD = ne3 / N;
6462
+ const int64_t OH = ne2;
6463
+ const int64_t OW = ne1;
6464
+ const int64_t OH_OW = OH*OW;
6465
+ const int64_t KD_KH_KW = KD*KH*KW;
6466
+ const int64_t KH_KW = KH*KW;
6467
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6468
+
6469
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
6470
+
6471
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6472
+ {
6473
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
6474
+
6475
+ for (int64_t in = 0; in < N; in++) {
6476
+ for (int64_t iod = 0; iod < OD; iod++) {
6477
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
6478
+ for (int64_t iow = 0; iow < OW; iow++) {
6479
+ for (int64_t iic = ith; iic < IC; iic += nth) {
6480
+
6481
+ // micro kernel
6482
+ wsp_ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6483
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6484
+
6485
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
6486
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
6487
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
6488
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
6489
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
6490
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
6491
+
6492
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6493
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6494
+ } else {
6495
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6496
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = WSP_GGML_CPU_FP32_TO_FP16(*s);
6497
+ }
6498
+ }
6499
+ }
6500
+ }
6501
+ }
6502
+ }
6503
+ }
6504
+ }
6505
+ }
6506
+ }
6507
+ }
6508
+
6509
+ // wsp_ggml_compute_forward_im2col_3d_f32
6510
+ // src0: kernel [OC*IC, KD, KH, KW]
6511
+ // src1: image [N*IC, ID, IH, IW]
6512
+ // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6513
+ static void wsp_ggml_compute_forward_im2col_3d_f32(
6514
+ const wsp_ggml_compute_params * params,
6515
+ wsp_ggml_tensor * dst) {
6516
+
6517
+ const wsp_ggml_tensor * src0 = dst->src[0];
6518
+ const wsp_ggml_tensor * src1 = dst->src[1];
6519
+
6520
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
6521
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
6522
+
6523
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
6524
+
6525
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6526
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6527
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6528
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6529
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6530
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6531
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6532
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6533
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6534
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6535
+
6536
+
6537
+ const int ith = params->ith;
6538
+ const int nth = params->nth;
6539
+
6540
+ const int64_t N = ne13 / IC;
6541
+ const int64_t ID = ne12;
6542
+ const int64_t IH = ne11;
6543
+ const int64_t IW = ne10;
6544
+
6545
+ const int64_t OC = ne03 / IC;
6546
+ WSP_GGML_UNUSED(OC);
6547
+ const int64_t KD = ne02;
6548
+ const int64_t KH = ne01;
6549
+ const int64_t KW = ne00;
6550
+
6551
+ const int64_t OD = ne3 / N;
6552
+ const int64_t OH = ne2;
6553
+ const int64_t OW = ne1;
6554
+
6555
+ const int64_t OH_OW = OH*OW;
6556
+ const int64_t KD_KH_KW = KD*KH*KW;
6557
+ const int64_t KH_KW = KH*KW;
6558
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6559
+
6560
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
6561
+
6562
+ // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6563
+ {
6564
+ float * const wdata = (float *) dst->data;
6565
+
6566
+ for (int64_t in = 0; in < N; in++) {
6567
+ for (int64_t iod = 0; iod < OD; iod++) {
6568
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
6569
+ for (int64_t iow = 0; iow < OW; iow++) {
6570
+ for (int64_t iic = ith; iic < IC; iic += nth) {
6571
+
6572
+ // micro kernel
6573
+ float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6574
+ const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6575
+
6576
+ for (int64_t ikd = 0; ikd < KD; ikd++) {
6577
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
6578
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
6579
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
6580
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
6581
+ const int64_t iid = iod*s2 + ikd*d2 - p2;
6582
+
6583
+ if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6584
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6585
+ } else {
6586
+ const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6587
+ dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6588
+ }
6589
+ }
6590
+ }
6591
+ }
6592
+ }
6593
+ }
6594
+ }
6595
+ }
6596
+ }
6597
+ }
6598
+ }
6599
+
6600
+
6601
+ void wsp_ggml_compute_forward_im2col_3d(
6602
+ const wsp_ggml_compute_params * params,
6603
+ wsp_ggml_tensor * dst) {
6604
+ switch (dst->type) {
6605
+ case WSP_GGML_TYPE_F16:
6606
+ {
6607
+ wsp_ggml_compute_forward_im2col_3d_f16(params, dst);
6608
+ } break;
6609
+ case WSP_GGML_TYPE_F32:
6610
+ {
6611
+ wsp_ggml_compute_forward_im2col_3d_f32(params, dst);
6612
+ } break;
6613
+ default:
6614
+ {
6615
+ WSP_GGML_ABORT("fatal error");
6616
+ }
6617
+ }
6618
+ }
6619
+
7030
6620
  static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
7031
6621
  void * a, void * b, float * c) {
7032
6622
  const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(type);
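Note: the hunk above adds 3-D im2col kernels (f16 and f32): every output voxel gets a row of IC*KD*KH*KW gathered input values, with zeros wherever a kernel tap lands in the padding. A single-channel, single-batch float sketch of the index math, assuming one shared stride/padding/dilation per axis for brevity; names and the simplified layout are illustrative only.

#include <cstdint>
#include <vector>

static std::vector<float> im2col_3d_sketch(
        const std::vector<float> & src, int64_t ID, int64_t IH, int64_t IW,
        int64_t KD, int64_t KH, int64_t KW,
        int64_t s, int64_t p, int64_t d) {
    // standard output-size formula per axis
    const int64_t OD = (ID + 2*p - d*(KD - 1) - 1)/s + 1;
    const int64_t OH = (IH + 2*p - d*(KH - 1) - 1)/s + 1;
    const int64_t OW = (IW + 2*p - d*(KW - 1) - 1)/s + 1;
    std::vector<float> out(OD*OH*OW * KD*KH*KW, 0.0f);
    for (int64_t od = 0; od < OD; ++od)
    for (int64_t oh = 0; oh < OH; ++oh)
    for (int64_t ow = 0; ow < OW; ++ow)
    for (int64_t kd = 0; kd < KD; ++kd)
    for (int64_t kh = 0; kh < KH; ++kh)
    for (int64_t kw = 0; kw < KW; ++kw) {
        const int64_t iz = od*s + kd*d - p;
        const int64_t iy = oh*s + kh*d - p;
        const int64_t ix = ow*s + kw*d - p;
        const int64_t row = (od*OH + oh)*OW + ow;      // one row per output voxel
        const int64_t col = (kd*KH + kh)*KW + kw;      // one column per kernel tap
        if (iz >= 0 && iz < ID && iy >= 0 && iy < IH && ix >= 0 && ix < IW) {
            out[row*KD*KH*KW + col] = src[(iz*IH + iy)*IW + ix];
        }                                              // else: stays zero (padding)
    }
    return out;
}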
@@ -7207,6 +6797,148 @@ void wsp_ggml_compute_forward_conv_2d(
7207
6797
  wsp_ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
7208
6798
  }
7209
6799
 
6800
+ // wsp_ggml_compute_forward_conv_3d
6801
+
6802
+ static void wsp_ggml_compute_forward_conv_3d_impl(const wsp_ggml_compute_params * params,
6803
+ const wsp_ggml_tensor * kernel,
6804
+ const wsp_ggml_tensor * src,
6805
+ wsp_ggml_tensor * dst,
6806
+ wsp_ggml_type kernel_type) {
6807
+
6808
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(kernel));
6809
+ WSP_GGML_ASSERT(kernel_type == WSP_GGML_TYPE_F16 || kernel_type == WSP_GGML_TYPE_F32);
6810
+ WSP_GGML_ASSERT(kernel->type == kernel_type);
6811
+
6812
+ const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(kernel_type);
6813
+
6814
+ const int32_t s0 = dst->op_params[0];
6815
+ const int32_t s1 = dst->op_params[1];
6816
+ const int32_t s2 = dst->op_params[2];
6817
+ const int32_t p0 = dst->op_params[3];
6818
+ const int32_t p1 = dst->op_params[4];
6819
+ const int32_t p2 = dst->op_params[5];
6820
+ const int32_t d0 = dst->op_params[6];
6821
+ const int32_t d1 = dst->op_params[7];
6822
+ const int32_t d2 = dst->op_params[8];
6823
+ const int32_t c = dst->op_params[9];
6824
+ const int32_t n = dst->op_params[10];
6825
+ const int32_t oc = dst->op_params[11];
6826
+
6827
+ const int64_t src_w = src->ne[0];
6828
+ const int64_t src_h = src->ne[1];
6829
+ const int64_t src_d = src->ne[2];
6830
+ const int64_t knl_w = kernel->ne[0];
6831
+ const int64_t knl_h = kernel->ne[1];
6832
+ const int64_t knl_d = kernel->ne[2];
6833
+ const int64_t dst_w = dst->ne[0];
6834
+ const int64_t dst_h = dst->ne[1];
6835
+ const int64_t dst_d = dst->ne[2];
6836
+
6837
+ const float * src_data = (float *) src->data;
6838
+ void * knl_data = kernel->data;
6839
+ float * dst_data = (float *) dst->data;
6840
+
6841
+ const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6842
+ const int64_t knl_n_total = knl_n_per_channel * c;
6843
+ const int64_t patch_total = n * dst_w * dst_h * dst_d;
6844
+
6845
+ const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
6846
+ const int64_t batch_size = params->wsize / space_per_patch;
6847
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6848
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6849
+
6850
+ WSP_GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6851
+
6852
+ void * tmp = params->wdata;
6853
+
6854
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6855
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
6856
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
6857
+ const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
6858
+
6859
+ const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6860
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6861
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6862
+
6863
+ for (int64_t p = patch_start; p < patch_end; ++p) {
6864
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6865
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6866
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6867
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6868
+ const int64_t dst_y = p_in_depth / dst_w;
6869
+ const int64_t dst_x = p_in_depth % dst_w;
6870
+
6871
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
6872
+
6873
+ for (int64_t ic = 0; ic < c; ++ic) {
6874
+ for (int64_t kz = 0; kz < knl_d; ++kz) {
6875
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
6876
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
6877
+ const int64_t sz = dst_z * s2 + kz * d2 - p2;
6878
+ const int64_t sy = dst_y * s1 + ky * d1 - p1;
6879
+ const int64_t sx = dst_x * s0 + kx * d0 - p0;
6880
+
6881
+ int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
6882
+
6883
+ float src_val;
6884
+ if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6885
+ src_val = 0.0f;
6886
+ } else {
6887
+ const int64_t cn_idx = batch_idx * c + ic;
6888
+ const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
6889
+ src_val = *src_ptr;
6890
+ }
6891
+
6892
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
6893
+ if (kernel_type == WSP_GGML_TYPE_F32) {
6894
+ *(float *)element_ptr = src_val;
6895
+ } else if (kernel_type == WSP_GGML_TYPE_F16) {
6896
+ *(wsp_ggml_fp16_t *)element_ptr = WSP_GGML_CPU_FP32_TO_FP16(src_val);
6897
+ }
6898
+ }
6899
+ }
6900
+ }
6901
+ }
6902
+ }
6903
+
6904
+ wsp_ggml_barrier(params->threadpool);
6905
+
6906
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
6907
+ wsp_ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
6908
+
6909
+ wsp_ggml_barrier(params->threadpool);
6910
+
6911
+ const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6912
+ const int64_t permute_start = params->ith * permute_per_thread;
6913
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
6914
+
6915
+ for (int64_t i = permute_start; i < permute_end; ++i) {
6916
+ const int64_t p = patch_start_batch + i;
6917
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6918
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6919
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6920
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6921
+ const int64_t dst_y = p_in_depth / dst_w;
6922
+ const int64_t dst_x = p_in_depth % dst_w;
6923
+
6924
+ for (int64_t ioc = 0; ioc < oc; ++ioc) {
6925
+ const float value = gemm_output[i * oc + ioc];
6926
+ const int64_t ocn_idx = batch_idx * oc + ioc;
6927
+ float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
6928
+ *dst_ptr = value;
6929
+ }
6930
+ }
6931
+ }
6932
+ }
6933
+
6934
+ void wsp_ggml_compute_forward_conv_3d(
6935
+ const wsp_ggml_compute_params * params,
6936
+ wsp_ggml_tensor * dst) {
6937
+ const wsp_ggml_tensor * src0 = dst->src[0];
6938
+ const wsp_ggml_tensor * src1 = dst->src[1];
6939
+ wsp_ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6940
+ }
6941
+
7210
6942
  // wsp_ggml_compute_forward_conv_transpose_2d
7211
6943
 
7212
6944
  void wsp_ggml_compute_forward_conv_transpose_2d(
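Note: the conv-3d path added above is im2col into a scratch buffer, a batched matrix multiply of the patch rows against the flattened kernel via wsp_ggml_call_mul_mat, and a final permute back into the [OC, OD, OH, OW] destination. The multiply itself reduces to the naive GEMM below, shown only for clarity; the real path goes through wsp_ggml_call_mul_mat with the kernel type's optimized routines.

#include <cstdint>

// out[p][o] = sum_k patches[p][k] * kernel[o][k]; later scattered into dst
static void gemm_sketch(const float * patches, const float * kernel, float * out,
                        int64_t n_patches, int64_t knl_n, int64_t oc) {
    for (int64_t p = 0; p < n_patches; ++p) {
        for (int64_t o = 0; o < oc; ++o) {
            float acc = 0.0f;
            for (int64_t k = 0; k < knl_n; ++k) {
                acc += patches[p*knl_n + k] * kernel[o*knl_n + k];
            }
            out[p*oc + o] = acc;
        }
    }
}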
@@ -7872,6 +7604,15 @@ static void wsp_ggml_compute_forward_pad_f32(
     WSP_GGML_TENSOR_UNARY_OP_LOCALS
 
     float * dst_ptr = (float *) dst->data;
+    const int32_t lp0 = wsp_ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = wsp_ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = wsp_ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = wsp_ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = wsp_ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = wsp_ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = wsp_ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = wsp_ggml_get_op_params_i32(dst, 7);
+
 
     // TODO: optimize
 
@@ -7880,10 +7621,12 @@ static void wsp_ggml_compute_forward_pad_f32(
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
                     const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                        && (i1 >= lp1 && i1 < ne1 - rp1) \
+                        && (i2 >= lp2 && i2 < ne2 - rp2) \
+                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
                         dst_ptr[dst_idx] = 0;
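Note: the pad changes above generalize the op from implicit trailing padding to explicit per-dimension left/right padding read from the op params (lp0/rp0 through lp3/rp3); a destination index now maps back to the source by subtracting the left pad, and everything outside the valid window is zero-filled. A minimal 1-D illustration of that index mapping follows; the helper is hypothetical, written for this note, and not library code.

    // 1-D sketch of the left/right padding index mapping (illustration only).
    #include <vector>

    static std::vector<float> pad_1d(const std::vector<float> & src, int lp0, int rp0) {
        const int ne00 = (int) src.size();
        const int ne0  = lp0 + ne00 + rp0;      // padded length
        std::vector<float> dst(ne0);
        for (int i0 = 0; i0 < ne0; ++i0) {
            if (i0 >= lp0 && i0 < ne0 - rp0) {
                dst[i0] = src[i0 - lp0];        // shifted copy of the source
            } else {
                dst[i0] = 0.0f;                 // left/right padding
            }
        }
        return dst;
    }
    // e.g. pad_1d({1, 2, 3}, /*lp0=*/2, /*rp0=*/1) -> {0, 0, 1, 2, 3, 0}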
@@ -8082,7 +7825,7 @@ static void wsp_ggml_compute_forward_timestep_embedding_f32(
             embed_data[j + half] = sinf(arg);
         }
         if (dim % 2 != 0 && ith == 0) {
-            embed_data[dim] = 0.f;
+            embed_data[2 * half] = 0.f;
         }
     }
 }
@@ -8388,7 +8131,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
         }
 
         // V /= S
-        const float S_inv = 1.0f/S;
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
         wsp_ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
@@ -8861,8 +8604,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
     WSP_GGML_ASSERT(src4->nb[0] == sizeof(float));
     WSP_GGML_ASSERT(src5->nb[0] == sizeof(float));
     WSP_GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
-    // allows optimizing the modulo since n_group should be a power of 2
-    WSP_GGML_ASSERT((ng & -ng) == ng);
+    WSP_GGML_ASSERT(nh % ng == 0);
 
     // heads per thread
     const int dh = (nh + nth - 1)/nth;
@@ -8891,8 +8633,9 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
                 const float dA = expf(dt_soft_plus * A[h]);
+                const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
                 for (int i1 = 0; i1 < nr; ++i1) {
@@ -8915,8 +8658,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                         // TODO: maybe unroll more?
                         for (int j = 0; j < 1; j++) {
                             WSP_GGML_F32_VEC t0 = WSP_GGML_F32_VEC_LOAD(s0 + i + j*wsp_ggml_f32_epr + ii*nc);
-                            WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
-                            WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + (h & (ng - 1))*nc);
+                            WSP_GGML_F32_VEC t1 = WSP_GGML_F32_VEC_LOAD(B + i + j*wsp_ggml_f32_epr + g*nc);
+                            WSP_GGML_F32_VEC t2 = WSP_GGML_F32_VEC_LOAD(C + i + j*wsp_ggml_f32_epr + g*nc);
 
                             t0 = WSP_GGML_F32_VEC_MUL(t0, adA);
                             t1 = WSP_GGML_F32_VEC_MUL(t1, axdt);
@@ -8930,6 +8673,9 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                         }
 
                         sumf = WSP_GGML_F32xt_REDUCE_ONE(sum);
+                    #elif defined(__riscv_v_intrinsic)
+                        // todo: RVV implementation
+                        const int np = 0;
                     #else
                         const int np = (nc & ~(WSP_GGML_F32_STEP - 1));
 
@@ -8945,8 +8691,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                         for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
                             for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
                                 ax[j] = WSP_GGML_F32_VEC_LOAD(s0 + i + j*WSP_GGML_F32_EPR + ii*nc);
-                                ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
-                                az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + (h & (ng - 1))*nc);
+                                ay[j] = WSP_GGML_F32_VEC_LOAD(B + i + j*WSP_GGML_F32_EPR + g*nc);
+                                az[j] = WSP_GGML_F32_VEC_LOAD(C + i + j*WSP_GGML_F32_EPR + g*nc);
 
                                 ax[j] = WSP_GGML_F32_VEC_MUL(ax[j], adA);
                                 ay[j] = WSP_GGML_F32_VEC_MUL(ay[j], axdt);
@@ -8968,7 +8714,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                     // d_state
                     for (int i0 = np; i0 < nc; ++i0) {
                         const int i = i0 + ii*nc;
-                        const int ig = i0 + (h & (ng - 1))*nc;
+                        const int ig = i0 + g*nc;
                         // state = prev_state * dA + dB * x
                         const float state = (s0[i] * dA) + (B[ig] * x_dt);
                         // y = rowwise_dotprod(state, C)
@@ -8984,7 +8730,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                const float dt_soft_plus = wsp_ggml_softplus(dt[h]);
+                const int g = h / (nh / ng); // repeat_interleave
 
                 // dim
                 for (int i1 = 0; i1 < nr; ++i1) {
@@ -8999,8 +8746,8 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                     // TODO: what happens when (d_state % svcntw()) != 0?
                     for (int64_t k = 0; k < nc; k += svcntw()) {
                         svfloat32_t vA = WSP_GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                        svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
-                        svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                        svfloat32_t vB = WSP_GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                        svfloat32_t vC = WSP_GGML_F32_VEC_LOAD(&C[k + g*nc]);
                         svfloat32_t vs0 = WSP_GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 
                         svfloat32_t t1 = WSP_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9020,7 +8767,7 @@ static void wsp_ggml_compute_forward_ssm_scan_f32(
                     // d_state
                    for (int i0 = 0; i0 < nc; ++i0) {
                        const int i = i0 + ii*nc;
-                       const int ig = i0 + (h & (ng - 1))*nc;
+                       const int ig = i0 + g*nc;
                        // state = prev_state * dA + dB * x
                        const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                        // y = rowwise_dotprod(state, C)
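Note: across the ssm_scan hunks above, the power-of-two mask (h & (ng - 1)) is replaced by an explicit group index g = h / (nh / ng), paired with the relaxed assertion nh % ng == 0, so the number of groups no longer has to be a power of two; consecutive blocks of nh/ng heads now share one B/C group, repeat_interleave style. A tiny standalone illustration of that mapping follows, with example values; it is not library code.

    // Head -> group mapping used by the updated scan (illustration only).
    #include <cassert>
    #include <cstdio>

    int main() {
        const int nh = 12;               // number of heads (example)
        const int ng = 3;                // number of groups; need not be a power of 2
        assert(nh % ng == 0);            // mirrors the new WSP_GGML_ASSERT(nh % ng == 0)
        for (int h = 0; h < nh; ++h) {
            const int g = h / (nh / ng); // heads 0..3 -> group 0, 4..7 -> 1, 8..11 -> 2
            std::printf("head %2d -> group %d\n", h, g);
        }
        return 0;
    }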
@@ -9246,6 +8993,26 @@ void wsp_ggml_compute_forward_unary(
             {
                 wsp_ggml_compute_forward_exp(params, dst);
             } break;
+        case WSP_GGML_UNARY_OP_FLOOR:
+            {
+                wsp_ggml_compute_forward_floor(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_CEIL:
+            {
+                wsp_ggml_compute_forward_ceil(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_ROUND:
+            {
+                wsp_ggml_compute_forward_round(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_TRUNC:
+            {
+                wsp_ggml_compute_forward_trunc(params, dst);
+            } break;
+        case WSP_GGML_UNARY_OP_XIELU:
+            {
+                wsp_ggml_compute_forward_xielu(params, dst);
+            } break;
         default:
             {
                 WSP_GGML_ABORT("fatal error");
@@ -9881,8 +9648,8 @@ static void wsp_ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
     #if defined(WSP_GGML_SIMD)
-        #if defined(__ARM_FEATURE_SVE)
-            // scalar Route to scalar implementation //TODO: Write SVE code
+        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+            // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
             for (int64_t t = 0; t < T; t++) {
                 int64_t t_offset = t * t_stride;
                 int64_t state_offset = head_size * C * (t / (T / n_seqs));