node-llama-cpp 2.5.1 → 2.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. package/README.md +20 -301
  2. package/dist/chatWrappers/{ChatMLPromptWrapper.d.ts → ChatMLChatPromptWrapper.d.ts} +1 -1
  3. package/dist/chatWrappers/{ChatMLPromptWrapper.js → ChatMLChatPromptWrapper.js} +2 -2
  4. package/dist/chatWrappers/ChatMLChatPromptWrapper.js.map +1 -0
  5. package/dist/chatWrappers/createChatWrapperByBos.js +2 -2
  6. package/dist/chatWrappers/createChatWrapperByBos.js.map +1 -1
  7. package/dist/cli/commands/BuildCommand.js +3 -1
  8. package/dist/cli/commands/BuildCommand.js.map +1 -1
  9. package/dist/cli/commands/ChatCommand.d.ts +8 -1
  10. package/dist/cli/commands/ChatCommand.js +88 -21
  11. package/dist/cli/commands/ChatCommand.js.map +1 -1
  12. package/dist/cli/commands/DownloadCommand.d.ts +3 -2
  13. package/dist/cli/commands/DownloadCommand.js +19 -38
  14. package/dist/cli/commands/DownloadCommand.js.map +1 -1
  15. package/dist/config.d.ts +5 -0
  16. package/dist/config.js +7 -0
  17. package/dist/config.js.map +1 -1
  18. package/dist/index.d.ts +5 -4
  19. package/dist/index.js +3 -2
  20. package/dist/index.js.map +1 -1
  21. package/dist/llamaEvaluator/LlamaBins.d.ts +3 -3
  22. package/dist/llamaEvaluator/LlamaBins.js +2 -2
  23. package/dist/llamaEvaluator/LlamaBins.js.map +1 -1
  24. package/dist/llamaEvaluator/LlamaChatSession.d.ts +79 -2
  25. package/dist/llamaEvaluator/LlamaChatSession.js +52 -8
  26. package/dist/llamaEvaluator/LlamaChatSession.js.map +1 -1
  27. package/dist/llamaEvaluator/LlamaContext.d.ts +60 -3
  28. package/dist/llamaEvaluator/LlamaContext.js +36 -4
  29. package/dist/llamaEvaluator/LlamaContext.js.map +1 -1
  30. package/dist/llamaEvaluator/LlamaGrammar.d.ts +16 -3
  31. package/dist/llamaEvaluator/LlamaGrammar.js +23 -4
  32. package/dist/llamaEvaluator/LlamaGrammar.js.map +1 -1
  33. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.d.ts +14 -0
  34. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js +16 -0
  35. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js.map +1 -0
  36. package/dist/llamaEvaluator/LlamaModel.d.ts +46 -14
  37. package/dist/llamaEvaluator/LlamaModel.js +23 -16
  38. package/dist/llamaEvaluator/LlamaModel.js.map +1 -1
  39. package/dist/state.d.ts +2 -0
  40. package/dist/state.js +8 -0
  41. package/dist/state.js.map +1 -0
  42. package/dist/utils/cloneLlamaCppRepo.d.ts +1 -0
  43. package/dist/utils/cloneLlamaCppRepo.js +59 -0
  44. package/dist/utils/cloneLlamaCppRepo.js.map +1 -0
  45. package/dist/utils/compileLLamaCpp.js +23 -5
  46. package/dist/utils/compileLLamaCpp.js.map +1 -1
  47. package/dist/utils/getBin.d.ts +21 -13
  48. package/dist/utils/gitReleaseBundles.d.ts +2 -0
  49. package/dist/utils/gitReleaseBundles.js +64 -0
  50. package/dist/utils/gitReleaseBundles.js.map +1 -0
  51. package/llama/addon.cpp +184 -110
  52. package/llama/binariesGithubRelease.json +1 -1
  53. package/llama/gitRelease.bundle +0 -0
  54. package/llama/toolchains/darwin.host-x64.target-arm64.cmake +8 -0
  55. package/llama/toolchains/linux.host-arm64.target-x64.cmake +5 -0
  56. package/llama/toolchains/linux.host-x64.target-arm64.cmake +5 -0
  57. package/llama/toolchains/linux.host-x64.target-arm71.cmake +5 -0
  58. package/llamaBins/linux-arm64/llama-addon.node +0 -0
  59. package/llamaBins/linux-armv7l/llama-addon.node +0 -0
  60. package/llamaBins/linux-x64/llama-addon.node +0 -0
  61. package/llamaBins/mac-arm64/ggml-metal.metal +258 -85
  62. package/llamaBins/mac-arm64/llama-addon.node +0 -0
  63. package/llamaBins/mac-x64/ggml-metal.metal +258 -85
  64. package/llamaBins/mac-x64/llama-addon.node +0 -0
  65. package/llamaBins/win-x64/llama-addon.node +0 -0
  66. package/package.json +10 -4
  67. package/dist/chatWrappers/ChatMLPromptWrapper.js.map +0 -1
  68. package/llamaBins/linux-ppc64le/llama-addon.node +0 -0
