node-llama-cpp 2.5.0 → 2.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. package/README.md +19 -301
  2. package/dist/chatWrappers/{ChatMLPromptWrapper.d.ts → ChatMLChatPromptWrapper.d.ts} +1 -1
  3. package/dist/chatWrappers/{ChatMLPromptWrapper.js → ChatMLChatPromptWrapper.js} +2 -2
  4. package/dist/chatWrappers/ChatMLChatPromptWrapper.js.map +1 -0
  5. package/dist/chatWrappers/createChatWrapperByBos.js +2 -2
  6. package/dist/chatWrappers/createChatWrapperByBos.js.map +1 -1
  7. package/dist/cli/commands/BuildCommand.js +3 -1
  8. package/dist/cli/commands/BuildCommand.js.map +1 -1
  9. package/dist/cli/commands/ChatCommand.d.ts +8 -1
  10. package/dist/cli/commands/ChatCommand.js +88 -21
  11. package/dist/cli/commands/ChatCommand.js.map +1 -1
  12. package/dist/cli/commands/DownloadCommand.d.ts +2 -2
  13. package/dist/cli/commands/DownloadCommand.js +13 -38
  14. package/dist/cli/commands/DownloadCommand.js.map +1 -1
  15. package/dist/config.d.ts +5 -0
  16. package/dist/config.js +7 -0
  17. package/dist/config.js.map +1 -1
  18. package/dist/index.d.ts +5 -4
  19. package/dist/index.js +3 -2
  20. package/dist/index.js.map +1 -1
  21. package/dist/llamaEvaluator/LlamaBins.d.ts +3 -3
  22. package/dist/llamaEvaluator/LlamaBins.js +2 -2
  23. package/dist/llamaEvaluator/LlamaBins.js.map +1 -1
  24. package/dist/llamaEvaluator/LlamaChatSession.d.ts +79 -2
  25. package/dist/llamaEvaluator/LlamaChatSession.js +52 -8
  26. package/dist/llamaEvaluator/LlamaChatSession.js.map +1 -1
  27. package/dist/llamaEvaluator/LlamaContext.d.ts +60 -3
  28. package/dist/llamaEvaluator/LlamaContext.js +36 -4
  29. package/dist/llamaEvaluator/LlamaContext.js.map +1 -1
  30. package/dist/llamaEvaluator/LlamaGrammar.d.ts +16 -3
  31. package/dist/llamaEvaluator/LlamaGrammar.js +23 -4
  32. package/dist/llamaEvaluator/LlamaGrammar.js.map +1 -1
  33. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.d.ts +14 -0
  34. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js +16 -0
  35. package/dist/llamaEvaluator/LlamaGrammarEvaluationState.js.map +1 -0
  36. package/dist/llamaEvaluator/LlamaModel.d.ts +46 -14
  37. package/dist/llamaEvaluator/LlamaModel.js +23 -16
  38. package/dist/llamaEvaluator/LlamaModel.js.map +1 -1
  39. package/dist/state.d.ts +2 -0
  40. package/dist/state.js +8 -0
  41. package/dist/state.js.map +1 -0
  42. package/dist/utils/cloneLlamaCppRepo.d.ts +1 -0
  43. package/dist/utils/cloneLlamaCppRepo.js +62 -0
  44. package/dist/utils/cloneLlamaCppRepo.js.map +1 -0
  45. package/dist/utils/compileLLamaCpp.js +24 -6
  46. package/dist/utils/compileLLamaCpp.js.map +1 -1
  47. package/dist/utils/getBin.d.ts +21 -13
  48. package/dist/utils/gitReleaseBundles.d.ts +2 -0
  49. package/dist/utils/gitReleaseBundles.js +25 -0
  50. package/dist/utils/gitReleaseBundles.js.map +1 -0
  51. package/llama/addon.cpp +184 -110
  52. package/llama/binariesGithubRelease.json +1 -1
  53. package/llama/gitRelease.bundle +0 -0
  54. package/llama/toolchains/darwin.host-x64.target-arm64.cmake +8 -0
  55. package/llama/toolchains/linux.host-arm64.target-x64.cmake +5 -0
  56. package/llama/toolchains/linux.host-x64.target-arm64.cmake +5 -0
  57. package/llama/toolchains/linux.host-x64.target-arm71.cmake +5 -0
  58. package/llamaBins/linux-arm64/llama-addon.node +0 -0
  59. package/llamaBins/linux-armv7l/llama-addon.node +0 -0
  60. package/llamaBins/linux-x64/llama-addon.node +0 -0
  61. package/llamaBins/mac-arm64/ggml-metal.metal +246 -79
  62. package/llamaBins/mac-arm64/llama-addon.node +0 -0
  63. package/llamaBins/mac-x64/ggml-metal.metal +246 -79
  64. package/llamaBins/mac-x64/llama-addon.node +0 -0
  65. package/llamaBins/win-x64/llama-addon.node +0 -0
  66. package/package.json +10 -4
  67. package/dist/chatWrappers/ChatMLPromptWrapper.js.map +0 -1
  68. package/llamaBins/linux-ppc64le/llama-addon.node +0 -0
