llama_cpp 0.14.7 → 0.15.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@
  #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
- #define LLAMA_SESSION_VERSION 5
+ #define LLAMA_SESSION_VERSION 6

  #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
  #define LLAMA_STATE_SEQ_VERSION 1
@@ -69,6 +69,23 @@ extern "C" {
  LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
  };

+ // pre-tokenization types
+ enum llama_vocab_pre_type {
+ LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
+ LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+ LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
+ LLAMA_VOCAB_PRE_TYPE_MPT = 5,
+ LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
+ LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
+ LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
+ LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
+ LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
+ LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
+ LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
+ };
+
  // note: these values should be synchronized with ggml_rope
  // TODO: maybe move this enum to ggml.h (ggml_rope_type)
  enum llama_rope_type {
@@ -122,6 +139,7 @@ extern "C" {
  LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
  LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
  LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors

  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
  };
@@ -159,7 +177,7 @@ extern "C" {
  bool sorted;
  } llama_token_data_array;

- typedef bool (*llama_progress_callback)(float progress, void *ctx);
+ typedef bool (*llama_progress_callback)(float progress, void * user_data);

  // Input data for llama_decode
  // A llama_batch object can contain input about one or many sequences
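
The opaque pointer passed to the progress callback is now named user_data instead of ctx; the signature is otherwise unchanged. A minimal sketch of a conforming callback (not part of the diff; progress_callback and progress_callback_user_data are the existing llama_model_params fields it plugs into):

    #include <cstdio>
    #include "llama.h"

    // Reports load progress; returning true tells the loader to continue,
    // returning false aborts the load.
    static bool on_progress(float progress, void * user_data) {
        const char * label = static_cast<const char *>(user_data);
        std::fprintf(stderr, "[%s] %.1f%%\r", label, progress * 100.0f);
        return true;
    }

    // Wiring it up:
    //     mparams.progress_callback           = on_progress;
    //     mparams.progress_callback_user_data = (void *) "loading";
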
@@ -195,15 +213,19 @@ extern "C" {
  LLAMA_KV_OVERRIDE_TYPE_INT,
  LLAMA_KV_OVERRIDE_TYPE_FLOAT,
  LLAMA_KV_OVERRIDE_TYPE_BOOL,
+ LLAMA_KV_OVERRIDE_TYPE_STR,
  };

  struct llama_model_kv_override {
- char key[128];
  enum llama_model_kv_override_type tag;
+
+ char key[128];
+
  union {
- int64_t int_value;
- double float_value;
- bool bool_value;
+ int64_t val_i64;
+ double val_f64;
+ bool val_bool;
+ char val_str[128];
  };
  };
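
llama_model_kv_override gains a string variant (LLAMA_KV_OVERRIDE_TYPE_STR with char val_str[128]) and the union members are renamed from int_value/float_value/bool_value to val_i64/val_f64/val_bool. A minimal sketch of building a string override against this struct (not part of the diff; the empty-key terminator and the tokenizer.ggml.pre key are assumptions about how the loader consumes kv_overrides):

    #include <cstring>
    #include "llama.h"

    // Fills a two-element override array: one string override plus a zeroed
    // terminator entry (assumed: an empty key ends the list).
    static void make_pre_tokenizer_override(struct llama_model_kv_override ov[2]) {
        std::memset(ov, 0, 2 * sizeof(ov[0]));

        ov[0].tag = LLAMA_KV_OVERRIDE_TYPE_STR;  // new tag in this release
        std::strncpy(ov[0].key, "tokenizer.ggml.pre", sizeof(ov[0].key) - 1);
        std::strncpy(ov[0].val_str, "llama3", sizeof(ov[0].val_str) - 1);  // renamed union member
    }

    // Usage: point llama_model_params::kv_overrides at ov before loading the model.
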
@@ -232,9 +254,10 @@ extern "C" {
  const struct llama_model_kv_override * kv_overrides;

  // Keep the booleans together to avoid misalignment during copy-by-value.
- bool vocab_only; // only load the vocabulary, no weights
- bool use_mmap; // use mmap if possible
- bool use_mlock; // force system to keep model in RAM
+ bool vocab_only; // only load the vocabulary, no weights
+ bool use_mmap; // use mmap if possible
+ bool use_mlock; // force system to keep model in RAM
+ bool check_tensors; // validate model tensor data
  };

  struct llama_context_params {
@@ -270,6 +293,7 @@ extern "C" {
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
  bool embeddings; // if true, extract embeddings (together with logits)
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+ bool flash_attn; // whether to use flash attention

  // Abort callback
  // if it returns true, execution of llama_decode() will be aborted
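
The two hunks above add one flag each: check_tensors on llama_model_params and flash_attn on llama_context_params, both off in the default parameter structs. Opting in looks like this (a minimal sketch using the existing default-params and loader calls, not code from the diff):

    #include "llama.h"

    static struct llama_context * open_context(const char * model_path) {
        struct llama_model_params mparams = llama_model_default_params();
        mparams.check_tensors = true;   // new: validate tensor data while loading

        struct llama_model * model = llama_load_model_from_file(model_path, mparams);
        if (model == nullptr) {
            return nullptr;
        }

        struct llama_context_params cparams = llama_context_default_params();
        cparams.flash_attn = true;      // new: enable the flash attention path

        return llama_new_context_with_model(model, cparams);
    }
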
@@ -525,7 +549,7 @@ extern "C" {
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
  LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);

- // Clear the KV cache
+ // Clear the KV cache - both cell info is erased and KV data is zeroed
  LLAMA_API void llama_kv_cache_clear(
  struct llama_context * ctx);

@@ -1,6 +1,3 @@
- // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
- // vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
- //
  // Copyright 2024 Mozilla Foundation
  //
  // Permission is hereby granted, free of charge, to any person obtaining
@@ -50,7 +47,6 @@
  #pragma GCC diagnostic ignored "-Wignored-attributes"

  #include "sgemm.h"
- #include <algorithm>
  #include "ggml-impl.h"
  #include "ggml-quants.h"

@@ -243,23 +239,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
  template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
  class tinyBLAS {
  public:
- tinyBLAS(int k,
- const TA *A, int lda,
- const TB *B, int ldb,
- TC *C, int ldc,
+ tinyBLAS(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }

- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
  if (task == GGML_TASK_TYPE_COMPUTE)
  mnpack(0, m, 0, n);
  }

  private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
  #if VECTOR_REGISTERS == 32
  case 0x55:
  mc = 5;
@@ -409,27 +405,27 @@ class tinyBLAS {
  }

  template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
  if (end > tiles)
  end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
  D Cv[RN][RM] = {};
- for (int l = 0; l < k; l += KN)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; l += KN)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
  Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
  load<V>(B + ldb * (jj + j) + l),
  Cv[j][i]);
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
  }
@@ -437,10 +433,10 @@ class tinyBLAS {
  const TA *const A;
  const TB *const B;
  TC *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
  const int ith;
  const int nth;
  };
@@ -452,23 +448,23 @@ class tinyBLAS {
  template <typename TA>
  class tinyBLAS_Q0_ARM {
  public:
- tinyBLAS_Q0_ARM(int k,
- const TA *A, int lda,
- const block_q8_0 *B, int ldb,
- float *C, int ldc,
+ tinyBLAS_Q0_ARM(int64_t k,
+ const TA *A, int64_t lda,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
  int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }

- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
  if (task == GGML_TASK_TYPE_COMPUTE)
  mnpack(0, m, 0, n);
  }

  private:
- NOINLINE void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
  case 0x33:
  mc = 3;
  nc = 3;
@@ -524,22 +520,22 @@ class tinyBLAS_Q0_ARM {
  }

  template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
  if (end > tiles)
  end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
  float32x4_t Cv[RN][RM] = {};
- for (int l = 0; l < k; ++l)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
  Cv[j][i] = vmlaq_n_f32(Cv[j][i],
  vcvtq_f32_s32(vdotq_s32(
  vdotq_s32(vdupq_n_s32(0),
@@ -549,8 +545,8 @@ class tinyBLAS_Q0_ARM {
  load_hi(B + ldb * (jj + j) + l))),
  unhalf(A[lda * (ii + i) + l].d) *
  unhalf(B[ldb * (jj + j) + l].d));
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
  }
@@ -577,36 +573,36 @@ class tinyBLAS_Q0_ARM {
  const TA *const A;
  const block_q8_0 *const B;
  float *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
  const int ith;
  const int nth;
  };
  #endif // __ARM_FEATURE_DOTPROD

- #if defined(__AVX2__) || defined(__AVX512F__)
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
  template <typename TA, typename TB, typename TC>
- class tinyBLAS_Q0_AVX2 {
+ class tinyBLAS_Q0_AVX {
  public:
- tinyBLAS_Q0_AVX2(int k,
- const TA *A, int lda,
- const TB *B, int ldb,
- TC *C, int ldc,
- int ith, int nth)
+ tinyBLAS_Q0_AVX(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
  : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
  }

- void matmul(int m, int n, int task) {
+ void matmul(int64_t m, int64_t n, int task) {
  if (task == GGML_TASK_TYPE_COMPUTE)
  mnpack(0, m, 0, n);
  }

  private:
- void mnpack(int m0, int m, int n0, int n) {
- int mc, nc, mp, np;
- switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
  #if VECTOR_REGISTERS == 32
  case 0x44:
  mc = 4;
@@ -714,31 +710,51 @@ class tinyBLAS_Q0_AVX2 {
  }

  template <int RM, int RN>
- NOINLINE void gemm(int m0, int m, int n0, int n) {
- int ytiles = (m - m0) / RM;
- int xtiles = (n - n0) / RN;
- int tiles = xtiles * ytiles;
- int duty = (tiles + nth - 1) / nth;
- int start = duty * ith;
- int end = start + duty;
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
  if (end > tiles)
  end = tiles;
- for (int job = start; job < end; ++job) {
- int ii = m0 + job / xtiles * RM;
- int jj = n0 + job % xtiles * RN;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
  __m256 Cv[RN][RM] = {};
- for (int l = 0; l < k; ++l)
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i) {
+ #if defined(__AVX2__)
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l)));
+ #else
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+ // updot
+ const __m128i oneFill = _mm_set1_epi16(1);
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+ #endif
  Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
  unhalf(B[ldb * (jj + j) + l].d)),
- updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
- load(A + lda * (ii + i) + l)),
- _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
- load(A + lda * (ii + i) + l))),
- Cv[j][i]);
- for (int j = 0; j < RN; ++j)
- for (int i = 0; i < RM; ++i)
+ udTmp,
+ Cv[j][i]);
+ }
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
  C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
  }
  }
@@ -747,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
  return _mm256_loadu_si256((const __m256i *)b->qs);
  }

+ inline __m128i load0(const block_q8_0 *b) {
+ return _mm_loadu_si128((const __m128i *)b->qs);
+ }
+
+ inline __m128i load1(const block_q8_0 *b) {
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+ }
+
  inline __m256i load(const block_q4_0 *b) {
  return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
  }

+ inline __m128i load0(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+ }
+
+ inline __m128i load1(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+ }
+
  inline __m256 updot(__m256i u, __m256i s) {
  __m256i res;
  #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -771,14 +805,14 @@ class tinyBLAS_Q0_AVX2 {
  const TA *const A;
  const TB *const B;
  TC *const C;
- const int k;
- const int lda;
- const int ldb;
- const int ldc;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
  const int ith;
  const int nth;
  };
- #endif // __AVX2__
+ #endif // __AVX__

  } // namespace

@@ -813,8 +847,8 @@ class tinyBLAS_Q0_AVX2 {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
- bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

  assert(m >= 0);
  assert(n >= 0);
@@ -824,9 +858,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
  assert(ldc >= m);
  assert(nth > 0);
  assert(ith < nth);
- assert(1ll * lda * m <= 0x7fffffff);
- assert(1ll * ldb * n <= 0x7fffffff);
- assert(1ll * ldc * n <= 0x7fffffff);

  if (Ctype != GGML_TYPE_F32)
  return false;
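
The three removed asserts capped each operand at 0x7fffffff (2^31 - 1) elements; with every dimension and leading dimension widened to int64_t, that cap is gone. A compile-time illustration of the arithmetic (not code from the package; the 65536 x 65536 shape is chosen only to show where the old limit overflows):

    #include <cstdint>

    // A square 65536 x 65536 operand already exceeds the old per-matrix cap
    // that the removed asserts enforced.
    static_assert(INT64_C(65536) * 65536 > INT64_C(0x7fffffff),
                  "an operand this large would have tripped the removed asserts");
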
@@ -932,8 +963,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
  case GGML_TYPE_Q8_0: {
  if (Btype != GGML_TYPE_Q8_0)
  return false;
- #if defined(__AVX2__) || defined(__AVX512F__)
- tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
  k, (const block_q8_0 *)A, lda,
  (const block_q8_0 *)B, ldb,
  (float *)C, ldc,
@@ -956,8 +987,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
  case GGML_TYPE_Q4_0: {
  if (Btype != GGML_TYPE_Q8_0)
  return false;
- #if defined(__AVX2__) || defined(__AVX512F__)
- tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
  k, (const block_q4_0 *)A, lda,
  (const block_q8_0 *)B, ldb,
  (float *)C, ldc,
@@ -1,11 +1,13 @@
  #pragma once
+ #include <stdint.h>
  #include <stdbool.h>
  #ifdef __cplusplus
  extern "C" {
  #endif

- bool llamafile_sgemm(int, int, int, const void *, int, const void *, int,
- void *, int, int, int, int, int, int, int);
+ bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+ const void *, int64_t, void *, int64_t, int, int,
+ int, int, int, int);

  #ifdef __cplusplus
  }
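
The exported prototype now takes 64-bit sizes and leading dimensions, while the thread index/count and type arguments stay int. A minimal sketch of a float32 call against this prototype (illustrative only; argument order follows the declaration above, the ggml type and task constants come from ggml.h, and the single-threaded ith/nth values are chosen for simplicity):

    #include <stdint.h>
    #include "ggml.h"
    #include "sgemm.h"

    // Each C(i, j) is the dot product of row i of A with row j of B, as in the
    // gemm loops above; lda/ldb are the row strides of A/B and ldc the column
    // stride of C. Returns true only if the routine could service the request.
    static bool f32_matmul(int64_t m, int64_t n, int64_t k,
                           const float * A, const float * B, float * C) {
        return llamafile_sgemm(m, n, k,
                               A, /*lda=*/k,
                               B, /*ldb=*/k,
                               C, /*ldc=*/m,
                               /*ith=*/0, /*nth=*/1,
                               GGML_TASK_TYPE_COMPUTE,
                               GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
    }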