@@ -13,8 +13,8 @@ typedef struct {
13
13
 
14
14
  #define QK4_1 32
15
15
  typedef struct {
16
- half d; // delta
17
- half m; // min
16
+ half d; // delta
17
+ half m; // min
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
@@ -24,12 +24,59 @@ typedef struct {
24
24
  int8_t qs[QK8_0]; // quants
25
25
  } block_q8_0;
26
26
 
27
+ // general-purpose kernel for addition of two tensors
28
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
29
+ // cons: not very efficient
27
30
  kernel void kernel_add(
28
- device const float4 * src0,
29
- device const float4 * src1,
30
- device float4 * dst,
31
- uint tpig[[thread_position_in_grid]]) {
32
- dst[tpig] = src0[tpig] + src1[tpig];
31
+ device const char * src0,
32
+ device const char * src1,
33
+ device char * dst,
34
+ constant int64_t & ne00,
35
+ constant int64_t & ne01,
36
+ constant int64_t & ne02,
37
+ constant int64_t & ne03,
38
+ constant int64_t & nb00,
39
+ constant int64_t & nb01,
40
+ constant int64_t & nb02,
41
+ constant int64_t & nb03,
42
+ constant int64_t & ne10,
43
+ constant int64_t & ne11,
44
+ constant int64_t & ne12,
45
+ constant int64_t & ne13,
46
+ constant int64_t & nb10,
47
+ constant int64_t & nb11,
48
+ constant int64_t & nb12,
49
+ constant int64_t & nb13,
50
+ constant int64_t & ne0,
51
+ constant int64_t & ne1,
52
+ constant int64_t & ne2,
53
+ constant int64_t & ne3,
54
+ constant int64_t & nb0,
55
+ constant int64_t & nb1,
56
+ constant int64_t & nb2,
57
+ constant int64_t & nb3,
58
+ uint3 tgpig[[threadgroup_position_in_grid]],
59
+ uint3 tpitg[[thread_position_in_threadgroup]],
60
+ uint3 ntg[[threads_per_threadgroup]]) {
61
+ const int64_t i03 = tgpig.z;
62
+ const int64_t i02 = tgpig.y;
63
+ const int64_t i01 = tgpig.x;
64
+
65
+ const int64_t i13 = i03 % ne13;
66
+ const int64_t i12 = i02 % ne12;
67
+ const int64_t i11 = i01 % ne11;
68
+
69
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
70
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
71
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
72
+
73
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
74
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
75
+
76
+ src0_ptr += ntg.x*nb00;
77
+ src1_ptr += ntg.x*nb10;
78
+ dst_ptr += ntg.x*nb0;
79
+ }
33
80
  }
34
81
 
35
82
  // assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
38
85
  device const float4 * src0,
39
86
  device const float4 * src1,
40
87
  device float4 * dst,
41
- constant int64_t & nb,
88
+ constant int64_t & nb [[buffer(27)]],
42
89
  uint tpig[[thread_position_in_grid]]) {
43
90
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
91
  }
@@ -85,6 +132,13 @@ kernel void kernel_relu(
85
132
  dst[tpig] = max(0.0f, src0[tpig]);
86
133
  }
87
134
 