@@ -13,8 +13,8 @@ typedef struct {
13
13
 
14
14
  #define QK4_1 32
15
15
  typedef struct {
16
- half d; // delta
17
- half m; // min
16
+ half d; // delta
17
+ half m; // min
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
@@ -24,12 +24,59 @@ typedef struct {
24
24
  int8_t qs[QK8_0]; // quants
25
25
  } block_q8_0;
26
26
 
27
+ // general-purpose kernel for addition of two tensors
28
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
29
+ // cons: not very efficient
27
30
  kernel void kernel_add(
28
- device const float4 * src0,
29
- device const float4 * src1,
30
- device float4 * dst,
31
- uint tpig[[thread_position_in_grid]]) {
32
- dst[tpig] = src0[tpig] + src1[tpig];
31
+ device const char * src0,
32
+ device const char * src1,
33
+ device char * dst,
34
+ constant int64_t & ne00,
35
+ constant int64_t & ne01,
36
+ constant int64_t & ne02,
37
+ constant int64_t & ne03,
38
+ constant int64_t & nb00,
39
+ constant int64_t & nb01,
40
+ constant int64_t & nb02,
41
+ constant int64_t & nb03,
42
+ constant int64_t & ne10,
43
+ constant int64_t & ne11,
44
+ constant int64_t & ne12,
45
+ constant int64_t & ne13,
46
+ constant int64_t & nb10,
47
+ constant int64_t & nb11,
48
+ constant int64_t & nb12,
49
+ constant int64_t & nb13,
50
+ constant int64_t & ne0,
51
+ constant int64_t & ne1,
52
+ constant int64_t & ne2,
53
+ constant int64_t & ne3,
54
+ constant int64_t & nb0,
55
+ constant int64_t & nb1,
56
+ constant int64_t & nb2,
57
+ constant int64_t & nb3,
58
+ uint3 tgpig[[threadgroup_position_in_grid]],
59
+ uint3 tpitg[[thread_position_in_threadgroup]],
60
+ uint3 ntg[[threads_per_threadgroup]]) {
61
+ const int64_t i03 = tgpig.z;
62
+ const int64_t i02 = tgpig.y;
63
+ const int64_t i01 = tgpig.x;
64
+
65
+ const int64_t i13 = i03 % ne13;
66
+ const int64_t i12 = i02 % ne12;
67
+ const int64_t i11 = i01 % ne11;
68
+
69
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
70
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
71
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
72
+
73
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
74
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
75
+
76
+ src0_ptr += ntg.x*nb00;
77
+ src1_ptr += ntg.x*nb10;
78
+ dst_ptr += ntg.x*nb0;
79
+ }
33
80
  }
34
81
 
35
82
  // assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
38
85
  device const float4 * src0,
39
86
  device const float4 * src1,
40
87
  device float4 * dst,
41
- constant int64_t & nb,
88
+ constant int64_t & nb [[buffer(27)]],
42
89
  uint tpig[[thread_position_in_grid]]) {
43
90
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
91
  }
@@ -85,6 +132,13 @@ kernel void kernel_relu(
85
132
  dst[tpig] = max(0.0f, src0[tpig]);
86
133
  }
