whisper.rn 0.4.2 → 0.5.0-rc.0

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (98)
  1. package/README.md +1 -3
  2. package/android/build.gradle +70 -11
  3. package/android/src/main/CMakeLists.txt +28 -1
  4. package/android/src/main/java/com/rnwhisper/JSCallInvokerResolver.java +40 -0
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +80 -27
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +21 -9
  7. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -1
  8. package/android/src/main/jni.cpp +79 -2
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  16. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  17. package/cpp/ggml-backend.cpp +36 -18
  18. package/cpp/ggml-backend.h +1 -1
  19. package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
  20. package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
  21. package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
  22. package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
  23. package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
  24. package/cpp/ggml-cpu/common.h +3 -2
  25. package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
  26. package/cpp/ggml-cpu/ggml-cpu.c +95 -17
  27. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  28. package/cpp/ggml-cpu/ops.cpp +775 -74
  29. package/cpp/ggml-cpu/ops.h +7 -0
  30. package/cpp/ggml-cpu/quants.c +25 -24
  31. package/cpp/ggml-cpu/repack.cpp +15 -14
  32. package/cpp/ggml-cpu/simd-mappings.h +211 -33
  33. package/cpp/ggml-cpu/vec.cpp +26 -2
  34. package/cpp/ggml-cpu/vec.h +99 -45
  35. package/cpp/ggml-cpu.h +2 -0
  36. package/cpp/ggml-impl.h +125 -183
  37. package/cpp/ggml-metal-impl.h +27 -0
  38. package/cpp/ggml-metal.m +298 -41
  39. package/cpp/ggml-quants.c +6 -6
  40. package/cpp/ggml-whisper-sim.metallib +0 -0
  41. package/cpp/ggml-whisper.metallib +0 -0
  42. package/cpp/ggml.c +269 -40
  43. package/cpp/ggml.h +122 -2
  44. package/cpp/gguf.cpp +5 -1
  45. package/cpp/jsi/RNWhisperJSI.cpp +681 -0
  46. package/cpp/jsi/RNWhisperJSI.h +44 -0
  47. package/cpp/jsi/ThreadPool.h +100 -0
  48. package/cpp/whisper.cpp +4 -0
  49. package/cpp/whisper.h +2 -0
  50. package/ios/RNWhisper.h +3 -0
  51. package/ios/RNWhisper.mm +66 -31
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  79. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  84. package/jest/mock.js +1 -0
  85. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  86. package/lib/commonjs/index.js +83 -2
  87. package/lib/commonjs/index.js.map +1 -1
  88. package/lib/module/NativeRNWhisper.js.map +1 -1
  89. package/lib/module/index.js +83 -2
  90. package/lib/module/index.js.map +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +4 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/index.d.ts +18 -6
  94. package/lib/typescript/index.d.ts.map +1 -1
  95. package/package.json +2 -3
  96. package/src/NativeRNWhisper.ts +2 -0
  97. package/src/index.ts +162 -33
  98. package/whisper-rn.podspec +6 -3
package/cpp/ggml-cpu/ops.cpp
@@ -3,6 +3,7 @@
  #include "ggml-cpu.h"
  #include "ggml-impl.h"
  #include "binary-ops.h"
+ #include "ggml.h"
  #include "unary-ops.h"
  #include "vec.h"

@@ -108,7 +109,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  for (int i01 = ir0; i01 < ir1; i01++) {
  const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]);
+ dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
  id++;
  }
  }
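Note: the recurring change in this file, first visible in the hunk above and repeated through most hunks below, swaps the generic WSP_GGML_FP16_TO_FP32 / WSP_GGML_FP32_TO_FP16 macros for WSP_GGML_CPU_-prefixed variants. Assuming they mirror upstream ggml's GGML_CPU_FP16_TO_FP32, the CPU-local variant can lower to a single hardware convert instruction on architectures that have one, while the generic macro keeps the portable bit-twiddling path. A minimal sketch of that idea (illustrative only, not the package's actual definition):

    #include <stdint.h>
    #include <string.h>

    typedef uint16_t fp16_bits; // stand-in for wsp_ggml_fp16_t (IEEE half bit pattern)

    static inline float cpu_fp16_to_fp32(fp16_bits h) {
    #if defined(__ARM_NEON)
        // Hardware path: reinterpret the bits as __fp16; the compiler emits one fcvt.
        __fp16 f;
        memcpy(&f, &h, sizeof(h));
        return (float) f;
    #else
        // Portable path: expand sign/exponent/mantissa by hand.
        const uint32_t sign = (uint32_t)(h >> 15) << 31;
        uint32_t exp  = (h >> 10) & 0x1f;
        uint32_t mant = h & 0x3ff;
        uint32_t bits;
        if (exp == 0x1f) {                    // inf / NaN
            bits = sign | 0x7f800000u | (mant << 13);
        } else if (exp != 0) {                // normal number: rebias 15 -> 127
            bits = sign | ((exp + 112) << 23) | (mant << 13);
        } else if (mant != 0) {               // subnormal: renormalize
            exp = 113;
            while ((mant & 0x400) == 0) { mant <<= 1; exp--; }
            bits = sign | (exp << 23) | ((mant & 0x3ff) << 13);
        } else {
            bits = sign;                      // signed zero
        }
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    #endif
    }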
@@ -130,7 +131,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

  for (int i00 = 0; i00 < ne00; i00++) {
- src0_f32[i00] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]);
+ src0_f32[i00] = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
  }

  wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
@@ -156,7 +157,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  for (int i00 = 0; i00 < ne00; i00++) {
  const wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

- dst_ptr[id] = WSP_GGML_FP16_TO_FP32(*src0_ptr);
+ dst_ptr[id] = WSP_GGML_CPU_FP16_TO_FP32(*src0_ptr);
  id++;
  }
  }
@@ -267,7 +268,7 @@ static void wsp_ggml_compute_forward_dup_f16(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(float *) dst_ptr = WSP_GGML_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr);
+ *(float *) dst_ptr = WSP_GGML_CPU_FP16_TO_FP32(*(const wsp_ggml_fp16_t *) src0_ptr);

  if (++i10 == ne0) {
  i10 = 0;
@@ -372,7 +373,7 @@ static void wsp_ggml_compute_forward_dup_bf16(
  for (int i01 = ir0; i01 < ir1; i01++) {
  const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  for (int i00 = 0; i00 < ne00; i00++) {
- dst_ptr[id] = WSP_GGML_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(src0_ptr[i00]));
+ dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(src0_ptr[i00]));
  id++;
  }
  }
@@ -473,7 +474,7 @@ static void wsp_ggml_compute_forward_dup_bf16(
  for (int i00 = 0; i00 < ne00; i00++) {
  const wsp_ggml_bf16_t * src0_ptr = (wsp_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

- dst_ptr[id] = WSP_GGML_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*src0_ptr));
+ dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*src0_ptr));
  id++;
  }
  }
@@ -566,7 +567,7 @@ static void wsp_ggml_compute_forward_dup_bf16(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr));
+ *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_BF16_TO_FP32(*(const wsp_ggml_bf16_t *) src0_ptr));

  if (++i10 == ne0) {
  i10 = 0;
@@ -696,24 +697,8 @@ static void wsp_ggml_compute_forward_dup_f32(
  if (wsp_ggml_is_contiguous(dst)) {
  // TODO: simplify
  if (nb00 == sizeof(float)) {
- if (dst->type == WSP_GGML_TYPE_F32) {
- size_t id = 0;
- const size_t rs = ne00 * nb00;
- char * dst_ptr = (char *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += rs * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
- memcpy(dst_ptr + id, src0_ptr, rs);
- id += rs;
- }
- id += rs * (ne01 - ir1);
- }
- }
- } else if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
- wsp_ggml_from_float_t const wsp_quantize_row_q = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
+ if (wsp_ggml_get_type_traits_cpu(dst->type)->from_float) {
+ wsp_ggml_from_float_t const from_float = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;

  size_t id = 0;
  size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
@@ -724,7 +709,7 @@ static void wsp_ggml_compute_forward_dup_f32(
  id += rs * ir0;
  for (int i01 = ir0; i01 < ir1; i01++) {
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- wsp_quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+ from_float(src0_ptr, dst_ptr + id, ne00);
  id += rs;
  }
  id += rs * (ne01 - ir1);
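Note: the two hunks above drop the dedicated F32-to-F32 memcpy branch and rename the misleading local wsp_quantize_row_q to from_float, so every contiguous destination type now goes through the type traits' from_float callback. That is only an equivalent simplification if the F32 entry of the CPU type traits exists and is itself a plain row copy, as it is in upstream ggml; conceptually:

    #include <stdint.h>
    #include <string.h>

    // Hypothetical F32 "from_float" entry: converting float to float is a copy,
    // which is what makes the removed memcpy fast path redundant.
    static void from_float_f32(const float * x, void * y, int64_t n) {
        memcpy(y, x, n * sizeof(float));
    }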
@@ -765,7 +750,7 @@ static void wsp_ggml_compute_forward_dup_f32(
  for (int i00 = 0; i00 < ne00; i00++) {
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

- dst_ptr[id] = WSP_GGML_FP32_TO_FP16(*src0_ptr);
+ dst_ptr[id] = WSP_GGML_CPU_FP32_TO_FP16(*src0_ptr);
  id++;
  }
  }
@@ -878,7 +863,7 @@ static void wsp_ggml_compute_forward_dup_f32(
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

- *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+ *(wsp_ggml_fp16_t *) dst_ptr = WSP_GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);

  if (++i10 == ne0) {
  i10 = 0;
@@ -1419,7 +1404,7 @@ static void wsp_ggml_compute_forward_add1_f16_f32(
  wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+ dst_ptr[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  }
  }
  }
@@ -1435,7 +1420,7 @@ static void wsp_ggml_compute_forward_add1_f16_f16(
  WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1));

  // scalar to add
- const float v = WSP_GGML_FP16_TO_FP32(*(wsp_ggml_fp16_t *) src1->data);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(*(wsp_ggml_fp16_t *) src1->data);

  const int ith = params->ith;
  const int nth = params->nth;
@@ -1467,7 +1452,7 @@ static void wsp_ggml_compute_forward_add1_f16_f16(
  wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  for (int i = 0; i < ne0; i++) {
- dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+ dst_ptr[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  }
  }
  }
@@ -1889,7 +1874,7 @@ static void wsp_ggml_compute_forward_sum_f16(
  }
  }
  }
- ((wsp_ggml_fp16_t *) dst->data)[0] = WSP_GGML_FP32_TO_FP16(sum);
+ ((wsp_ggml_fp16_t *) dst->data)[0] = WSP_GGML_CPU_FP32_TO_FP16(sum);
  }

  static void wsp_ggml_compute_forward_sum_bf16(
@@ -2300,6 +2285,12 @@ void wsp_ggml_compute_forward_repeat(
  {
  wsp_ggml_compute_forward_repeat_f32(params, dst);
  } break;
+ // TODO: templateify the implemenation and support for I64
+ // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+ //case WSP_GGML_TYPE_I64:
+ // {
+ // wsp_ggml_compute_forward_repeat_i64(params, dst);
+ // } break;
  default:
  {
  WSP_GGML_ABORT("fatal error");
@@ -2660,7 +2651,7 @@ static void wsp_ggml_compute_forward_gelu_f16(
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
  const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = WSP_GGML_FP16_TO_FP32(x);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x);
  WSP_GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
@@ -2763,7 +2754,7 @@ static void wsp_ggml_compute_forward_gelu_erf_f16(
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
  const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = WSP_GGML_FP16_TO_FP32(x);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x);
  WSP_GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
@@ -2866,7 +2857,7 @@ static void wsp_ggml_compute_forward_gelu_quick_f16(
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
  const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = WSP_GGML_FP16_TO_FP32(x);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x);
  WSP_GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
@@ -2969,7 +2960,7 @@ static void wsp_ggml_compute_forward_silu_f16(
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
  const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
- const float v = WSP_GGML_FP16_TO_FP32(x);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x);
  WSP_GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
@@ -3163,7 +3154,7 @@ static void wsp_ggml_compute_forward_silu_back_f16(
  #ifndef NDEBUG
  for (int k = 0; k < nc; k++) {
  const float x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
- const float v = WSP_GGML_FP16_TO_FP32(x);
+ const float v = WSP_GGML_CPU_FP16_TO_FP32(x);
  WSP_GGML_UNUSED(v);
  assert(!isnan(v));
  assert(!isinf(v));
@@ -3194,6 +3185,435 @@ void wsp_ggml_compute_forward_silu_back(
  }
  }

+ // wsp_ggml_compute_forward_reglu
+
+ static void wsp_ggml_compute_forward_reglu_f32(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ WSP_GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_reglu_f16(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_reglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = WSP_GGML_FP16_TO_FP32(x);
+ WSP_GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_reglu(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_reglu_f32(params, dst);
+ } break;
+ case WSP_GGML_TYPE_F16:
+ {
+ wsp_ggml_compute_forward_reglu_f16(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // wsp_ggml_compute_forward_geglu
+
+ static void wsp_ggml_compute_forward_geglu_f32(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ WSP_GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_geglu_f16(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_geglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = WSP_GGML_FP16_TO_FP32(x);
+ WSP_GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_geglu(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_geglu_f32(params, dst);
+ } break;
+ case WSP_GGML_TYPE_F16:
+ {
+ wsp_ggml_compute_forward_geglu_f16(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // wsp_ggml_compute_forward_swiglu
+
+ static void wsp_ggml_compute_forward_swiglu_f32(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ WSP_GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_swiglu_f16(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src1));
+ WSP_GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = wsp_ggml_nrows(src0);
+
+ WSP_GGML_ASSERT(dst->ne[0] == nc);
+ WSP_GGML_ASSERT(wsp_ggml_nrows(dst) == nr);
+
+ const int32_t swapped = wsp_ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ wsp_ggml_fp16_t * src0_p = (wsp_ggml_fp16_t *) (src0_d + i1*src0_o);
+ wsp_ggml_fp16_t * src1_p = (wsp_ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ wsp_ggml_vec_swiglu_f16(nc, (wsp_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const wsp_ggml_fp16_t x = ((wsp_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = WSP_GGML_FP16_TO_FP32(x);
+ WSP_GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void wsp_ggml_compute_forward_swiglu(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_swiglu_f32(params, dst);
+ } break;
+ case WSP_GGML_TYPE_F16:
+ {
+ wsp_ggml_compute_forward_swiglu_f16(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  // wsp_ggml_compute_forward_norm

  static void wsp_ggml_compute_forward_norm_f32(
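Note: the roughly 430 lines added above are the fused GLU family. The three variants share all of their layout logic (the activation input and the gate come either from two separate tensors, src0 and src1, or from the two halves of src0, with the swapped op param selecting which half gates) and differ only in the activation applied before the element-wise product. A sketch of the per-row math the wsp_ggml_vec_*_f32 helpers presumably implement (the package may use the erf-based GELU rather than the tanh approximation shown here):

    #include <math.h>

    static inline float relu_f32(float x) { return x > 0.0f ? x : 0.0f; }
    static inline float silu_f32(float x) { return x / (1.0f + expf(-x)); }
    static inline float gelu_f32(float x) { // tanh approximation of GELU
        return 0.5f*x*(1.0f + tanhf(0.7978845608f*(x + 0.044715f*x*x*x)));
    }

    // One output row of each fused op: y[i] = activation(x[i]) * g[i]
    static void vec_reglu_f32(int n, float * y, const float * x, const float * g) {
        for (int i = 0; i < n; ++i) y[i] = relu_f32(x[i]) * g[i];
    }
    static void vec_geglu_f32(int n, float * y, const float * x, const float * g) {
        for (int i = 0; i < n; ++i) y[i] = gelu_f32(x[i]) * g[i];
    }
    static void vec_swiglu_f32(int n, float * y, const float * x, const float * g) {
        for (int i = 0; i < n; ++i) y[i] = silu_f32(x[i]) * g[i];
    }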
@@ -4470,6 +4890,74 @@ void wsp_ggml_compute_forward_get_rows(
  //}
  }

+ static void wsp_ggml_compute_forward_set_rows_f32(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ne01;
+
+ assert(ne0 == nc);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+ assert(src0->type == WSP_GGML_TYPE_F32);
+ assert(ne02 % ne11 == 0);
+ assert(ne03 % ne12 == 0);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = std::min(ir0 + dr, nr);
+
+ wsp_ggml_from_float_t const from_float = wsp_ggml_get_type_traits_cpu(dst->type)->from_float;
+
+ for (int64_t i03 = 0; i03 < ne03; ++i03) {
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i03%ne12;
+ const int64_t i11 = i02%ne11;
+ const int64_t i10 = i;
+
+ const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ WSP_GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+ from_float(
+ (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+ ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+ }
+ }
+ }
+ }
+
+ void wsp_ggml_compute_forward_set_rows(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_set_rows_f32(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("src0->type = %d (%s) not supported", src0->type, wsp_ggml_type_name(src0->type));
+ }
+ }
+ }
+
  // wsp_ggml_compute_forward_get_rows_back

  static void wsp_ggml_compute_forward_get_rows_back_f32_f16(
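Note: SET_ROWS (added above) is the scatter counterpart of GET_ROWS: row i of src0 is written into the dst row named by the int64 index tensor src1, converting through from_float so the destination may be F16 or quantized. Stripped of the dim-2/3 broadcasting and threading, the semantics reduce to this sketch:

    #include <stdint.h>
    #include <string.h>

    // Minimal 2-D case: dst[idx[i], :] = src[i, :]
    static void set_rows_f32(int64_t nr, int64_t nc,
                             const float * src, const int64_t * idx, float * dst) {
        for (int64_t i = 0; i < nr; ++i) {
            memcpy(dst + idx[i] * nc, src + i * nc, nc * sizeof(float));
        }
    }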
@@ -4500,7 +4988,7 @@ static void wsp_ggml_compute_forward_get_rows_back_f32_f16(

  for (int j = 0; j < nc; ++j) {
  wsp_ggml_fp16_t v = ((wsp_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
- ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += WSP_GGML_FP16_TO_FP32(v);
+ ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += WSP_GGML_CPU_FP16_TO_FP32(v);
  }
  }
  }
@@ -4792,7 +5280,7 @@ static void wsp_ggml_compute_forward_soft_max_f32(
  if (mp_f32) {
  if (use_f16) {
  for (int i = 0; i < nc; ++i) {
- wp[i] += slope*WSP_GGML_FP16_TO_FP32(mp_f16[i]);
+ wp[i] += slope*WSP_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
  }
  } else {
  for (int i = 0; i < nc; ++i) {
@@ -5018,8 +5506,8 @@ static void wsp_ggml_compute_forward_clamp_f16(
  wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + j*nb01);

  for (int i = 0; i < nc; i++) {
- float v = WSP_GGML_FP16_TO_FP32(src0_ptr[i]);
- dst_ptr[i] = WSP_GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+ float v = WSP_GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
+ dst_ptr[i] = WSP_GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
  }
  }
  }
@@ -5476,11 +5964,11 @@ static void wsp_ggml_compute_forward_rope_f16(
  const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

- const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
- const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
+ const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
+ const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);

- dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  } else {
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -5492,11 +5980,11 @@ static void wsp_ggml_compute_forward_rope_f16(
  const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

- const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
- const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims/2]);
+ const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
+ const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);

- dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims/2] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  }
  } else {
@@ -5507,11 +5995,11 @@ static void wsp_ggml_compute_forward_rope_f16(
  const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

- const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
- const float x1 = WSP_GGML_FP16_TO_FP32(src[1]);
+ const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
+ const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[1]);

- dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[1] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  }

@@ -5525,11 +6013,11 @@ static void wsp_ggml_compute_forward_rope_f16(
  const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

- const float x0 = WSP_GGML_FP16_TO_FP32(src[0]);
- const float x1 = WSP_GGML_FP16_TO_FP32(src[n_dims]);
+ const float x0 = WSP_GGML_CPU_FP16_TO_FP32(src[0]);
+ const float x1 = WSP_GGML_CPU_FP16_TO_FP32(src[n_dims]);

- dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = WSP_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims] = WSP_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  } else {
  for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
@@ -5640,7 +6128,7 @@ static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
  for (int64_t i11 = 0; i11 < ne11; i11++) {
  const float * const src = (float *)((char *) src1->data + i11*nb11);
  for (int64_t i10 = 0; i10 < ne10; i10++) {
- dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
+ dst_data[i10*ne11 + i11] = WSP_GGML_CPU_FP32_TO_FP16(src[i10]);
  }
  }
  }
@@ -5933,7 +6421,7 @@ static void wsp_ggml_compute_forward_im2col_f16(
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  } else {
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
  }
  }
  }
@@ -6058,6 +6546,186 @@ void wsp_ggml_compute_forward_im2col_back_f32(
  }
  }

+ static void wsp_ggml_call_mul_mat(wsp_ggml_type type, const wsp_ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+ void * a, void * b, float * c) {
+ const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(type);
+ struct wsp_ggml_tensor src1 = {};
+ src1.type = type;
+ src1.ne[0] = k;
+ src1.ne[1] = m;
+ src1.ne[2] = 1;
+ src1.ne[3] = 1;
+ src1.nb[0] = traits->type_size;
+ src1.nb[1] = k * traits->type_size;
+ src1.nb[2] = src1.nb[1];
+ src1.nb[3] = src1.nb[2];
+ src1.data = a;
+
+ struct wsp_ggml_tensor src0 = {};
+ src0.type = type;
+ src0.ne[0] = k;
+ src0.ne[1] = n;
+ src0.ne[2] = 1;
+ src0.ne[3] = 1;
+ src0.nb[0] = traits->type_size;
+ src0.nb[1] = k * traits->type_size;
+ src0.nb[2] = src0.nb[1];
+ src0.nb[3] = src0.nb[2];
+ src0.data = b;
+
+ struct wsp_ggml_tensor dst = {};
+ dst.ne[0] = n;
+ dst.ne[1] = m;
+ dst.ne[2] = 1;
+ dst.ne[3] = 1;
+ dst.nb[0] = sizeof(float);
+ dst.nb[1] = n * sizeof(float);
+ dst.nb[2] = dst.nb[1];
+ dst.nb[3] = dst.nb[2];
+ dst.data = c;
+ dst.src[0] = &src0;
+ dst.src[1] = &src1;
+
+ wsp_ggml_compute_forward_mul_mat(params, &dst);
+ }
+
+ // wsp_ggml_compute_forward_conv_2d
+
+ static void wsp_ggml_compute_forward_conv_2d_impl(const wsp_ggml_compute_params * params,
+ const wsp_ggml_tensor * kernel, // [KW, KH, IC, OC]
+ const wsp_ggml_tensor * src, // [W, H, C, N]
+ wsp_ggml_tensor * dst, // [OW, OH, OC, N]
+ wsp_ggml_type kernel_type) {
+
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(kernel));
+ WSP_GGML_ASSERT(kernel_type == WSP_GGML_TYPE_F16 || kernel_type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(kernel->type == kernel_type);
+
+ const wsp_ggml_type_traits * traits = wsp_ggml_get_type_traits(kernel_type);
+
+ const int32_t stride_x = dst->op_params[0];
+ const int32_t stride_y = dst->op_params[1];
+ const int32_t pad_x = dst->op_params[2];
+ const int32_t pad_y = dst->op_params[3];
+ const int32_t dilation_x = dst->op_params[4];
+ const int32_t dilation_y = dst->op_params[5];
+
+ const int64_t c_in = src->ne[2];
+ const int64_t c_out = kernel->ne[3];
+ WSP_GGML_ASSERT(c_in == kernel->ne[2]);
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n = knl_w * knl_h * c_in;
+ const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+ const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ WSP_GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
+ patch_total);
+ const int64_t patch_n = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ //im2col for a patch
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t src_x = (p / dst_w) % dst_h;
+ const int64_t src_y = p % dst_w;
+
+ const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+ for (int64_t ic = 0; ic < c_in; ++ic) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+ const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+ int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == WSP_GGML_TYPE_F32) {
+ *(float *) element_ptr = src_val;
+ } else if (kernel_type == WSP_GGML_TYPE_F16) {
+ *(wsp_ggml_fp16_t *) element_ptr = WSP_GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ } // patches handled by this thread
+
+ wsp_ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+ WSP_GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+ // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+ wsp_ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+ wsp_ggml_barrier(params->threadpool);
+
+
+ //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+ const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t dst_y = (p / dst_w) % dst_h;
+ const int64_t dst_x = p % dst_w;
+
+ for (int64_t oc = 0; oc < c_out; ++oc) {
+ const float value = gemm_output[i * c_out + oc];
+ float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void wsp_ggml_compute_forward_conv_2d(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_tensor * src0 = dst->src[0];
+ const wsp_ggml_tensor * src1 = dst->src[1];
+
+ wsp_ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+ }
+
  // wsp_ggml_compute_forward_conv_transpose_2d

  void wsp_ggml_compute_forward_conv_transpose_2d(
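Note: the new direct CONV_2D above lowers convolution to im2col plus GEMM, processing as many output patches per pass as fit into the params->wsize scratch buffer and then permuting the GEMM result back into the destination layout. The im2col bounds checks assume the usual output-size relation between input size, kernel size, stride, padding, and dilation; as a hypothetical helper (conv_out_size is not a name from the diff):

    #include <stdint.h>

    // Standard convolution output size per axis.
    static inline int64_t conv_out_size(int64_t in, int64_t k, int64_t stride,
                                        int64_t pad, int64_t dil) {
        return (in + 2*pad - dil*(k - 1) - 1) / stride + 1;
    }
    // e.g. in = 224, k = 3, stride = 1, pad = 1, dil = 1  ->  224

The per-batch GEMM then computes output[patch_n, c_out] from patches[patch_n, knl_n] and the kernel viewed as [knl_n, c_out], with knl_n = KW*KH*IC, exactly as the comment in the added code states.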
@@ -6109,7 +6777,7 @@ void wsp_ggml_compute_forward_conv_transpose_2d(
  const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
  wsp_ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
  for (int i10 = 0; i10 < ne10; i10++) {
- dst_data[i10*ne12 + i12] = WSP_GGML_FP32_TO_FP16(src[i10]);
+ dst_data[i10*ne12 + i12] = WSP_GGML_CPU_FP32_TO_FP16(src[i10]);
  }
  }
  }
@@ -6358,7 +7026,7 @@ static void wsp_ggml_compute_forward_pool_1d_sk_p0(
  case WSP_GGML_OP_POOL_COUNT: WSP_GGML_ABORT("fatal error");
  }
  for (int ki = 0; ki < k; ++ki) {
- const float srow_j = (src->type == WSP_GGML_TYPE_F32) ? ((const float*)srow)[j] : WSP_GGML_FP16_TO_FP32(((const wsp_ggml_fp16_t*)srow)[j]);
+ const float srow_j = (src->type == WSP_GGML_TYPE_F32) ? ((const float*)srow)[j] : WSP_GGML_CPU_FP16_TO_FP32(((const wsp_ggml_fp16_t*)srow)[j]);
  switch (op) {
  case WSP_GGML_OP_POOL_AVG: drow[i] += srow_j; break;
  case WSP_GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
@@ -6450,7 +7118,7 @@ void wsp_ggml_compute_forward_pool_2d(
  for (int kx = 0; kx < k0; ++kx) {
  int j = ix + kx;
  if (j < 0 || j >= src->ne[0]) continue;
- const float srow_j = (src->type == WSP_GGML_TYPE_F32) ? ((const float*)srow)[j] : WSP_GGML_FP16_TO_FP32(((const wsp_ggml_fp16_t*)srow)[j]);
+ const float srow_j = (src->type == WSP_GGML_TYPE_F32) ? ((const float*)srow)[j] : WSP_GGML_CPU_FP16_TO_FP32(((const wsp_ggml_fp16_t*)srow)[j]);
  switch (op) {
  case WSP_GGML_OP_POOL_AVG: *out += srow_j; break;
  case WSP_GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
@@ -6538,7 +7206,7 @@ void wsp_ggml_compute_forward_pool_2d_back(
  }

  const float val = dst->type == WSP_GGML_TYPE_F32 ?
- ((const float *) drowf)[j] : WSP_GGML_FP16_TO_FP32(((const wsp_ggml_fp16_t *) drowf)[j]);
+ ((const float *) drowf)[j] : WSP_GGML_CPU_FP16_TO_FP32(((const wsp_ggml_fp16_t *) drowf)[j]);
  if (val <= maxval) {
  continue;
  }
@@ -6558,7 +7226,7 @@ void wsp_ggml_compute_forward_pool_2d_back(
  if (dst->type == WSP_GGML_TYPE_F32) {
  ((float *) drow)[j] += grad0;
  } else {
- ((wsp_ggml_fp16_t *) drow)[j] = WSP_GGML_FP32_TO_FP16(grad0 + WSP_GGML_FP16_TO_FP32(((const wsp_ggml_fp16_t *) drow)[j]));
+ ((wsp_ggml_fp16_t *) drow)[j] = WSP_GGML_CPU_FP32_TO_FP16(grad0 + WSP_GGML_CPU_FP16_TO_FP32(((const wsp_ggml_fp16_t *) drow)[j]));
  }
  } else if (op == WSP_GGML_OP_POOL_AVG) {
  const float grad = grad0 / ka;
@@ -6577,7 +7245,7 @@ void wsp_ggml_compute_forward_pool_2d_back(
  if (dst->type == WSP_GGML_TYPE_F32) {
  ((float *) drow)[j] += grad;
  } else {
- ((wsp_ggml_fp16_t *) drow)[j] += WSP_GGML_FP32_TO_FP16(grad);
+ ((wsp_ggml_fp16_t *) drow)[j] += WSP_GGML_CPU_FP32_TO_FP16(grad);
  }
  }
  }
@@ -6608,12 +7276,13 @@ static void wsp_ggml_compute_forward_upscale_f32(

  WSP_GGML_TENSOR_UNARY_OP_LOCALS

- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
+ float sf0 = (float)ne0/src0->ne[0];
+ float sf1 = (float)ne1/src0->ne[1];
+ float sf2 = (float)ne2/src0->ne[2];
+ float sf3 = (float)ne3/src0->ne[3];

- const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) wsp_ggml_get_op_params_i32(dst, 0);
+ const int32_t mode_flags = wsp_ggml_get_op_params_i32(dst, 0);
+ const wsp_ggml_scale_mode mode = (wsp_ggml_scale_mode) (mode_flags & 0xFF);

  if (mode == WSP_GGML_SCALE_MODE_NEAREST) {
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -6634,8 +7303,12 @@ static void wsp_ggml_compute_forward_upscale_f32(
  }
  }
  } else if (mode == WSP_GGML_SCALE_MODE_BILINEAR) {
- // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
- const float pixel_offset = 0.5f;
+ float pixel_offset = 0.5f;
+ if (mode_flags & WSP_GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ pixel_offset = 0.0f;
+ sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+ sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+ }

  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
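Note: the upscale hunks above start packing flag bits alongside the scale mode in op param 0 (the low byte keeps the mode). With WSP_GGML_SCALE_FLAG_ALIGN_CORNERS set, bilinear sampling drops the half-pixel offset and rescales sf0/sf1 to (ne - 1)/(src_ne - 1), so the first and last output samples land exactly on the corner pixels, matching PyTorch interpolate with align_corners=True, as the removed comment noted. Per axis, the source coordinate for output index i becomes:

    #include <stdint.h>

    // Source coordinate for output index i along one axis (illustrative helper).
    static inline float src_coord(int64_t i, int64_t src_n, int64_t dst_n,
                                  int align_corners) {
        return align_corners
            ? (float) i * (float)(src_n - 1) / (float)(dst_n - 1)        // ends map to ends
            : ((float) i + 0.5f) * (float) src_n / (float) dst_n - 0.5f; // half-pixel centers
    }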
@@ -7142,7 +7815,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(
  // loop over n_kv and n_head_kv
  // ref: https://arxiv.org/pdf/2112.05682.pdf
  for (int64_t ic = 0; ic < nek1; ++ic) {
- const float mv = mp ? slope*WSP_GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+ const float mv = mp ? slope*WSP_GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
  if (mv == -INFINITY) {
  continue;
  }
@@ -7210,7 +7883,7 @@ static void wsp_ggml_compute_forward_flash_attn_ext_f16(

  if (v->type == WSP_GGML_TYPE_F16) {
  for (int64_t d = 0; d < DV; ++d) {
- VKQ32[d] = WSP_GGML_FP16_TO_FP32(VKQ16[d]);
+ VKQ32[d] = WSP_GGML_CPU_FP16_TO_FP32(VKQ16[d]);
  }
  }

@@ -7994,6 +8667,34 @@ void wsp_ggml_compute_forward_unary(
  }
  }

+ //wsp_ggml_compute_forward_glu
+
+ void wsp_ggml_compute_forward_glu(
+ const wsp_ggml_compute_params * params,
+ wsp_ggml_tensor * dst) {
+
+ const wsp_ggml_glu_op op = wsp_ggml_get_glu_op(dst);
+
+ switch (op) {
+ case WSP_GGML_GLU_OP_REGLU:
+ {
+ wsp_ggml_compute_forward_reglu(params, dst);
+ } break;
+ case WSP_GGML_GLU_OP_GEGLU:
+ {
+ wsp_ggml_compute_forward_geglu(params, dst);
+ } break;
+ case WSP_GGML_GLU_OP_SWIGLU:
+ {
+ wsp_ggml_compute_forward_swiglu(params, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  // wsp_ggml_compute_forward_get_rel_pos

  static void wsp_ggml_compute_forward_get_rel_pos_f16(