135
+ kernel void kernel_sqr(
136
+ device const float * src0,
137
+ device float * dst,
138
+ uint tpig[[thread_position_in_grid]]) {
139
+ dst[tpig] = src0[tpig] * src0[tpig];
140
+ }
141
+
88
142
  constant float GELU_COEF_A = 0.044715f;
89
143
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
144
 
@@ -291,10 +345,11 @@ kernel void kernel_rms_norm(
291
345
  uint sgitg[[simdgroup_index_in_threadgroup]],
292
346
  uint tiisg[[thread_index_in_simdgroup]],
293
347
  uint ntg[[threads_per_threadgroup]]) {
294
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
295
- device const float * x_scalar = (device const float *) x;
296
- float4 sumf=0;
297
- float all_sum=0;
348
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
349
+ device const float * x_scalar = (device const float *) x;
350
+
351
+ float4 sumf = 0;
352
+ float all_sum = 0;
298
353
 
299
354
  // parallel sum
300
355
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -307,6 +362,7 @@ kernel void kernel_rms_norm(
307
362
  }
308
363
 
309
364
  threadgroup_barrier(mem_flags::mem_threadgroup);
365
+
310
366
  // broadcast, simd group number is ntg / 32
311
367
  for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
312
368
  if (tpitg < i) {
@@ -314,7 +370,9 @@ kernel void kernel_rms_norm(
314
370
  }
315
371
  }
316
372
  if (tpitg == 0) {
317
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
373
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {
374
+ sum[0] += x_scalar[i];
375
+ }
318
376
  sum[0] /= ne00;
319
377
  }
320
378
 
@@ -329,7 +387,9 @@ kernel void kernel_rms_norm(
329
387
  y[i00] = x[i00] * scale;
330
388
  }
331
389
  if (tpitg == 0) {
332
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
390
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
391
+ y_scalar[i00] = x_scalar[i00] * scale;
392
+ }
333
393
  }
334
394
  }
335
395
 
@@ -369,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
369
429
  }
370
430
 
371
431
  // putting them in the kernel cause a significant performance penalty
372
- #define N_DST 4 // each SIMD group works on 4 rows
373
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
432
+ #define N_DST 4 // each SIMD group works on 4 rows
433
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
374
434
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
375
435
  //Note: This is a template, but strictly speaking it only applies to
376
436
  // quantizations where the block size is 32. It also does not
@@ -381,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
381
441
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
382
442
  uint3 tgpig, uint tiisg, uint sgitg) {
383
443
  const int nb = ne00/QK4_0;
444
+
384
445
  const int r0 = tgpig.x;
385
446
  const int r1 = tgpig.y;
386
447
  const int im = tgpig.z;
448
+
387
449
  const int first_row = (r0 * nsg + sgitg) * nr;
450
+
388
451
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
452
+
389
453
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
390
454
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
391
- float yl[16]; // src1 vector cache
392
- float sumf[nr]={0.f};
393
455
 
394
- const int ix = tiisg/2;
395
- const int il = 8*(tiisg%2);
456
+ float yl[16]; // src1 vector cache
457
+ float sumf[nr] = {0.f};
458
+
459
+ const int ix = (tiisg/2);
460
+ const int il = (tiisg%2)*8;
396
461
 
397
462
  device const float * yb = y + ix * QK4_0 + il;
398
463
 
@@ -403,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
403
468
  sumy += yb[i] + yb[i+1];
404
469
  yl[i+0] = yb[i+ 0];
405
470
  yl[i+1] = yb[i+ 1]/256.f;
471
+
406
472
  sumy += yb[i+16] + yb[i+17];
407
473
  yl[i+8] = yb[i+16]/16.f;
408
474
  yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
418
484
  for (int row = 0; row < nr; ++row) {
419
485
  const float tot = simd_sum(sumf[row]);
420
486
  if (tiisg == 0 && first_row + row < ne01) {
421
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
487
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
422
488
  }
423
489
  }
424
490
  }
425
491
 
