llama_cpp 0.14.7 → 0.15.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -40,7 +40,7 @@
40
40
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
41
41
 
42
42
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
43
- #define LLAMA_SESSION_VERSION 5
43
+ #define LLAMA_SESSION_VERSION 6
44
44
 
45
45
  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
46
46
  #define LLAMA_STATE_SEQ_VERSION 1
@@ -69,6 +69,23 @@ extern "C" {
69
69
  LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
70
70
  };
71
71
 
72
+ // pre-tokenization types
73
+ enum llama_vocab_pre_type {
74
+ LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
75
+ LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
76
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
77
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
78
+ LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
79
+ LLAMA_VOCAB_PRE_TYPE_MPT = 5,
80
+ LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
81
+ LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
82
+ LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
83
+ LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
84
+ LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
85
+ LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
86
+ LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
87
+ };
88
+
72
89
  // note: these values should be synchronized with ggml_rope
73
90
  // TODO: maybe move this enum to ggml.h (ggml_rope_type)
74
91
  enum llama_rope_type {
@@ -122,6 +139,7 @@ extern "C" {
122
139
  LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
123
140
  LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
124
141
  LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
142
+ LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
125
143
 
126
144
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
127
145
  };
@@ -159,7 +177,7 @@ extern "C" {
159
177
  bool sorted;
160
178
  } llama_token_data_array;
161
179
 
162
- typedef bool (*llama_progress_callback)(float progress, void *ctx);
180
+ typedef bool (*llama_progress_callback)(float progress, void * user_data);
163
181
 
164
182
  // Input data for llama_decode
165
183
  // A llama_batch object can contain input about one or many sequences
@@ -195,15 +213,19 @@ extern "C" {
195
213
  LLAMA_KV_OVERRIDE_TYPE_INT,
196
214
  LLAMA_KV_OVERRIDE_TYPE_FLOAT,
197
215
  LLAMA_KV_OVERRIDE_TYPE_BOOL,
216
+ LLAMA_KV_OVERRIDE_TYPE_STR,
198
217
  };
199
218
 
200
219
  struct llama_model_kv_override {
201
- char key[128];
202
220
  enum llama_model_kv_override_type tag;
221
+
222
+ char key[128];
223
+
203
224
  union {
204
- int64_t int_value;
205
- double float_value;
206
- bool bool_value;
225
+ int64_t val_i64;
226
+ double val_f64;
227
+ bool val_bool;
228
+ char val_str[128];
207
229
  };
208
230
  };
209
231
 
@@ -232,9 +254,10 @@ extern "C" {
232
254
  const struct llama_model_kv_override * kv_overrides;
233
255
 
234
256
  // Keep the booleans together to avoid misalignment during copy-by-value.
235
- bool vocab_only; // only load the vocabulary, no weights
236
- bool use_mmap; // use mmap if possible
237
- bool use_mlock; // force system to keep model in RAM
257
+ bool vocab_only; // only load the vocabulary, no weights
258
+ bool use_mmap; // use mmap if possible
259
+ bool use_mlock; // force system to keep model in RAM
260
+ bool check_tensors; // validate model tensor data
238
261
  };
239
262
 
240
263
  struct llama_context_params {
@@ -270,6 +293,7 @@ extern "C" {
270
293
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
271
294
  bool embeddings; // if true, extract embeddings (together with logits)
272
295
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
296
+ bool flash_attn; // whether to use flash attention
273
297
 
274
298
  // Abort callback
275
299
  // if it returns true, execution of llama_decode() will be aborted
@@ -525,7 +549,7 @@ extern "C" {
525
549
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
526
550
  LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
527
551
 
528
- // Clear the KV cache
552
+ // Clear the KV cache - both cell info is erased and KV data is zeroed
529
553
  LLAMA_API void llama_kv_cache_clear(
530
554
  struct llama_context * ctx);
531
555
 
@@ -1,6 +1,3 @@
1
- // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2
- // vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
3
- //
4
1
  // Copyright 2024 Mozilla Foundation
5
2
  //
6
3
  // Permission is hereby granted, free of charge, to any person obtaining
@@ -50,7 +47,6 @@
50
47
  #pragma GCC diagnostic ignored "-Wignored-attributes"
51
48
 
52
49
  #include "sgemm.h"
53
- #include <algorithm>
54
50
  #include "ggml-impl.h"
55
51
  #include "ggml-quants.h"
56
52
 
@@ -243,23 +239,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
243
239
  template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
244
240
  class tinyBLAS {
245
241
  public:
246
- tinyBLAS(int k,
247
- const TA *A, int lda,
248
- const TB *B, int ldb,
249
- TC *C, int ldc,
242
+ tinyBLAS(int64_t k,
243
+ const TA *A, int64_t lda,
244
+ const TB *B, int64_t ldb,
245
+ TC *C, int64_t ldc,
250
246
  int ith, int nth)
251
247
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
252
248
  }
253
249
 
254
- void matmul(int m, int n, int task) {
250
+ void matmul(int64_t m, int64_t n, int task) {
255
251
  if (task == GGML_TASK_TYPE_COMPUTE)
256
252
  mnpack(0, m, 0, n);
257
253
  }
258
254
 
259
255
  private:
260
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
261
- int mc, nc, mp, np;
262
- switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
256
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
257
+ int64_t mc, nc, mp, np;
258
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
263
259
  #if VECTOR_REGISTERS == 32
264
260
  case 0x55:
265
261
  mc = 5;
@@ -409,27 +405,27 @@ class tinyBLAS {
409
405
  }
410
406
 
411
407
  template <int RM, int RN>
412
- NOINLINE void gemm(int m0, int m, int n0, int n) {
413
- int ytiles = (m - m0) / RM;
414
- int xtiles = (n - n0) / RN;
415
- int tiles = xtiles * ytiles;
416
- int duty = (tiles + nth - 1) / nth;
417
- int start = duty * ith;
418
- int end = start + duty;
408
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
409
+ int64_t ytiles = (m - m0) / RM;
410
+ int64_t xtiles = (n - n0) / RN;
411
+ int64_t tiles = xtiles * ytiles;
412
+ int64_t duty = (tiles + nth - 1) / nth;
413
+ int64_t start = duty * ith;
414
+ int64_t end = start + duty;
419
415
  if (end > tiles)
420
416
  end = tiles;
421
- for (int job = start; job < end; ++job) {
422
- int ii = m0 + job / xtiles * RM;
423
- int jj = n0 + job % xtiles * RN;
417
+ for (int64_t job = start; job < end; ++job) {
418
+ int64_t ii = m0 + job / xtiles * RM;
419
+ int64_t jj = n0 + job % xtiles * RN;
424
420
  D Cv[RN][RM] = {};
425
- for (int l = 0; l < k; l += KN)
426
- for (int j = 0; j < RN; ++j)
427
- for (int i = 0; i < RM; ++i)
421
+ for (int64_t l = 0; l < k; l += KN)
422
+ for (int64_t j = 0; j < RN; ++j)
423
+ for (int64_t i = 0; i < RM; ++i)
428
424
  Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
429
425
  load<V>(B + ldb * (jj + j) + l),
430
426
  Cv[j][i]);
431
- for (int j = 0; j < RN; ++j)
432
- for (int i = 0; i < RM; ++i)
427
+ for (int64_t j = 0; j < RN; ++j)
428
+ for (int64_t i = 0; i < RM; ++i)
433
429
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
434
430
  }
435
431
  }
@@ -437,10 +433,10 @@ class tinyBLAS {
437
433
  const TA *const A;
438
434
  const TB *const B;
439
435
  TC *const C;
440
- const int k;
441
- const int lda;
442
- const int ldb;
443
- const int ldc;
436
+ const int64_t k;
437
+ const int64_t lda;
438
+ const int64_t ldb;
439
+ const int64_t ldc;
444
440
  const int ith;
445
441
  const int nth;
446
442
  };
@@ -452,23 +448,23 @@ class tinyBLAS {
452
448
  template <typename TA>
453
449
  class tinyBLAS_Q0_ARM {
454
450
  public:
455
- tinyBLAS_Q0_ARM(int k,
456
- const TA *A, int lda,
457
- const block_q8_0 *B, int ldb,
458
- float *C, int ldc,
451
+ tinyBLAS_Q0_ARM(int64_t k,
452
+ const TA *A, int64_t lda,
453
+ const block_q8_0 *B, int64_t ldb,
454
+ float *C, int64_t ldc,
459
455
  int ith, int nth)
460
456
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
461
457
  }
462
458
 
463
- void matmul(int m, int n, int task) {
459
+ void matmul(int64_t m, int64_t n, int task) {
464
460
  if (task == GGML_TASK_TYPE_COMPUTE)
465
461
  mnpack(0, m, 0, n);
466
462
  }
467
463
 
468
464
  private:
469
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
470
- int mc, nc, mp, np;
471
- switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
465
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
466
+ int64_t mc, nc, mp, np;
467
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
472
468
  case 0x33:
473
469
  mc = 3;
474
470
  nc = 3;
@@ -524,22 +520,22 @@ class tinyBLAS_Q0_ARM {
524
520
  }
525
521
 
526
522
  template <int RM, int RN>
527
- NOINLINE void gemm(int m0, int m, int n0, int n) {
528
- int ytiles = (m - m0) / RM;
529
- int xtiles = (n - n0) / RN;
530
- int tiles = xtiles * ytiles;
531
- int duty = (tiles + nth - 1) / nth;
532
- int start = duty * ith;
533
- int end = start + duty;
523
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
524
+ int64_t ytiles = (m - m0) / RM;
525
+ int64_t xtiles = (n - n0) / RN;
526
+ int64_t tiles = xtiles * ytiles;
527
+ int64_t duty = (tiles + nth - 1) / nth;
528
+ int64_t start = duty * ith;
529
+ int64_t end = start + duty;
534
530
  if (end > tiles)
535
531
  end = tiles;
536
- for (int job = start; job < end; ++job) {
537
- int ii = m0 + job / xtiles * RM;
538
- int jj = n0 + job % xtiles * RN;
532
+ for (int64_t job = start; job < end; ++job) {
533
+ int64_t ii = m0 + job / xtiles * RM;
534
+ int64_t jj = n0 + job % xtiles * RN;
539
535
  float32x4_t Cv[RN][RM] = {};
540
- for (int l = 0; l < k; ++l)
541
- for (int j = 0; j < RN; ++j)
542
- for (int i = 0; i < RM; ++i)
536
+ for (int64_t l = 0; l < k; ++l)
537
+ for (int64_t j = 0; j < RN; ++j)
538
+ for (int64_t i = 0; i < RM; ++i)
543
539
  Cv[j][i] = vmlaq_n_f32(Cv[j][i],
544
540
  vcvtq_f32_s32(vdotq_s32(
545
541
  vdotq_s32(vdupq_n_s32(0),
@@ -549,8 +545,8 @@ class tinyBLAS_Q0_ARM {
549
545
  load_hi(B + ldb * (jj + j) + l))),
550
546
  unhalf(A[lda * (ii + i) + l].d) *
551
547
  unhalf(B[ldb * (jj + j) + l].d));
552
- for (int j = 0; j < RN; ++j)
553
- for (int i = 0; i < RM; ++i)
548
+ for (int64_t j = 0; j < RN; ++j)
549
+ for (int64_t i = 0; i < RM; ++i)
554
550
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
555
551
  }
556
552
  }
@@ -577,36 +573,36 @@ class tinyBLAS_Q0_ARM {
577
573
  const TA *const A;
578
574
  const block_q8_0 *const B;
579
575
  float *const C;
580
- const int k;
581
- const int lda;
582
- const int ldb;
583
- const int ldc;
576
+ const int64_t k;
577
+ const int64_t lda;
578
+ const int64_t ldb;
579
+ const int64_t ldc;
584
580
  const int ith;
585
581
  const int nth;
586
582
  };
587
583
  #endif // __ARM_FEATURE_DOTPROD
588
584
 
589
- #if defined(__AVX2__) || defined(__AVX512F__)
585
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
590
586
  template <typename TA, typename TB, typename TC>
591
- class tinyBLAS_Q0_AVX2 {
587
+ class tinyBLAS_Q0_AVX {
592
588
  public:
593
- tinyBLAS_Q0_AVX2(int k,
594
- const TA *A, int lda,
595
- const TB *B, int ldb,
596
- TC *C, int ldc,
597
- int ith, int nth)
589
+ tinyBLAS_Q0_AVX(int64_t k,
590
+ const TA *A, int64_t lda,
591
+ const TB *B, int64_t ldb,
592
+ TC *C, int64_t ldc,
593
+ int ith, int nth)
598
594
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
599
595
  }
600
596
 
601
- void matmul(int m, int n, int task) {
597
+ void matmul(int64_t m, int64_t n, int task) {
602
598
  if (task == GGML_TASK_TYPE_COMPUTE)
603
599
  mnpack(0, m, 0, n);
604
600
  }
605
601
 
606
602
  private:
607
- void mnpack(int m0, int m, int n0, int n) {
608
- int mc, nc, mp, np;
609
- switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
603
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
604
+ int64_t mc, nc, mp, np;
605
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
610
606
  #if VECTOR_REGISTERS == 32
611
607
  case 0x44:
612
608
  mc = 4;
@@ -714,31 +710,51 @@ class tinyBLAS_Q0_AVX2 {
714
710
  }
715
711
 
716
712
  template <int RM, int RN>
717
- NOINLINE void gemm(int m0, int m, int n0, int n) {
718
- int ytiles = (m - m0) / RM;
719
- int xtiles = (n - n0) / RN;
720
- int tiles = xtiles * ytiles;
721
- int duty = (tiles + nth - 1) / nth;
722
- int start = duty * ith;
723
- int end = start + duty;
713
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
714
+ int64_t ytiles = (m - m0) / RM;
715
+ int64_t xtiles = (n - n0) / RN;
716
+ int64_t tiles = xtiles * ytiles;
717
+ int64_t duty = (tiles + nth - 1) / nth;
718
+ int64_t start = duty * ith;
719
+ int64_t end = start + duty;
724
720
  if (end > tiles)
725
721
  end = tiles;
726
- for (int job = start; job < end; ++job) {
727
- int ii = m0 + job / xtiles * RM;
728
- int jj = n0 + job % xtiles * RN;
722
+ for (int64_t job = start; job < end; ++job) {
723
+ int64_t ii = m0 + job / xtiles * RM;
724
+ int64_t jj = n0 + job % xtiles * RN;
729
725
  __m256 Cv[RN][RM] = {};
730
- for (int l = 0; l < k; ++l)
731
- for (int j = 0; j < RN; ++j)
732
- for (int i = 0; i < RM; ++i)
726
+ for (int64_t l = 0; l < k; ++l)
727
+ for (int64_t j = 0; j < RN; ++j)
728
+ for (int64_t i = 0; i < RM; ++i) {
729
+ #if defined(__AVX2__)
730
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
731
+ load(A + lda * (ii + i) + l)),
732
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
733
+ load(A + lda * (ii + i) + l)));
734
+ #else
735
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
736
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
737
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
738
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
739
+
740
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
741
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
742
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
743
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
744
+
745
+ // updot
746
+ const __m128i oneFill = _mm_set1_epi16(1);
747
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
748
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
749
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
750
+ #endif
733
751
  Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
734
752
  unhalf(B[ldb * (jj + j) + l].d)),
735
- updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
736
- load(A + lda * (ii + i) + l)),
737
- _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
738
- load(A + lda * (ii + i) + l))),
739
- Cv[j][i]);
740
- for (int j = 0; j < RN; ++j)
741
- for (int i = 0; i < RM; ++i)
753
+ udTmp,
754
+ Cv[j][i]);
755
+ }
756
+ for (int64_t j = 0; j < RN; ++j)
757
+ for (int64_t i = 0; i < RM; ++i)
742
758
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
743
759
  }
744
760
  }
@@ -747,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
747
763
  return _mm256_loadu_si256((const __m256i *)b->qs);
748
764
  }
749
765
 
766
+ inline __m128i load0(const block_q8_0 *b) {
767
+ return _mm_loadu_si128((const __m128i *)b->qs);
768
+ }
769
+
770
+ inline __m128i load1(const block_q8_0 *b) {
771
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
772
+ }
773
+
750
774
  inline __m256i load(const block_q4_0 *b) {
751
775
  return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
752
776
  }
753
777
 
778
+ inline __m128i load0(const block_q4_0 *b) {
779
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
780
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
781
+ }
782
+
783
+ inline __m128i load1(const block_q4_0 *b) {
784
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
785
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
786
+ }
787
+
754
788
  inline __m256 updot(__m256i u, __m256i s) {
755
789
  __m256i res;
756
790
  #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -771,14 +805,14 @@ class tinyBLAS_Q0_AVX2 {
771
805
  const TA *const A;
772
806
  const TB *const B;
773
807
  TC *const C;
774
- const int k;
775
- const int lda;
776
- const int ldb;
777
- const int ldc;
808
+ const int64_t k;
809
+ const int64_t lda;
810
+ const int64_t ldb;
811
+ const int64_t ldc;
778
812
  const int ith;
779
813
  const int nth;
780
814
  };
781
- #endif // __AVX2__
815
+ #endif // __AVX__
782
816
 
783
817
  } // namespace
784
818
 
@@ -813,8 +847,8 @@ class tinyBLAS_Q0_AVX2 {
813
847
  * @param Ctype is GGML data type of `C`
814
848
  * @return true if this function was able to service the matmul request
815
849
  */
816
- bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
817
- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
850
+ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
851
+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
818
852
 
819
853
  assert(m >= 0);
820
854
  assert(n >= 0);
@@ -824,9 +858,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
824
858
  assert(ldc >= m);
825
859
  assert(nth > 0);
826
860
  assert(ith < nth);
827
- assert(1ll * lda * m <= 0x7fffffff);
828
- assert(1ll * ldb * n <= 0x7fffffff);
829
- assert(1ll * ldc * n <= 0x7fffffff);
830
861
 
831
862
  if (Ctype != GGML_TYPE_F32)
832
863
  return false;
@@ -932,8 +963,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
932
963
  case GGML_TYPE_Q8_0: {
933
964
  if (Btype != GGML_TYPE_Q8_0)
934
965
  return false;
935
- #if defined(__AVX2__) || defined(__AVX512F__)
936
- tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
966
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
967
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
937
968
  k, (const block_q8_0 *)A, lda,
938
969
  (const block_q8_0 *)B, ldb,
939
970
  (float *)C, ldc,
@@ -956,8 +987,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
956
987
  case GGML_TYPE_Q4_0: {
957
988
  if (Btype != GGML_TYPE_Q8_0)
958
989
  return false;
959
- #if defined(__AVX2__) || defined(__AVX512F__)
960
- tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
990
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
991
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
961
992
  k, (const block_q4_0 *)A, lda,
962
993
  (const block_q8_0 *)B, ldb,
963
994
  (float *)C, ldc,
@@ -1,11 +1,13 @@
1
1
  #pragma once
2
+ #include <stdint.h>
2
3
  #include <stdbool.h>
3
4
  #ifdef __cplusplus
4
5
  extern "C" {
5
6
  #endif
6
7
 
7
- bool llamafile_sgemm(int, int, int, const void *, int, const void *, int,
8
- void *, int, int, int, int, int, int, int);
8
+ bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
9
+ const void *, int64_t, void *, int64_t, int, int,
10
+ int, int, int, int);
9
11
 
10
12
  #ifdef __cplusplus
11
13
  }