whisper.rn 0.5.3 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/README.md +1 -1
  2. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  3. package/android/src/main/jni.cpp +13 -0
  4. package/cpp/ggml-alloc.c +78 -26
  5. package/cpp/ggml-alloc.h +9 -0
  6. package/cpp/ggml-backend-impl.h +1 -1
  7. package/cpp/ggml-backend-reg.cpp +19 -3
  8. package/cpp/ggml-backend.cpp +72 -20
  9. package/cpp/ggml-backend.h +2 -1
  10. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  12. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  13. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  14. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  16. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  17. package/cpp/ggml-cpu/ops.cpp +170 -18
  18. package/cpp/ggml-cpu/ops.h +1 -0
  19. package/cpp/ggml-cpu/repack.cpp +531 -5
  20. package/cpp/ggml-cpu/repack.h +14 -0
  21. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  22. package/cpp/ggml-cpu/vec.cpp +41 -1
  23. package/cpp/ggml-cpu/vec.h +241 -138
  24. package/cpp/ggml-cpu.h +1 -0
  25. package/cpp/ggml-impl.h +0 -4
  26. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  27. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  28. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  29. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  30. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  31. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  32. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  33. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  34. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  35. package/cpp/ggml.c +110 -31
  36. package/cpp/ggml.h +51 -12
  37. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  38. package/cpp/whisper.cpp +17 -4
  39. package/ios/CMakeLists.txt +21 -1
  40. package/ios/RNWhisperContext.mm +5 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  79. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  80. package/lib/commonjs/jest-mock.js +2 -0
  81. package/lib/commonjs/jest-mock.js.map +1 -1
  82. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
  83. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  84. package/lib/commonjs/version.json +1 -1
  85. package/lib/module/NativeRNWhisper.js.map +1 -1
  86. package/lib/module/jest-mock.js +2 -0
  87. package/lib/module/jest-mock.js.map +1 -1
  88. package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
  89. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  90. package/lib/module/version.json +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
  94. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  95. package/lib/typescript/realtime-transcription/types.d.ts +7 -0
  96. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  97. package/package.json +1 -1
  98. package/src/NativeRNWhisper.ts +1 -0
  99. package/src/jest-mock.ts +2 -0
  100. package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
  101. package/src/realtime-transcription/types.ts +9 -0
  102. package/src/version.json +1 -1
@@ -1957,6 +1957,8 @@ WSP_GGML_TABLE_END()
1957
1957
  #define FC_MUL_MV 600
1958
1958
  #define FC_MUL_MM 700
1959
1959
  #define FC_ROPE 800
1960
+ #define FC_SSM_CONV 900
1961
+ #define FC_COUNT_EQUAL 1000
1960
1962
 
1961
1963
  // op-specific constants
1962
1964
  #define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -2062,6 +2064,10 @@ typedef struct {
2062
2064
  float bias;
2063
2065
  } wsp_ggml_metal_kargs_scale;
2064
2066
 