87
134
 
135
+ kernel void kernel_sqr(
136
+ device const float * src0,
137
+ device float * dst,
138
+ uint tpig[[thread_position_in_grid]]) {
139
+ dst[tpig] = src0[tpig] * src0[tpig];
140
+ }
141
+
88
142
  constant float GELU_COEF_A = 0.044715f;
89
143
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
144
 
@@ -369,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
369
423
  }
370
424
 
371
425
  // putting them in the kernel cause a significant performance penalty
372
- #define N_DST 4 // each SIMD group works on 4 rows
373
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
426
+ #define N_DST 4 // each SIMD group works on 4 rows
427
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
374
428
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
375
429
  //Note: This is a template, but strictly speaking it only applies to
376
430
  // quantizations where the block size is 32. It also does not
@@ -381,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
381
435
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
382
436
  uint3 tgpig, uint tiisg, uint sgitg) {
383
437
  const int nb = ne00/QK4_0;
438
+
384
439
  const int r0 = tgpig.x;
385
440
  const int r1 = tgpig.y;
386
441
  const int im = tgpig.z;
442
+
387
443
  const int first_row = (r0 * nsg + sgitg) * nr;
444
+
388
445
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
446
+
389
447
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
390
448
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
391
- float yl[16]; // src1 vector cache
392
- float sumf[nr]={0.f};
393
449
 
394
- const int ix = tiisg/2;
395
- const int il = 8*(tiisg%2);
450
+ float yl[16]; // src1 vector cache
451
+ float sumf[nr] = {0.f};
452
+
453
+ const int ix = (tiisg/2);
454
+ const int il = (tiisg%2)*8;
396
455
 
397
456
  device const float * yb = y + ix * QK4_0 + il;
398
457
 
@@ -403,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
403
462
  sumy += yb[i] + yb[i+1];
404
463
  yl[i+0] = yb[i+ 0];
405
464
  yl[i+1] = yb[i+ 1]/256.f;
465
+
406
466
  sumy += yb[i+16] + yb[i+17];
407
467
  yl[i+8] = yb[i+16]/16.f;
408
468
  yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +478,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
418
478
  for (int row = 0; row < nr; ++row) {
419
479
  const float tot = simd_sum(sumf[row]);
420
480
  if (tiisg == 0 && first_row + row < ne01) {
421
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
481
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
422
482
  }
423
483
  }
424
484
  }
425
485
 