426
- kernel void kernel_mul_mat_q4_0_f32(
492
+ kernel void kernel_mul_mv_q4_0_f32(
427
493
  device const void * src0,
428
494
  device const float * src1,
429
495
  device float * dst,
@@ -436,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
436
502
  constant int64_t & ne1[[buffer(16)]],
437
503
  constant uint & gqa[[buffer(17)]],
438
504
  uint3 tgpig[[threadgroup_position_in_grid]],
439
- uint tiisg[[thread_index_in_simdgroup]],
440
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
505
+ uint tiisg[[thread_index_in_simdgroup]],
506
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
441
507
  mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
442
508
  }
443
509
 
444
- kernel void kernel_mul_mat_q4_1_f32(
510
+ kernel void kernel_mul_mv_q4_1_f32(
445
511
  device const void * src0,
446
512
  device const float * src1,
447
513
  device float * dst,
@@ -461,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
461
527
 
462
528
  #define NB_Q8_0 8
463
529
 
464
- kernel void kernel_mul_mat_q8_0_f32(
530
+ kernel void kernel_mul_mv_q8_0_f32(
465
531
  device const void * src0,
466
532
  device const float * src1,
467
533
  device float * dst,
@@ -525,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
525
591
 
526
592
  #define N_F32_F32 4
527
593
 
528
- kernel void kernel_mul_mat_f32_f32(
594
+ kernel void kernel_mul_mv_f32_f32(
529
595
  device const char * src0,
530
596
  device const char * src1,
531
597
  device float * dst,
@@ -596,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
596
662
  }
597
663
  }
598
664
 
599
- kernel void kernel_mul_mat_f16_f32_1row(
665
+ kernel void kernel_mul_mv_f16_f32_1row(
600
666
  device const char * src0,
601
667
  device const char * src1,
602
668
  device float * dst,
@@ -615,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
615
681
  constant int64_t & ne0,
616
682
  constant int64_t & ne1,
617
683
  uint3 tgpig[[threadgroup_position_in_grid]],
618
- uint tiisg[[thread_index_in_simdgroup]]) {
684
+ uint tiisg[[thread_index_in_simdgroup]]) {
619
685
 
620
686
  const int64_t r0 = tgpig.x;
621
687
  const int64_t r1 = tgpig.y;
@@ -650,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
650
716
 
651
717
  #define N_F16_F32 4
652
718
 
653
- kernel void kernel_mul_mat_f16_f32(
719
+ kernel void kernel_mul_mv_f16_f32(
654
720
  device const char * src0,
655
721
  device const char * src1,
656
722
  device float * dst,
@@ -722,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
722
788
  }
723
789
 
724
790
  // Assumes row size (ne00) is a multiple of 4
725
- kernel void kernel_mul_mat_f16_f32_l4(
791
+ kernel void kernel_mul_mv_f16_f32_l4(
726
792
  device const char * src0,
727
793
  device const char * src1,
728
794
  device float * dst,
@@ -783,7 +849,9 @@ kernel void kernel_alibi_f32(
783
849
  constant uint64_t & nb1,
784
850
  constant uint64_t & nb2,
785
851
  constant uint64_t & nb3,
786
- constant float & m0,
852
+ constant float & m0,
853
+ constant float & m1,
854
+ constant int & n_heads_log2_floor,
787
855
  uint3 tgpig[[threadgroup_position_in_grid]],
788
856
  uint3 tpitg[[thread_position_in_threadgroup]],
789
857
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -799,37 +867,73 @@ kernel void kernel_alibi_f32(
799
867
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
800
868
 
801
869
  device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
802
- float m_k = pow(m0, i2 + 1);
870
+ float m_k;
871
+ if (i2 < n_heads_log2_floor) {
872
+ m_k = pow(m0, i2 + 1);
873
+ } else {
874
+ m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
875
+ }
803
876
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
804
877
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
805
878
  dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
806
879
  }
807
880
  }
808
881
 
882
+ typedef void (rope_t)(
883
+ device const void * src0,
884
+ device const int32_t * src1,
885
+ device float * dst,
886
+ constant int64_t & ne00,
887
+ constant int64_t & ne01,
888
+ constant int64_t & ne02,
889
+ constant int64_t & ne03,
890
+ constant uint64_t & nb00,
891
+ constant uint64_t & nb01,
892
+ constant uint64_t & nb02,
893
+ constant uint64_t & nb03,
894
+ constant int64_t & ne0,
895
+ constant int64_t & ne1,
896
+ constant int64_t & ne2,
897
+ constant int64_t & ne3,
898
+ constant uint64_t & nb0,
899
+ constant uint64_t & nb1,
900
+ constant uint64_t & nb2,
901
+ constant uint64_t & nb3,
902
+ constant int & n_past,
903
+ constant int & n_dims,
904
+ constant int & mode,
905
+ constant float & freq_base,
906
+ constant float & freq_scale,
907
+ uint tiitg[[thread_index_in_threadgroup]],
908
+ uint3 tptg[[threads_per_threadgroup]],
909
+ uint3 tgpig[[threadgroup_position_in_grid]]);
910
+
911
+ template<typename T>
809
912
  kernel void kernel_rope(
810
- device const void * src0,
811
- device float * dst,
812
- constant int64_t & ne00,
813
- constant int64_t & ne01,
814
- constant int64_t & ne02,
815
- constant int64_t & ne03,
816
- constant uint64_t & nb00,
817
- constant uint64_t & nb01,
818
- constant uint64_t & nb02,
819
- constant uint64_t & nb03,
820
- constant int64_t & ne0,
821
- constant int64_t & ne1,
822
- constant int64_t & ne2,
823
- constant int64_t & ne3,
824
- constant uint64_t & nb0,
825
- constant uint64_t & nb1,
826
- constant uint64_t & nb2,
827
- constant uint64_t & nb3,
828
- constant int & n_past,
829
- constant int & n_dims,
830
- constant int & mode,
831
- constant float & freq_base,
832
- constant float & freq_scale,
913
+ device const void * src0,
914
+ device const int32_t * src1,
915
+ device float * dst,
916
+ constant int64_t & ne00,
917
+ constant int64_t & ne01,
918
+ constant int64_t & ne02,
919
+ constant int64_t & ne03,
920
+ constant uint64_t & nb00,
921
+ constant uint64_t & nb01,
922
+ constant uint64_t & nb02,
923
+ constant uint64_t & nb03,
924
+ constant int64_t & ne0,
925
+ constant int64_t & ne1,
926
+ constant int64_t & ne2,
927
+ constant int64_t & ne3,
928
+ constant uint64_t & nb0,
929
+ constant uint64_t & nb1,
930
+ constant uint64_t & nb2,
931
+ constant uint64_t & nb3,
932
+ constant int & n_past,
933
+ constant int & n_dims,
934
+ constant int & mode,
935
+ constant float & freq_base,
936
+ constant float & freq_scale,
833
937
  uint tiitg[[thread_index_in_threadgroup]],
834
938
  uint3 tptg[[threads_per_threadgroup]],
835
939
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -839,7 +943,9 @@ kernel void kernel_rope(
839
943
 
840
944
  const bool is_neox = mode & 2;
841
945
 
842
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
946
+ device const int32_t * pos = src1;
947
+
948
+ const int64_t p = pos[i2];
843
949
 
844
950
  const float theta_0 = freq_scale * (float)p;
845
951
  const float inv_ndims = -1.f/n_dims;
@@ -851,11 +957,11 @@ kernel void kernel_rope(
851
957
  const float cos_theta = cos(theta);
852
958
  const float sin_theta = sin(theta);
853
959
 
854
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
855
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
960
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
961
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
856
962
 
857
- const float x0 = src[0];
858
- const float x1 = src[1];
963
+ const T x0 = src[0];
964
+ const T x1 = src[1];
859
965
 
860
966
  dst_data[0] = x0*cos_theta - x1*sin_theta;
861
967
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -870,8 +976,8 @@ kernel void kernel_rope(
870
976
 
871
977
  const int64_t i0 = ib*n_dims + ic/2;
872
978
 
873
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
874
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
979
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
980
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
875
981
 
876
982
  const float x0 = src[0];
877
983
  const float x1 = src[n_dims/2];
@@ -883,6 +989,9 @@ kernel void kernel_rope(
883
989
  }
884
990
  }
885
991
 
992
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
993
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
994
+
886
995
  kernel void kernel_cpy_f16_f16(
887
996
  device const half * src0,
888
997
  device half * dst,
@@ -1008,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
1008
1117
  }
1009
1118
  }
1010
1119
 
1120
+ kernel void kernel_concat(
1121
+ device const char * src0,
1122
+ device const char * src1,
1123
+ device char * dst,
1124
+ constant int64_t & ne00,
1125
+ constant int64_t & ne01,
1126
+ constant int64_t & ne02,
1127
+ constant int64_t & ne03,
1128
+ constant uint64_t & nb00,
1129
+ constant uint64_t & nb01,
1130
+ constant uint64_t & nb02,
1131
+ constant uint64_t & nb03,
1132
+ constant int64_t & ne10,
1133
+ constant int64_t & ne11,
1134
+ constant int64_t & ne12,
1135
+ constant int64_t & ne13,
1136
+ constant uint64_t & nb10,
1137
+ constant uint64_t & nb11,
1138
+ constant uint64_t & nb12,
1139
+ constant uint64_t & nb13,
1140
+ constant int64_t & ne0,
1141
+ constant int64_t & ne1,
1142
+ constant int64_t & ne2,
1143
+ constant int64_t & ne3,
1144
+ constant uint64_t & nb0,
1145
+ constant uint64_t & nb1,
1146
+ constant uint64_t & nb2,
1147
+ constant uint64_t & nb3,
1148
+ uint3 tgpig[[threadgroup_position_in_grid]],
1149
+ uint3 tpitg[[thread_position_in_threadgroup]],
1150
+ uint3 ntg[[threads_per_threadgroup]]) {
1151
+
1152
+ const int64_t i03 = tgpig.z;
1153
+ const int64_t i02 = tgpig.y;
1154
+ const int64_t i01 = tgpig.x;
1155
+
1156
+ const int64_t i13 = i03 % ne13;
1157
+ const int64_t i12 = i02 % ne12;
1158
+ const int64_t i11 = i01 % ne11;
1159
+
1160
+ device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
1161
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1162
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1163
+
1164
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1165
+ if (i02 < ne02) {
1166
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
1167
+ src0_ptr += ntg.x*nb00;
1168
+ } else {
1169
+ ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
1170
+ src1_ptr += ntg.x*nb10;
1171
+ }
1172
+ dst_ptr += ntg.x*nb0;
1173
+ }
1174
+ }
1175
+
1011
1176
  //============================================ k-quants ======================================================
1012
1177
 
1013
1178
  #ifndef QK_K
@@ -1100,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1100
1265
 
1101
1266
  //====================================== dot products =========================
1102
1267
 
1103
- kernel void kernel_mul_mat_q2_K_f32(
1268
+ kernel void kernel_mul_mv_q2_K_f32(
1104
1269
  device const void * src0,
1105
1270
  device const float * src1,
1106
1271
  device float * dst,
@@ -1244,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1244
1409
  }
1245
1410
 
1246
1411
  #if QK_K == 256
1247
- kernel void kernel_mul_mat_q3_K_f32(
1412
+ kernel void kernel_mul_mv_q3_K_f32(
1248
1413
  device const void * src0,
1249
1414
  device const float * src1,
1250
1415
  device float * dst,
@@ -1273,8 +1438,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1273
1438
 
1274
1439
  float yl[32];
1275
1440
 
1276
- const uint16_t kmask1 = 0x3030;
1277
- const uint16_t kmask2 = 0x0f0f;
1441
+ //const uint16_t kmask1 = 0x3030;
1442
+ //const uint16_t kmask2 = 0x0f0f;
1278
1443
 
1279
1444
  const int tid = tiisg/4;
1280
1445
  const int ix = tiisg%4;
@@ -1396,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1396
1561
  }
1397
1562
  }
1398
1563
  #else
1399
- kernel void kernel_mul_mat_q3_K_f32(
1564
+ kernel void kernel_mul_mv_q3_K_f32(
1400
1565
  device const void * src0,
1401
1566
  device const float * src1,
1402
1567
  device float * dst,
@@ -1467,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1467
1632
  #endif
1468
1633
 
1469
1634
  #if QK_K == 256
1470
- kernel void kernel_mul_mat_q4_K_f32(
1635
+ kernel void kernel_mul_mv_q4_K_f32(
1471
1636
  device const void * src0,
1472
1637
  device const float * src1,
1473
1638
  device float * dst,
@@ -1573,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1573
1738
  }
1574
1739
  }
1575
1740
  #else
1576
- kernel void kernel_mul_mat_q4_K_f32(
1741
+ kernel void kernel_mul_mv_q4_K_f32(
1577
1742
  device const void * src0,
1578
1743
  device const float * src1,
1579
1744
  device float * dst,
@@ -1662,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1662
1827
  }
1663
1828
  #endif
1664
1829
 
1665
- kernel void kernel_mul_mat_q5_K_f32(
1830
+ kernel void kernel_mul_mv_q5_K_f32(
1666
1831
  device const void * src0,
1667
1832
  device const float * src1,
1668
1833
  device float * dst,
@@ -1835,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1835
2000
 
1836
2001
  }
1837
2002
 
1838
- kernel void kernel_mul_mat_q6_K_f32(
2003
+ kernel void kernel_mul_mv_q6_K_f32(
1839
2004
  device const void * src0,
1840
2005
  device const float * src1,
1841
2006
  device float * dst,
@@ -2173,7 +2338,7 @@ kernel void kernel_get_rows(
2173
2338
  }
2174
2339
 
2175
2340
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2176
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
2341
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
2177
2342
  #define BLOCK_SIZE_K 32
2178
2343
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2179
2344
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2210,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
2210
2375
  const uint r0 = tgpig.y;
2211
2376
  const uint r1 = tgpig.x;
2212
2377
  const uint im = tgpig.z;
2378
+
2213
2379
  // if this block is of 64x32 shape or smaller
2214
2380
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2215
2381
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2382
+
2216
2383
  // a thread shouldn't load data outside of the matrix
2217
2384
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2218
2385
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2236,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
2236
2403
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2237
2404
 
2238
2405
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2239
- //load data and store to threadgroup memory
2406
+ // load data and store to threadgroup memory
2240
2407
  half4x4 temp_a;
2241
2408
  dequantize_func(x, il, temp_a);
2242
2409
  threadgroup_barrier(mem_flags::mem_threadgroup);
2410
+
2243
2411
  #pragma unroll(16)
2244
2412
  for (int i = 0; i < 16; i++) {
2245
2413
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2246
- + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
2247
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2414
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2415
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2248
2416
  }
2249
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
2250
- = *((device float2x4 *)y);
2417
+
2418
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2419
+
2251
2420
  il = (il + 2 < nl) ? il + 2 : il % 2;
2252
2421
  x = (il < 2) ? x + (2+nl-1)/nl : x;
2253
2422
  y += BLOCK_SIZE_K;
2254
2423
 
2255
2424
  threadgroup_barrier(mem_flags::mem_threadgroup);
2256
- //load matrices from threadgroup memory and conduct outer products
2425
+
2426
+ // load matrices from threadgroup memory and conduct outer products
2257
2427
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2258
2428
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2429
+
2259
2430
  #pragma unroll(4)
2260
2431
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2261
2432
  #pragma unroll(4)
@@ -2270,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2270
2441
 
2271
2442
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2272
2443
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2444
+
2273
2445
  #pragma unroll(8)
2274
2446
  for (int i = 0; i < 8; i++){
2275
2447
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
2278
2450
  }
2279
2451
 
2280
2452
  if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2281
- device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
2282
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
2453
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2454
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
2283
2455
  for (int i = 0; i < 8; i++) {
2284
2456
  simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2285
2457
  }
2286
2458
  } else {
2287
2459
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
2288
2460
  threadgroup_barrier(mem_flags::mem_threadgroup);
2289
- threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2461
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
2290
2462
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2291
2463
  for (int i = 0; i < 8; i++) {
2292
2464
  simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2293
2465
  }
2294
2466
 
2295
2467
  threadgroup_barrier(mem_flags::mem_threadgroup);
2296
- device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2297
- if (sgitg==0) {
2468
+
2469
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2470
+ if (sgitg == 0) {
2298
2471
  for (int i = 0; i < n_rows; i++) {
2299
- for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2472
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
2300
2473
  *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2301
2474
  }
2302
2475
  }