2067
+ typedef struct {
2068
+ float val;
2069
+ } wsp_ggml_metal_kargs_fill;
2070
+
2065
2071
  typedef struct {
2066
2072
  float min;
2067
2073
  float max;
@@ -2712,14 +2718,38 @@ typedef struct {
2712
2718
  } wsp_ggml_metal_kargs_leaky_relu;
2713
2719
 
2714
2720
  typedef struct {
2715
- int64_t ne00;
2716
- int64_t ne01;
2717
- int64_t ne02;
2718
- int64_t ne03;
2721
+ int32_t ne00;
2722
+ int32_t ne01;
2723
+ int32_t ne02;
2724
+ int32_t ne03;
2719
2725
  uint64_t nb00;
2720
2726
  uint64_t nb01;
2721
2727
  uint64_t nb02;
2722
2728
  uint64_t nb03;
2729
+ int32_t ne0;
2730
+ int32_t ne1;
2731
+ int32_t ne2;
2732
+ int32_t ne3;
2733
+ uint64_t nb0;
2734
+ uint64_t nb1;
2735
+ uint64_t nb2;
2736
+ uint64_t nb3;
2737
+ } wsp_ggml_metal_kargs_tri;
2738
+
2739
+ typedef struct {
2740
+ int32_t ne00;
2741
+ int32_t ne01;
2742
+ int32_t ne02;
2743
+ int32_t ne03;
2744
+ uint64_t nb00;
2745
+ uint64_t nb01;
2746
+ uint64_t nb02;
2747
+ uint64_t nb03;
2748
+ int32_t ne0;
2749
+ int32_t ne1;
2750
+ int32_t ne2;
2751
+ int32_t ne3;
2752
+ int32_t top_k;
2723
2753
  } wsp_ggml_metal_kargs_argsort;
2724
2754
 
2725
2755
  typedef struct {
@@ -2731,6 +2761,11 @@ typedef struct {
2731
2761
  uint64_t nb01;
2732
2762
  uint64_t nb02;
2733
2763
  uint64_t nb03;
2764
+ int32_t ne0;
2765
+ int32_t ne1;
2766
+ int32_t ne2;
2767
+ int32_t ne3;
2768
+ int32_t top_k;
2734
2769
  int32_t len;
2735
2770
  } wsp_ggml_metal_kargs_argsort_merge;
2736
2771
 
@@ -2740,6 +2775,25 @@ typedef struct {
2740
2775
  float step;
2741
2776
  } wsp_ggml_metal_kargs_arange;
2742
2777
 
2778
+ typedef struct {
2779
+ int64_t val;
2780
+ } wsp_ggml_metal_kargs_memset;
2781
+
2782
+ typedef struct {
2783
+ int32_t ne00;
2784
+ int32_t ne01;
2785
+ int32_t ne02;
2786
+ int32_t ne03;
2787
+ uint64_t nb00;
2788
+ uint64_t nb01;
2789
+ uint64_t nb02;
2790
+ uint64_t nb03;
2791
+ uint64_t nb10;
2792
+ uint64_t nb11;
2793
+ uint64_t nb12;
2794
+ uint64_t nb13;
2795
+ } wsp_ggml_metal_kargs_count_equal;
2796
+
2743
2797
  typedef struct {
2744
2798
  int32_t k0;
2745
2799
  int32_t k1;
@@ -4011,6 +4065,22 @@ kernel void kernel_scale_f32_4(
4011
4065
  dst[tpig] = src0[tpig] * args.scale + args.bias;
4012
4066
  }
4013
4067
 
4068
+ kernel void kernel_fill_f32(
4069
+ constant wsp_ggml_metal_kargs_fill & args,
4070
+ device const float * src0,
4071
+ device float * dst,
4072
+ uint tpig[[thread_position_in_grid]]) {
4073
+ dst[tpig] = args.val;
4074
+ }
4075
+
4076
+ kernel void kernel_fill_f32_4(
4077
+ constant wsp_ggml_metal_kargs_fill & args,
4078
+ device const float4 * src0,
4079
+ device float4 * dst,
4080
+ uint tpig[[thread_position_in_grid]]) {
4081
+ dst[tpig] = args.val;
4082
+ }
4083
+
4014
4084
  kernel void kernel_clamp_f32(
4015
4085
  constant wsp_ggml_metal_kargs_clamp & args,
4016
4086
  device const float * src0,
@@ -4357,6 +4427,36 @@ kernel void kernel_exp_f32_4(
4357
4427
  dst[tpig] = exp(src0[tpig]);
4358
4428
  }
4359
4429
 
4430
+ kernel void kernel_softplus_f32(
4431
+ device const float * src0,
4432
+ device float * dst,
4433
+ uint tpig[[thread_position_in_grid]]) {
4434
+ device const float & x = src0[tpig];
4435
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
4436
+ }
4437
+
4438
+ kernel void kernel_softplus_f32_4(
4439
+ device const float4 * src0,
4440
+ device float4 * dst,
4441
+ uint tpig[[thread_position_in_grid]]) {
4442
+ device const float4 & x = src0[tpig];
4443
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
4444
+ }
4445
+
4446
+ kernel void kernel_expm1_f32(
4447
+ device const float * src0,
4448
+ device float * dst,
4449
+ uint tpig[[thread_position_in_grid]]) {
4450
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
4451
+ }
4452
+
4453
+ kernel void kernel_expm1_f32_4(
4454
+ device const float4 * src0,
4455
+ device float4 * dst,
4456
+ uint tpig[[thread_position_in_grid]]) {
4457
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
4458
+ }
4459
+
4360
4460
  kernel void kernel_reglu_f32(
4361
4461
  constant wsp_ggml_metal_kargs_glu & args,
4362
4462
  device const char * src0,
@@ -4506,6 +4606,7 @@ kernel void kernel_op_sum_f32(
4506
4606
  return;
4507
4607
  }
4508
4608
 
4609
+ // TODO: become function constant
4509
4610
  const uint nsg = (ntg.x + 31) / 32;
4510
4611
 
4511
4612
  float sumf = 0;
@@ -4705,6 +4806,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
4705
4806
 
4706
4807
  template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
4707
4808
 
4809
+
4810
+ template<uint32_t ttype>
4811
+ bool _wsp_ggml_vec_tri_cmp(const int i, const int r);
4812
+
4813
+ template<>
4814
+ bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
4815
+ return i < r;
4816
+ }
4817
+
4818
+ template<>
4819
+ bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
4820
+ return i <= r;
4821
+ }
4822
+
4823
+ template<>
4824
+ bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
4825
+ return i > r;
4826
+ }
4827
+
4828
+ template<>
4829
+ bool _wsp_ggml_vec_tri_cmp</* WSP_GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
4830
+ return i >= r;
4831
+ }
4832
+
4833
+ template<typename T, int ttype>
4834
+ kernel void kernel_tri(
4835
+ constant wsp_ggml_metal_kargs_tri & args,
4836
+ device const char * src0,
4837
+ device const char * dst,
4838
+ uint3 tgpig[[threadgroup_position_in_grid]],
4839
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4840
+ ushort3 ntg[[threads_per_threadgroup]]) {
4841
+ const int i3 = tgpig.z;
4842
+ const int i2 = tgpig.y;
4843
+ const int i1 = tgpig.x;
4844
+
4845
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
4846
+ return;
4847
+ }
4848
+
4849
+ device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
4850
+ device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
4851
+
4852
+ // Each thread is a single element of the row if ne00 < max threads per
4853
+ // threadgroup, so this will loop once for each index that this thread is
4854
+ // responsible for
4855
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
4856
+ // Use the comparison as a mask for branchless
4857
+ dst_row[i0] = static_cast<T>(_wsp_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
4858
+ }
4859
+ }
4860
+
4861
+ typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
4862
+
4863
+ template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
4864
+ template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
4865
+ template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
4866
+ template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
4867
+ template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
4868
+ template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
4869
+ template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
4870
+ template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
4871
+ #if defined(WSP_GGML_METAL_HAS_BF16)
4872
+ template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
4873
+ template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
4874
+ template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
4875
+ template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
4876
+ #endif
4877
+
4708
4878
  template<typename T>
4709
4879
  kernel void kernel_soft_max(
4710
4880
  constant wsp_ggml_metal_kargs_soft_max & args,
@@ -4990,7 +5160,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
4990
5160
  x[0] = sumf;
4991
5161
  }
4992
5162
 
5163
+ constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
5164
+
5165
+ // Batched version: each threadgroup processes multiple tokens for better efficiency
5166
+ // Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
5167
+ kernel void kernel_ssm_conv_f32_f32_batched(
5168
+ constant wsp_ggml_metal_kargs_ssm_conv & args,
5169
+ device const void * src0,
5170
+ device const void * src1,
5171
+ device float * dst,
5172
+ uint3 tgpig[[threadgroup_position_in_grid]],
5173
+ uint3 tpitg[[thread_position_in_threadgroup]],
5174
+ uint3 ntg[[threads_per_threadgroup]]) {
5175
+ // tgpig.x = row index (ir)
5176
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
5177
+ // tgpig.z = sequence index (i3)
5178
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
5179
+ const short BATCH_SIZE = FC_ssm_conv_bs;
5180
+
5181
+ const int64_t ir = tgpig.x;
5182
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
5183
+ const int64_t i3 = tgpig.z;
5184
+ const int64_t i2_off = tpitg.x;
5185
+ const int64_t i2 = i2_base + i2_off;
5186
+
5187
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
5188
+ const int64_t n_t = args.ne1; // number of tokens
5189
+
5190
+ // Bounds check for partial batches at the end
5191
+ if (i2 >= n_t) {
5192
+ return;
5193
+ }
5194
+
5195
+ // Load conv weights (shared across all tokens for this row)
5196
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
5197
+
5198
+ // Load source for this specific token
5199
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
5200
+
5201
+ // Output location for this token
5202
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
5203
+
5204
+ float sumf = 0.0f;
5205
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
5206
+ sumf += s[i0] * c[i0];
5207
+ }
5208
+
5209
+ x[0] = sumf;
5210
+ }
5211
+
5212
+ kernel void kernel_ssm_conv_f32_f32_batched_4(
5213
+ constant wsp_ggml_metal_kargs_ssm_conv & args,
5214
+ device const void * src0,
5215
+ device const void * src1,
5216
+ device float * dst,
5217
+ uint3 tgpig[[threadgroup_position_in_grid]],
5218
+ uint3 tpitg[[thread_position_in_threadgroup]],
5219
+ uint3 ntg[[threads_per_threadgroup]]) {
5220
+ // tgpig.x = row index (ir)
5221
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
5222
+ // tgpig.z = sequence index (i3)
5223
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
5224
+ const short BATCH_SIZE = FC_ssm_conv_bs;
5225
+
5226
+ const int64_t ir = tgpig.x;
5227
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
5228
+ const int64_t i3 = tgpig.z;
5229
+ const int64_t i2_off = tpitg.x;
5230
+ const int64_t i2 = i2_base + i2_off;
5231
+
5232
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
5233
+ const int64_t n_t = args.ne1; // number of tokens
5234
+
5235
+ // Bounds check for partial batches at the end
5236
+ if (i2 >= n_t) {
5237
+ return;
5238
+ }
5239
+
5240
+ // Load conv weights (shared across all tokens for this row)
5241
+ device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
5242
+
5243
+ // Load source for this specific token
5244
+ device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
5245
+
5246
+ // Output location for this token
5247
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
5248
+
5249
+ float sumf = 0.0f;
5250
+ for (int64_t i0 = 0; i0 < nc/4; ++i0) {
5251
+ sumf += dot(s[i0], c[i0]);
5252
+ }
5253
+
5254
+ x[0] = sumf;
5255
+ }
5256
+
4993
5257
  // ref: ggml.c:wsp_ggml_compute_forward_ssm_scan_f32, Mamba-2 part
5258
+ // Optimized version: reduces redundant memory loads by having one thread load shared values
4994
5259
  kernel void kernel_ssm_scan_f32(
4995
5260
  constant wsp_ggml_metal_kargs_ssm_scan & args,
4996
5261
  device const void * src0,
@@ -5010,7 +5275,15 @@ kernel void kernel_ssm_scan_f32(
5010
5275
  uint3 tgpg[[threadgroups_per_grid]]) {
5011
5276
  constexpr short NW = N_SIMDWIDTH;
5012
5277
 
5013
- shared[tpitg.x] = 0.0f;
5278
+ // Shared memory layout:
5279
+ // [0..sgptg*NW-1]: partial sums for reduction (existing)
5280
+ // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
5281
+ // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
5282
+ threadgroup float * shared_sums = shared;
5283
+ threadgroup float * shared_x_dt = shared + sgptg * NW;
5284
+ threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
5285
+
5286
+ shared_sums[tpitg.x] = 0.0f;
5014
5287
 
5015
5288
  const int32_t i0 = tpitg.x;
5016
5289
  const int32_t i1 = tgpig.x;
@@ -5050,32 +5323,47 @@ kernel void kernel_ssm_scan_f32(
5050
5323
  for (int i2 = 0; i2 < n_t; i2 += sgptg) {
5051
5324
  threadgroup_barrier(mem_flags::mem_threadgroup);
5052
5325
 
5053
- for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
5054
- const float dt0 = dt[0];
5326
+ // Pre-compute x_dt and dA for this batch of tokens
5327
+ // Only first sgptg threads do the loads and expensive math
5328
+ if (i0 < sgptg && i2 + i0 < n_t) {
5329
+ // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
5330
+ device const float * x_t = x + i0 * args.ns12;
5331
+ device const float * dt_t = dt + i0 * args.ns21;
5332
+
5333
+ const float dt0 = dt_t[0];
5055
5334
  const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
5056
- const float x_dt = x[0] * dtsp;
5057
- const float dA = exp(dtsp * A0);
5335
+ shared_x_dt[i0] = x_t[0] * dtsp;
5336
+ shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
5337
+ }
5338
+
5339
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5340
+
5341
+ for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
5342
+ const float x_dt = shared_x_dt[t];
5343
+ const float dA = exp(shared_dA[t] * A0);
5058
5344
 
5059
5345
  s = (s0 * dA) + (B[i0] * x_dt);
5060
5346
 
5061
5347
  const float sumf = simd_sum(s * C[i0]);
5062
5348
 
5063
5349
  if (tiisg == 0) {
5064
- shared[t*NW + sgitg] = sumf;
5350
+ shared_sums[t*NW + sgitg] = sumf;
5065
5351
  }
5066
5352
 
5067
5353
  // recurse
5068
5354
  s0 = s;
5069
5355
 
5070
- x += args.ns12;
5071
- dt += args.ns21;
5072
5356
  B += args.ns42;
5073
5357
  C += args.ns52;
5074
5358
  }
5075
5359
 
5360
+ // Advance pointers for next batch
5361
+ x += sgptg * args.ns12;
5362
+ dt += sgptg * args.ns21;
5363
+
5076
5364
  threadgroup_barrier(mem_flags::mem_threadgroup);
5077
5365
 
5078
- const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
5366
+ const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
5079
5367
 
5080
5368
  if (tiisg == 0 && i2 + sgitg < n_t) {
5081
5369
  y[sgitg*nh*nr] = sumf;
@@ -7432,11 +7720,12 @@ kernel void kernel_argsort_f32_i32(
7432
7720
  ushort3 ntg[[threads_per_threadgroup]]) {
7433
7721
  // bitonic sort
7434
7722
  const int col = tpitg[0];
7723
+ const int ib = tgpig[0] / args.ne01;
7435
7724
 
7436
- const int i00 = (tgpig[0]/args.ne01)*ntg.x;
7437
- const int i01 = tgpig[0]%args.ne01;
7438
- const int i02 = tgpig[1];
7439
- const int i03 = tgpig[2];
7725
+ const int i00 = ib*ntg.x;
7726
+ const int i01 = tgpig[0] % args.ne01;
7727
+ const int i02 = tgpig[1];
7728
+ const int i03 = tgpig[2];
7440
7729
 
7441
7730
  device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
7442
7731
 
@@ -7472,9 +7761,11 @@ kernel void kernel_argsort_f32_i32(
7472
7761
  }
7473
7762
  }
7474
7763
 
7764
+ const int64_t i0 = ib*args.top_k;
7765
+
7475
7766
  // copy the result to dst without the padding
7476
- if (i00 + col < args.ne00) {
7477
- dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
7767
+ if (i0 + col < args.ne0 && col < args.top_k) {
7768
+ dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
7478
7769
 
7479
7770
  dst[col] = shmem_i32[col];
7480
7771
  }
@@ -7509,22 +7800,22 @@ kernel void kernel_argsort_merge_f32_i32(
7509
7800
 
7510
7801
  const int start = im * (2 * args.len);
7511
7802
 
7512
- const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
7513
- const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
7803
+ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
7804
+ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
7514
7805
 
7515
7806
  const int total = len0 + len1;
7516
7807
 
7517
7808
  device const int32_t * tmp0 = tmp + start
7518
- + i01*args.ne00
7519
- + i02*args.ne00*args.ne01
7520
- + i03*args.ne00*args.ne01*args.ne02;
7809
+ + i01*args.ne0
7810
+ + i02*args.ne0*args.ne01
7811
+ + i03*args.ne0*args.ne01*args.ne02;
7521
7812
 
7522
7813
  device const int32_t * tmp1 = tmp0 + args.len;
7523
7814
 
7524
7815
  dst += start
7525
- + i01*args.ne00
7526
- + i02*args.ne00*args.ne01
7527
- + i03*args.ne00*args.ne01*args.ne02;
7816
+ + i01*args.top_k
7817
+ + i02*args.top_k*args.ne01
7818
+ + i03*args.top_k*args.ne01*args.ne02;
7528
7819
 
7529
7820
  device const float * src0_row = (device const float *)(src0
7530
7821
  + args.nb01*i01
@@ -7538,7 +7829,11 @@ kernel void kernel_argsort_merge_f32_i32(
7538
7829
  const int chunk = (total + ntg.x - 1) / ntg.x;
7539
7830
 
7540
7831
  const int k0 = tpitg.x * chunk;
7541
- const int k1 = min(k0 + chunk, total);
7832
+ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
7833
+
7834
+ if (k0 >= args.top_k) {
7835
+ return;
7836
+ }
7542
7837
 
7543
7838
  if (k0 >= total) {
7544
7839
  return;
@@ -8512,6 +8807,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_
8512
8807
 
8513
8808
  template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 32, 32>;
8514
8809
  template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 40, 40>;
8810
+ template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 48, 48>;
8515
8811
  template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 64, 64>;
8516
8812
  template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 72, 72>;
8517
8813
  template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, wsp_dewsp_quantize_f32, float4x4, 1, wsp_dewsp_quantize_f32, 80, 80>;
@@ -8525,6 +8821,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
8525
8821
 
8526
8822
  template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 32, 32>;
8527
8823
  template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 40, 40>;
8824
+ template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 48, 48>;
8528
8825
  template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 64, 64>;
8529
8826
  template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 72, 72>;
8530
8827
  template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, wsp_dewsp_quantize_f16, half4x4, 1, wsp_dewsp_quantize_f16, 80, 80>;
@@ -8539,6 +8836,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
8539
8836
  #if defined(WSP_GGML_METAL_HAS_BF16)
8540
8837
  template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 32, 32>;
8541
8838
  template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 40, 40>;
8839
+ template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 48, 48>;
8542
8840
  template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 64, 64>;
8543
8841
  template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 72, 72>;
8544
8842
  template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat4x4, 1, wsp_dewsp_quantize_bf16, 80, 80>;
@@ -8553,6 +8851,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
8553
8851
 
8554
8852
  template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 32, 32>;
8555
8853
  template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 40, 40>;
8854
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 48, 48>;
8556
8855
  template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 64, 64>;
8557
8856
  template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 72, 72>;
8558
8857
  template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, wsp_dewsp_quantize_q4_0, block_q4_0, 2, wsp_dewsp_quantize_q4_0, 80, 80>;
@@ -8566,6 +8865,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
8566
8865
 
8567
8866
  template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 32, 32>;
8568
8867
  template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 40, 40>;
8868
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 48, 48>;
8569
8869
  template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 64, 64>;
8570
8870
  template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 72, 72>;
8571
8871
  template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, wsp_dewsp_quantize_q4_1, block_q4_1, 2, wsp_dewsp_quantize_q4_1, 80, 80>;
@@ -8579,6 +8879,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
8579
8879
 
8580
8880
  // Explicit instantiations of kernel_flash_attn_ext that pass block_q5_0 / wsp_dewsp_quantize_q5_0 twice
  // (presumably once for K and once for V - confirm against the kernel template). Each [[host_name]] sets
  // the name under which the instantiation is exported; the two trailing integer template arguments match
  // the dk/dv values embedded in that name. The dk48_dv48 entry sits on a '+' line, i.e. it is added by
  // this diff.
  template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 32, 32>;
8581
8881
  template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 40, 40>;
8882
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 48, 48>;
8582
8883
  template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 64, 64>;
8583
8884
  template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 72, 72>;
8584
8885
  template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, wsp_dewsp_quantize_q5_0, block_q5_0, 2, wsp_dewsp_quantize_q5_0, 80, 80>;
@@ -8592,6 +8893,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
8592
8893
 
8593
8894
  // Explicit instantiations of kernel_flash_attn_ext that pass block_q5_1 / wsp_dewsp_quantize_q5_1 twice
  // (presumably once for K and once for V - confirm against the kernel template). Each [[host_name]] sets
  // the name under which the instantiation is exported; the two trailing integer template arguments match
  // the dk/dv values embedded in that name. The dk48_dv48 entry sits on a '+' line, i.e. it is added by
  // this diff.
  template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 32, 32>;
8594
8895
  template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 40, 40>;
8896
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 48, 48>;
8595
8897
  template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 64, 64>;
8596
8898
  template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 72, 72>;
8597
8899
  template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, wsp_dewsp_quantize_q5_1, block_q5_1, 2, wsp_dewsp_quantize_q5_1, 80, 80>;
@@ -8605,6 +8907,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
8605
8907
 
8606
8908
  // Explicit instantiations of kernel_flash_attn_ext that pass block_q8_0 / wsp_dewsp_quantize_q8_0 twice
  // (presumably once for K and once for V - confirm against the kernel template). Each [[host_name]] sets
  // the name under which the instantiation is exported; the two trailing integer template arguments match
  // the dk/dv values embedded in that name. The dk48_dv48 entry sits on a '+' line, i.e. it is added by
  // this diff.
  template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 32, 32>;
8607
8909
  template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 40, 40>;
8910
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 48, 48>;
8608
8911
  template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 64, 64>;
8609
8912
  template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 72, 72>;
8610
8913
  template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, wsp_dewsp_quantize_q8_0, block_q8_0, 2, wsp_dewsp_quantize_q8_0, 80, 80>;
@@ -11661,6 +11964,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
11661
11964
  // Explicit instantiations of kernel_mul_mm_id_map0, one per value of its integer template argument.
  // The exported host name carries the same value as its ne20_<N> suffix. The <5> entry sits on a '+'
  // line, i.e. it is added by this diff.
  template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
11662
11965
  template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
11663
11966
  template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
11967
+ template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
11664
11968
  template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
11665
11969
  template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
11666
11970
  template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
@@ -12071,9 +12375,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
12071
12375
 
12072
12376
  // Explicit instantiations of kernel_mul_mm whose exported host names end in _f16. The entries differ in
  // the block type, the integer that follows it, and the wsp_dewsp_quantize_* function passed as template
  // arguments. The kernel_mul_mm_bf16_f16 entry and its WSP_GGML_METAL_HAS_BF16 guard sit on '-' lines,
  // i.e. they are removed by this diff.
  template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, wsp_dewsp_quantize_f32, float, float4x4, half, half2x4>;
12073
12377
  template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, wsp_dewsp_quantize_f16, half, half4x4, half, half2x4>;
12074
- #if defined(WSP_GGML_METAL_HAS_BF16)
12075
- template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat, bfloat4x4, half, half2x4>;
12076
- #endif
12077
12378
  template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, wsp_dewsp_quantize_q4_0, float, float4x4, half, half2x4>;
12078
12379
  template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, wsp_dewsp_quantize_q4_1, float, float4x4, half, half2x4>;
12079
12380
  template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, wsp_dewsp_quantize_q5_0, float, float4x4, half, half2x4>;
@@ -12129,9 +12430,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
12129
12430
 
12130
12431
  // Explicit instantiations of kernel_mul_mm_id whose exported host names end in _f16. The entries differ
  // in the block type, the integer that follows it, and the wsp_dewsp_quantize_* function passed as
  // template arguments. The kernel_mul_mm_id_bf16_f16 entry and its WSP_GGML_METAL_HAS_BF16 guard sit on
  // '-' lines, i.e. they are removed by this diff.
  template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, wsp_dewsp_quantize_f32, float, float4x4, half, half2x4>;
12131
12432
  template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, wsp_dewsp_quantize_f16, half, half4x4, half, half2x4>;
12132
- #if defined(WSP_GGML_METAL_HAS_BF16)
12133
- template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, wsp_dewsp_quantize_bf16, bfloat, bfloat4x4, half, half2x4>;
12134
- #endif
12135
12433
  template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, wsp_dewsp_quantize_q4_0, float, float4x4, half, half2x4>;
12136
12434
  template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, wsp_dewsp_quantize_q4_1, float, float4x4, half, half2x4>;
12137
12435
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, wsp_dewsp_quantize_q5_0, float, float4x4, half, half2x4>;
@@ -12434,3 +12732,75 @@ kernel void kernel_opt_step_sgd_f32(
12434
12732
 
12435
12733
  x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
12436
12734
  }
12735
+
12736
// Fill kernel: every thread stores args.val into the element of dst selected
// by its position in the grid.
//
//   args - fill arguments; only args.val is read here
//   dst  - destination buffer, indexed directly by the grid position
//   tpig - this thread's position in the grid
//
// There is no bounds check, so the host presumably dispatches exactly one
// thread per destination element - TODO confirm against the dispatch code.
template<typename T>
kernel void kernel_memset(
        constant wsp_ggml_metal_kargs_fill & args,
        device   T * dst,
        uint     tpig [[thread_position_in_grid]]) {
    device T * const slot = dst + tpig;

    *slot = args.val;
}
12743
+
12744
// kernel_memset_t names the function type of the int64_t instantiation of kernel_memset; the explicit
// instantiation that follows exports that instantiation under the host name "kernel_memset_i64".
+ typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
12745
+
12746
+ template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
12747
+
12748
// Function constant at index FC_COUNT_EQUAL + 0. kernel_count_equal reads it as NSG, the number of
// per-simdgroup partial sums to fold in its final reduction - presumably the number of simdgroups per
// threadgroup chosen by the host (confirm against the pipeline setup).
+ constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
12749
+
12750
// Counts the positions at which src0 and src1 hold equal values of type T and
// adds that count to *dst.
//
// One threadgroup handles one row (i1, i2, i3) = tgpig.{x, y, z}. Its threads
// stride over the ne00 elements of the row, the per-thread counts are summed
// inside each simdgroup, and the first simdgroup then folds the per-simdgroup
// partials from threadgroup memory before a single thread issues the atomic
// add. Every row accumulates into the same counter, so the host presumably
// zeroes dst before dispatch - TODO confirm.
//
//   args      - element counts (ne0*) and byte strides (nb0*, nb1*) of the inputs
//   src0/src1 - inputs, addressed in bytes through the strides in args
//   dst       - atomic counter that receives the number of equal elements
//   shmem_i32 - threadgroup scratch; slot sgitg holds that simdgroup's partial
//
// FC_count_equal_nsg (NSG) is the number of scratch slots folded in the final
// reduction - presumably the simdgroup count per threadgroup.
//
// The final reduction is carried out in int32_t. Routing the partial counts
// through float is only exact up to 2^24, so a row with more matches than that
// could be miscounted.
template<typename T>
kernel void kernel_count_equal(
        constant wsp_ggml_metal_kargs_count_equal & args,
        device   const char * src0,
        device   const char * src1,
        device   atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // tgpig is the same for every thread of the threadgroup, so this exit is
    // taken by all of them or none - no thread is left waiting at the barrier.
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    // each thread visits elements tpitg.x, tpitg.x + ntg.x, ... of the row
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    // per-simdgroup partial count
    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        // threads beyond the NSG published slots contribute zero
        int32_t v = 0;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        const int32_t total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, total, memory_order_relaxed);
        }
    }
}
12803
+
12804
// kernel_count_equal_t names the function type of the int32_t instantiation of kernel_count_equal; the
// explicit instantiation that follows exports that instantiation under the host name
// "kernel_count_equal_i32".
+ typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
12805
+
12806
+ template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;