426
- kernel void kernel_mul_mat_q4_0_f32(
486
+ kernel void kernel_mul_mv_q4_0_f32(
427
487
  device const void * src0,
428
488
  device const float * src1,
429
489
  device float * dst,
@@ -436,12 +496,12 @@ kernel void kernel_mul_mat_q4_0_f32(
436
496
  constant int64_t & ne1[[buffer(16)]],
437
497
  constant uint & gqa[[buffer(17)]],
438
498
  uint3 tgpig[[threadgroup_position_in_grid]],
439
- uint tiisg[[thread_index_in_simdgroup]],
440
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
499
+ uint tiisg[[thread_index_in_simdgroup]],
500
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
441
501
  mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
442
502
  }
443
503
 
444
- kernel void kernel_mul_mat_q4_1_f32(
504
+ kernel void kernel_mul_mv_q4_1_f32(
445
505
  device const void * src0,
446
506
  device const float * src1,
447
507
  device float * dst,
@@ -461,7 +521,7 @@ kernel void kernel_mul_mat_q4_1_f32(
461
521
 
462
522
  #define NB_Q8_0 8
463
523
 
464
- kernel void kernel_mul_mat_q8_0_f32(
524
+ kernel void kernel_mul_mv_q8_0_f32(
465
525
  device const void * src0,
466
526
  device const float * src1,
467
527
  device float * dst,
@@ -525,7 +585,7 @@ kernel void kernel_mul_mat_q8_0_f32(
525
585
 
526
586
  #define N_F32_F32 4
527
587
 
528
- kernel void kernel_mul_mat_f32_f32(
588
+ kernel void kernel_mul_mv_f32_f32(
529
589
  device const char * src0,
530
590
  device const char * src1,
531
591
  device float * dst,
@@ -596,7 +656,7 @@ kernel void kernel_mul_mat_f32_f32(
596
656
  }
597
657
  }
598
658
 
599
- kernel void kernel_mul_mat_f16_f32_1row(
659
+ kernel void kernel_mul_mv_f16_f32_1row(
600
660
  device const char * src0,
601
661
  device const char * src1,
602
662
  device float * dst,
@@ -615,7 +675,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
615
675
  constant int64_t & ne0,
616
676
  constant int64_t & ne1,
617
677
  uint3 tgpig[[threadgroup_position_in_grid]],
618
- uint tiisg[[thread_index_in_simdgroup]]) {
678
+ uint tiisg[[thread_index_in_simdgroup]]) {
619
679
 
620
680
  const int64_t r0 = tgpig.x;
621
681
  const int64_t r1 = tgpig.y;
@@ -650,7 +710,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
650
710
 
651
711
  #define N_F16_F32 4
652
712
 
653
- kernel void kernel_mul_mat_f16_f32(
713
+ kernel void kernel_mul_mv_f16_f32(
654
714
  device const char * src0,
655
715
  device const char * src1,
656
716
  device float * dst,
@@ -722,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32(
722
782
  }
723
783
 
724
784
  // Assumes row size (ne00) is a multiple of 4
725
- kernel void kernel_mul_mat_f16_f32_l4(
785
+ kernel void kernel_mul_mv_f16_f32_l4(
726
786
  device const char * src0,
727
787
  device const char * src1,
728
788
  device float * dst,
@@ -783,7 +843,9 @@ kernel void kernel_alibi_f32(
783
843
  constant uint64_t & nb1,
784
844
  constant uint64_t & nb2,
785
845
  constant uint64_t & nb3,
786
- constant float & m0,
846
+ constant float & m0,
847
+ constant float & m1,
848
+ constant int & n_heads_log2_floor,
787
849
  uint3 tgpig[[threadgroup_position_in_grid]],
788
850
  uint3 tpitg[[thread_position_in_threadgroup]],
789
851
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -799,37 +861,73 @@ kernel void kernel_alibi_f32(
799
861
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
800
862
 
801
863
  device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
802
- float m_k = pow(m0, i2 + 1);
864
+ float m_k;
865
+ if (i2 < n_heads_log2_floor) {
866
+ m_k = pow(m0, i2 + 1);
867
+ } else {
868
+ m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
869
+ }
803
870
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
804
871
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
805
872
  dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
806
873
  }
807
874
  }
808
875
 
876
+ typedef void (rope_t)(
877
+ device const void * src0,
878
+ device const int32_t * src1,
879
+ device float * dst,
880
+ constant int64_t & ne00,
881
+ constant int64_t & ne01,
882
+ constant int64_t & ne02,
883
+ constant int64_t & ne03,
884
+ constant uint64_t & nb00,
885
+ constant uint64_t & nb01,
886
+ constant uint64_t & nb02,
887
+ constant uint64_t & nb03,
888
+ constant int64_t & ne0,
889
+ constant int64_t & ne1,
890
+ constant int64_t & ne2,
891
+ constant int64_t & ne3,
892
+ constant uint64_t & nb0,
893
+ constant uint64_t & nb1,
894
+ constant uint64_t & nb2,
895
+ constant uint64_t & nb3,
896
+ constant int & n_past,
897
+ constant int & n_dims,
898
+ constant int & mode,
899
+ constant float & freq_base,
900
+ constant float & freq_scale,
901
+ uint tiitg[[thread_index_in_threadgroup]],
902
+ uint3 tptg[[threads_per_threadgroup]],
903
+ uint3 tgpig[[threadgroup_position_in_grid]]);
904
+
905
+ template<typename T>
809
906
  kernel void kernel_rope(
810
- device const void * src0,
811
- device float * dst,
812
- constant int64_t & ne00,
813
- constant int64_t & ne01,
814
- constant int64_t & ne02,
815
- constant int64_t & ne03,
816
- constant uint64_t & nb00,
817
- constant uint64_t & nb01,
818
- constant uint64_t & nb02,
819
- constant uint64_t & nb03,
820
- constant int64_t & ne0,
821
- constant int64_t & ne1,
822
- constant int64_t & ne2,
823
- constant int64_t & ne3,
824
- constant uint64_t & nb0,
825
- constant uint64_t & nb1,
826
- constant uint64_t & nb2,
827
- constant uint64_t & nb3,
828
- constant int & n_past,
829
- constant int & n_dims,
830
- constant int & mode,
831
- constant float & freq_base,
832
- constant float & freq_scale,
907
+ device const void * src0,
908
+ device const int32_t * src1,
909
+ device float * dst,
910
+ constant int64_t & ne00,
911
+ constant int64_t & ne01,
912
+ constant int64_t & ne02,
913
+ constant int64_t & ne03,
914
+ constant uint64_t & nb00,
915
+ constant uint64_t & nb01,
916
+ constant uint64_t & nb02,
917
+ constant uint64_t & nb03,
918
+ constant int64_t & ne0,
919
+ constant int64_t & ne1,
920
+ constant int64_t & ne2,
921
+ constant int64_t & ne3,
922
+ constant uint64_t & nb0,
923
+ constant uint64_t & nb1,
924
+ constant uint64_t & nb2,
925
+ constant uint64_t & nb3,
926
+ constant int & n_past,
927
+ constant int & n_dims,
928
+ constant int & mode,
929
+ constant float & freq_base,
930
+ constant float & freq_scale,
833
931
  uint tiitg[[thread_index_in_threadgroup]],
834
932
  uint3 tptg[[threads_per_threadgroup]],
835
933
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -839,7 +937,9 @@ kernel void kernel_rope(
839
937
 
840
938
  const bool is_neox = mode & 2;
841
939
 
842
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
940
+ device const int32_t * pos = src1;
941
+
942
+ const int64_t p = pos[i2];
843
943
 
844
944
  const float theta_0 = freq_scale * (float)p;
845
945
  const float inv_ndims = -1.f/n_dims;
@@ -851,11 +951,11 @@ kernel void kernel_rope(
851
951
  const float cos_theta = cos(theta);
852
952
  const float sin_theta = sin(theta);
853
953
 
854
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
855
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
954
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
955
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
856
956
 
857
- const float x0 = src[0];
858
- const float x1 = src[1];
957
+ const T x0 = src[0];
958
+ const T x1 = src[1];
859
959
 
860
960
  dst_data[0] = x0*cos_theta - x1*sin_theta;
861
961
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -870,8 +970,8 @@ kernel void kernel_rope(
870
970
 
871
971
  const int64_t i0 = ib*n_dims + ic/2;
872
972
 
873
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
874
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
973
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
974
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
875
975
 
876
976
  const float x0 = src[0];
877
977
  const float x1 = src[n_dims/2];
@@ -883,6 +983,9 @@ kernel void kernel_rope(
883
983
  }
884
984
  }
885
985
 
986
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
987
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
988
+
886
989
  kernel void kernel_cpy_f16_f16(
887
990
  device const half * src0,
888
991
  device half * dst,
@@ -1008,6 +1111,62 @@ kernel void kernel_cpy_f32_f32(
1008
1111
  }
1009
1112
  }
1010
1113
 
1114
+ kernel void kernel_concat(
1115
+ device const char * src0,
1116
+ device const char * src1,
1117
+ device char * dst,
1118
+ constant int64_t & ne00,
1119
+ constant int64_t & ne01,
1120
+ constant int64_t & ne02,
1121
+ constant int64_t & ne03,
1122
+ constant uint64_t & nb00,
1123
+ constant uint64_t & nb01,
1124
+ constant uint64_t & nb02,
1125
+ constant uint64_t & nb03,
1126
+ constant int64_t & ne10,
1127
+ constant int64_t & ne11,
1128
+ constant int64_t & ne12,
1129
+ constant int64_t & ne13,
1130
+ constant uint64_t & nb10,
1131
+ constant uint64_t & nb11,
1132
+ constant uint64_t & nb12,
1133
+ constant uint64_t & nb13,
1134
+ constant int64_t & ne0,
1135
+ constant int64_t & ne1,
1136
+ constant int64_t & ne2,
1137
+ constant int64_t & ne3,
1138
+ constant uint64_t & nb0,
1139
+ constant uint64_t & nb1,
1140
+ constant uint64_t & nb2,
1141
+ constant uint64_t & nb3,
1142
+ uint3 tgpig[[threadgroup_position_in_grid]],
1143
+ uint3 tpitg[[thread_position_in_threadgroup]],
1144
+ uint3 ntg[[threads_per_threadgroup]]) {
1145
+
1146
+ const int64_t i03 = tgpig.z;
1147
+ const int64_t i02 = tgpig.y;
1148
+ const int64_t i01 = tgpig.x;
1149
+
1150
+ const int64_t i13 = i03 % ne13;
1151
+ const int64_t i12 = i02 % ne12;
1152
+ const int64_t i11 = i01 % ne11;
1153
+
1154
+ device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
1155
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1156
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1157
+
1158
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1159
+ if (i02 < ne02) {
1160
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
1161
+ src0_ptr += ntg.x*nb00;
1162
+ } else {
1163
+ ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
1164
+ src1_ptr += ntg.x*nb10;
1165
+ }
1166
+ dst_ptr += ntg.x*nb0;
1167
+ }
1168
+ }
1169
+
1011
1170
  //============================================ k-quants ======================================================
1012
1171
 
1013
1172
  #ifndef QK_K
@@ -1100,7 +1259,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1100
1259
 
1101
1260
  //====================================== dot products =========================
1102
1261
 
1103
- kernel void kernel_mul_mat_q2_K_f32(
1262
+ kernel void kernel_mul_mv_q2_K_f32(
1104
1263
  device const void * src0,
1105
1264
  device const float * src1,
1106
1265
  device float * dst,
@@ -1244,7 +1403,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1244
1403
  }
1245
1404
 
1246
1405
  #if QK_K == 256
1247
- kernel void kernel_mul_mat_q3_K_f32(
1406
+ kernel void kernel_mul_mv_q3_K_f32(
1248
1407
  device const void * src0,
1249
1408
  device const float * src1,
1250
1409
  device float * dst,
@@ -1273,8 +1432,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1273
1432
 
1274
1433
  float yl[32];
1275
1434
 
1276
- const uint16_t kmask1 = 0x3030;
1277
- const uint16_t kmask2 = 0x0f0f;
1435
+ //const uint16_t kmask1 = 0x3030;
1436
+ //const uint16_t kmask2 = 0x0f0f;
1278
1437
 
1279
1438
  const int tid = tiisg/4;
1280
1439
  const int ix = tiisg%4;
@@ -1396,7 +1555,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1396
1555
  }
1397
1556
  }
1398
1557
  #else
1399
- kernel void kernel_mul_mat_q3_K_f32(
1558
+ kernel void kernel_mul_mv_q3_K_f32(
1400
1559
  device const void * src0,
1401
1560
  device const float * src1,
1402
1561
  device float * dst,
@@ -1467,7 +1626,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1467
1626
  #endif
1468
1627
 
1469
1628
  #if QK_K == 256
1470
- kernel void kernel_mul_mat_q4_K_f32(
1629
+ kernel void kernel_mul_mv_q4_K_f32(
1471
1630
  device const void * src0,
1472
1631
  device const float * src1,
1473
1632
  device float * dst,
@@ -1573,7 +1732,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1573
1732
  }
1574
1733
  }
1575
1734
  #else
1576
- kernel void kernel_mul_mat_q4_K_f32(
1735
+ kernel void kernel_mul_mv_q4_K_f32(
1577
1736
  device const void * src0,
1578
1737
  device const float * src1,
1579
1738
  device float * dst,
@@ -1662,7 +1821,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1662
1821
  }
1663
1822
  #endif
1664
1823
 
1665
- kernel void kernel_mul_mat_q5_K_f32(
1824
+ kernel void kernel_mul_mv_q5_K_f32(
1666
1825
  device const void * src0,
1667
1826
  device const float * src1,
1668
1827
  device float * dst,
@@ -1835,7 +1994,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1835
1994
 
1836
1995
  }
1837
1996
 
1838
- kernel void kernel_mul_mat_q6_K_f32(
1997
+ kernel void kernel_mul_mv_q6_K_f32(
1839
1998
  device const void * src0,
1840
1999
  device const float * src1,
1841
2000
  device float * dst,
@@ -2173,7 +2332,7 @@ kernel void kernel_get_rows(
2173
2332
  }
2174
2333
 
2175
2334
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2176
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
2335
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
2177
2336
  #define BLOCK_SIZE_K 32
2178
2337
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2179
2338
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2210,9 +2369,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
2210
2369
  const uint r0 = tgpig.y;
2211
2370
  const uint r1 = tgpig.x;
2212
2371
  const uint im = tgpig.z;
2372
+
2213
2373
  // if this block is of 64x32 shape or smaller
2214
2374
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2215
2375
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2376
+
2216
2377
  // a thread shouldn't load data outside of the matrix
2217
2378
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2218
2379
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2236,26 +2397,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
2236
2397
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2237
2398
 
2238
2399
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2239
- //load data and store to threadgroup memory
2400
+ // load data and store to threadgroup memory
2240
2401
  half4x4 temp_a;
2241
2402
  dequantize_func(x, il, temp_a);
2242
2403
  threadgroup_barrier(mem_flags::mem_threadgroup);
2404
+
2243
2405
  #pragma unroll(16)
2244
2406
  for (int i = 0; i < 16; i++) {
2245
2407
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2246
- + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
2247
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2408
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2409
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2248
2410
  }
2249
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
2250
- = *((device float2x4 *)y);
2411
+
2412
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2413
+
2251
2414
  il = (il + 2 < nl) ? il + 2 : il % 2;
2252
2415
  x = (il < 2) ? x + (2+nl-1)/nl : x;
2253
2416
  y += BLOCK_SIZE_K;
2254
2417
 
2255
2418
  threadgroup_barrier(mem_flags::mem_threadgroup);
2256
- //load matrices from threadgroup memory and conduct outer products
2419
+
2420
+ // load matrices from threadgroup memory and conduct outer products
2257
2421
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2258
2422
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2423
+
2259
2424
  #pragma unroll(4)
2260
2425
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2261
2426
  #pragma unroll(4)
@@ -2270,6 +2435,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2270
2435
 
2271
2436
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2272
2437
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2438
+
2273
2439
  #pragma unroll(8)
2274
2440
  for (int i = 0; i < 8; i++){
2275
2441
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +2444,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
2278
2444
  }
2279
2445
 
2280
2446
  if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2281
- device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
2282
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
2447
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2448
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
2283
2449
  for (int i = 0; i < 8; i++) {
2284
2450
  simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2285
2451
  }
2286
2452
  } else {
2287
2453
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
2288
2454
  threadgroup_barrier(mem_flags::mem_threadgroup);
2289
- threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2455
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
2290
2456
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2291
2457
  for (int i = 0; i < 8; i++) {
2292
2458
  simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2293
2459
  }
2294
2460
 
2295
2461
  threadgroup_barrier(mem_flags::mem_threadgroup);
2296
- device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2297
- if (sgitg==0) {
2462
+
2463
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2464
+ if (sgitg == 0) {
2298
2465
  for (int i = 0; i < n_rows; i++) {
2299
- for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2466
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
2300
2467
  *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2301
2468
  }
2302
2469
  }
Binary file
Binary file
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "node-llama-cpp",
3
- "version": "2.5.0",
3
+ "version": "2.6.0",
4
4
  "description": "Run AI models locally on your machine with node.js bindings for llama.cpp",
5
5
  "main": "dist/index.js",
6
6
  "type": "module",
@@ -44,7 +44,6 @@
44
44
  "prebuild": "rimraf ./dist ./tsconfig.tsbuildinfo",
45
45
  "build": "tsc --build tsconfig.json --force",
46
46
  "addPostinstallScript": "npm pkg set scripts.postinstall=\"node ./dist/cli/cli.js postinstall\"",
47
- "generate-docs": "typedoc --plugin typedoc-plugin-mdn-links",
48
47
  "prewatch": "rimraf ./dist ./tsconfig.tsbuildinfo",
49
48
  "watch": "tsc --build tsconfig.json --watch --force",
50
49
  "cmake-js-llama": "cd llama && cmake-js",
@@ -53,9 +52,13 @@
53
52
  "lint": "npm run lint:eslint",
54
53
  "lint:eslint": "eslint --ext .js --ext .ts .",
55
54
  "format": "npm run lint:eslint -- --fix",
56
- "dev:setup": "npm run build && node ./dist/cli/cli.js download",
55
+ "dev:setup": "npm run build && node ./dist/cli/cli.js download && npm run docs:generateTypedoc",
57
56
  "dev:build": "npm run build && node ./dist/cli/cli.js build",
58
57
  "clean": "rm -rf ./node_modules ./dist ./tsconfig.tsbuildinfo",
58
+ "docs:generateTypedoc": "typedoc && rimraf ./docs/api/index.md ./docs/api/exports.md",
59
+ "docs:dev": "npm run docs:generateTypedoc && vitepress dev",
60
+ "docs:build": "npm run docs:generateTypedoc && vitepress build",
61
+ "docs:preview": "npm run docs:generateTypedoc && vitepress preview",
59
62
  "postinstall": "node ./dist/cli/cli.js postinstall"
60
63
  },
61
64
  "repository": {
@@ -112,9 +115,12 @@
112
115
  "semantic-release": "^21.0.7",
113
116
  "ts-node": "^10.9.1",
114
117
  "tslib": "^2.6.1",
115
- "typedoc": "^0.24.8",
118
+ "typedoc": "^0.25.1",
119
+ "typedoc-plugin-markdown": "^4.0.0-next.22",
116
120
  "typedoc-plugin-mdn-links": "^3.1.0",
121
+ "typedoc-vitepress-theme": "^1.0.0-next.3",
117
122
  "typescript": "^5.1.6",
123
+ "vitepress": "^1.0.0-rc.20",
118
124
  "zx": "^7.2.3"
119
125
  },
120
126
  "dependencies": {
@@ -1 +0,0 @@
1
- {"version":3,"file":"ChatMLPromptWrapper.js","sourceRoot":"","sources":["../../src/chatWrappers/ChatMLPromptWrapper.ts"],"names":[],"mappings":"AAAA,OAAO,EAAC,iBAAiB,EAAC,MAAM,yBAAyB,CAAC;AAC1D,OAAO,EAAC,iBAAiB,EAAC,MAAM,+BAA+B,CAAC;AAEhE,0GAA0G;AAC1G,MAAM,OAAO,mBAAoB,SAAQ,iBAAiB;IACtC,WAAW,GAAW,QAAQ,CAAC;IAE/B,UAAU,CAAC,MAAc,EAAE,EAAC,YAAY,EAAE,WAAW,EAAE,cAAc,EAAE,oBAAoB,EAE1G;QACG,MAAM,qBAAqB,GAAG,CAAC,cAAc,IAAI,EAAE,CAAC,GAAG,CAAC,oBAAoB,IAAI,EAAE,CAAC,CAAC;QAEpF,IAAI,WAAW,KAAK,CAAC,IAAI,YAAY,IAAI,EAAE;YACvC,OAAO,CAAC,iBAAiB,CAAC,qBAAqB,EAAE,sBAAsB,CAAC,IAAI,sBAAsB,CAAC;gBAC/F,YAAY,GAAG,gCAAgC,GAAG,MAAM,GAAG,qCAAqC,CAAC;;YAErG,OAAO,CAAC,iBAAiB,CAAC,qBAAqB,EAAE,gCAAgC,CAAC,IAAI,gCAAgC,CAAC;gBACnH,MAAM,GAAG,qCAAqC,CAAC;IAC3D,CAAC;IAEe,cAAc;QAC1B,OAAO,CAAC,YAAY,CAAC,CAAC;IAC1B,CAAC;IAEe,oBAAoB;QAChC,OAAO,YAAY,CAAC;IACxB,CAAC;CACJ"}