whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/android/src/main/jni.cpp CHANGED
@@ -276,6 +276,7 @@ JNIEXPORT jlong JNICALL
  Java_com_rnwhisper_WhisperContext_initContextWithAsset(
  JNIEnv *env,
  jobject thiz,
+ jint context_id,
  jobject asset_manager,
  jstring model_path_str
  ) {
@@ -290,6 +291,7 @@ Java_com_rnwhisper_WhisperContext_initContextWithAsset(
  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
  context = whisper_init_from_asset(env, asset_manager, model_path_chars, cparams);
  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+ rnwhisper_jsi::addContext(context_id, reinterpret_cast<jlong>(context));
  return reinterpret_cast<jlong>(context);
  }

@@ -297,6 +299,7 @@ JNIEXPORT jlong JNICALL
  Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
  JNIEnv *env,
  jobject thiz,
+ jint context_id,
  jobject input_stream
  ) {
  UNUSED(thiz);
@@ -308,6 +311,7 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(

  struct whisper_context *context = nullptr;
  context = whisper_init_from_input_stream(env, input_stream, cparams);
+ rnwhisper_jsi::addContext(context_id, reinterpret_cast<jlong>(context));
  return reinterpret_cast<jlong>(context);
  }

@@ -421,8 +425,9 @@ Java_com_rnwhisper_WhisperContext_fullWithNewJob(
  LOGI("About to reset timings");
  whisper_reset_timings(context);

- LOGI("About to run whisper_full");
- int code = whisper_full(context, params, audio_data_arr, audio_data_len);
+ int n_processors = readablemap::getInt(env, options, "nProcessors", 1);
+ LOGI("About to run whisper_full_parallel with n_processors=%d", n_processors);
+ int code = whisper_full_parallel(context, params, audio_data_arr, audio_data_len, n_processors);
  if (code == 0) {
  // whisper_print_timings(context);
  }
@@ -441,8 +446,11 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
  jlong context_ptr,
  jobject options
  ) {
+ UNUSED(thiz);
+ UNUSED(context_ptr);
  whisper_full_params params = createFullParams(env, options);
  rnwhisper::job* job = rnwhisper::job_new(job_id, params);
+ job->n_processors = readablemap::getInt(env, options, "nProcessors", 1);
  rnwhisper::vad_params vad;
  vad.use_vad = readablemap::getBool(env, options, "useVad", false);
  vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
@@ -534,11 +542,12 @@ Java_com_rnwhisper_WhisperContext_fullWithJob(
  jint n_samples
  ) {
  UNUSED(thiz);
+ UNUSED(env);
  struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);

  rnwhisper::job* job = rnwhisper::job_get(job_id);
  float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples);
- int code = whisper_full(context, job->params, pcmf32, n_samples);
+ int code = whisper_full_parallel(context, job->params, pcmf32, n_samples, job->n_processors);
  free(pcmf32);
  if (code == 0) {
  // whisper_print_timings(context);
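
Note on the jni.cpp changes above: the Android JNI layer now reads an nProcessors value from the transcribe options (defaulting to 1) and calls whisper_full_parallel instead of whisper_full, and the new context_id parameter registers each context with the JSI bindings. A minimal usage sketch in TypeScript, assuming the option is surfaced through the existing transcribe options as the NativeRNWhisper.ts change suggests (the exact typings added in this release are not shown in this excerpt):

```ts
import { initWhisper } from 'whisper.rn'

// Sketch only: `nProcessors` is inferred from the JNI code that reads
// "nProcessors" and forwards it to whisper_full_parallel; check the
// published TypeScript definitions before relying on the exact name.
async function transcribeParallel(modelPath: string, wavPath: string): Promise<string> {
  const ctx = await initWhisper({ filePath: modelPath })
  const { promise } = ctx.transcribe(wavPath, {
    language: 'en',
    nProcessors: 2, // forwarded to whisper_full_parallel(..., n_processors)
  })
  const { result } = await promise
  await ctx.release()
  return result
}
```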
package/cpp/ggml-alloc.c CHANGED
@@ -226,16 +226,23 @@ static struct buffer_address wsp_ggml_dyn_tallocr_alloc(struct wsp_ggml_dyn_tall
  }

  if (best_fit_block == -1) {
- // no suitable block found, try the last block (this will grow a chunks size)
+ // no suitable block found, try the last block (this may grow a chunks size)
+ int64_t best_reuse = INT64_MIN;
  for (int c = 0; c < alloc->n_chunks; ++c) {
  struct tallocr_chunk * chunk = alloc->chunks[c];
  if (chunk->n_free_blocks > 0) {
  struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
  max_avail = MAX(max_avail, block->size);
- if (block->size >= size) {
+ int64_t reuse_factor = chunk->max_size - block->offset - size;
+ // reuse_factor < 0 : amount of extra memory that needs to be allocated
+ // reuse_factor = 0 : allocated free space exactly matches tensor size
+ // reuse_factor > 0 : superfluous memory that will remain unused
+ bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
+ bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
+ if (block->size >= size && (better_reuse || better_fit)) {
  best_fit_chunk = c;
  best_fit_block = chunk->n_free_blocks - 1;
- break;
+ best_reuse = reuse_factor;
  }
  }
  }
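
The hunk above changes the fallback path in wsp_ggml_dyn_tallocr_alloc from taking the first last-block that is large enough to choosing the candidate with the best reuse_factor: prefer the block that grows its chunk the least and, among blocks that already fit, the tightest fit. A self-contained TypeScript sketch of that selection rule (the names below are hypothetical simplifications of the C structures, not part of the package):

```ts
// Illustration of the reuse_factor heuristic described in the diff comments above.
interface ChunkView {
  maxSize: number     // current high-water mark of the chunk
  blockOffset: number // offset of the chunk's last free block
  blockSize: number   // size of that free block
}

function pickChunk(chunks: ChunkView[], size: number): number {
  let best = -1
  let bestReuse = Number.NEGATIVE_INFINITY
  chunks.forEach((c, i) => {
    const reuse = c.maxSize - c.blockOffset - size
    // reuse < 0: the chunk would have to grow by |reuse| bytes
    // reuse >= 0: the tensor fits, leaving `reuse` bytes unused
    const betterReuse = bestReuse < 0 && reuse > bestReuse
    const betterFit = reuse >= 0 && reuse < bestReuse
    if (c.blockSize >= size && (betterReuse || betterFit)) {
      best = i
      bestReuse = reuse
    }
  })
  return best
}

// A 400-byte tensor picks the chunk that can host it without growing:
console.log(pickChunk([
  { maxSize: 1024, blockOffset: 900, blockSize: 600 },  // would grow by 276 bytes
  { maxSize: 2048, blockOffset: 1500, blockSize: 800 }, // fits with 148 bytes spare
], 400)) // -> 1
```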
@@ -268,7 +275,7 @@ static struct buffer_address wsp_ggml_dyn_tallocr_alloc(struct wsp_ggml_dyn_tall
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
  add_allocated_tensor(alloc, addr, tensor);
  size_t cur_max = addr.offset + size;
- if (cur_max > alloc->max_size[addr.chunk]) {
+ if (cur_max > chunk->max_size) {
  // sort allocated_tensors by chunk/offset
  for (int i = 0; i < 1024; i++) {
  for (int j = i + 1; j < 1024; j++) {
@@ -392,12 +399,8 @@ static void wsp_ggml_dyn_tallocr_free(struct wsp_ggml_dyn_tallocr * alloc) {
  free(alloc);
  }

- static size_t wsp_ggml_dyn_tallocr_max_size(struct wsp_ggml_dyn_tallocr * alloc) {
- size_t max_size = 0;
- for (int i = 0; i < alloc->n_chunks; i++) {
- max_size += alloc->chunks[i]->max_size;
- }
- return max_size;
+ static size_t wsp_ggml_dyn_tallocr_max_size(struct wsp_ggml_dyn_tallocr * alloc, int chunk) {
+ return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
  }


@@ -417,10 +420,8 @@ static void wsp_ggml_vbuffer_free(struct vbuffer * buf) {
  free(buf);
  }

- static int wsp_ggml_vbuffer_n_chunks(struct vbuffer * buf) {
- int n = 0;
- while (n < WSP_GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
- return n;
+ static size_t wsp_ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
+ return buf->chunks[chunk] ? wsp_ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
  }

  static size_t wsp_ggml_vbuffer_size(struct vbuffer * buf) {
@@ -604,6 +605,26 @@ static bool wsp_ggml_gallocr_is_allocated(wsp_ggml_gallocr_t galloc, struct wsp_
  return t->data != NULL || wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
  }

+ // free the extra space at the end if the new tensor is smaller
+ static void wsp_ggml_gallocr_free_extra_space(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, struct wsp_ggml_tensor * parent) {
+ struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
+ struct hash_node * p_hn = wsp_ggml_gallocr_hash_get(galloc, parent);
+
+ size_t parent_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
+ size_t node_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+
+ WSP_GGML_ASSERT(parent_size >= node_size);
+
+ if (parent_size > node_size) {
+ struct wsp_ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
+ struct buffer_address p_addr = p_hn->addr;
+ p_addr.offset += node_size;
+ size_t extra_size = parent_size - node_size;
+ AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
+ wsp_ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
+ }
+ }
+
  static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, int buffer_id) {
  WSP_GGML_ASSERT(buffer_id >= 0);
  struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
@@ -649,6 +670,7 @@ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp
  hn->addr = p_hn->addr;
  p_hn->allocated = false; // avoid freeing the parent
  view_src_hn->allocated = false;
+ wsp_ggml_gallocr_free_extra_space(galloc, node, view_src);
  return;
  }
  } else {
@@ -656,6 +678,7 @@ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp
  hn->buffer_id = p_hn->buffer_id;
  hn->addr = p_hn->addr;
  p_hn->allocated = false; // avoid freeing the parent
+ wsp_ggml_gallocr_free_extra_space(galloc, node, parent);
  return;
  }
  }
@@ -885,12 +908,20 @@ bool wsp_ggml_gallocr_reserve_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgrap
  }
  }

- size_t cur_size = galloc->buffers[i] ? wsp_ggml_vbuffer_size(galloc->buffers[i]) : 0;
- size_t new_size = wsp_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
-
  // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
- if (new_size > cur_size || galloc->buffers[i] == NULL) {
+ bool realloc = galloc->buffers[i] == NULL;
+ size_t new_size = 0;
+ for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
+ size_t cur_chunk_size = galloc->buffers[i] ? wsp_ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
+ size_t new_chunk_size = wsp_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
+ new_size += new_chunk_size;
+ if (new_chunk_size > cur_chunk_size) {
+ realloc = true;
+ }
+ }
+ if (realloc) {
  #ifndef NDEBUG
+ size_t cur_size = galloc->buffers[i] ? wsp_ggml_vbuffer_size(galloc->buffers[i]) : 0;
  WSP_GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, wsp_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
  #endif

package/cpp/ggml-backend-impl.h CHANGED
@@ -209,9 +209,6 @@ extern "C" {
  void * context;
  };

- // Internal backend registry API
- WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
-
  // Add backend dynamic loading support to the backend

  // Initialize the backend
package/cpp/ggml-backend-reg.cpp CHANGED
@@ -57,6 +57,10 @@
  #include "ggml-opencl.h"
  #endif

+ #ifdef WSP_GGML_USE_HEXAGON
+ #include "ggml-hexagon.h"
+ #endif
+
  #ifdef WSP_GGML_USE_BLAS
  #include "ggml-blas.h"
  #endif
@@ -199,6 +203,9 @@ struct wsp_ggml_backend_registry {
  #ifdef WSP_GGML_USE_OPENCL
  register_backend(wsp_ggml_backend_opencl_reg());
  #endif
+ #ifdef WSP_GGML_USE_HEXAGON
+ register_backend(wsp_ggml_backend_hexagon_reg());
+ #endif
  #ifdef WSP_GGML_USE_CANN
  register_backend(wsp_ggml_backend_cann_reg());
  #endif
@@ -598,6 +605,7 @@ void wsp_ggml_backend_load_all_from_path(const char * dir_path) {
  wsp_ggml_backend_load_best("sycl", silent, dir_path);
  wsp_ggml_backend_load_best("vulkan", silent, dir_path);
  wsp_ggml_backend_load_best("opencl", silent, dir_path);
+ wsp_ggml_backend_load_best("hexagon", silent, dir_path);
  wsp_ggml_backend_load_best("musa", silent, dir_path);
  wsp_ggml_backend_load_best("cpu", silent, dir_path);
  // check the environment variable WSP_GGML_BACKEND_PATH to load an out-of-tree backend
package/cpp/ggml-backend.cpp CHANGED
@@ -1698,8 +1698,6 @@ bool wsp_ggml_backend_sched_reserve(wsp_ggml_backend_sched_t sched, struct wsp_g
  WSP_GGML_ASSERT(sched);
  WSP_GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

- wsp_ggml_backend_sched_reset(sched);
-
  wsp_ggml_backend_sched_synchronize(sched);

  wsp_ggml_backend_sched_split_graph(sched, measure_graph);
package/cpp/ggml-backend.h CHANGED
@@ -215,6 +215,8 @@ extern "C" {
  // Backend registry
  //

+ WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
+
  WSP_GGML_API void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device);

  // Backend (reg) enumeration
package/cpp/ggml-cpu/amx/amx.cpp CHANGED
@@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
  if (op->op == WSP_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
  is_contiguous_2d(op->src[1]) && // src1 must be contiguous
  op->src[0]->buffer && op->src[0]->buffer->buft == wsp_ggml_backend_amx_buffer_type() &&
+ op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
  op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
  (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == WSP_GGML_TYPE_F16))) {
  // src1 must be host buffer