llama_cpp 0.0.5 → 0.0.7
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/ext/llama_cpp/extconf.rb +24 -1
- data/ext/llama_cpp/llama_cpp.cpp +72 -0
- data/ext/llama_cpp/src/ggml-cuda.h +44 -0
- data/ext/llama_cpp/src/ggml-opencl.c +216 -0
- data/ext/llama_cpp/src/ggml-opencl.h +24 -0
- data/ext/llama_cpp/src/ggml.c +2324 -969
- data/ext/llama_cpp/src/ggml.h +656 -619
- data/ext/llama_cpp/src/llama.cpp +269 -42
- data/ext/llama_cpp/src/llama.h +22 -14
- data/ext/llama_cpp/src/llama_util.h +15 -3
- data/lib/llama_cpp/client.rb +151 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +16 -8
- data/sig/llama_cpp.rbs +26 -2
- metadata +6 -2
data/ext/llama_cpp/src/ggml.c (CHANGED)
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
 } \
 } while (0)
 
-#
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include "ggml-cuda.h"
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif
 
 #undef MIN
@@ -325,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#if defined(__ARM_NEON)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8bits to 8 bytes (shl 4)
+static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
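The `B1`..`B8` macros above generate, at compile time, a 256-entry table in which each 8-bit index is expanded to 8 bytes, a set bit becoming `0x10` in its byte (hence "shl 4"). A scalar sketch of what a single entry encodes, under my reading of the macro expansion (the helper name is made up for illustration):

```c
#include <stdint.h>
#include <stdio.h>

// Scalar model of one table_b2b_u entry: bit k of the index becomes byte k
// of the 64-bit result, holding 0x10 when the bit is set and 0x00 otherwise.
static uint64_t b2b_u_entry(uint8_t idx) {
    uint64_t r = 0;
    for (int k = 0; k < 8; k++) {
        if (idx & (1u << k)) {
            r |= (uint64_t)0x10 << (8*k);
        }
    }
    return r;
}

int main(void) {
    // index 0x05 has bits 0 and 2 set -> bytes 0 and 2 become 0x10
    printf("%016llx\n", (unsigned long long)b2b_u_entry(0x05)); // 0000000000100010
    return 0;
}
```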
@@ -427,12 +446,69 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-
-//
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+// Load 8 bytes from memory
+__m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
+
+// Expand bytes into uint16_t values
+__m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+// Unpack values into individual bytes
+const __m128i lowMask = _mm_set1_epi8( 0xF );
+__m128i high = _mm_andnot_si128( lowMask, bytes );
+__m128i low = _mm_and_si128( lowMask, bytes );
+high = _mm_slli_epi16( high, 4 );
+bytes = _mm_or_si128( low, high );
+return bytes;
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+__m128 res = _mm256_extractf128_ps(x, 1);
+res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+res = _mm_add_ss(res, _mm_movehdup_ps(res));
+return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+const __m128i sum64 = _mm_add_epi32(hi64, a);
+const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+uint32_t x32;
+memcpy(&x32, x, sizeof(uint32_t));
+const __m256i shuf_mask = _mm256_set_epi64x(
+0x0303030303030303, 0x0202020202020202,
+0x0101010101010101, 0x0000000000000000);
+__m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+bytes = _mm256_or_si256(bytes, bit_mask);
+return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
 // Load 16 bytes from memory
 __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
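Read as plain C, the two most-used helpers above do the following; the `_scalar` names are mine, for illustration only, and assume the same nibble order the AVX shuffles produce:

```c
#include <stdint.h>
#include <stddef.h>

// Scalar model of bytes_from_nibbles_16: input byte k yields output bytes
// 2k (low nibble) and 2k+1 (high nibble), each in the [0, 15] interval.
static void bytes_from_nibbles_16_scalar(const uint8_t src[8], uint8_t dst[16]) {
    for (size_t k = 0; k < 8; k++) {
        dst[2*k + 0] = src[k] & 0x0F;
        dst[2*k + 1] = src[k] >> 4;
    }
}

// Scalar model of hsum_float_8: the sum of the 8 lanes of a __m256.
static float hsum_float_8_scalar(const float v[8]) {
    float s = 0.0f;
    for (size_t k = 0; k < 8; k++) {
        s += v[k];
    }
    return s;
}
```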
@@ -449,9 +525,38 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
 return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+const __m256i ones = _mm256_set1_epi16(1);
+const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+// Get absolute values of x vectors
+const __m256i ax = _mm256_sign_epi8(x, x);
+// Sign the values of the y vectors
+const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+const __m256i zero = _mm256_setzero_si256();
+const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+return _mm256_cvtepi32_ps(summed_pairs);
+#else
+// Perform multiplication and create 16-bit values
+const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+return sum_i16_pairs_float(dot);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+#else
 const __m256i lowByte = _mm256_set1_epi16( 0xFF );
 __m256i high = _mm256_andnot_si256( lowByte, bytes );
 __m256i low = _mm256_and_si256( lowByte, bytes );
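`mul_sum_i8_pairs_float` is the workhorse of the new integer dot products: 32 signed-byte products are folded, four at a time, into 8 float lanes. The `_mm256_sign_epi8` pair rewrites `x[i]*y[i]` as `|x[i]| * (sign(x[i])*y[i])` so the unsigned-by-signed `_mm256_maddubs_epi16` can be used. A scalar sketch of the intended result (ignoring that instruction's int16 saturation corner case; the `_scalar` name is mine):

```c
#include <stdint.h>

// For each group of four adjacent lanes g, out[g] = sum of x[i]*y[i] over
// that group, converted to float.
static void mul_sum_i8_pairs_float_scalar(const int8_t x[32], const int8_t y[32],
                                          float out[8]) {
    for (int g = 0; g < 8; g++) {
        int32_t acc = 0;
        for (int i = 0; i < 4; i++) {
            acc += (int32_t)x[4*g + i] * (int32_t)y[4*g + i];
        }
        out[g] = (float)acc;
    }
}
```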
@@ -462,25 +567,9 @@ static inline __m128i packNibbles( __m256i bytes )
 __m128i r0 = _mm256_castsi256_si128( bytes );
 __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
 return _mm_packus_epi16( r0, r1 );
+#endif
 }
-#
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-// Load 8 bytes from memory
-__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-// Expand bytes into uint16_t values
-__m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-// Unpack values into individual bytes
-const __m128i lowMask = _mm_set1_epi8( 0xF );
-__m128i high = _mm_andnot_si128( lowMask, bytes );
-__m128i low = _mm_and_si128( lowMask, bytes );
-high = _mm_slli_epi16( high, 4 );
-bytes = _mm_or_si128( low, high );
-return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +586,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
@@ -514,6 +604,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
 (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+return
+(int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
+(int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
+(int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
+(int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
+(int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
+(int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+(int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+(int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
 return
 (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +685,39 @@ typedef struct {
 float m; // min
 uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float)
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+ggml_fp16_t d; // delta
+uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+ggml_fp16_t d; // delta
+ggml_fp16_t m; // min
+uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+ggml_fp16_t d; // delta
+uint8_t qh[4]; // 5-th bit of quants
+uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+ggml_fp16_t d; // delta
+ggml_fp16_t m; // min
+uint8_t qh[4]; // 5-th bit of quants
+uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
 
 #define QK8_0 32
 typedef struct {
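The new block types trade file size against accuracy in slightly different ways. Assuming a 2-byte `ggml_fp16_t`, the `static_assert`s above work out to the storage costs computed below (an illustrative program, not part of the diff; bits/weight ignores any per-tensor overhead):

```c
#include <stdio.h>

int main(void) {
    // bytes per block and elements per block, mirroring the struct layouts above
    const struct { const char *name; int bytes; int qk; } fmt[] = {
        { "q4_0",  4 + 16, 32 },  // float d            + 16 nibble bytes
        { "q4_1",  8 + 16, 32 },  // float d, m         + 16 nibble bytes
        { "q4_2",  2 +  8, 16 },  // fp16 d             +  8 nibble bytes
        { "q4_3",  4 +  8, 16 },  // fp16 d, m          +  8 nibble bytes
        { "q5_0",  6 + 16, 32 },  // fp16 d + 4-byte qh + 16 nibble bytes
        { "q5_1",  8 + 16, 32 },  // fp16 d, m + qh     + 16 nibble bytes
        { "q8_0",  4 + 32, 32 },  // float d            + 32 int8 quants
        { "q8_1", 12 + 32, 32 },  // float d, s0, s1    + 32 int8 quants
    };
    for (int i = 0; i < 8; i++) {
        printf("%s: %2d bytes / %2d weights = %5.2f bits per weight\n",
               fmt[i].name, fmt[i].bytes, fmt[i].qk,
               8.0 * fmt[i].bytes / fmt[i].qk);
    }
    return 0;
}
```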
@@ -592,6 +726,14 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define QK8_1 32
+typedef struct {
+float d; // delta
+float s0; // d * sum(qs[i]) low
+float s1; // d * sum(qs[i]) high
+int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
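The two extra floats in `block_q8_1` are the quant sums of each half of the block, pre-scaled by `d`. They pay off because a q4_1-style block dequantizes to `d4*q + m`, so in a dot product the `m` term only needs the sum of the activation quants, not the individual values. A scalar sketch of the algebra for one block pair (the helper below is hypothetical and is not one of the kernels added elsewhere in this diff):

```c
#include <stdint.h>

// sum_l (d4*q4[l] + m) * (d8*q8[l])
//   = d4*d8 * sum_l q4[l]*q8[l]  +  m * (d8 * sum_l q8[l])
// and d8*sum(q8) is exactly s0 + s1, so the offset term costs one
// multiply-add per block instead of per-element work.
static float dot_q4_1_q8_1_block(float d4, float m, const uint8_t qs4[16],
                                 float d8, float s0, float s1, const int8_t qs8[32]) {
    int32_t sumi = 0;
    for (int l = 0; l < 32; l += 2) {
        const int q0 = qs4[l/2] & 0x0F;  // element l
        const int q1 = qs4[l/2] >> 4;    // element l + 1
        sumi += q0 * qs8[l] + q1 * qs8[l + 1];
    }
    return d4 * d8 * (float)sumi + m * (s0 + s1);
}
```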
@@ -602,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
 for (int i = 0; i < nb; i++) {
 float amax = 0.0f; // absolute max
+float max = 0.0f;
 
 for (int l = 0; l < QK4_0; l++) {
 const float v = x[i*QK4_0 + l];
-
+if (amax < fabsf(v)) {
+amax = fabsf(v);
+max = v;
+}
 }
 
-const float d =
+const float d = max / -8;
 const float id = d ? 1.0f/d : 0.0f;
 
 y[i].d = d;
@@ -617,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 const float v0 = x[i*QK4_0 + l + 0]*id;
 const float v1 = x[i*QK4_0 + l + 1]*id;
 
-const uint8_t vi0 = (int8_t)roundf(v0) + 8;
-const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
+const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
 
 assert(vi0 < 16);
 assert(vi1 < 16);
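The reference q4_0 quantizer now remembers the signed value with the largest magnitude and divides by -8, so that value lands exactly on the edge of the signed nibble range, and `MIN(15, ...)` catches the one rounding case that would otherwise produce 16. A small worked example (illustrative only):

```c
#include <math.h>
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const float x[4] = { 2.0f, -1.9f, 0.5f, -0.25f }; // toy block, extreme value is +2.0
    const float max  = 2.0f;                          // signed value with the largest magnitude
    const float d    = max / -8;                      // -0.25
    const float id   = d ? 1.0f/d : 0.0f;             // -4
    for (int l = 0; l < 4; l++) {
        const int q  = (int)roundf(x[l]*id) + 8;
        const int qc = MIN(15, q);
        printf("x = %+5.2f -> %2d (stored %2d)\n", x[l], q, qc);
        // +2.00 -> 0, -1.90 -> 16 -> 15, +0.50 -> 6, -0.25 -> 9
    }
    return 0;
}
```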
@@ -638,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__POWER9_VECTOR__)
 const vector float v85 = vec_splats(8.5f);
+const vector signed int v15 = vec_splats(15);
 for (int i = 0; i < nb; i++) {
-float
+float max = 0.0f;
+float min = 0.0f;
 
 vector float srcv [8];
-vector float
-vector float
+vector float maxv[8];
+vector float minv[8];
 
 for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
-for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
-
-for (int l = 0; l < 4; l++)
-//for (int l = 0; l < 2; l++)
-
-
-//for (int l = 0; l < 1; l++)
-
-
-
-
-
-
-
+//for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+//for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
+maxv[0] = vec_max(maxv[0], maxv[2]);
+maxv[4] = vec_max(maxv[4], maxv[6]);
+//for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
+maxv[0] = vec_max(maxv[0], maxv[4]);
+
+for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
+//for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
+minv[0] = vec_min(minv[0], minv[2]);
+minv[4] = vec_min(minv[4], minv[6]);
+//for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
+minv[0] = vec_min(minv[0], minv[4]);
+
+
+max = MAX(
+MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
+MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
+min = MIN(
+MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
+MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
+
+const float magnitude = max >= fabsf(min) ? max : min;
+const float d = magnitude / -8;
 const float id = d ? 1.0/d : 0.0;
 
 y[i].d = d;
@@ -669,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 for (int l = 0; l < 8; l++) {
 const vector float vf = vec_madd(srcv[l], vid, v85);
 const vector signed int vi = vec_signed(vf);
+const vector signed int vc = vec_min(vi, v15);
 
-pb[2*l + 0] = vec_extract(
-pb[2*l + 1] = vec_extract(
+pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
+pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
 }
 }
 #elif __ARM_NEON
 for (int i = 0; i < nb; i++) {
 float32x4_t srcv [8];
-float32x4_t
-float32x4_t
+float32x4_t maxv[8];
+float32x4_t minv[8];
 
 for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
-for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
 
-for (int l = 0; l < 4; l++)
-for (int l = 0; l < 2; l++)
-for (int l = 0; l < 1; l++)
+for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
+for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
+for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
 
-
+for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
+for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
+for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
+
+const float max = vmaxvq_f32(maxv[0]);
+const float min = vminvq_f32(minv[0]);
 
-const float
+const float magnitude = max >= fabsf(min) ? max : min;
+const float d = magnitude / -8;
 const float id = d ? 1.0f/d : 0.0f;
 
 y[i].d = d;
@@ -698,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 const float32x4_t v = vmulq_n_f32(srcv[l], id);
 const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
 const int32x4_t vi = vcvtq_s32_f32(vf);
+const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
 
-y[i].qs[2*l + 0] = vgetq_lane_s32(
-y[i].qs[2*l + 1] = vgetq_lane_s32(
+y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
+y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
 }
 }
 #elif defined(__AVX2__)
@@ -712,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 __m256 v3 = _mm256_loadu_ps( x + 24 );
 x += 32;
 
-// Compute max
-
-__m256
-
-maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+// Compute max for the block
+__m256 max = _mm256_max_ps( v0, v1 );
+__m256 maxTmp = _mm256_max_ps( v2, v3 );
+max = _mm256_max_ps( max, maxTmp );
 
-__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
+__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
 max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
 max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
 const float maxScalar = _mm_cvtss_f32( max4 );
 
+// Compute min for the block
+__m256 min = _mm256_min_ps( v0, v1 );
+__m256 minTmp = _mm256_min_ps( v2, v3 );
+min = _mm256_min_ps( min, minTmp );
+
+__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+const float minScalar = _mm_cvtss_f32( min4 );
+
 // Quantize these floats
-const float
+const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+const float d = magnitude / -8.0f;
 y[i].d = d;
-const float id = (
+const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
 const __m256 mul = _mm256_set1_ps( id );
 
 // Apply the multiplier
@@ -760,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
 i0 = _mm256_permutevar8x32_epi32( i0, perm );
 
-// Apply offset to translate the range from [ -
+// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
 const __m256i off = _mm256_set1_epi8( 8 );
 i0 = _mm256_add_epi8( i0, off );
+const __m256i maxNibble = _mm256_set1_epi8( 15 );
+i0 = _mm256_min_epi8( i0, maxNibble );
 
 // Compress the vector into 4 bit/value, and store
 __m128i res = packNibbles( i0 );
@@ -777,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 __m256 v3 = _mm256_loadu_ps( x + 24 );
 x += 32;
 
-// Compute max
-
-__m256
-
-maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+// Compute max for the block
+__m256 max = _mm256_max_ps( v0, v1 );
+__m256 maxTmp = _mm256_max_ps( v2, v3 );
+max = _mm256_max_ps( max, maxTmp );
 
-__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
+__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
 max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
 max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
 const float maxScalar = _mm_cvtss_f32( max4 );
 
+// Compute min for the block
+__m256 min = _mm256_min_ps( v0, v1 );
+__m256 minTmp = _mm256_min_ps( v2, v3 );
+min = _mm256_min_ps( min, minTmp );
+
+__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+const float minScalar = _mm_cvtss_f32( min4 );
+
 // Quantize these floats
-const float
+const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+const float d = magnitude / -8.0f;
 y[i].d = d;
-const float id = (
+const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
 const __m256 mul = _mm256_set1_ps( id );
 
 // Apply the multiplier
@@ -833,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 ni0 = _mm_packs_epi16( ni0, ni2 );
 ni4 = _mm_packs_epi16( ni4, ni6 );
 
-// Apply offset to translate the range from [ -
-const __m128i off = _mm_set1_epi8( 8);
+// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
+const __m128i off = _mm_set1_epi8( 8 );
 ni0 = _mm_add_epi8( ni0, off );
 ni4 = _mm_add_epi8( ni4, off );
+const __m128i maxNibble = _mm_set1_epi8( 15 );
+ni0 = _mm_min_epi8( ni0, maxNibble );
+ni4 = _mm_min_epi8( ni4, maxNibble );
 
 // Compress the vector into 4 bit/value, and store
 __m128i res = packNibbles( ni0, ni4 );
@@ -844,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 #elif defined(__wasm_simd128__)
 for (int i = 0; i < nb; i++) {
-float
+float max = 0.0f;
+float min = 0.0f;
 
 v128_t srcv [8];
-v128_t
-v128_t
+v128_t maxv[8];
+v128_t minv[8];
 
 for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
-for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
 
-for (int l = 0; l < 4; l++)
-for (int l = 0; l < 2; l++)
-for (int l = 0; l < 1; l++)
+for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
+for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
+for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
+
+for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
+for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
+for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
 
-
-MAX(wasm_f32x4_extract_lane(
-MAX(wasm_f32x4_extract_lane(
+max = MAX(
+MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
+MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
+min = MIN(
+MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
+MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
 
-const float
+const float magnitude = max >= fabsf(min) ? max : min;
+const float d = magnitude / -8;
 const float id = d ? 1.0/d : 0.0;
 
 y[i].d = d;
@@ -870,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
 const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
 const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
 
-y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(
-y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(
+y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
+y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
 }
 }
 #else
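All of the SIMD q4_0 paths above quantize by adding `8.5f` and truncating instead of calling `roundf`: for the non-negative intermediate values involved, truncating `v*id + 8.5` rounds `v*id + 8` half-up, and the newly added clamp to 15 handles the single value that rounds up to 16. A scalar demonstration (illustrative only):

```c
#include <stdio.h>

int main(void) {
    const float scaled[4] = { -8.0f, -0.49f, 7.49f, 8.0f }; // example values of v*id
    for (int l = 0; l < 4; l++) {
        int q = (int)(scaled[l] + 8.5f); // truncation == floor here, since the sum is >= 0
        if (q > 15) q = 15;              // the clamp added in this diff
        printf("%+5.2f -> %2d\n", scaled[l], q); // 0, 8, 15, 15
    }
    return 0;
}
```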
@@ -1045,6 +1244,193 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|
1045
1244
|
#endif
|
1046
1245
|
}
|
1047
1246
|
|
1247
|
+
// reference implementation for deterministic creation of model files
|
1248
|
+
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
|
1249
|
+
assert(k % QK4_2 == 0);
|
1250
|
+
|
1251
|
+
const int nb = k / QK4_2;
|
1252
|
+
|
1253
|
+
for (int i = 0; i < nb; i++) {
|
1254
|
+
float amax = 0.0f; // absolute max
|
1255
|
+
float max = 0.0f;
|
1256
|
+
|
1257
|
+
for (int l = 0; l < QK4_2; l++) {
|
1258
|
+
const float v = x[i*QK4_2 + l];
|
1259
|
+
if (amax < fabsf(v)) {
|
1260
|
+
amax = fabsf(v);
|
1261
|
+
max = v;
|
1262
|
+
}
|
1263
|
+
}
|
1264
|
+
|
1265
|
+
const float d = max / -8;
|
1266
|
+
|
1267
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1268
|
+
|
1269
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1270
|
+
|
1271
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1272
|
+
const float v0 = x[i*QK4_2 + l + 0]*id;
|
1273
|
+
const float v1 = x[i*QK4_2 + l + 1]*id;
|
1274
|
+
|
1275
|
+
const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
|
1276
|
+
const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
|
1277
|
+
|
1278
|
+
assert(vi0 < 16);
|
1279
|
+
assert(vi1 < 16);
|
1280
|
+
|
1281
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1282
|
+
}
|
1283
|
+
}
|
1284
|
+
}
|
1285
|
+
|
1286
|
+
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
|
1287
|
+
assert(k % QK4_2 == 0);
|
1288
|
+
|
1289
|
+
block_q4_2 * restrict y = vy;
|
1290
|
+
|
1291
|
+
quantize_row_q4_2_reference(x, y, k);
|
1292
|
+
}
|
1293
|
+
|
1294
|
+
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
|
1295
|
+
assert(k % QK4_3 == 0);
|
1296
|
+
const int nb = k / QK4_3;
|
1297
|
+
|
1298
|
+
for (int i = 0; i < nb; i++) {
|
1299
|
+
float min = FLT_MAX;
|
1300
|
+
float max = -FLT_MAX;
|
1301
|
+
|
1302
|
+
for (int l = 0; l < QK4_3; l++) {
|
1303
|
+
const float v = x[i*QK4_3 + l];
|
1304
|
+
if (v < min) min = v;
|
1305
|
+
if (v > max) max = v;
|
1306
|
+
}
|
1307
|
+
|
1308
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
1309
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1310
|
+
|
1311
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1312
|
+
y[i].m = GGML_FP32_TO_FP16(min);
|
1313
|
+
|
1314
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1315
|
+
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
|
1316
|
+
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
|
1317
|
+
|
1318
|
+
const uint8_t vi0 = (int) (v0 + 0.5f);
|
1319
|
+
const uint8_t vi1 = (int) (v1 + 0.5f);
|
1320
|
+
|
1321
|
+
assert(vi0 < 16);
|
1322
|
+
assert(vi1 < 16);
|
1323
|
+
|
1324
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1325
|
+
}
|
1326
|
+
}
|
1327
|
+
}
|
1328
|
+
|
1329
|
+
static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
|
1330
|
+
assert(k % QK4_3 == 0);
|
1331
|
+
|
1332
|
+
block_q4_3 * restrict y = vy;
|
1333
|
+
|
1334
|
+
quantize_row_q4_3_reference(x, y, k);
|
1335
|
+
}
|
1336
|
+
|
1337
|
+
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
|
1338
|
+
assert(k % QK5_0 == 0);
|
1339
|
+
const int nb = k / QK5_0;
|
1340
|
+
|
1341
|
+
for (int i = 0; i < nb; i++) {
|
1342
|
+
float amax = 0.0f; // absolute max
|
1343
|
+
float max = 0.0f;
|
1344
|
+
|
1345
|
+
for (int l = 0; l < QK5_0; l++) {
|
1346
|
+
const float v = x[i*QK5_0 + l];
|
1347
|
+
if (amax < fabsf(v)) {
|
1348
|
+
amax = fabsf(v);
|
1349
|
+
max = v;
|
1350
|
+
}
|
1351
|
+
}
|
1352
|
+
|
1353
|
+
const float d = max / -16;
|
1354
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1355
|
+
|
1356
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1357
|
+
|
1358
|
+
uint32_t qh = 0;
|
1359
|
+
|
1360
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
1361
|
+
const float v0 = x[i*QK5_0 + l + 0]*id;
|
1362
|
+
const float v1 = x[i*QK5_0 + l + 1]*id;
|
1363
|
+
|
1364
|
+
const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
|
1365
|
+
const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
|
1366
|
+
|
1367
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1368
|
+
|
1369
|
+
// get the 5-th bit and store it in qh at the right position
|
1370
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1371
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1372
|
+
}
|
1373
|
+
|
1374
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1375
|
+
}
|
1376
|
+
}
|
1377
|
+
|
1378
|
+
static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
|
1379
|
+
assert(k % QK5_0 == 0);
|
1380
|
+
|
1381
|
+
block_q5_0 * restrict y = vy;
|
1382
|
+
|
1383
|
+
quantize_row_q5_0_reference(x, y, k);
|
1384
|
+
}
|
1385
|
+
|
1386
|
+
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
|
1387
|
+
assert(k % QK5_1 == 0);
|
1388
|
+
const int nb = k / QK5_1;
|
1389
|
+
|
1390
|
+
for (int i = 0; i < nb; i++) {
|
1391
|
+
float min = FLT_MAX;
|
1392
|
+
float max = -FLT_MAX;
|
1393
|
+
|
1394
|
+
for (int l = 0; l < QK5_1; l++) {
|
1395
|
+
const float v = x[i*QK5_1 + l];
|
1396
|
+
if (v < min) min = v;
|
1397
|
+
if (v > max) max = v;
|
1398
|
+
}
|
1399
|
+
|
1400
|
+
const float d = (max - min) / ((1 << 5) - 1);
|
1401
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1402
|
+
|
1403
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1404
|
+
y[i].m = GGML_FP32_TO_FP16(min);
|
1405
|
+
|
1406
|
+
uint32_t qh = 0;
|
1407
|
+
|
1408
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
1409
|
+
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
|
1410
|
+
const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
|
1411
|
+
|
1412
|
+
const uint32_t vi0 = (int) (v0 + 0.5f);
|
1413
|
+
const uint32_t vi1 = (int) (v1 + 0.5f);
|
1414
|
+
|
1415
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1416
|
+
|
1417
|
+
// get the 5-th bit and store it in qh at the right position
|
1418
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1419
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1420
|
+
}
|
1421
|
+
|
1422
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1423
|
+
}
|
1424
|
+
}
|
1425
|
+
|
1426
|
+
static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
|
1427
|
+
assert(k % QK5_1 == 0);
|
1428
|
+
|
1429
|
+
block_q5_1 * restrict y = vy;
|
1430
|
+
|
1431
|
+
quantize_row_q5_1_reference(x, y, k);
|
1432
|
+
}
|
1433
|
+
|
1048
1434
|
// reference implementation for deterministic creation of model files
|
1049
1435
|
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
|
1050
1436
|
assert(k % QK8_0 == 0);
|
@@ -1064,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 y[i].d = d;
 
 for (int l = 0; l < QK8_0; ++l) {
-const float
-
+const float v0 = x[i*QK8_0 + l]*id;
+
+y[i].qs[l] = roundf(v0);
 }
 }
 }
 
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
 assert(k % QK8_0 == 0);
-const int nb = k / QK8_0;
 
 block_q8_0 * restrict y = vy;
 
+quantize_row_q8_0_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+assert(k % QK8_1 == 0);
+const int nb = k / QK8_1;
+
+for (int i = 0; i < nb; i++) {
+float amax = 0.0f; // absolute max
+
+for (int l = 0; l < QK8_1; l++) {
+const float v = x[i*QK8_1 + l];
+amax = MAX(amax, fabsf(v));
+}
+
+const float d = amax / ((1 << 7) - 1);
+const float id = d ? 1.0f/d : 0.0f;
+
+y[i].d = d;
+
+int sum0 = 0;
+int sum1 = 0;
+
+for (int l = 0; l < QK8_1/2; ++l) {
+const float v0 = x[i*QK8_1 + l]*id;
+const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
+
+y[i].qs[ l] = roundf(v0);
+y[i].qs[QK8_1/2 + l] = roundf(v1);
+
+sum0 += y[i].qs[ l];
+sum1 += y[i].qs[QK8_1/2 + l];
+}
+
+y[i].s0 = d * sum0;
+y[i].s1 = d * sum1;
+}
+}
+
+static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+assert(k % QK8_1 == 0);
+const int nb = k / QK8_1;
+
+block_q8_1 * restrict y = vy;
+
 #if defined(__ARM_NEON)
 for (int i = 0; i < nb; i++) {
 float32x4_t srcv [8];
@@ -1096,7 +1528,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
 y[i].d = d;
 
-
+int32x4_t accv0 = vdupq_n_s32(0);
+int32x4_t accv1 = vdupq_n_s32(0);
+
+// low half
+for (int l = 0; l < 4; l++) {
 const float32x4_t v = vmulq_n_f32(srcv[l], id);
 const int32x4_t vi = vcvtnq_s32_f32(v);
 
@@ -1104,19 +1540,40 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
 y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
 y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+accv0 = vaddq_s32(accv0, vi);
 }
-}
-#elif defined(__AVX2__) || defined(__AVX__)
-for (int i = 0; i < nb; i++) {
-// Load elements into 4 AVX vectors
-__m256 v0 = _mm256_loadu_ps( x );
-__m256 v1 = _mm256_loadu_ps( x + 8 );
-__m256 v2 = _mm256_loadu_ps( x + 16 );
-__m256 v3 = _mm256_loadu_ps( x + 24 );
-x += 32;
 
-//
-
+// high half
+for (int l = 4; l < 8; l++) {
+const float32x4_t v = vmulq_n_f32(srcv[l], id);
+const int32x4_t vi = vcvtnq_s32_f32(v);
+
+y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+accv1 = vaddq_s32(accv1, vi);
+}
+
+const int32_t sum0 = vaddvq_s32(accv0);
+const int32_t sum1 = vaddvq_s32(accv1);
+
+y[i].s0 = d * sum0;
+y[i].s1 = d * sum1;
+}
+#elif defined(__AVX2__) || defined(__AVX__)
+for (int i = 0; i < nb; i++) {
+// Load elements into 4 AVX vectors
+__m256 v0 = _mm256_loadu_ps( x );
+__m256 v1 = _mm256_loadu_ps( x + 8 );
+__m256 v2 = _mm256_loadu_ps( x + 16 );
+__m256 v3 = _mm256_loadu_ps( x + 24 );
+x += 32;
+
+// Compute max(abs(e)) for the block
+const __m256 signBit = _mm256_set1_ps( -0.0f );
 __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
 maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
@@ -1152,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+// Compute the sum of the quants and set y[i].s
+//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+
 // Convert int32 to int16
 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
 i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1177,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 __m128i ni6 = _mm256_castsi256_si128( i3 );
 __m128i ni7 = _mm256_extractf128_si256( i3, 1);
 
+// Compute the sum of the quants and set y[i].s
+const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+y[i].s0 = d * hsum_i32_4(s0);
+y[i].s1 = d * hsum_i32_4(s1);
+
 // Convert int32 to int16
 ni0 = _mm_packs_epi32( ni0, ni1 );
 ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1192,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 }
 #else
 // scalar
-
+quantize_row_q8_1_reference(x, y, k);
 #endif
 }
 
@@ -1211,7 +1679,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
 for (int l = 0; l < QK4_0; l += 32) {
 // Load 32x4-bit integers into 32x8-bit integers
-__m256i vx8 =
+__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
 // Subtract 8 from the integers
 vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1246,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 const uint8x8_t v8 = vld1_u8(pp + l/2);
 
 // Expand 4-bit qs to 8-bit bytes
-const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
+const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
 const uint8x8_t v1 = vshr_n_u8(v8, 4);
 
 // Convert to signed 8-bit integers
@@ -1296,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 for (int l = 0; l < QK4_0; l += 2) {
 const uint8_t vi = pp[l/2];
 
-const int8_t vi0 = vi &
+const int8_t vi0 = vi & 0x0F;
 const int8_t vi1 = vi >> 4;
 
 const float v0 = (vi0 - 8)*d;
@@ -1329,7 +1797,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
 for (int l = 0; l < QK4_1; l += 32) {
 // Load 32x4-bit integers into 32x8-bit integers
-__m256i vx8 =
+__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
 // Convert to 16-bit int
 const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1362,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 const uint8x8_t v8 = vld1_u8(pp + l/2);
 
 // Expand 4-bit qs to 8-bit bytes
-const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
+const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
 const uint8x8_t v1 = vshr_n_u8(v8, 4);
 
 // Interleave and combine
@@ -1404,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 for (int l = 0; l < QK4_1; l += 2) {
 const uint8_t vi = pp[l/2];
 
-const int8_t vi0 = vi &
+const int8_t vi0 = vi & 0x0F;
 const int8_t vi1 = vi >> 4;
 
 const float v0 = vi0*d + m;
@@ -1420,8 +1888,162 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1420
1888
|
#endif
|
1421
1889
|
}
|
1422
1890
|
|
1423
|
-
static void
|
1891
|
+
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
|
1892
|
+
assert(k % QK4_2 == 0);
|
1893
|
+
const int nb = k / QK4_2;
|
1894
|
+
|
1895
|
+
const block_q4_2 * restrict x = vx;
|
1896
|
+
|
1897
|
+
for (int i = 0; i < nb; i++) {
|
1898
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1899
|
+
|
1900
|
+
const uint8_t * restrict pp = x[i].qs;
|
1901
|
+
|
1902
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1903
|
+
const uint8_t vi = pp[l/2];
|
1904
|
+
|
1905
|
+
const int8_t vi0 = vi & 0x0F;
|
1906
|
+
const int8_t vi1 = vi >> 4;
|
1907
|
+
|
1908
|
+
const float v0 = (vi0 - 8)*d;
|
1909
|
+
const float v1 = (vi1 - 8)*d;
|
1910
|
+
|
1911
|
+
y[i*QK4_2 + l + 0] = v0;
|
1912
|
+
y[i*QK4_2 + l + 1] = v1;
|
1913
|
+
|
1914
|
+
assert(!isnan(y[i*QK4_2 + l + 0]));
|
1915
|
+
assert(!isnan(y[i*QK4_2 + l + 1]));
|
1916
|
+
}
|
1917
|
+
}
|
1918
|
+
}
|
1919
|
+
|
1920
|
+
static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
|
1921
|
+
assert(k % QK4_3 == 0);
|
1922
|
+
const int nb = k / QK4_3;
|
1923
|
+
|
1924
|
+
const block_q4_3 * restrict x = vx;
|
1925
|
+
|
1926
|
+
for (int i = 0; i < nb; i++) {
|
1927
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1928
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
1929
|
+
|
1930
|
+
const uint8_t * restrict pp = x[i].qs;
|
1931
|
+
|
1932
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1933
|
+
const uint8_t vi = pp[l/2];
|
1934
|
+
|
1935
|
+
const int8_t vi0 = vi & 0x0F;
|
1936
|
+
const int8_t vi1 = vi >> 4;
|
1937
|
+
|
1938
|
+
const float v0 = vi0*d + m;
|
1939
|
+
const float v1 = vi1*d + m;
|
1940
|
+
|
1941
|
+
y[i*QK4_3 + l + 0] = v0;
|
1942
|
+
y[i*QK4_3 + l + 1] = v1;
|
1943
|
+
|
1944
|
+
assert(!isnan(y[i*QK4_3 + l + 0]));
|
1945
|
+
assert(!isnan(y[i*QK4_3 + l + 1]));
|
1946
|
+
}
|
1947
|
+
}
|
1948
|
+
}
|
1949
|
+
|
1950
|
+
static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
|
1951
|
+
assert(k % QK5_0 == 0);
|
1952
|
+
const int nb = k / QK5_0;
|
1953
|
+
|
1954
|
+
const block_q5_0 * restrict x = vx;
|
1955
|
+
|
1956
|
+
for (int i = 0; i < nb; i++) {
|
1957
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1958
|
+
|
1959
|
+
const uint8_t * restrict pp = x[i].qs;
|
1960
|
+
|
1961
|
+
uint32_t qh;
|
1962
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
1963
|
+
|
1964
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
1965
|
+
const uint8_t vi = pp[l/2];
|
1966
|
+
|
1967
|
+
// extract the 5-th bit from qh
|
1968
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
1969
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
1970
|
+
|
1971
|
+
const int8_t vi0 = (vi & 0x0F) | vh0;
|
1972
|
+
const int8_t vi1 = (vi >> 4) | vh1;
|
1973
|
+
|
1974
|
+
const float v0 = (vi0 - 16)*d;
|
1975
|
+
const float v1 = (vi1 - 16)*d;
|
1976
|
+
|
1977
|
+
y[i*QK5_0 + l + 0] = v0;
|
1978
|
+
y[i*QK5_0 + l + 1] = v1;
|
1979
|
+
|
1980
|
+
assert(!isnan(y[i*QK5_0 + l + 0]));
|
1981
|
+
assert(!isnan(y[i*QK5_0 + l + 1]));
|
1982
|
+
}
|
1983
|
+
}
|
1984
|
+
}
|
1985
|
+
|
1986
|
+
static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
|
1987
|
+
assert(k % QK5_1 == 0);
|
1988
|
+
const int nb = k / QK5_1;
|
1989
|
+
|
1990
|
+
const block_q5_1 * restrict x = vx;
|
1991
|
+
|
1992
|
+
for (int i = 0; i < nb; i++) {
|
1993
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1994
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
1995
|
+
|
1996
|
+
const uint8_t * restrict pp = x[i].qs;
|
1997
|
+
|
1998
|
+
uint32_t qh;
|
1999
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2000
|
+
|
2001
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
2002
|
+
const uint8_t vi = pp[l/2];
|
2003
|
+
|
2004
|
+
// extract the 5-th bit from qh
|
2005
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
2006
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
2007
|
+
|
2008
|
+
const uint8_t vi0 = (vi & 0x0F) | vh0;
|
2009
|
+
const uint8_t vi1 = (vi >> 4) | vh1;
|
2010
|
+
|
2011
|
+
const float v0 = vi0*d + m;
|
2012
|
+
const float v1 = vi1*d + m;
|
2013
|
+
|
2014
|
+
y[i*QK5_1 + l + 0] = v0;
|
2015
|
+
y[i*QK5_1 + l + 1] = v1;
|
2016
|
+
|
2017
|
+
assert(!isnan(y[i*QK5_1 + l + 0]));
|
2018
|
+
assert(!isnan(y[i*QK5_1 + l + 1]));
|
2019
|
+
}
|
2020
|
+
}
|
2021
|
+
}
|
2022
|
+
|
2023
|
+
static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
|
2024
|
+
assert(k % QK8_0 == 0);
|
2025
|
+
const int nb = k / QK8_0;
|
2026
|
+
|
2027
|
+
const block_q8_0 * restrict x = vx;
|
2028
|
+
|
2029
|
+
for (int i = 0; i < nb; i++) {
|
2030
|
+
const float d = x[i].d;
|
2031
|
+
|
2032
|
+
const int8_t * restrict pp = x[i].qs;
|
2033
|
+
|
2034
|
+
for (int l = 0; l < QK8_0; ++l) {
|
2035
|
+
y[i*QK8_0 + l] = pp[l]*d;
|
2036
|
+
}
|
2037
|
+
}
|
2038
|
+
}
|
2039
|
+
|
1424
2040
|
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2041
|
+
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2042
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2043
|
+
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2044
|
+
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2045
|
+
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2046
|
+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1425
2047
|
|
1426
2048
|
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
1427
2049
|
[GGML_TYPE_Q4_0] = {
|
@@ -1430,15 +2052,64 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
 .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
 .quantize_row_q_dot = quantize_row_q8_0,
 .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+.vec_dot_type = GGML_TYPE_Q8_0,
 },
 [GGML_TYPE_Q4_1] = {
 .dequantize_row_q = dequantize_row_q4_1,
 .quantize_row_q = quantize_row_q4_1,
 .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-.quantize_row_q_dot =
-.vec_dot_q =
+.quantize_row_q_dot = quantize_row_q8_1,
+.vec_dot_q = ggml_vec_dot_q4_1_q8_1,
+.vec_dot_type = GGML_TYPE_Q8_1,
+},
+[GGML_TYPE_Q4_2] = {
+.dequantize_row_q = dequantize_row_q4_2,
+.quantize_row_q = quantize_row_q4_2,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+.quantize_row_q_dot = quantize_row_q8_0,
+.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
+.vec_dot_type = GGML_TYPE_Q8_0,
+},
+[GGML_TYPE_Q4_3] = {
+.dequantize_row_q = dequantize_row_q4_3,
+.quantize_row_q = quantize_row_q4_3,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
+.quantize_row_q_dot = quantize_row_q8_1,
+.vec_dot_q = ggml_vec_dot_q4_3_q8_1,
+.vec_dot_type = GGML_TYPE_Q8_1,
+},
+[GGML_TYPE_Q5_0] = {
+.dequantize_row_q = dequantize_row_q5_0,
+.quantize_row_q = quantize_row_q5_0,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
+.quantize_row_q_dot = quantize_row_q8_0,
+.vec_dot_q = ggml_vec_dot_q5_0_q8_0,
+.vec_dot_type = GGML_TYPE_Q8_0,
+},
+[GGML_TYPE_Q5_1] = {
+.dequantize_row_q = dequantize_row_q5_1,
+.quantize_row_q = quantize_row_q5_1,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
+.quantize_row_q_dot = quantize_row_q8_1,
+.vec_dot_q = ggml_vec_dot_q5_1_q8_1,
+.vec_dot_type = GGML_TYPE_Q8_1,
+},
+[GGML_TYPE_Q8_0] = {
+.dequantize_row_q = dequantize_row_q8_0,
+.quantize_row_q = quantize_row_q8_0,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+.quantize_row_q_dot = quantize_row_q8_0,
+.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
+.vec_dot_type = GGML_TYPE_Q8_0,
+},
+[GGML_TYPE_Q8_1] = {
+.dequantize_row_q = NULL, // TODO
+.quantize_row_q = quantize_row_q8_1,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
+.quantize_row_q_dot = quantize_row_q8_1,
+.vec_dot_q = NULL, // TODO
+.vec_dot_type = GGML_TYPE_Q8_1,
 },
-// TODO: GGML_TYPE_Q8_0
 };
 
 // For internal test use
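The new `.vec_dot_type` field tells callers of `quantize_fns` which format the activation row has to be quantized into before `.vec_dot_q` may be used: Q8_0 for the q4_0/q4_2/q5_0/q8_0 kernels and Q8_1 for q4_1/q4_3/q5_1. A hypothetical caller-side sketch that assumes the ggml.c internals above (this is not code from the diff):

```c
// Sketch only: quantize_fns_t and quantize_fns are the internals defined above.
static void example_row_dot(enum ggml_type wtype, int n,
                            float * restrict dst,
                            const void * restrict weight_row,
                            const float * restrict activation_f32,
                            void * restrict scratch) {
    const quantize_fns_t * fns = &quantize_fns[wtype];
    // fns->vec_dot_type describes how `scratch` is laid out after this call
    // (GGML_TYPE_Q8_0 or GGML_TYPE_Q8_1 for the entries in the table above).
    fns->quantize_row_q_dot(activation_f32, scratch, n);
    fns->vec_dot_q(n, dst, weight_row, scratch);
}
```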
@@ -2004,191 +2675,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|
2004
2675
|
*s = sumf;
|
2005
2676
|
}
|
2006
2677
|
|
2007
|
-
#if __AVX512F__ && QK4_0 == 32
|
2008
|
-
static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
|
2009
|
-
// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
|
2010
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2011
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2012
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2013
|
-
// | :. =_ () [] <> () Zz Yy|
|
2014
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2015
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2016
|
-
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-//
-// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
-// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
-// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
-// Bytes 40..63 are masked when loading the data, so they are zeroed out.
-#ifdef __AVX512VBMI__
-    const __m512i byte_perm = _mm512_set_epi8(
-        39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
-        31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
-        19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
-        11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
-    );
-    const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
-    // After applying VPERMB, `permuted` holds each nibble pair duplicated into adjacent bytes:
-    // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
-#else
-    const __m512i word_perm = _mm512_set_epi16(
-        19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
-         9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
-    );
-    const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
-    // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
-    // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
-    // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
-#endif
-
-    // Shift every odd-numbered 16-bit group to the right by 4 bits.
-    const __mmask32 shift_mask = 0xaaaaaaaa;
-    const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
-    // After applying VPSRAW, the "empty" nibbles are filled with zeroes.
-
-    // Now we just need to zero out the higher nibble in each byte, and we're done.
-    const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
-    return _mm512_and_si512( low_nibble_mask, shifted );
-    // The final result holds one (low) nibble per byte.
-}
-
-static inline __m512 dot_q4_0_twoblocks_avx512(
-    __m512 acc,
-    const block_q4_0 * restrict x,
-    const block_q4_0 * restrict y,
-    int i
-) {
-    // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
-    // can potentially be unaddressable, so we make sure to mask them out before the load, even though
-    // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
-    // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
-    const __mmask8 load_mask = 0x1f;
-    const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
-    const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
-
-    // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
-    // lanes 00 and 05 hold the scales (A, B in blocks_0 and C, D in blocks_1), the rest is irrelevant.
-    const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
-    const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
-    // We absolutely shouldn't touch the floats marked with `xx`: they contain some
-    // random data, which might very well underflow. At least on Intel, this leads
-    // to a huge penalty that can't be ignored (easily 100x or more) unless you
-    // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
-    // (and ggml can't assume that you do)...
-    const __mmask16 scale_mul_mask = 0x21;
-#ifdef __clang__
-    // ...however, clang decides to optimize the multiplication mask away:
-    // https://godbolt.org/z/P8PqdsfvW
-    // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
-    __m512i scales;
-    __asm__(
-        "vmulps %1, %2, %0%{%3%}"
-        : "=v" ( scales )
-        : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
-    );
-#else
-    const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
-#endif
-    const __m512i scale_perm = _mm512_set_epi32(
-        5, 5, 5, 5, 5, 5, 5, 5,
-        0, 0, 0, 0, 0, 0, 0, 0
-    );
-    const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
-    // After VMULPS and VPERMPS, `permuted_scales` holds B*D in the upper 8 lanes and A*C in the lower 8.
-
-    const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
-    const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
-
-    // Now we want to compute dot products of 4-element byte vectors and store them in
-    // 32-bit integers, i.e. A*E + B*F + C*G + D*H for each group of four bytes.
-    const __m512i plus_8 = _mm512_set1_epi8( 8 );
-    const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
-
-#ifdef __AVX512VNNI__
-    // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
-    // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
-    // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
-    // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
-    // which means we only need 2 instructions.
-    const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
-    const __m512i minus_8 = _mm512_set1_epi8( -8 );
-    const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
-    const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
-#else
-    // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
-    // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
-    // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
-    // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
-    const __m512i one = _mm512_set1_epi16( 1 );
-    const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
-    const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
-    const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
-    const __m512i final_res_int = _mm512_madd_epi16( diff, one );
-#endif
-
-    // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
-    const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
-    return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
-}
-#endif
-
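The VPDPBUSDS workaround in the removed block rests on a small algebraic identity: because the instruction's left operand must be unsigned, the signed product (a - 8)(b - 8) is rewritten as a(b - 8) + b(-8) + 64, which expands to ab - 8a - 8b + 64 on both sides. Summed over the four bytes of each 32-bit lane, the constant term contributes the 4 * 64 used to initialise the accumulator.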
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;

@@ -2225,67 +2711,62 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }

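The dot-product kernels below index fields such as x[i].d, x[i].qs, y[i].qs, y[i].s0 and y[i].s1. The block layouts themselves are declared earlier in this file and are not part of this hunk; as a rough sketch of what the kernels appear to assume (sizes inferred from QK8_0/QK8_1 = 32, field names taken only from how they are used here), they look roughly like this:

// Sketch only - inferred from the usage below, not copied from the diff.
#define QK4_0 32
typedef struct {
    float   d;              // scale
    uint8_t qs[QK4_0 / 2];  // 32 nibbles, two per byte
} block_q4_0;

#define QK8_0 32
typedef struct {
    float  d;               // scale
    int8_t qs[QK8_0];       // 32 signed 8-bit quants
} block_q8_0;

#define QK8_1 32
typedef struct {
    float  d;               // scale
    float  s0, s1;          // per-half sums d * sum(qs), judging by how summs is built below
    int8_t qs[QK8_1];
} block_q8_1;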
-static void
-    const int nb = n /
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;

-    assert(n %
+    assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);

     const block_q4_0 * restrict x = vx;
-    const
-
-    float sumf = 0.0;
+    const block_q8_0 * restrict y = vy;

 #if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);

     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
-        const
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];

-        const uint8x16_t m4b
-        const int8x16_t s8b
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);

         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
-        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));

         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);

-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2297,125 +2778,41 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));

-        const
-        const
-        const
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }

-#elif defined(
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
     // Initialize accumulator with zeros
-    __m512 acc1 = _mm512_setzero_ps();
-    const int superblock_size = 16;
-    const int superblock_count = nb / superblock_size;
+    __m256 acc = _mm256_setzero_ps();

+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
-    }
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);

-    }
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );

-    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    const __m256i offset_8 = _mm256_set1_epi16( 8 );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);

+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }

-    for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-        // This loop will be unrolled by the compiler
-        for (int u=0;u<UNROLL_COUNT;u++) {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                _mm256_broadcast_ss( &x[i+u].d ),
-                _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
-            /* Unpack values into individual bytes */
-            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
-            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
-
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
-            /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
-
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
-
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
-
-            /* Convert vector of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
-        }
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2428,13 +2825,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         __m128i i32[2];
         for (int j = 0; j < 2; ++j) {
             // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx =
-            __m128i by =
+            __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));

             // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
             const __m128i off = _mm_set1_epi8( 8 );
             bx = _mm_sub_epi8( bx, off );
-            by = _mm_sub_epi8( by, off );

             // Get absolute values of x vectors
             const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2445,516 +2841,833 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
             // Perform multiplication and create 16-bit values
             const __m128i dot = _mm_maddubs_epi16(ax, sy);

-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
-        }
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        int sumi = 0;
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const int i0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1 = (int8_t) (v0 >>   4) - 8;
+
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
+
+            sumi += i0*i2 + i1*i3;
+        }
+        sumf += d0*d1*sumi;
+    }
+    *s = sumf;
+#endif
+}
+
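The AVX2 path above leans on two helpers, mul_sum_i8_pairs_float and hsum_float_8, whose definitions sit earlier in the file and are not part of this hunk. As a rough sketch of what they need to do, inferred only from the call sites here (not a copy of the actual definitions): the first multiplies two vectors of 32 signed bytes, reduces each group of four products to a 32-bit integer and converts to float; the second horizontally adds the eight lanes of a __m256, the same four-step shuffle the removed AVX2 epilogue used.

// Sketch only: plausible AVX2 implementations matching the call sites above.
#include <immintrin.h>

static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
    // |x| as unsigned bytes, and y with x's sign copied onto it, so that the
    // unsigned*signed _mm256_maddubs_epi16 effectively computes x*y pairwise.
    const __m256i ax   = _mm256_sign_epi8(x, x);
    const __m256i sy   = _mm256_sign_epi8(y, x);
    const __m256i dot  = _mm256_maddubs_epi16(ax, sy);          // 16-bit pair sums
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i summed_pairs = _mm256_madd_epi16(ones, dot);  // 32-bit quad sums
    return _mm256_cvtepi32_ps(summed_pairs);
}

static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}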
+static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
+
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i + 0];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+
+        summs += x[i].m * (y[i].s0 + y[i].s1);
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0x0F) + m0;
+            const float f1 = d0*(v0 >>   4) + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+    *s = sumf;
+#endif
+}
+
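The summs term in ggml_vec_dot_q4_1_q8_1 comes from expanding the affine dequantization. Each Q4_1 value decodes as d_x*q_x + m_x and each Q8_1 value as d_y*q_y, so the block dot product is sum((d_x*q_x + m_x) * d_y*q_y) = d_x*d_y * sum(q_x*q_y) + m_x * d_y * sum(q_y). The second term only needs the per-half sums d_y * sum(q_y), which is what y[i].s0 and y[i].s1 appear to hold; that is why the kernel accumulates m*(s0 + s1) separately and adds it once at the end instead of adding m inside the inner loop.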
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
+
+    const block_q4_2 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
+
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
+
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
+
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
+
+        int sumi_0 = 0;
+        int sumi_1 = 0;
+
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
+
+            const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1_0 = (int8_t) (v0 >>   4) - 8;
+
+            const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
+            const int i1_1 = (int8_t) (v1 >>   4) - 8;
+
+            const int i2_0 = y0[2*j + 0];
+            const int i3_0 = y0[2*j + 1];
+
+            const int i2_1 = y0[2*(j + QK8_0/4) + 0];
+            const int i3_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sumi_0 += i0_0*i2_0 + i1_0*i3_0;
+            sumi_1 += i0_1*i2_1 + i1_1*i3_1;
+        }
+
+        sumf += (d0 * y[i].d) * sumi_0;
+        sumf += (d1 * y[i].d) * sumi_1;
+    }
+    *s = sumf;
+#endif
+}
+
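The NEON paths above pair vzip1q_s8/vzip2q_s8 (to re-interleave the separated low and high nibbles of x back into their original element order) with vuzp1q_s8/vuzp2q_s8 in the Q4_0 kernel (to de-interleave y into its even- and odd-indexed bytes so they line up with the nibble vectors). A small scalar model of the two permutations, for illustration only under the assumption of 16-byte lanes:

#include <stdint.h>

// Scalar models of the NEON permutes used above (16 lanes, illustration only).
static void zip1_zip2_u8(const uint8_t a[16], const uint8_t b[16],
                         uint8_t lo[16], uint8_t hi[16]) {
    for (int i = 0; i < 8; ++i) {
        lo[2*i + 0] = a[i];      // vzip1q: interleave the low halves of a and b
        lo[2*i + 1] = b[i];
        hi[2*i + 0] = a[i + 8];  // vzip2q: interleave the high halves
        hi[2*i + 1] = b[i + 8];
    }
}

static void uzp1_uzp2_u8(const uint8_t a[16], const uint8_t b[16],
                         uint8_t ev[16], uint8_t od[16]) {
    for (int i = 0; i < 8; ++i) {
        ev[i]     = a[2*i];      // vuzp1q: even-indexed elements of a, then of b
        ev[i + 8] = b[2*i];
        od[i]     = a[2*i + 1];  // vuzp2q: odd-indexed elements of a, then of b
        od[i + 8] = b[2*i + 1];
    }
}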
+static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
+
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_1 == 2*QK4_3);
+
+    const block_q4_3 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
+
+        const block_q8_1 * restrict y0 = &y[i + 0];
+
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 dx = _mm256_set_m128(d1, d0);
+
+        summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
+               + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
+
+        const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        const __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
+        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
+
+        int sxy_0 = 0;
+        int sxy_1 = 0;
+
+        for (int j = 0; j < QK8_1/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
+
+            const int x0_0 = v0 & 0x0F;
+            const int x1_0 = v0 >> 4;
+
+            const int x0_1 = v1 & 0x0F;
+            const int x1_1 = v1 >> 4;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            const int y0_1 = y0[2*(j + QK8_1/4) + 0];
+            const int y1_1 = y0[2*(j + QK8_1/4) + 1];
+
+            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
+            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+        }
+
+        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
+    }
+    *s = sumf;
+#endif
+}
+
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == QK5_0);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const uint8x16_t m4b  = vdupq_n_u8(0x0F);
+        const int8x16_t  s16b = vdupq_n_s8(0x10);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, m4b));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add high bit and sub 16
+        const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
+        const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
+
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        int sxy = 0;
+
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = x0[j];
+
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
+
+            const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
+            const int x1_0 = ((v0 >>   4) | x1_0h) - 16;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
+        }
+
+        sumf += (d*sxy)*y[i].d;
+    }
+    *s = sumf;
+#endif
+}
+
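ggml_vec_dot_q5_0_q8_0 splices the fifth bit of each quant back in by looking up table_b2b_u, which, judging by how qhl/qhh are OR-ed into the nibble vectors and by what the scalar path does with ((qh >> k) & 1) << 4, expands each byte of qh into eight bytes that are either 0x00 or 0x10. A portable sketch doing roughly the same expansion without the precomputed table (little-endian layout and a hypothetical helper name assumed):

#include <stdint.h>
#include <string.h>

// Sketch: expand the 32 high bits in qh into 32 bytes of 0x00/0x10, the same
// role the table_b2b_u lookups play in the NEON path above (assumption, not the file's API).
static void expand_qh_bits(const uint8_t qh_bytes[4], uint8_t out[32]) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));
    for (int j = 0; j < 32; ++j) {
        out[j] = (uint8_t)(((qh >> j) & 1) << 4);   // bit j -> 0x00 or 0x10
    }
}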
-static void
-    const int nb = n /
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
+
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_1 == QK5_1);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
+
+    float summs = 0.0f;
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add
+        const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
+        const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
+
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#else
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        int sxy = 0;
+
+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = x0[j];
+
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
+
+            const int x0_0 = (v0 & 0x0F) | x0_0h;
+            const int x1_0 = (v0 >>   4) | x1_0h;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
+        }
+
+        sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
+    }
+
+    *s = sumf;
+#endif
+}
+
2757
|
-
static void
|
3569
|
+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2758
3570
|
const int nb = n / QK8_0;
|
2759
3571
|
|
2760
3572
|
assert(n % QK8_0 == 0);
|
2761
3573
|
assert(nb % 2 == 0);
|
3574
|
+
assert(QK8_0 == QK8_0);
|
2762
3575
|
|
2763
|
-
const
|
3576
|
+
const block_q8_0 * restrict x = vx;
|
2764
3577
|
const block_q8_0 * restrict y = vy;
|
2765
3578
|
|
2766
|
-
float sumf = 0.0;
|
2767
|
-
|
2768
3579
|
#if defined(__ARM_NEON)
|
2769
|
-
|
2770
|
-
|
3580
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3581
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2771
3582
|
|
2772
3583
|
for (int i = 0; i < nb; i += 2) {
|
2773
|
-
const
|
2774
|
-
const
|
3584
|
+
const block_q8_0 * restrict x0 = &x[i + 0];
|
3585
|
+
const block_q8_0 * restrict x1 = &x[i + 1];
|
2775
3586
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
2776
3587
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
2777
3588
|
|
2778
|
-
const
|
2779
|
-
const int8x16_t
|
2780
|
-
|
2781
|
-
const
|
2782
|
-
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2783
|
-
|
2784
|
-
// 4-bit -> 8-bit
|
2785
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2786
|
-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2787
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2788
|
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2789
|
-
|
2790
|
-
// sub 8
|
2791
|
-
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2792
|
-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2793
|
-
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2794
|
-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3589
|
+
const int8x16_t x0_0 = vld1q_s8(x0->qs);
|
3590
|
+
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
|
3591
|
+
const int8x16_t x1_0 = vld1q_s8(x1->qs);
|
3592
|
+
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
|
2795
3593
|
|
2796
3594
|
// load y
|
2797
|
-
const int8x16_t
|
2798
|
-
const int8x16_t
|
2799
|
-
const int8x16_t
|
2800
|
-
const int8x16_t
|
2801
|
-
|
2802
|
-
// interleave
|
2803
|
-
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2804
|
-
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2805
|
-
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2806
|
-
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
3595
|
+
const int8x16_t y0_0 = vld1q_s8(y0->qs);
|
3596
|
+
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
|
3597
|
+
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
3598
|
+
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
2807
3599
|
|
2808
3600
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2809
|
-
|
2810
|
-
|
2811
|
-
|
3601
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
3602
|
+
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
3603
|
+
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
|
2812
3604
|
|
2813
|
-
|
2814
|
-
|
3605
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
3606
|
+
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
3607
|
+
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
|
2815
3608
|
|
2816
|
-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
2817
|
-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2818
3609
|
#else
|
2819
|
-
const int16x8_t
|
2820
|
-
const int16x8_t
|
2821
|
-
const int16x8_t
|
2822
|
-
const int16x8_t
|
2823
|
-
|
2824
|
-
const int16x8_t
|
2825
|
-
const int16x8_t
|
2826
|
-
const int16x8_t
|
2827
|
-
const int16x8_t
|
2828
|
-
|
2829
|
-
const
|
2830
|
-
const
|
2831
|
-
|
2832
|
-
const
|
2833
|
-
|
2834
|
-
|
2835
|
-
|
2836
|
-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
2837
|
-
|
2838
|
-
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
2839
|
-
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
3610
|
+
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
|
3611
|
+
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
3612
|
+
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
|
3613
|
+
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
3614
|
+
|
3615
|
+
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
|
3616
|
+
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
3617
|
+
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
|
3618
|
+
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
3619
|
+
|
3620
|
+
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
3621
|
+
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
3622
|
+
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
3623
|
+
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
3624
|
+
|
3625
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
|
3626
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
|
2840
3627
|
#endif
|
2841
3628
|
}
|
2842
3629
|
|
2843
|
-
|
3630
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2844
3631
|
#elif defined(__AVX2__)
|
2845
3632
|
// Initialize accumulator with zeros
|
2846
3633
|
__m256 acc = _mm256_setzero_ps();
|
2847
3634
|
|
2848
3635
|
// Main loop
|
2849
3636
|
for (int i = 0; i < nb; ++i) {
|
2850
|
-
|
3637
|
+
// Compute combined scale for the block
|
2851
3638
|
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2852
|
-
|
2853
|
-
__m256i bx = bytesFromNibbles(x[i].qs);
|
2854
|
-
|
2855
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2856
|
-
const __m256i off = _mm256_set1_epi8( 8 );
|
2857
|
-
bx = _mm256_sub_epi8( bx, off );
|
2858
|
-
|
3639
|
+
__m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
2859
3640
|
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2860
3641
|
|
2861
|
-
|
2862
|
-
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2863
|
-
|
2864
|
-
// Sign the values of the y vectors
|
2865
|
-
const __m256i sy = _mm256_sign_epi8(by, bx);
|
3642
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
2866
3643
|
|
2867
|
-
//
|
2868
|
-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2869
|
-
|
2870
|
-
const __m256i ones = _mm256_set1_epi16(1);
|
2871
|
-
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2872
|
-
|
2873
|
-
/* Convert to vectore of 8 int32_t to 8 floats */
|
2874
|
-
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2875
|
-
|
2876
|
-
/* Multiply q with scale and accumulate */
|
3644
|
+
// Multiply q with scale and accumulate
|
2877
3645
|
acc = _mm256_fmadd_ps( d, q, acc );
|
2878
3646
|
}
|
2879
3647
|
|
2880
|
-
|
2881
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2882
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2883
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2884
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2885
|
-
|
2886
|
-
sumf = _mm_cvtss_f32( res );
|
2887
|
-
#elif defined(__AVX__)
|
2888
|
-
// Initialize accumulator with zeros
|
2889
|
-
__m256 acc = _mm256_setzero_ps();
|
2890
|
-
|
2891
|
-
// Main loop
|
2892
|
-
for (int i = 0; i < nb; ++i) {
|
2893
|
-
// Compute combined scale for the block
|
2894
|
-
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2895
|
-
|
2896
|
-
__m128i i32[2];
|
2897
|
-
for (int j = 0; j < 2; ++j) {
|
2898
|
-
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2899
|
-
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
|
2900
|
-
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2901
|
-
|
2902
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2903
|
-
const __m128i off = _mm_set1_epi8( 8 );
|
2904
|
-
bx = _mm_sub_epi8( bx, off );
|
2905
|
-
|
2906
|
-
// Get absolute values of x vectors
|
2907
|
-
const __m128i ax = _mm_sign_epi8(bx, bx);
|
2908
|
-
|
2909
|
-
// Sign the values of the y vectors
|
2910
|
-
const __m128i sy = _mm_sign_epi8(by, bx);
|
2911
|
-
|
2912
|
-
// Perform multiplication and create 16-bit values
|
2913
|
-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
2914
|
-
|
2915
|
-
const __m128i ones = _mm_set1_epi16(1);
|
2916
|
-
i32[j] = _mm_madd_epi16(ones, dot);
|
2917
|
-
}
|
2918
|
-
|
2919
|
-
// Convert int32_t to float
|
2920
|
-
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
2921
|
-
// Apply the scale, and accumulate
|
2922
|
-
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
2923
|
-
}
|
2924
|
-
|
2925
|
-
// Return horizontal sum of the acc vector
|
2926
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2927
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2928
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2929
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2930
|
-
|
2931
|
-
sumf = _mm_cvtss_f32( res );
|
3648
|
+
*s = hsum_float_8(acc);
|
2932
3649
|
#else
|
2933
3650
|
// scalar
|
2934
|
-
|
2935
|
-
const float d0 = x[i].d;
|
2936
|
-
const float d1 = y[i].d;
|
3651
|
+
float sumf = 0.0;
|
2937
3652
|
|
2938
|
-
|
2939
|
-
const
|
3653
|
+
for (int i = 0; i < nb; i++) {
|
3654
|
+
const int8_t * restrict x0 = x[i].qs;
|
3655
|
+
const int8_t * restrict y0 = y[i].qs;
|
2940
3656
|
|
2941
3657
|
int sumi = 0;
|
2942
|
-
for (int j = 0; j < QK8_0/2; j++) {
|
2943
|
-
const uint8_t v0 = p0[j];
|
2944
3658
|
|
2945
|
-
|
2946
|
-
const int
|
3659
|
+
for (int j = 0; j < QK8_0; j++) {
|
3660
|
+
const int v0 = x0[j];
|
3661
|
+
const int v1 = y0[j];
|
2947
3662
|
|
2948
|
-
|
2949
|
-
const int i3 = p1[2*j + 1];
|
2950
|
-
|
2951
|
-
sumi += i0*i2 + i1*i3;
|
3663
|
+
sumi += v0*v1;
|
2952
3664
|
}
|
2953
|
-
|
3665
|
+
|
3666
|
+
sumf += (x[i].d*y[i].d)*sumi;
|
2954
3667
|
}
|
2955
|
-
#endif
|
2956
3668
|
|
2957
3669
|
*s = sumf;
|
3670
|
+
#endif
|
2958
3671
|
}
|
2959
3672
|
|
2960
3673
|
// compute GGML_VEC_DOT_UNROLL dot products at once
|
@@ -3153,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #endif
 }
 
+inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     float max = -INFINITY;
@@ -3203,24 +3924,34 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = 1,
     [GGML_TYPE_Q4_0] = QK4_0,
     [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q4_3] = QK4_3,
+    [GGML_TYPE_Q5_0] = QK5_0,
+    [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
+    [GGML_TYPE_Q8_1] = QK8_1,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
     [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
+    [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
+    [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
+    [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3959,34 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = "f16",
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q4_3] = "q4_3",
+    [GGML_TYPE_Q5_0] = "q5_0",
+    [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
+    [GGML_TYPE_Q8_1] = "q8_1",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+
+static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = false,
+    [GGML_TYPE_F16]  = false,
+    [GGML_TYPE_Q4_0] = true,
+    [GGML_TYPE_Q4_1] = true,
+    [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q4_3] = true,
+    [GGML_TYPE_Q5_0] = true,
+    [GGML_TYPE_Q5_1] = true,
+    [GGML_TYPE_Q8_0] = true,
+    [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_I8]   = false,
+    [GGML_TYPE_I16]  = false,
+    [GGML_TYPE_I32]  = false,
+};
+static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3495,6 +4248,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+bool ggml_is_quantized(enum ggml_type type) {
+    return GGML_IS_QUANTIZED[type];
+}
+
 static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
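The sketch below is not part of the diff; it is only a minimal illustration of how the new per-type tables and the ggml_is_quantized() helper above fit together. The helper name example_row_size is made up for illustration; the row-size arithmetic mirrors what the updated ggml.c does later in this diff when it sizes rows of quantized data.

// hypothetical helper: bytes needed for one row of ne0 elements of a given type
static size_t example_row_size(enum ggml_type type, int64_t ne0) {
    // quantized types store whole blocks of GGML_BLCK_SIZE[type] elements,
    // each block occupying GGML_TYPE_SIZE[type] bytes; for plain types the
    // block size is 1, so the same formula covers both cases
    (void) ggml_is_quantized(type); // query available if a caller needs to branch
    return GGML_TYPE_SIZE[type] * (size_t)(ne0 / GGML_BLCK_SIZE[type]);
}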
@@ -3605,6 +4362,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        ggml_init_cublas();
+        #elif defined(GGML_USE_CLBLAST)
+        ggml_cl_init();
+        #endif
+
         is_first_call = false;
     }
 
@@ -5535,7 +6299,6 @@ static void ggml_compute_forward_dup_f16(
|
|
5535
6299
|
const struct ggml_compute_params * params,
|
5536
6300
|
const struct ggml_tensor * src0,
|
5537
6301
|
struct ggml_tensor * dst) {
|
5538
|
-
GGML_ASSERT(params->ith == 0);
|
5539
6302
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5540
6303
|
|
5541
6304
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5547,6 +6310,11 @@ static void ggml_compute_forward_dup_f16(
|
|
5547
6310
|
const int64_t ne02 = src0->ne[2];
|
5548
6311
|
const int64_t ne03 = src0->ne[3];
|
5549
6312
|
|
6313
|
+
const int64_t ne0 = dst->ne[0];
|
6314
|
+
const int64_t ne1 = dst->ne[1];
|
6315
|
+
const int64_t ne2 = dst->ne[2];
|
6316
|
+
const int64_t ne3 = dst->ne[3];
|
6317
|
+
|
5550
6318
|
const size_t nb00 = src0->nb[0];
|
5551
6319
|
const size_t nb01 = src0->nb[1];
|
5552
6320
|
const size_t nb02 = src0->nb[2];
|
@@ -5557,19 +6325,40 @@ static void ggml_compute_forward_dup_f16(
|
|
5557
6325
|
const size_t nb2 = dst->nb[2];
|
5558
6326
|
const size_t nb3 = dst->nb[3];
|
5559
6327
|
|
6328
|
+
const int ith = params->ith; // thread index
|
6329
|
+
const int nth = params->nth; // number of threads
|
6330
|
+
|
5560
6331
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5561
|
-
|
6332
|
+
// parallelize by elements
|
6333
|
+
const int ne = ggml_nelements(dst);
|
6334
|
+
const int dr = (ne + nth - 1) / nth;
|
6335
|
+
const int ie0 = dr * ith;
|
6336
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
6337
|
+
|
6338
|
+
memcpy(
|
6339
|
+
((char *) dst->data + ie0*nb0),
|
6340
|
+
((char *) src0->data + ie0*nb00),
|
6341
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
6342
|
+
|
5562
6343
|
return;
|
5563
6344
|
}
|
5564
6345
|
|
6346
|
+
// parallelize by rows
|
6347
|
+
const int nr = ne01;
|
6348
|
+
// number of rows per thread
|
6349
|
+
const int dr = (nr + nth - 1) / nth;
|
6350
|
+
// row range for this thread
|
6351
|
+
const int ir0 = dr * ith;
|
6352
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6353
|
+
|
5565
6354
|
if (src0->type == dst->type &&
|
5566
|
-
|
5567
|
-
|
6355
|
+
ne00 == ne0 &&
|
6356
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5568
6357
|
// copy by rows
|
5569
6358
|
const size_t rs = ne00*nb00;
|
5570
6359
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5571
6360
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5572
|
-
for (int64_t i01 =
|
6361
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5573
6362
|
memcpy(
|
5574
6363
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5575
6364
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
@@ -5583,21 +6372,21 @@ static void ggml_compute_forward_dup_f16(
|
|
5583
6372
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
5584
6373
|
|
5585
6374
|
if (ggml_is_contiguous(dst)) {
|
5586
|
-
if (
|
6375
|
+
if (nb00 == sizeof(ggml_fp16_t)) {
|
5587
6376
|
if (dst->type == GGML_TYPE_F16) {
|
5588
6377
|
size_t id = 0;
|
5589
|
-
const size_t rs = ne00*nb00;
|
6378
|
+
const size_t rs = ne00 * nb00;
|
6379
|
+
char * dst_ptr = (char *) dst->data;
|
5590
6380
|
|
5591
6381
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5592
6382
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5593
|
-
|
6383
|
+
id += rs * ir0;
|
6384
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5594
6385
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5595
|
-
|
5596
|
-
|
5597
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5598
|
-
|
5599
|
-
id++;
|
6386
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
6387
|
+
id += rs;
|
5600
6388
|
}
|
6389
|
+
id += rs * (ne01 - ir1);
|
5601
6390
|
}
|
5602
6391
|
}
|
5603
6392
|
} else if (dst->type == GGML_TYPE_F32) {
|
@@ -5606,34 +6395,39 @@ static void ggml_compute_forward_dup_f16(
|
|
5606
6395
|
|
5607
6396
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5608
6397
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5609
|
-
|
6398
|
+
id += ne00 * ir0;
|
6399
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
6400
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5610
6401
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5611
|
-
|
5612
|
-
|
5613
|
-
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
6402
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5614
6403
|
id++;
|
5615
6404
|
}
|
5616
6405
|
}
|
6406
|
+
id += ne00 * (ne01 - ir1);
|
5617
6407
|
}
|
5618
6408
|
}
|
5619
|
-
} else if (dst->type
|
6409
|
+
} else if (ggml_is_quantized(dst->type)) {
|
5620
6410
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
6411
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
6412
|
+
|
5621
6413
|
size_t id = 0;
|
5622
|
-
|
5623
|
-
|
5624
|
-
float * src0_f32 = (float *) params->wdata;
|
6414
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
6415
|
+
char * dst_ptr = (char *) dst->data;
|
5625
6416
|
|
5626
6417
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5627
6418
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5628
|
-
|
6419
|
+
id += rs * ir0;
|
6420
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5629
6421
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5630
|
-
|
6422
|
+
|
5631
6423
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5632
6424
|
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5633
6425
|
}
|
6426
|
+
|
5634
6427
|
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
5635
|
-
id +=
|
6428
|
+
id += rs;
|
5636
6429
|
}
|
6430
|
+
id += rs * (ne01 - ir1);
|
5637
6431
|
}
|
5638
6432
|
}
|
5639
6433
|
} else {
|
@@ -5648,7 +6442,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5648
6442
|
|
5649
6443
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5650
6444
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5651
|
-
|
6445
|
+
id += ne00 * ir0;
|
6446
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5652
6447
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5653
6448
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5654
6449
|
|
@@ -5656,6 +6451,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5656
6451
|
id++;
|
5657
6452
|
}
|
5658
6453
|
}
|
6454
|
+
id += ne00 * (ne01 - ir1);
|
5659
6455
|
}
|
5660
6456
|
}
|
5661
6457
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5664,7 +6460,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5664
6460
|
|
5665
6461
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5666
6462
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5667
|
-
|
6463
|
+
id += ne00 * ir0;
|
6464
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5668
6465
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5669
6466
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5670
6467
|
|
@@ -5672,6 +6469,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5672
6469
|
id++;
|
5673
6470
|
}
|
5674
6471
|
}
|
6472
|
+
id += ne00 * (ne01 - ir1);
|
5675
6473
|
}
|
5676
6474
|
}
|
5677
6475
|
} else {
|
@@ -5690,7 +6488,20 @@ static void ggml_compute_forward_dup_f16(
|
|
5690
6488
|
if (dst->type == GGML_TYPE_F16) {
|
5691
6489
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5692
6490
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5693
|
-
|
6491
|
+
i10 += ne00 * ir0;
|
6492
|
+
while (i10 >= ne0) {
|
6493
|
+
i10 -= ne0;
|
6494
|
+
if (++i11 == ne1) {
|
6495
|
+
i11 = 0;
|
6496
|
+
if (++i12 == ne2) {
|
6497
|
+
i12 = 0;
|
6498
|
+
if (++i13 == ne3) {
|
6499
|
+
i13 = 0;
|
6500
|
+
}
|
6501
|
+
}
|
6502
|
+
}
|
6503
|
+
}
|
6504
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5694
6505
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5695
6506
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5696
6507
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
@@ -5711,25 +6522,51 @@ static void ggml_compute_forward_dup_f16(
|
|
5711
6522
|
}
|
5712
6523
|
}
|
5713
6524
|
}
|
6525
|
+
i10 += ne00 * (ne01 - ir1);
|
6526
|
+
while (i10 >= ne0) {
|
6527
|
+
i10 -= ne0;
|
6528
|
+
if (++i11 == ne1) {
|
6529
|
+
i11 = 0;
|
6530
|
+
if (++i12 == ne2) {
|
6531
|
+
i12 = 0;
|
6532
|
+
if (++i13 == ne3) {
|
6533
|
+
i13 = 0;
|
6534
|
+
}
|
6535
|
+
}
|
6536
|
+
}
|
6537
|
+
}
|
5714
6538
|
}
|
5715
6539
|
}
|
5716
6540
|
} else if (dst->type == GGML_TYPE_F32) {
|
5717
6541
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5718
6542
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5719
|
-
|
6543
|
+
i10 += ne00 * ir0;
|
6544
|
+
while (i10 >= ne0) {
|
6545
|
+
i10 -= ne0;
|
6546
|
+
if (++i11 == ne1) {
|
6547
|
+
i11 = 0;
|
6548
|
+
if (++i12 == ne2) {
|
6549
|
+
i12 = 0;
|
6550
|
+
if (++i13 == ne3) {
|
6551
|
+
i13 = 0;
|
6552
|
+
}
|
6553
|
+
}
|
6554
|
+
}
|
6555
|
+
}
|
6556
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5720
6557
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5721
6558
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5722
6559
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5723
6560
|
|
5724
6561
|
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
5725
6562
|
|
5726
|
-
if (++i10 ==
|
6563
|
+
if (++i10 == ne0) {
|
5727
6564
|
i10 = 0;
|
5728
|
-
if (++i11 ==
|
6565
|
+
if (++i11 == ne1) {
|
5729
6566
|
i11 = 0;
|
5730
|
-
if (++i12 ==
|
6567
|
+
if (++i12 == ne2) {
|
5731
6568
|
i12 = 0;
|
5732
|
-
if (++i13 ==
|
6569
|
+
if (++i13 == ne3) {
|
5733
6570
|
i13 = 0;
|
5734
6571
|
}
|
5735
6572
|
}
|
@@ -5737,6 +6574,19 @@ static void ggml_compute_forward_dup_f16(
|
|
5737
6574
|
}
|
5738
6575
|
}
|
5739
6576
|
}
|
6577
|
+
i10 += ne00 * (ne01 - ir1);
|
6578
|
+
while (i10 >= ne0) {
|
6579
|
+
i10 -= ne0;
|
6580
|
+
if (++i11 == ne1) {
|
6581
|
+
i11 = 0;
|
6582
|
+
if (++i12 == ne2) {
|
6583
|
+
i12 = 0;
|
6584
|
+
if (++i13 == ne3) {
|
6585
|
+
i13 = 0;
|
6586
|
+
}
|
6587
|
+
}
|
6588
|
+
}
|
6589
|
+
}
|
5740
6590
|
}
|
5741
6591
|
}
|
5742
6592
|
} else {
|
@@ -5748,7 +6598,6 @@ static void ggml_compute_forward_dup_f32(
|
|
5748
6598
|
const struct ggml_compute_params * params,
|
5749
6599
|
const struct ggml_tensor * src0,
|
5750
6600
|
struct ggml_tensor * dst) {
|
5751
|
-
GGML_ASSERT(params->ith == 0);
|
5752
6601
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5753
6602
|
|
5754
6603
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5760,6 +6609,11 @@ static void ggml_compute_forward_dup_f32(
|
|
5760
6609
|
const int64_t ne02 = src0->ne[2];
|
5761
6610
|
const int64_t ne03 = src0->ne[3];
|
5762
6611
|
|
6612
|
+
const int64_t ne0 = dst->ne[0];
|
6613
|
+
const int64_t ne1 = dst->ne[1];
|
6614
|
+
const int64_t ne2 = dst->ne[2];
|
6615
|
+
const int64_t ne3 = dst->ne[3];
|
6616
|
+
|
5763
6617
|
const size_t nb00 = src0->nb[0];
|
5764
6618
|
const size_t nb01 = src0->nb[1];
|
5765
6619
|
const size_t nb02 = src0->nb[2];
|
@@ -5770,19 +6624,40 @@ static void ggml_compute_forward_dup_f32(
|
|
5770
6624
|
const size_t nb2 = dst->nb[2];
|
5771
6625
|
const size_t nb3 = dst->nb[3];
|
5772
6626
|
|
6627
|
+
const int ith = params->ith; // thread index
|
6628
|
+
const int nth = params->nth; // number of threads
|
6629
|
+
|
5773
6630
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5774
|
-
|
6631
|
+
// parallelize by elements
|
6632
|
+
const int ne = ggml_nelements(dst);
|
6633
|
+
const int dr = (ne + nth - 1) / nth;
|
6634
|
+
const int ie0 = dr * ith;
|
6635
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
6636
|
+
|
6637
|
+
memcpy(
|
6638
|
+
((char *) dst->data + ie0*nb0),
|
6639
|
+
((char *) src0->data + ie0*nb00),
|
6640
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
6641
|
+
|
5775
6642
|
return;
|
5776
6643
|
}
|
5777
6644
|
|
6645
|
+
// parallelize by rows
|
6646
|
+
const int nr = ne01;
|
6647
|
+
// number of rows per thread
|
6648
|
+
const int dr = (nr + nth - 1) / nth;
|
6649
|
+
// row range for this thread
|
6650
|
+
const int ir0 = dr * ith;
|
6651
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6652
|
+
|
5778
6653
|
if (src0->type == dst->type &&
|
5779
|
-
|
5780
|
-
|
6654
|
+
ne00 == ne0 &&
|
6655
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5781
6656
|
// copy by rows
|
5782
6657
|
const size_t rs = ne00*nb00;
|
5783
6658
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5784
6659
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5785
|
-
for (int64_t i01 =
|
6660
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5786
6661
|
memcpy(
|
5787
6662
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5788
6663
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
@@ -5795,21 +6670,21 @@ static void ggml_compute_forward_dup_f32(
|
|
5795
6670
|
|
5796
6671
|
if (ggml_is_contiguous(dst)) {
|
5797
6672
|
// TODO: simplify
|
5798
|
-
if (
|
6673
|
+
if (nb00 == sizeof(float)) {
|
5799
6674
|
if (dst->type == GGML_TYPE_F32) {
|
5800
6675
|
size_t id = 0;
|
5801
|
-
const size_t rs = ne00*nb00;
|
6676
|
+
const size_t rs = ne00 * nb00;
|
6677
|
+
char * dst_ptr = (char *) dst->data;
|
5802
6678
|
|
5803
6679
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5804
6680
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5805
|
-
|
6681
|
+
id += rs * ir0;
|
6682
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5806
6683
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5807
|
-
|
5808
|
-
|
5809
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5810
|
-
|
5811
|
-
id++;
|
6684
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
6685
|
+
id += rs;
|
5812
6686
|
}
|
6687
|
+
id += rs * (ne01 - ir1);
|
5813
6688
|
}
|
5814
6689
|
}
|
5815
6690
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5818,7 +6693,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5818
6693
|
|
5819
6694
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5820
6695
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5821
|
-
|
6696
|
+
id += ne00 * ir0;
|
6697
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5822
6698
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5823
6699
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5824
6700
|
|
@@ -5826,21 +6702,25 @@ static void ggml_compute_forward_dup_f32(
|
|
5826
6702
|
id++;
|
5827
6703
|
}
|
5828
6704
|
}
|
6705
|
+
id += ne00 * (ne01 - ir1);
|
5829
6706
|
}
|
5830
6707
|
}
|
5831
|
-
} else if (dst->type
|
6708
|
+
} else if (ggml_is_quantized(dst->type)) {
|
5832
6709
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
6710
|
+
|
5833
6711
|
size_t id = 0;
|
5834
|
-
|
5835
|
-
|
6712
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
6713
|
+
char * dst_ptr = (char *) dst->data;
|
5836
6714
|
|
5837
6715
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5838
6716
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5839
|
-
|
6717
|
+
id += rs * ir0;
|
6718
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5840
6719
|
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5841
6720
|
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
5842
|
-
id +=
|
6721
|
+
id += rs;
|
5843
6722
|
}
|
6723
|
+
id += rs * (ne01 - ir1);
|
5844
6724
|
}
|
5845
6725
|
}
|
5846
6726
|
} else {
|
@@ -5855,7 +6735,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5855
6735
|
|
5856
6736
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5857
6737
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5858
|
-
|
6738
|
+
id += ne00 * ir0;
|
6739
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5859
6740
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5860
6741
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5861
6742
|
|
@@ -5863,6 +6744,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5863
6744
|
id++;
|
5864
6745
|
}
|
5865
6746
|
}
|
6747
|
+
id += ne00 * (ne01 - ir1);
|
5866
6748
|
}
|
5867
6749
|
}
|
5868
6750
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5871,7 +6753,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5871
6753
|
|
5872
6754
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5873
6755
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5874
|
-
|
6756
|
+
id += ne00 * ir0;
|
6757
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5875
6758
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5876
6759
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5877
6760
|
|
@@ -5879,6 +6762,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5879
6762
|
id++;
|
5880
6763
|
}
|
5881
6764
|
}
|
6765
|
+
id += ne00 * (ne01 - ir1);
|
5882
6766
|
}
|
5883
6767
|
}
|
5884
6768
|
} else {
|
@@ -5890,6 +6774,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5890
6774
|
}
|
5891
6775
|
|
5892
6776
|
// dst counters
|
6777
|
+
|
5893
6778
|
int64_t i10 = 0;
|
5894
6779
|
int64_t i11 = 0;
|
5895
6780
|
int64_t i12 = 0;
|
@@ -5898,20 +6783,33 @@ static void ggml_compute_forward_dup_f32(
|
|
5898
6783
|
if (dst->type == GGML_TYPE_F32) {
|
5899
6784
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5900
6785
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5901
|
-
|
6786
|
+
i10 += ne00 * ir0;
|
6787
|
+
while (i10 >= ne0) {
|
6788
|
+
i10 -= ne0;
|
6789
|
+
if (++i11 == ne1) {
|
6790
|
+
i11 = 0;
|
6791
|
+
if (++i12 == ne2) {
|
6792
|
+
i12 = 0;
|
6793
|
+
if (++i13 == ne3) {
|
6794
|
+
i13 = 0;
|
6795
|
+
}
|
6796
|
+
}
|
6797
|
+
}
|
6798
|
+
}
|
6799
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5902
6800
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5903
6801
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5904
6802
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5905
6803
|
|
5906
6804
|
memcpy(dst_ptr, src0_ptr, sizeof(float));
|
5907
6805
|
|
5908
|
-
if (++i10 ==
|
6806
|
+
if (++i10 == ne0) {
|
5909
6807
|
i10 = 0;
|
5910
|
-
if (++i11 ==
|
6808
|
+
if (++i11 == ne1) {
|
5911
6809
|
i11 = 0;
|
5912
|
-
if (++i12 ==
|
6810
|
+
if (++i12 == ne2) {
|
5913
6811
|
i12 = 0;
|
5914
|
-
if (++i13 ==
|
6812
|
+
if (++i13 == ne3) {
|
5915
6813
|
i13 = 0;
|
5916
6814
|
}
|
5917
6815
|
}
|
@@ -5919,25 +6817,51 @@ static void ggml_compute_forward_dup_f32(
|
|
5919
6817
|
}
|
5920
6818
|
}
|
5921
6819
|
}
|
6820
|
+
i10 += ne00 * (ne01 - ir1);
|
6821
|
+
while (i10 >= ne0) {
|
6822
|
+
i10 -= ne0;
|
6823
|
+
if (++i11 == ne1) {
|
6824
|
+
i11 = 0;
|
6825
|
+
if (++i12 == ne2) {
|
6826
|
+
i12 = 0;
|
6827
|
+
if (++i13 == ne3) {
|
6828
|
+
i13 = 0;
|
6829
|
+
}
|
6830
|
+
}
|
6831
|
+
}
|
6832
|
+
}
|
5922
6833
|
}
|
5923
6834
|
}
|
5924
6835
|
} else if (dst->type == GGML_TYPE_F16) {
|
5925
6836
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5926
6837
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5927
|
-
|
6838
|
+
i10 += ne00 * ir0;
|
6839
|
+
while (i10 >= ne0) {
|
6840
|
+
i10 -= ne0;
|
6841
|
+
if (++i11 == ne1) {
|
6842
|
+
i11 = 0;
|
6843
|
+
if (++i12 == ne2) {
|
6844
|
+
i12 = 0;
|
6845
|
+
if (++i13 == ne3) {
|
6846
|
+
i13 = 0;
|
6847
|
+
}
|
6848
|
+
}
|
6849
|
+
}
|
6850
|
+
}
|
6851
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5928
6852
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5929
6853
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5930
6854
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5931
6855
|
|
5932
6856
|
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
|
5933
6857
|
|
5934
|
-
if (++i10 ==
|
6858
|
+
if (++i10 == ne0) {
|
5935
6859
|
i10 = 0;
|
5936
|
-
if (++i11 ==
|
6860
|
+
if (++i11 == ne1) {
|
5937
6861
|
i11 = 0;
|
5938
|
-
if (++i12 ==
|
6862
|
+
if (++i12 == ne2) {
|
5939
6863
|
i12 = 0;
|
5940
|
-
if (++i13 ==
|
6864
|
+
if (++i13 == ne3) {
|
5941
6865
|
i13 = 0;
|
5942
6866
|
}
|
5943
6867
|
}
|
@@ -5945,6 +6869,19 @@ static void ggml_compute_forward_dup_f32(
|
|
5945
6869
|
}
|
5946
6870
|
}
|
5947
6871
|
}
|
6872
|
+
i10 += ne00 * (ne01 - ir1);
|
6873
|
+
while (i10 >= ne0) {
|
6874
|
+
i10 -= ne0;
|
6875
|
+
if (++i11 == ne1) {
|
6876
|
+
i11 = 0;
|
6877
|
+
if (++i12 == ne2) {
|
6878
|
+
i12 = 0;
|
6879
|
+
if (++i13 == ne3) {
|
6880
|
+
i13 = 0;
|
6881
|
+
}
|
6882
|
+
}
|
6883
|
+
}
|
6884
|
+
}
|
5948
6885
|
}
|
5949
6886
|
}
|
5950
6887
|
} else {
|
@@ -6191,7 +7128,7 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(src0->type
+    GGML_ASSERT(ggml_is_quantized(src0->type));
     GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
@@ -6205,7 +7142,7 @@ static void ggml_compute_forward_add_q_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * wdata = (float*) params->wdata + ne00 * ith;
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -6261,6 +7198,11 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -6518,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = 0; i01 < ne01; i01++) {
-
-
+                ggml_vec_sum_ggf(ne00,
+                        &row_sum,
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
             }
         }
     }
+    ((float *) dst->data)[0] = sum;
 }
 
 static void ggml_compute_forward_sum(
@@ -7161,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7186,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 
     return false;
 }
+
 #endif
 
 static void ggml_compute_forward_mul_mat_f32(
@@ -7201,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7258,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7272,6 +8220,19 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7272
8220
|
return;
|
7273
8221
|
}
|
7274
8222
|
|
8223
|
+
#if defined(GGML_USE_CUBLAS)
|
8224
|
+
const float alpha = 1.0f;
|
8225
|
+
const float beta = 0.0f;
|
8226
|
+
const int x_ne = ne01 * ne10;
|
8227
|
+
const int y_ne = ne11 * ne10;
|
8228
|
+
const int d_ne = ne11 * ne01;
|
8229
|
+
|
8230
|
+
size_t x_size, y_size, d_size;
|
8231
|
+
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
8232
|
+
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
8233
|
+
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
8234
|
+
#endif
|
8235
|
+
|
7275
8236
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7276
8237
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
7277
8238
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
@@ -7279,15 +8240,44 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7279
8240
|
|
7280
8241
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7281
8242
|
|
8243
|
+
#if defined(GGML_USE_CUBLAS)
|
8244
|
+
// copy data to device
|
8245
|
+
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
8246
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
8247
|
+
|
8248
|
+
// compute
|
8249
|
+
CUBLAS_CHECK(
|
8250
|
+
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8251
|
+
ne01, ne11, ne10,
|
8252
|
+
&alpha, d_X, ne00,
|
8253
|
+
d_Y, ne10,
|
8254
|
+
&beta, d_D, ne01));
|
8255
|
+
|
8256
|
+
// copy data to host
|
8257
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
8258
|
+
#elif defined(GGML_USE_CLBLAST)
|
7282
8259
|
// zT = y * xT
|
8260
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8261
|
+
ne11, ne01, ne10,
|
8262
|
+
1.0f, y, ne10,
|
8263
|
+
x, ne10,
|
8264
|
+
0.0f, d, ne01,
|
8265
|
+
GGML_TYPE_F32);
|
8266
|
+
#else
|
7283
8267
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
7284
8268
|
ne11, ne01, ne10,
|
7285
8269
|
1.0f, y, ne10,
|
7286
8270
|
x, ne00,
|
7287
8271
|
0.0f, d, ne01);
|
8272
|
+
#endif
|
7288
8273
|
}
|
7289
8274
|
}
|
7290
|
-
|
8275
|
+
#if defined(GGML_USE_CUBLAS)
|
8276
|
+
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
8277
|
+
ggml_cuda_pool_free(d_X, x_size);
|
8278
|
+
ggml_cuda_pool_free(d_Y, y_size);
|
8279
|
+
ggml_cuda_pool_free(d_D, d_size);
|
8280
|
+
#endif
|
7291
8281
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
7292
8282
|
|
7293
8283
|
return;
|
@@ -7417,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -7433,10 +8423,35 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7433
8423
|
return;
|
7434
8424
|
}
|
7435
8425
|
|
7436
|
-
|
8426
|
+
#if defined(GGML_USE_CUBLAS)
|
8427
|
+
ggml_fp16_t * const wdata = params->wdata;
|
8428
|
+
|
8429
|
+
const float alpha = 1.0f;
|
8430
|
+
const float beta = 0.0f;
|
8431
|
+
const int x_ne = ne01 * ne10;
|
8432
|
+
const int y_ne = ne11 * ne10;
|
8433
|
+
const int d_ne = ne11 * ne01;
|
7437
8434
|
|
8435
|
+
size_t x_size, y_size, d_size;
|
8436
|
+
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
8437
|
+
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
8438
|
+
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
8439
|
+
#else
|
8440
|
+
float * const wdata = params->wdata;
|
8441
|
+
#endif
|
7438
8442
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7439
8443
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
8444
|
+
#if defined(GGML_USE_CUBLAS)
|
8445
|
+
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
8446
|
+
{
|
8447
|
+
size_t id = 0;
|
8448
|
+
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
8449
|
+
for (int64_t i00 = 0; i00 < ne10; ++i00) {
|
8450
|
+
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
|
8451
|
+
}
|
8452
|
+
}
|
8453
|
+
}
|
8454
|
+
#else
|
7440
8455
|
{
|
7441
8456
|
size_t id = 0;
|
7442
8457
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
@@ -7445,7 +8460,44 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7445
8460
|
}
|
7446
8461
|
}
|
7447
8462
|
}
|
8463
|
+
#endif
|
8464
|
+
|
8465
|
+
#if defined(GGML_USE_CUBLAS)
|
8466
|
+
const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
8467
|
+
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
8468
|
+
|
8469
|
+
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8470
|
+
|
8471
|
+
// copy data to device
|
8472
|
+
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
8473
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
8474
|
+
|
8475
|
+
// compute
|
8476
|
+
CUBLAS_CHECK(
|
8477
|
+
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8478
|
+
ne01, ne11, ne10,
|
8479
|
+
&alpha, d_X, CUDA_R_16F, ne00,
|
8480
|
+
d_Y, CUDA_R_16F, ne10,
|
8481
|
+
&beta, d_D, CUDA_R_32F, ne01,
|
8482
|
+
CUBLAS_COMPUTE_32F,
|
8483
|
+
CUBLAS_GEMM_DEFAULT));
|
8484
|
+
|
8485
|
+
// copy data to host
|
8486
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
8487
|
+
#elif defined(GGML_USE_CLBLAST)
|
8488
|
+
const float * x = wdata;
|
8489
|
+
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
7448
8490
|
|
8491
|
+
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8492
|
+
|
8493
|
+
// zT = y * xT
|
8494
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8495
|
+
ne11, ne01, ne10,
|
8496
|
+
1.0f, y, ne10,
|
8497
|
+
x, ne10,
|
8498
|
+
0.0f, d, ne01,
|
8499
|
+
GGML_TYPE_F32);
|
8500
|
+
#else
|
7449
8501
|
const float * x = wdata;
|
7450
8502
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
7451
8503
|
|
@@ -7457,9 +8509,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                     1.0f, y, ne10,
                           x, ne00,
                     0.0f, d, ne01);
+#endif
         }
     }
 
+#if defined(GGML_USE_CUBLAS)
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+#endif
     /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
     return;
@@ -7592,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const enum ggml_type type = src0->type;
     quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
     vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
+    enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7611,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7625,11 +8685,66 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7625
8685
|
return;
|
7626
8686
|
}
|
7627
8687
|
|
8688
|
+
#if defined(GGML_USE_CUBLAS)
|
8689
|
+
const float alpha = 1.0f;
|
8690
|
+
const float beta = 0.0f;
|
8691
|
+
const int x_ne = ne01 * ne10;
|
8692
|
+
const int y_ne = ne11 * ne10;
|
8693
|
+
const int d_ne = ne11 * ne01;
|
8694
|
+
|
8695
|
+
size_t x_size, y_size, d_size, q_size;
|
8696
|
+
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
8697
|
+
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
8698
|
+
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
8699
|
+
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
8700
|
+
|
8701
|
+
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
8702
|
+
if (type == GGML_TYPE_Q4_0) {
|
8703
|
+
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
8704
|
+
}
|
8705
|
+
else if (type == GGML_TYPE_Q4_1) {
|
8706
|
+
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
|
8707
|
+
}
|
8708
|
+
else if (type == GGML_TYPE_Q4_2) {
|
8709
|
+
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
|
8710
|
+
}
|
8711
|
+
else if (type == GGML_TYPE_Q4_3) {
|
8712
|
+
dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
|
8713
|
+
}
|
8714
|
+
else if (type == GGML_TYPE_Q5_0) {
|
8715
|
+
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
|
8716
|
+
}
|
8717
|
+
else if (type == GGML_TYPE_Q5_1) {
|
8718
|
+
dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
|
8719
|
+
}
|
8720
|
+
else if (type == GGML_TYPE_Q8_0) {
|
8721
|
+
dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
|
8722
|
+
}
|
8723
|
+
else {
|
8724
|
+
GGML_ASSERT(false);
|
8725
|
+
}
|
8726
|
+
#elif !defined(GGML_USE_CLBLAST)
|
7628
8727
|
float * const wdata = params->wdata;
|
7629
8728
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
8729
|
+
#endif
|
7630
8730
|
|
7631
8731
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7632
8732
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
8733
|
+
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
8734
|
+
|
8735
|
+
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8736
|
+
|
8737
|
+
#if defined(GGML_USE_CUBLAS)
|
8738
|
+
// copy and dequantize on device
|
8739
|
+
CUDA_CHECK(
|
8740
|
+
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
8741
|
+
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
|
8742
|
+
|
8743
|
+
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
8744
|
+
CUDA_CHECK(cudaGetLastError());
|
8745
|
+
#elif defined(GGML_USE_CLBLAST)
|
8746
|
+
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
8747
|
+
#else
|
7633
8748
|
{
|
7634
8749
|
size_t id = 0;
|
7635
8750
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
@@ -7637,21 +8752,49 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7637
8752
|
id += ne00;
|
7638
8753
|
}
|
7639
8754
|
}
|
7640
|
-
|
7641
8755
|
const float * x = wdata;
|
7642
|
-
|
8756
|
+
#endif
|
7643
8757
|
|
7644
|
-
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7645
8758
|
|
8759
|
+
#if defined(GGML_USE_CUBLAS)
|
8760
|
+
// copy data to device
|
8761
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
8762
|
+
|
8763
|
+
// compute
|
8764
|
+
CUBLAS_CHECK(
|
8765
|
+
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8766
|
+
ne01, ne11, ne10,
|
8767
|
+
&alpha, d_X, ne00,
|
8768
|
+
d_Y, ne10,
|
8769
|
+
&beta, d_D, ne01));
|
8770
|
+
|
8771
|
+
// copy data to host
|
8772
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
8773
|
+
#elif defined(GGML_USE_CLBLAST)
|
7646
8774
|
// zT = y * xT
|
8775
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8776
|
+
ne11, ne01, ne10,
|
8777
|
+
1.0f, y, ne10,
|
8778
|
+
x, ne10,
|
8779
|
+
0.0f, d, ne01,
|
8780
|
+
type);
|
8781
|
+
#else
|
7647
8782
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
7648
8783
|
ne11, ne01, ne10,
|
7649
8784
|
1.0f, y, ne10,
|
7650
8785
|
x, ne00,
|
7651
8786
|
0.0f, d, ne01);
|
8787
|
+
#endif
|
7652
8788
|
}
|
7653
8789
|
}
|
7654
8790
|
|
8791
|
+
#if defined(GGML_USE_CUBLAS)
|
8792
|
+
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
8793
|
+
ggml_cuda_pool_free(d_X, x_size);
|
8794
|
+
ggml_cuda_pool_free(d_Y, y_size);
|
8795
|
+
ggml_cuda_pool_free(d_D, d_size);
|
8796
|
+
ggml_cuda_pool_free(d_Q, q_size);
|
8797
|
+
#endif
|
7655
8798
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
7656
8799
|
|
7657
8800
|
return;
|
@@ -7660,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[
+        const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -7691,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[
+    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -7739,7 +8882,12 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -7756,34 +8904,6 @@ static void ggml_compute_forward_mul_mat(
                 GGML_ASSERT(false);
             } break;
     }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }
 
 // ggml_compute_forward_scale
@@ -7994,7 +9114,12 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+       case GGML_TYPE_Q4_2:
+       case GGML_TYPE_Q4_3:
+       case GGML_TYPE_Q5_0:
+       case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+       case GGML_TYPE_Q8_1:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;
@@ -8132,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
+            //printf("p[%3d] = %8.4f\n", i, p[i]);
             if (p[i] == -INFINITY) {
                 p[i] = 0.0f;
             } else {
@@ -8224,9 +9350,11 @@ static void ggml_compute_forward_rope_f32(
|
|
8224
9350
|
|
8225
9351
|
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8226
9352
|
|
9353
|
+
const bool is_neox = mode & 2;
|
9354
|
+
|
8227
9355
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
8228
|
-
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8229
|
-
const int p = (mode == 0 ? n_past + i2 : i2);
|
9356
|
+
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
9357
|
+
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
8230
9358
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
8231
9359
|
if (ir++ < ir0) continue;
|
8232
9360
|
if (ir > ir1) break;
|
@@ -8239,14 +9367,25 @@ static void ggml_compute_forward_rope_f32(
|
|
8239
9367
|
|
8240
9368
|
theta *= theta_scale;
|
8241
9369
|
|
8242
|
-
|
8243
|
-
|
9370
|
+
if (!is_neox) {
|
9371
|
+
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9372
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9373
|
+
|
9374
|
+
const float x0 = src[0];
|
9375
|
+
const float x1 = src[1];
|
9376
|
+
|
9377
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
9378
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
9379
|
+
} else {
|
9380
|
+
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
9381
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8244
9382
|
|
8245
|
-
|
8246
|
-
|
9383
|
+
const float x0 = src[0];
|
9384
|
+
const float x1 = src[n_dims/2];
|
8247
9385
|
|
8248
|
-
|
8249
|
-
|
9386
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
9387
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
9388
|
+
}
|
8250
9389
|
}
|
8251
9390
|
}
|
8252
9391
|
}
|
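Two things change in the rope hunks above: mode is now treated as a bit field ((mode & 1) controls whether the position counter starts from n_past, and (mode & 2) selects the GPT-NeoX layout), and the rotation pairs elements differently per layout — adjacent elements (x[0], x[1]) in the original style versus elements half a head apart (x[0], x[n_dims/2]) for NeoX. A hedged sketch of just that rotation step on a plain float buffer, leaving out the tensor stride bookkeeping:

#include <math.h>
#include <stdio.h>

/* Hedged sketch of the two RoPE pairings from the hunk above, applied to
 * one head of n_dims floats.  Plain arrays stand in for ggml tensors. */
static void rope_head(float * x, int n_dims, int p, float theta_scale, int is_neox) {
    float theta = (float) p;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        theta *= theta_scale;

        if (!is_neox) {                      /* rotate adjacent pairs */
            const float x0 = x[i0];
            const float x1 = x[i0 + 1];
            x[i0]     = x0*cos_theta - x1*sin_theta;
            x[i0 + 1] = x0*sin_theta + x1*cos_theta;
        } else {                             /* rotate elements half a head apart */
            const float x0 = x[i0/2];
            const float x1 = x[i0/2 + n_dims/2];
            x[i0/2]            = x0*cos_theta - x1*sin_theta;
            x[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
        }
    }
}

int main(void) {
    float head[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    rope_head(head, 8, /*p=*/3, powf(10000.0f, -2.0f/8), /*is_neox=*/0);
    for (int i = 0; i < 8; ++i) printf("%.3f ", head[i]);
    printf("\n");
    return 0;
}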
@@ -8301,9 +9440,11 @@ static void ggml_compute_forward_rope_f16(
|
|
8301
9440
|
|
8302
9441
|
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8303
9442
|
|
9443
|
+
const bool is_neox = mode & 2;
|
9444
|
+
|
8304
9445
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
8305
|
-
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8306
|
-
const int p = (mode == 0 ? n_past + i2 : i2);
|
9446
|
+
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
9447
|
+
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
8307
9448
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
8308
9449
|
if (ir++ < ir0) continue;
|
8309
9450
|
if (ir > ir1) break;
|
@@ -8316,14 +9457,25 @@ static void ggml_compute_forward_rope_f16(
|
|
8316
9457
|
|
8317
9458
|
theta *= theta_scale;
|
8318
9459
|
|
8319
|
-
|
8320
|
-
|
9460
|
+
if (!is_neox) {
|
9461
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9462
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9463
|
+
|
9464
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
9465
|
+
const float x1 = GGML_FP16_TO_FP32(src[1]);
|
9466
|
+
|
9467
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
9468
|
+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
9469
|
+
} else {
|
9470
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
9471
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8321
9472
|
|
8322
|
-
|
8323
|
-
|
9473
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
9474
|
+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
8324
9475
|
|
8325
|
-
|
8326
|
-
|
9476
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
9477
|
+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
9478
|
+
}
|
8327
9479
|
}
|
8328
9480
|
}
|
8329
9481
|
}
|
@@ -10402,11 +11554,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10402
11554
|
case GGML_OP_CPY:
|
10403
11555
|
case GGML_OP_DUP:
|
10404
11556
|
{
|
10405
|
-
node->n_tasks =
|
11557
|
+
node->n_tasks = n_threads;
|
10406
11558
|
|
10407
11559
|
size_t cur = 0;
|
10408
|
-
if (node->type
|
10409
|
-
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
11560
|
+
if (ggml_is_quantized(node->type)) {
|
11561
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
|
10410
11562
|
}
|
10411
11563
|
|
10412
11564
|
work_size = MAX(work_size, cur);
|
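The hunk above generalizes the CPY/DUP work-buffer check from a hardcoded pair of types to ggml_is_quantized, and multiplies the f32 scratch row by n_threads so each thread can convert into its own slice of wdata. A small hedged sketch of that sizing, with illustrative dimensions:

#include <stddef.h>
#include <stdio.h>

/* Hedged sketch of the CPY/DUP work-buffer sizing above.  sizeof(float)
 * stands in for GGML_TYPE_SIZE[GGML_TYPE_F32]; dimensions are illustrative. */
int main(void) {
    const size_t ne0       = 4096;
    const size_t n_threads = 8;

    const size_t per_thread = sizeof(float) * ne0;    /* one f32 scratch row   */
    const size_t cur        = per_thread * n_threads; /* work_size contribution */

    /* thread ith would work in the slice starting at wdata + ith * per_thread */
    printf("work buffer: %zu bytes total, %zu per thread\n", cur, per_thread);
    return 0;
}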
@@ -10417,7 +11569,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10417
11569
|
|
10418
11570
|
size_t cur = 0;
|
10419
11571
|
|
10420
|
-
if (node->src0->type
|
11572
|
+
if (ggml_is_quantized(node->src0->type)) {
|
10421
11573
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
10422
11574
|
}
|
10423
11575
|
|
@@ -10466,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10466
11618
|
size_t cur = 0;
|
10467
11619
|
|
10468
11620
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
10469
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
11621
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
10470
11622
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10471
11623
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
10472
11624
|
// the threads are still spinning
|
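The BLAS fast path above is now also compiled in when the cuBLAS or CLBlast backends are enabled. A hedged sketch of how a caller might report which of these backends a given build has, using the ggml_cpu_has_* capability functions added near the end of this file's diff (link against the compiled ggml to run it):

#include <stdio.h>

/* declarations matching the capability functions added later in this diff */
extern int ggml_cpu_has_blas(void);
extern int ggml_cpu_has_cublas(void);
extern int ggml_cpu_has_clblast(void);
extern int ggml_cpu_has_gpublas(void);

int main(void) {
    printf("BLAS:     %d\n", ggml_cpu_has_blas());
    printf("cuBLAS:   %d\n", ggml_cpu_has_cublas());
    printf("CLBlast:  %d\n", ggml_cpu_has_clblast());
    printf("GPU BLAS: %d\n", ggml_cpu_has_gpublas());
    return 0;
}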
@@ -10482,15 +11634,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10482
11634
|
#endif
|
10483
11635
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
10484
11636
|
cur = 0;
|
10485
|
-
} else if (
|
10486
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
11637
|
+
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
11638
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
10487
11639
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10488
11640
|
node->n_tasks = 1;
|
10489
11641
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10490
11642
|
} else
|
10491
11643
|
#endif
|
10492
11644
|
{
|
10493
|
-
|
11645
|
+
const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
|
11646
|
+
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
|
10494
11647
|
}
|
10495
11648
|
} else {
|
10496
11649
|
GGML_ASSERT(false);
|
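For the quantized-src0 case above, the non-BLAS branch now sizes its scratch buffer from the intermediate type the dot kernel expects: cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(src1)/GGML_BLCK_SIZE[type_q], i.e. room for all of src1 requantized to vec_dot_type. A hedged worked example with illustrative (not table-exact) block parameters:

#include <stddef.h>
#include <stdio.h>

/* Hedged worked example of the work-size formula above.  The 36-byte /
 * 32-element block is an assumption standing in for the real
 * GGML_TYPE_SIZE[type_q] / GGML_BLCK_SIZE[type_q] entries. */
int main(void) {
    const size_t nelements  = 4096 * 32;  /* ggml_nelements(src1), illustrative */
    const size_t type_size  = 36;         /* assumed bytes per block of type_q  */
    const size_t block_size = 32;         /* assumed elements per block         */

    const size_t cur = type_size * nelements / block_size;
    printf("scratch for requantized src1: %zu bytes\n", cur); /* 147456 here */
    return 0;
}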
@@ -10818,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
10818
11971
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
10819
11972
|
struct ggml_tensor * node = cgraph->nodes[i];
|
10820
11973
|
|
10821
|
-
perf_total_per_op_us[node->op] += node->perf_time_us;
|
11974
|
+
perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
|
10822
11975
|
|
10823
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
11976
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
10824
11977
|
i,
|
10825
11978
|
node->ne[0], node->ne[1], node->ne[2],
|
10826
11979
|
GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
|
@@ -10834,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
10834
11987
|
for (int i = 0; i < cgraph->n_leafs; i++) {
|
10835
11988
|
struct ggml_tensor * node = cgraph->leafs[i];
|
10836
11989
|
|
10837
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
|
11990
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
10838
11991
|
i,
|
10839
11992
|
node->ne[0], node->ne[1],
|
10840
11993
|
GGML_OP_LABEL[node->op]);
|
10841
11994
|
}
|
10842
11995
|
|
10843
11996
|
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
11997
|
+
if (perf_total_per_op_us[i] == 0) {
|
11998
|
+
continue;
|
11999
|
+
}
|
12000
|
+
|
10844
12001
|
GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
|
10845
12002
|
}
|
10846
12003
|
|
@@ -11674,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
|
11674
12831
|
|
11675
12832
|
for (int i = 0; i < nb; i++) {
|
11676
12833
|
for (int l = 0; l < QK4_0; l += 2) {
|
11677
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
12834
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
11678
12835
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
11679
12836
|
|
11680
12837
|
hist[vi0]++;
|
@@ -11697,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
11697
12854
|
|
11698
12855
|
for (int i = 0; i < nb; i++) {
|
11699
12856
|
for (int l = 0; l < QK4_1; l += 2) {
|
11700
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
12857
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
11701
12858
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
11702
12859
|
|
11703
12860
|
hist[vi0]++;
|
@@ -11709,6 +12866,184 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
11709
12866
|
return (n/QK4_1*sizeof(block_q4_1));
|
11710
12867
|
}
|
11711
12868
|
|
12869
|
+
size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12870
|
+
assert(k % QK4_2 == 0);
|
12871
|
+
const int nb = k / QK4_2;
|
12872
|
+
|
12873
|
+
for (int j = 0; j < n; j += k) {
|
12874
|
+
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
12875
|
+
|
12876
|
+
quantize_row_q4_2_reference(src + j, y, k);
|
12877
|
+
|
12878
|
+
for (int i = 0; i < nb; i++) {
|
12879
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
12880
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12881
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12882
|
+
|
12883
|
+
hist[vi0]++;
|
12884
|
+
hist[vi1]++;
|
12885
|
+
}
|
12886
|
+
}
|
12887
|
+
}
|
12888
|
+
|
12889
|
+
return (n/QK4_2*sizeof(block_q4_2));
|
12890
|
+
}
|
12891
|
+
|
12892
|
+
size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12893
|
+
assert(k % QK4_3 == 0);
|
12894
|
+
const int nb = k / QK4_3;
|
12895
|
+
|
12896
|
+
for (int j = 0; j < n; j += k) {
|
12897
|
+
block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
|
12898
|
+
|
12899
|
+
quantize_row_q4_3_reference(src + j, y, k);
|
12900
|
+
|
12901
|
+
for (int i = 0; i < nb; i++) {
|
12902
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
12903
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12904
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12905
|
+
|
12906
|
+
hist[vi0]++;
|
12907
|
+
hist[vi1]++;
|
12908
|
+
}
|
12909
|
+
}
|
12910
|
+
}
|
12911
|
+
|
12912
|
+
return (n/QK4_3*sizeof(block_q4_3));
|
12913
|
+
}
|
12914
|
+
|
12915
|
+
size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12916
|
+
assert(k % QK5_0 == 0);
|
12917
|
+
const int nb = k / QK5_0;
|
12918
|
+
|
12919
|
+
for (int j = 0; j < n; j += k) {
|
12920
|
+
block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
|
12921
|
+
|
12922
|
+
quantize_row_q5_0_reference(src + j, y, k);
|
12923
|
+
|
12924
|
+
for (int i = 0; i < nb; i++) {
|
12925
|
+
uint32_t qh;
|
12926
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
12927
|
+
|
12928
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
12929
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
12930
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
12931
|
+
|
12932
|
+
// cast to 16 bins
|
12933
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
12934
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12935
|
+
|
12936
|
+
hist[vi0]++;
|
12937
|
+
hist[vi1]++;
|
12938
|
+
}
|
12939
|
+
}
|
12940
|
+
}
|
12941
|
+
|
12942
|
+
return (n/QK5_0*sizeof(block_q5_0));
|
12943
|
+
}
|
12944
|
+
|
12945
|
+
size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12946
|
+
assert(k % QK5_1 == 0);
|
12947
|
+
const int nb = k / QK5_1;
|
12948
|
+
|
12949
|
+
for (int j = 0; j < n; j += k) {
|
12950
|
+
block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
|
12951
|
+
|
12952
|
+
quantize_row_q5_1_reference(src + j, y, k);
|
12953
|
+
|
12954
|
+
for (int i = 0; i < nb; i++) {
|
12955
|
+
uint32_t qh;
|
12956
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
12957
|
+
|
12958
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
12959
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
12960
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
12961
|
+
|
12962
|
+
// cast to 16 bins
|
12963
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
12964
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12965
|
+
|
12966
|
+
hist[vi0]++;
|
12967
|
+
hist[vi1]++;
|
12968
|
+
}
|
12969
|
+
}
|
12970
|
+
}
|
12971
|
+
|
12972
|
+
return (n/QK5_1*sizeof(block_q5_1));
|
12973
|
+
}
|
12974
|
+
|
12975
|
+
size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12976
|
+
assert(k % QK8_0 == 0);
|
12977
|
+
const int nb = k / QK8_0;
|
12978
|
+
|
12979
|
+
for (int j = 0; j < n; j += k) {
|
12980
|
+
block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
|
12981
|
+
|
12982
|
+
quantize_row_q8_0_reference(src + j, y, k);
|
12983
|
+
|
12984
|
+
for (int i = 0; i < nb; i++) {
|
12985
|
+
for (int l = 0; l < QK8_0; ++l) {
|
12986
|
+
const int8_t vi = y[i].qs[l];
|
12987
|
+
|
12988
|
+
hist[vi/16 + 8]++;
|
12989
|
+
}
|
12990
|
+
}
|
12991
|
+
}
|
12992
|
+
|
12993
|
+
return (n/QK8_0*sizeof(block_q8_0));
|
12994
|
+
}
|
12995
|
+
|
12996
|
+
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
|
12997
|
+
size_t result = 0;
|
12998
|
+
switch (type) {
|
12999
|
+
case GGML_TYPE_Q4_0:
|
13000
|
+
{
|
13001
|
+
GGML_ASSERT(start % QK4_0 == 0);
|
13002
|
+
block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
|
13003
|
+
result = ggml_quantize_q4_0(src + start, block, n, n, hist);
|
13004
|
+
} break;
|
13005
|
+
case GGML_TYPE_Q4_1:
|
13006
|
+
{
|
13007
|
+
GGML_ASSERT(start % QK4_1 == 0);
|
13008
|
+
block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
|
13009
|
+
result = ggml_quantize_q4_1(src + start, block, n, n, hist);
|
13010
|
+
} break;
|
13011
|
+
case GGML_TYPE_Q4_2:
|
13012
|
+
{
|
13013
|
+
GGML_ASSERT(start % QK4_2 == 0);
|
13014
|
+
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
|
13015
|
+
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
|
13016
|
+
} break;
|
13017
|
+
case GGML_TYPE_Q4_3:
|
13018
|
+
{
|
13019
|
+
GGML_ASSERT(start % QK4_3 == 0);
|
13020
|
+
block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
|
13021
|
+
result = ggml_quantize_q4_3(src + start, block, n, n, hist);
|
13022
|
+
} break;
|
13023
|
+
case GGML_TYPE_Q5_0:
|
13024
|
+
{
|
13025
|
+
GGML_ASSERT(start % QK5_0 == 0);
|
13026
|
+
block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
|
13027
|
+
result = ggml_quantize_q5_0(src + start, block, n, n, hist);
|
13028
|
+
} break;
|
13029
|
+
case GGML_TYPE_Q5_1:
|
13030
|
+
{
|
13031
|
+
GGML_ASSERT(start % QK5_1 == 0);
|
13032
|
+
block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
|
13033
|
+
result = ggml_quantize_q5_1(src + start, block, n, n, hist);
|
13034
|
+
} break;
|
13035
|
+
case GGML_TYPE_Q8_0:
|
13036
|
+
{
|
13037
|
+
GGML_ASSERT(start % QK8_0 == 0);
|
13038
|
+
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
|
13039
|
+
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
|
13040
|
+
} break;
|
13041
|
+
default:
|
13042
|
+
assert(false);
|
13043
|
+
}
|
13044
|
+
return result;
|
13045
|
+
}
|
13046
|
+
|
11712
13047
|
////////////////////////////////////////////////////////////////////////////////
|
11713
13048
|
|
11714
13049
|
int ggml_cpu_has_avx(void) {
|
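The block above adds reference quantizers for the new formats plus ggml_quantize_chunk, which dispatches on the target type, asserts that start is block-aligned, quantizes n floats, and fills a 16-bin histogram of the resulting quantized values. A hedged usage sketch against the signature shown above; the buffer sizes and test data are assumptions made for the example, and ggml.h is assumed to be on the include path:

#include <stdio.h>
#include <stdint.h>
#include "ggml.h"

int main(void) {
    enum { N = 1024 };                 /* multiple of the 32-element block size */
    static float   src[N];
    static uint8_t dst[N];             /* generously sized output buffer        */
    int64_t hist[16] = {0};            /* one bin per quantized value           */

    for (int i = 0; i < N; ++i) {
        src[i] = (float) i / N - 0.5f; /* some test data */
    }

    /* quantize the whole chunk starting at element 0 */
    const size_t nbytes = ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst, 0, N, hist);

    printf("quantized %d floats into %zu bytes\n", N, nbytes);
    for (int i = 0; i < 16; ++i) {
        printf("bin %2d: %lld\n", i, (long long) hist[i]);
    }
    return 0;
}

The return value is simply the number of bytes produced for the chunk, which lets callers advance their output pointer between chunks.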
@@ -11800,13 +13135,33 @@ int ggml_cpu_has_wasm_simd(void) {
|
|
11800
13135
|
}
|
11801
13136
|
|
11802
13137
|
int ggml_cpu_has_blas(void) {
|
11803
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
13138
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
13139
|
+
return 1;
|
13140
|
+
#else
|
13141
|
+
return 0;
|
13142
|
+
#endif
|
13143
|
+
}
|
13144
|
+
|
13145
|
+
int ggml_cpu_has_cublas(void) {
|
13146
|
+
#if defined(GGML_USE_CUBLAS)
|
11804
13147
|
return 1;
|
11805
13148
|
#else
|
11806
13149
|
return 0;
|
11807
13150
|
#endif
|
11808
13151
|
}
|
11809
13152
|
|
13153
|
+
int ggml_cpu_has_clblast(void) {
|
13154
|
+
#if defined(GGML_USE_CLBLAST)
|
13155
|
+
return 1;
|
13156
|
+
#else
|
13157
|
+
return 0;
|
13158
|
+
#endif
|
13159
|
+
}
|
13160
|
+
|
13161
|
+
int ggml_cpu_has_gpublas(void) {
|
13162
|
+
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
|
13163
|
+
}
|
13164
|
+
|
11810
13165
|
int ggml_cpu_has_sse3(void) {
|
11811
13166
|
#if defined(__SSE3__)
|
11812
13167
|
return 1;
|