llama_cpp 0.0.5 → 0.0.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/ext/llama_cpp/extconf.rb +24 -1
- data/ext/llama_cpp/llama_cpp.cpp +72 -0
- data/ext/llama_cpp/src/ggml-cuda.h +44 -0
- data/ext/llama_cpp/src/ggml-opencl.c +216 -0
- data/ext/llama_cpp/src/ggml-opencl.h +24 -0
- data/ext/llama_cpp/src/ggml.c +2324 -969
- data/ext/llama_cpp/src/ggml.h +656 -619
- data/ext/llama_cpp/src/llama.cpp +269 -42
- data/ext/llama_cpp/src/llama.h +22 -14
- data/ext/llama_cpp/src/llama_util.h +15 -3
- data/lib/llama_cpp/client.rb +151 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +16 -8
- data/sig/llama_cpp.rbs +26 -2
- metadata +6 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
     } \
 } while (0)
 
-#
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include "ggml-cuda.h"
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif
 
 #undef MIN
@@ -325,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#if defined(__ARM_NEON)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes (shl 4)
+static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
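A quick illustration of what the new `table_b2b_u` lookup table above holds, assuming the bit-to-byte ordering implied by the B1..B8 macro recursion (low bit of the index ends up in the low byte of the constant): entry `i` spreads the 8 bits of `i` into 8 bytes, each bit shifted left by 4 into its own byte. The helper name below is hypothetical; only the expected values come from the macros.

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Scalar model of table_b2b_u[i]: byte k is 0x10 if bit k of i is set, else 0x00.
static uint64_t expand_bits_shl4(uint8_t i) {
    uint64_t out = 0;
    for (int k = 0; k < 8; ++k) {
        out |= (uint64_t)(((i >> k) & 1) << 4) << (8 * k);
    }
    return out;
}

int main(void) {
    assert(expand_bits_shl4(0x00) == 0x0000000000000000ULL);
    assert(expand_bits_shl4(0xFF) == 0x1010101010101010ULL);
    assert(expand_bits_shl4(0x01) == 0x0000000000000010ULL);
    printf("table_b2b_u[0x81] would be 0x%016llx\n",
           (unsigned long long) expand_bits_shl4(0x81));
    return 0;
}
```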
@@ -427,12 +446,69 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-
-//
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -449,9 +525,38 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
     return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
     const __m256i lowByte = _mm256_set1_epi16( 0xFF );
     __m256i high = _mm256_andnot_si256( lowByte, bytes );
     __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -462,25 +567,9 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r0 = _mm256_castsi256_si128( bytes );
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
+#endif
 }
-#
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +586,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
@@ -514,6 +604,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +685,39 @@ typedef struct {
     float m;               // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float)
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
 
 #define QK8_0 32
 typedef struct {
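For orientation, here is a minimal scalar sketch of how one Q5_0 weight is reconstructed from the new block layout above: the low 4 bits come from `qs`, the fifth bit from the packed `qh` word, and the result is offset by -16 and scaled by `d`. It mirrors `dequantize_row_q5_0` further down in this diff; the fp16 scale is replaced by a plain `float` and the function name is illustrative only.

```c
#include <stdint.h>
#include <string.h>

// Reconstruct element l (0..31) of one Q5_0 block.
static float dequant_q5_0_value(float d, const uint8_t qs[16], const uint8_t qh[4], int l) {
    uint32_t qh32;
    memcpy(&qh32, qh, sizeof(qh32));                        // 32 "5-th bits", one per quant

    const uint8_t nibble = (l % 2 == 0) ? (qs[l/2] & 0x0F)  // even elements: low nibble
                                        : (qs[l/2] >> 4);   // odd elements: high nibble
    const uint8_t high   = (uint8_t)(((qh32 >> l) & 1) << 4);

    const int8_t q = (int8_t)(nibble | high);               // 5-bit value in [0 .. 31]
    return (q - 16) * d;                                     // centered and scaled
}
```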
@@ -592,6 +726,14 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define QK8_1 32
+typedef struct {
+    float  d;         // delta
+    float  s0;        // d * sum(qs[i]) low
+    float  s1;        // d * sum(qs[i]) high
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
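A standalone scalar sketch of the new `block_q8_1` bookkeeping: besides the scale `d`, the block stores `s0 = d * sum(first 16 quants)` and `s1 = d * sum(last 16 quants)`, mirroring `quantize_row_q8_1_reference` later in this diff. Presumably these precomputed sums let the min/offset term of the Q4_1-style dot products be folded in without re-summing the int8 quants; that use is not shown here, and the struct/function names below are illustrative.

```c
#include <math.h>
#include <stdint.h>

#define QK8_1 32

typedef struct {
    float  d;          // delta
    float  s0;         // d * sum(qs[i]) low half
    float  s1;         // d * sum(qs[i]) high half
    int8_t qs[QK8_1];  // quants
} block_q8_1_sketch;

static void quantize_block_q8_1(const float x[QK8_1], block_q8_1_sketch * y) {
    float amax = 0.0f;
    for (int l = 0; l < QK8_1; l++) {
        amax = fmaxf(amax, fabsf(x[l]));
    }

    const float d  = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;
    y->d = d;

    int sum0 = 0, sum1 = 0;
    for (int l = 0; l < QK8_1/2; ++l) {
        y->qs[l]           = (int8_t)roundf(x[l]*id);
        y->qs[QK8_1/2 + l] = (int8_t)roundf(x[QK8_1/2 + l]*id);
        sum0 += y->qs[l];
        sum1 += y->qs[QK8_1/2 + l];
    }

    y->s0 = d * sum0;   // scaled sum of the low half
    y->s1 = d * sum1;   // scaled sum of the high half
}
```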
@@ -602,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
+        float max  = 0.0f;
 
         for (int l = 0; l < QK4_0; l++) {
             const float v = x[i*QK4_0 + l];
-
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
         }
 
-        const float d =
+        const float d = max / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
@@ -617,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
             const float v0 = x[i*QK4_0 + l + 0]*id;
             const float v1 = x[i*QK4_0 + l + 1]*id;
 
-            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
-            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+            const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
+            const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
 
             assert(vi0 < 16);
             assert(vi1 < 16);
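A small numeric check of the updated Q4_0 scaling shown in the two hunks above: deriving the scale from the signed value with the largest magnitude (`d = max / -8`) maps that extreme value exactly onto quant 0 (i.e. -8 after the +8 offset), so it is represented without clipping, while `MIN(15, ...)` guards the opposite end of the 4-bit range. This is illustrative only; the values are made up.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

int main(void) {
    const float max = -3.2f;                 // signed value with the largest magnitude
    const float d   = max / -8;              // new scale choice: d = 0.4f here
    const float id  = d ? 1.0f/d : 0.0f;

    const int8_t q_extreme = (int8_t)roundf(max*id) + 8;   // -> 0
    const int8_t q_zero    = (int8_t)roundf(0.0f*id) + 8;  // -> 8

    assert(q_extreme == 0);
    assert(q_zero == 8);
    assert(fabsf((q_extreme - 8)*d - max) < 1e-6f);        // the extreme round-trips exactly
    return 0;
}
```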
@@ -638,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__POWER9_VECTOR__)
     const vector float v85 = vec_splats(8.5f);
+    const vector signed int v15 = vec_splats(15);
     for (int i = 0; i < nb; i++) {
-        float
+        float max = 0.0f;
+        float min = 0.0f;
 
         vector float srcv [8];
-        vector float
-        vector float
+        vector float maxv[8];
+        vector float minv[8];
 
         for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
-
-        for (int l = 0; l < 4; l++)
-        //for (int l = 0; l < 2; l++)
-
-
-        //for (int l = 0; l < 1; l++)
-
-
-
-
-
-
-
+        //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
+        maxv[0] = vec_max(maxv[0], maxv[2]);
+        maxv[4] = vec_max(maxv[4], maxv[6]);
+        //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
+        maxv[0] = vec_max(maxv[0], maxv[4]);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
+        minv[0] = vec_min(minv[0], minv[2]);
+        minv[4] = vec_min(minv[4], minv[6]);
+        //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
+        minv[0] = vec_min(minv[0], minv[4]);
+
+
+        max = MAX(
+                MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
+                MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
+        min = MIN(
+                MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
+                MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
+
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0/d : 0.0;
 
         y[i].d = d;
@@ -669,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         for (int l = 0; l < 8; l++) {
             const vector float vf = vec_madd(srcv[l], vid, v85);
             const vector signed int vi = vec_signed(vf);
+            const vector signed int vc = vec_min(vi, v15);
 
-            pb[2*l + 0] = vec_extract(
-            pb[2*l + 1] = vec_extract(
+            pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
+            pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
         }
     }
 #elif __ARM_NEON
     for (int i = 0; i < nb; i++) {
         float32x4_t srcv [8];
-        float32x4_t
-        float32x4_t
+        float32x4_t maxv[8];
+        float32x4_t minv[8];
 
         for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
 
-        for (int l = 0; l < 4; l++)
-        for (int l = 0; l < 2; l++)
-        for (int l = 0; l < 1; l++)
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
 
-
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
+
+        const float max = vmaxvq_f32(maxv[0]);
+        const float min = vminvq_f32(minv[0]);
 
-        const float
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
@@ -698,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t   vi = vcvtq_s32_f32(vf);
+            const int32x4_t   vc = vminq_s32(vi, vdupq_n_s32(15));
 
-            y[i].qs[2*l + 0] = vgetq_lane_s32(
-            y[i].qs[2*l + 1] = vgetq_lane_s32(
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
         }
     }
 #elif defined(__AVX2__)
@@ -712,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
712
879
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
713
880
|
x += 32;
|
714
881
|
|
715
|
-
// Compute max
|
716
|
-
|
717
|
-
__m256
|
718
|
-
|
719
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
720
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
882
|
+
// Compute max for the block
|
883
|
+
__m256 max = _mm256_max_ps( v0, v1 );
|
884
|
+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
885
|
+
max = _mm256_max_ps( max, maxTmp );
|
721
886
|
|
722
|
-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
|
887
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
723
888
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
724
889
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
725
890
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
726
891
|
|
892
|
+
// Compute min for the block
|
893
|
+
__m256 min = _mm256_min_ps( v0, v1 );
|
894
|
+
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
895
|
+
min = _mm256_min_ps( min, minTmp );
|
896
|
+
|
897
|
+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
898
|
+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
899
|
+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
900
|
+
const float minScalar = _mm_cvtss_f32( min4 );
|
901
|
+
|
727
902
|
// Quantize these floats
|
728
|
-
const float
|
903
|
+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
904
|
+
const float d = magnitude / -8.0f;
|
729
905
|
y[i].d = d;
|
730
|
-
const float id = (
|
906
|
+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
731
907
|
const __m256 mul = _mm256_set1_ps( id );
|
732
908
|
|
733
909
|
// Apply the multiplier
|
@@ -760,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
760
936
|
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
761
937
|
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
762
938
|
|
763
|
-
// Apply offset to translate the range from [ -
|
939
|
+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
764
940
|
const __m256i off = _mm256_set1_epi8( 8 );
|
765
941
|
i0 = _mm256_add_epi8( i0, off );
|
942
|
+
const __m256i maxNibble = _mm256_set1_epi8( 15 );
|
943
|
+
i0 = _mm256_min_epi8( i0, maxNibble );
|
766
944
|
|
767
945
|
// Compress the vector into 4 bit/value, and store
|
768
946
|
__m128i res = packNibbles( i0 );
|
@@ -777,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
777
955
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
778
956
|
x += 32;
|
779
957
|
|
780
|
-
// Compute max
|
781
|
-
|
782
|
-
__m256
|
783
|
-
|
784
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
785
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
958
|
+
// Compute max for the block
|
959
|
+
__m256 max = _mm256_max_ps( v0, v1 );
|
960
|
+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
961
|
+
max = _mm256_max_ps( max, maxTmp );
|
786
962
|
|
787
|
-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
|
963
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
788
964
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
789
965
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
790
966
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
791
967
|
|
968
|
+
// Compute min for the block
|
969
|
+
__m256 min = _mm256_min_ps( v0, v1 );
|
970
|
+
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
971
|
+
min = _mm256_min_ps( min, minTmp );
|
972
|
+
|
973
|
+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
974
|
+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
975
|
+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
976
|
+
const float minScalar = _mm_cvtss_f32( min4 );
|
977
|
+
|
792
978
|
// Quantize these floats
|
793
|
-
const float
|
979
|
+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
980
|
+
const float d = magnitude / -8.0f;
|
794
981
|
y[i].d = d;
|
795
|
-
const float id = (
|
982
|
+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
796
983
|
const __m256 mul = _mm256_set1_ps( id );
|
797
984
|
|
798
985
|
// Apply the multiplier
|
@@ -833,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
833
1020
|
ni0 = _mm_packs_epi16( ni0, ni2 );
|
834
1021
|
ni4 = _mm_packs_epi16( ni4, ni6 );
|
835
1022
|
|
836
|
-
// Apply offset to translate the range from [ -
|
837
|
-
const __m128i off = _mm_set1_epi8( 8);
|
1023
|
+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
1024
|
+
const __m128i off = _mm_set1_epi8( 8 );
|
838
1025
|
ni0 = _mm_add_epi8( ni0, off );
|
839
1026
|
ni4 = _mm_add_epi8( ni4, off );
|
1027
|
+
const __m128i maxNibble = _mm_set1_epi8( 15 );
|
1028
|
+
ni0 = _mm_min_epi8( ni0, maxNibble );
|
1029
|
+
ni4 = _mm_min_epi8( ni4, maxNibble );
|
840
1030
|
|
841
1031
|
// Compress the vector into 4 bit/value, and store
|
842
1032
|
__m128i res = packNibbles( ni0, ni4 );
|
@@ -844,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
844
1034
|
}
|
845
1035
|
#elif defined(__wasm_simd128__)
|
846
1036
|
for (int i = 0; i < nb; i++) {
|
847
|
-
float
|
1037
|
+
float max = 0.0f;
|
1038
|
+
float min = 0.0f;
|
848
1039
|
|
849
1040
|
v128_t srcv [8];
|
850
|
-
v128_t
|
851
|
-
v128_t
|
1041
|
+
v128_t maxv[8];
|
1042
|
+
v128_t minv[8];
|
852
1043
|
|
853
1044
|
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
|
854
|
-
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
|
855
1045
|
|
856
|
-
for (int l = 0; l < 4; l++)
|
857
|
-
for (int l = 0; l < 2; l++)
|
858
|
-
for (int l = 0; l < 1; l++)
|
1046
|
+
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
|
1047
|
+
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
|
1048
|
+
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
|
1049
|
+
|
1050
|
+
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
|
1051
|
+
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
|
1052
|
+
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
|
859
1053
|
|
860
|
-
|
861
|
-
MAX(wasm_f32x4_extract_lane(
|
862
|
-
MAX(wasm_f32x4_extract_lane(
|
1054
|
+
max = MAX(
|
1055
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
|
1056
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
|
1057
|
+
min = MIN(
|
1058
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
|
1059
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
|
863
1060
|
|
864
|
-
const float
|
1061
|
+
const float magnitude = max >= fabsf(min) ? max : min;
|
1062
|
+
const float d = magnitude / -8;
|
865
1063
|
const float id = d ? 1.0/d : 0.0;
|
866
1064
|
|
867
1065
|
y[i].d = d;
|
@@ -870,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
870
1068
|
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
|
871
1069
|
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
|
872
1070
|
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
|
1071
|
+
const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
|
873
1072
|
|
874
|
-
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(
|
875
|
-
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(
|
1073
|
+
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
|
1074
|
+
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
|
876
1075
|
}
|
877
1076
|
}
|
878
1077
|
#else
|
@@ -1045,6 +1244,193 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|
1045
1244
|
#endif
|
1046
1245
|
}
|
1047
1246
|
|
1247
|
+
// reference implementation for deterministic creation of model files
|
1248
|
+
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
|
1249
|
+
assert(k % QK4_2 == 0);
|
1250
|
+
|
1251
|
+
const int nb = k / QK4_2;
|
1252
|
+
|
1253
|
+
for (int i = 0; i < nb; i++) {
|
1254
|
+
float amax = 0.0f; // absolute max
|
1255
|
+
float max = 0.0f;
|
1256
|
+
|
1257
|
+
for (int l = 0; l < QK4_2; l++) {
|
1258
|
+
const float v = x[i*QK4_2 + l];
|
1259
|
+
if (amax < fabsf(v)) {
|
1260
|
+
amax = fabsf(v);
|
1261
|
+
max = v;
|
1262
|
+
}
|
1263
|
+
}
|
1264
|
+
|
1265
|
+
const float d = max / -8;
|
1266
|
+
|
1267
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1268
|
+
|
1269
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1270
|
+
|
1271
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1272
|
+
const float v0 = x[i*QK4_2 + l + 0]*id;
|
1273
|
+
const float v1 = x[i*QK4_2 + l + 1]*id;
|
1274
|
+
|
1275
|
+
const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
|
1276
|
+
const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
|
1277
|
+
|
1278
|
+
assert(vi0 < 16);
|
1279
|
+
assert(vi1 < 16);
|
1280
|
+
|
1281
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1282
|
+
}
|
1283
|
+
}
|
1284
|
+
}
|
1285
|
+
|
1286
|
+
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
|
1287
|
+
assert(k % QK4_2 == 0);
|
1288
|
+
|
1289
|
+
block_q4_2 * restrict y = vy;
|
1290
|
+
|
1291
|
+
quantize_row_q4_2_reference(x, y, k);
|
1292
|
+
}
|
1293
|
+
|
1294
|
+
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
|
1295
|
+
assert(k % QK4_3 == 0);
|
1296
|
+
const int nb = k / QK4_3;
|
1297
|
+
|
1298
|
+
for (int i = 0; i < nb; i++) {
|
1299
|
+
float min = FLT_MAX;
|
1300
|
+
float max = -FLT_MAX;
|
1301
|
+
|
1302
|
+
for (int l = 0; l < QK4_3; l++) {
|
1303
|
+
const float v = x[i*QK4_3 + l];
|
1304
|
+
if (v < min) min = v;
|
1305
|
+
if (v > max) max = v;
|
1306
|
+
}
|
1307
|
+
|
1308
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
1309
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1310
|
+
|
1311
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1312
|
+
y[i].m = GGML_FP32_TO_FP16(min);
|
1313
|
+
|
1314
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1315
|
+
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
|
1316
|
+
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
|
1317
|
+
|
1318
|
+
const uint8_t vi0 = (int) (v0 + 0.5f);
|
1319
|
+
const uint8_t vi1 = (int) (v1 + 0.5f);
|
1320
|
+
|
1321
|
+
assert(vi0 < 16);
|
1322
|
+
assert(vi1 < 16);
|
1323
|
+
|
1324
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1325
|
+
}
|
1326
|
+
}
|
1327
|
+
}
|
1328
|
+
|
1329
|
+
static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
|
1330
|
+
assert(k % QK4_3 == 0);
|
1331
|
+
|
1332
|
+
block_q4_3 * restrict y = vy;
|
1333
|
+
|
1334
|
+
quantize_row_q4_3_reference(x, y, k);
|
1335
|
+
}
|
1336
|
+
|
1337
|
+
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
|
1338
|
+
assert(k % QK5_0 == 0);
|
1339
|
+
const int nb = k / QK5_0;
|
1340
|
+
|
1341
|
+
for (int i = 0; i < nb; i++) {
|
1342
|
+
float amax = 0.0f; // absolute max
|
1343
|
+
float max = 0.0f;
|
1344
|
+
|
1345
|
+
for (int l = 0; l < QK5_0; l++) {
|
1346
|
+
const float v = x[i*QK5_0 + l];
|
1347
|
+
if (amax < fabsf(v)) {
|
1348
|
+
amax = fabsf(v);
|
1349
|
+
max = v;
|
1350
|
+
}
|
1351
|
+
}
|
1352
|
+
|
1353
|
+
const float d = max / -16;
|
1354
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1355
|
+
|
1356
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1357
|
+
|
1358
|
+
uint32_t qh = 0;
|
1359
|
+
|
1360
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
1361
|
+
const float v0 = x[i*QK5_0 + l + 0]*id;
|
1362
|
+
const float v1 = x[i*QK5_0 + l + 1]*id;
|
1363
|
+
|
1364
|
+
const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
|
1365
|
+
const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
|
1366
|
+
|
1367
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1368
|
+
|
1369
|
+
// get the 5-th bit and store it in qh at the right position
|
1370
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1371
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1372
|
+
}
|
1373
|
+
|
1374
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1375
|
+
}
|
1376
|
+
}
|
1377
|
+
|
1378
|
+
static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
|
1379
|
+
assert(k % QK5_0 == 0);
|
1380
|
+
|
1381
|
+
block_q5_0 * restrict y = vy;
|
1382
|
+
|
1383
|
+
quantize_row_q5_0_reference(x, y, k);
|
1384
|
+
}
|
1385
|
+
|
1386
|
+
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
|
1387
|
+
assert(k % QK5_1 == 0);
|
1388
|
+
const int nb = k / QK5_1;
|
1389
|
+
|
1390
|
+
for (int i = 0; i < nb; i++) {
|
1391
|
+
float min = FLT_MAX;
|
1392
|
+
float max = -FLT_MAX;
|
1393
|
+
|
1394
|
+
for (int l = 0; l < QK5_1; l++) {
|
1395
|
+
const float v = x[i*QK5_1 + l];
|
1396
|
+
if (v < min) min = v;
|
1397
|
+
if (v > max) max = v;
|
1398
|
+
}
|
1399
|
+
|
1400
|
+
const float d = (max - min) / ((1 << 5) - 1);
|
1401
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1402
|
+
|
1403
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1404
|
+
y[i].m = GGML_FP32_TO_FP16(min);
|
1405
|
+
|
1406
|
+
uint32_t qh = 0;
|
1407
|
+
|
1408
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
1409
|
+
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
|
1410
|
+
const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
|
1411
|
+
|
1412
|
+
const uint32_t vi0 = (int) (v0 + 0.5f);
|
1413
|
+
const uint32_t vi1 = (int) (v1 + 0.5f);
|
1414
|
+
|
1415
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1416
|
+
|
1417
|
+
// get the 5-th bit and store it in qh at the right position
|
1418
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1419
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1420
|
+
}
|
1421
|
+
|
1422
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1423
|
+
}
|
1424
|
+
}
|
1425
|
+
|
1426
|
+
static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
|
1427
|
+
assert(k % QK5_1 == 0);
|
1428
|
+
|
1429
|
+
block_q5_1 * restrict y = vy;
|
1430
|
+
|
1431
|
+
quantize_row_q5_1_reference(x, y, k);
|
1432
|
+
}
|
1433
|
+
|
1048
1434
|
// reference implementation for deterministic creation of model files
|
1049
1435
|
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
|
1050
1436
|
assert(k % QK8_0 == 0);
|
@@ -1064,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
|
|
1064
1450
|
y[i].d = d;
|
1065
1451
|
|
1066
1452
|
for (int l = 0; l < QK8_0; ++l) {
|
1067
|
-
const float
|
1068
|
-
|
1453
|
+
const float v0 = x[i*QK8_0 + l]*id;
|
1454
|
+
|
1455
|
+
y[i].qs[l] = roundf(v0);
|
1069
1456
|
}
|
1070
1457
|
}
|
1071
1458
|
}
|
1072
1459
|
|
1073
1460
|
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
1074
1461
|
assert(k % QK8_0 == 0);
|
1075
|
-
const int nb = k / QK8_0;
|
1076
1462
|
|
1077
1463
|
block_q8_0 * restrict y = vy;
|
1078
1464
|
|
1465
|
+
quantize_row_q8_0_reference(x, y, k);
|
1466
|
+
}
|
1467
|
+
|
1468
|
+
// reference implementation for deterministic creation of model files
|
1469
|
+
static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
|
1470
|
+
assert(k % QK8_1 == 0);
|
1471
|
+
const int nb = k / QK8_1;
|
1472
|
+
|
1473
|
+
for (int i = 0; i < nb; i++) {
|
1474
|
+
float amax = 0.0f; // absolute max
|
1475
|
+
|
1476
|
+
for (int l = 0; l < QK8_1; l++) {
|
1477
|
+
const float v = x[i*QK8_1 + l];
|
1478
|
+
amax = MAX(amax, fabsf(v));
|
1479
|
+
}
|
1480
|
+
|
1481
|
+
const float d = amax / ((1 << 7) - 1);
|
1482
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1483
|
+
|
1484
|
+
y[i].d = d;
|
1485
|
+
|
1486
|
+
int sum0 = 0;
|
1487
|
+
int sum1 = 0;
|
1488
|
+
|
1489
|
+
for (int l = 0; l < QK8_1/2; ++l) {
|
1490
|
+
const float v0 = x[i*QK8_1 + l]*id;
|
1491
|
+
const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
|
1492
|
+
|
1493
|
+
y[i].qs[ l] = roundf(v0);
|
1494
|
+
y[i].qs[QK8_1/2 + l] = roundf(v1);
|
1495
|
+
|
1496
|
+
sum0 += y[i].qs[ l];
|
1497
|
+
sum1 += y[i].qs[QK8_1/2 + l];
|
1498
|
+
}
|
1499
|
+
|
1500
|
+
y[i].s0 = d * sum0;
|
1501
|
+
y[i].s1 = d * sum1;
|
1502
|
+
}
|
1503
|
+
}
|
1504
|
+
|
1505
|
+
static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
|
1506
|
+
assert(k % QK8_1 == 0);
|
1507
|
+
const int nb = k / QK8_1;
|
1508
|
+
|
1509
|
+
block_q8_1 * restrict y = vy;
|
1510
|
+
|
1079
1511
|
#if defined(__ARM_NEON)
|
1080
1512
|
for (int i = 0; i < nb; i++) {
|
1081
1513
|
float32x4_t srcv [8];
|
@@ -1096,7 +1528,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1096
1528
|
|
1097
1529
|
y[i].d = d;
|
1098
1530
|
|
1099
|
-
|
1531
|
+
int32x4_t accv0 = vdupq_n_s32(0);
|
1532
|
+
int32x4_t accv1 = vdupq_n_s32(0);
|
1533
|
+
|
1534
|
+
// low half
|
1535
|
+
for (int l = 0; l < 4; l++) {
|
1100
1536
|
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1101
1537
|
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1102
1538
|
|
@@ -1104,19 +1540,40 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1104
1540
|
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1105
1541
|
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1106
1542
|
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1543
|
+
|
1544
|
+
accv0 = vaddq_s32(accv0, vi);
|
1107
1545
|
}
|
1108
|
-
}
|
1109
|
-
#elif defined(__AVX2__) || defined(__AVX__)
|
1110
|
-
for (int i = 0; i < nb; i++) {
|
1111
|
-
// Load elements into 4 AVX vectors
|
1112
|
-
__m256 v0 = _mm256_loadu_ps( x );
|
1113
|
-
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
1114
|
-
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
1115
|
-
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
1116
|
-
x += 32;
|
1117
1546
|
|
1118
|
-
//
|
1119
|
-
|
1547
|
+
// high half
|
1548
|
+
for (int l = 4; l < 8; l++) {
|
1549
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1550
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1551
|
+
|
1552
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1553
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1554
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1555
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1556
|
+
|
1557
|
+
accv1 = vaddq_s32(accv1, vi);
|
1558
|
+
}
|
1559
|
+
|
1560
|
+
const int32_t sum0 = vaddvq_s32(accv0);
|
1561
|
+
const int32_t sum1 = vaddvq_s32(accv1);
|
1562
|
+
|
1563
|
+
y[i].s0 = d * sum0;
|
1564
|
+
y[i].s1 = d * sum1;
|
1565
|
+
}
|
1566
|
+
#elif defined(__AVX2__) || defined(__AVX__)
|
1567
|
+
for (int i = 0; i < nb; i++) {
|
1568
|
+
// Load elements into 4 AVX vectors
|
1569
|
+
__m256 v0 = _mm256_loadu_ps( x );
|
1570
|
+
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
1571
|
+
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
1572
|
+
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
1573
|
+
x += 32;
|
1574
|
+
|
1575
|
+
// Compute max(abs(e)) for the block
|
1576
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
1120
1577
|
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
1121
1578
|
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
1122
1579
|
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
@@ -1152,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1152
1609
|
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
1153
1610
|
|
1154
1611
|
#if defined(__AVX2__)
|
1612
|
+
// Compute the sum of the quants and set y[i].s
|
1613
|
+
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
1614
|
+
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
|
1615
|
+
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
|
1616
|
+
|
1155
1617
|
// Convert int32 to int16
|
1156
1618
|
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
1157
1619
|
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
@@ -1177,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1177
1639
|
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
1178
1640
|
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
1179
1641
|
|
1642
|
+
// Compute the sum of the quants and set y[i].s
|
1643
|
+
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
|
1644
|
+
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
|
1645
|
+
y[i].s0 = d * hsum_i32_4(s0);
|
1646
|
+
y[i].s1 = d * hsum_i32_4(s1);
|
1647
|
+
|
1180
1648
|
// Convert int32 to int16
|
1181
1649
|
ni0 = _mm_packs_epi32( ni0, ni1 );
|
1182
1650
|
ni2 = _mm_packs_epi32( ni2, ni3 );
|
@@ -1192,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1192
1660
|
}
|
1193
1661
|
#else
|
1194
1662
|
// scalar
|
1195
|
-
|
1663
|
+
quantize_row_q8_1_reference(x, y, k);
|
1196
1664
|
#endif
|
1197
1665
|
}
|
1198
1666
|
|
@@ -1211,7 +1679,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1211
1679
|
|
1212
1680
|
for (int l = 0; l < QK4_0; l += 32) {
|
1213
1681
|
// Load 32x4-bit integers into 32x8-bit integers
|
1214
|
-
__m256i vx8 =
|
1682
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1215
1683
|
|
1216
1684
|
// Subtract 8 from the integers
|
1217
1685
|
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
@@ -1246,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1246
1714
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1247
1715
|
|
1248
1716
|
// Expand 4-bit qs to 8-bit bytes
|
1249
|
-
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
|
1717
|
+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
1250
1718
|
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
1251
1719
|
|
1252
1720
|
// Convert to signed 8-bit integers
|
@@ -1296,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1296
1764
|
for (int l = 0; l < QK4_0; l += 2) {
|
1297
1765
|
const uint8_t vi = pp[l/2];
|
1298
1766
|
|
1299
|
-
const int8_t vi0 = vi &
|
1767
|
+
const int8_t vi0 = vi & 0x0F;
|
1300
1768
|
const int8_t vi1 = vi >> 4;
|
1301
1769
|
|
1302
1770
|
const float v0 = (vi0 - 8)*d;
|
@@ -1329,7 +1797,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1329
1797
|
|
1330
1798
|
for (int l = 0; l < QK4_1; l += 32) {
|
1331
1799
|
// Load 32x4-bit integers into 32x8-bit integers
|
1332
|
-
__m256i vx8 =
|
1800
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1333
1801
|
|
1334
1802
|
// Convert to 16-bit int
|
1335
1803
|
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
@@ -1362,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1362
1830
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1363
1831
|
|
1364
1832
|
// Expand 4-bit qs to 8-bit bytes
|
1365
|
-
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
|
1833
|
+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
1366
1834
|
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
1367
1835
|
|
1368
1836
|
// Interleave and combine
|
@@ -1404,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1404
1872
|
for (int l = 0; l < QK4_1; l += 2) {
|
1405
1873
|
const uint8_t vi = pp[l/2];
|
1406
1874
|
|
1407
|
-
const int8_t vi0 = vi &
|
1875
|
+
const int8_t vi0 = vi & 0x0F;
|
1408
1876
|
const int8_t vi1 = vi >> 4;
|
1409
1877
|
|
1410
1878
|
const float v0 = vi0*d + m;
|
@@ -1420,8 +1888,162 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1420
1888
|
#endif
|
1421
1889
|
}
|
1422
1890
|
|
1423
|
-
static void
|
1891
|
+
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
|
1892
|
+
assert(k % QK4_2 == 0);
|
1893
|
+
const int nb = k / QK4_2;
|
1894
|
+
|
1895
|
+
const block_q4_2 * restrict x = vx;
|
1896
|
+
|
1897
|
+
for (int i = 0; i < nb; i++) {
|
1898
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1899
|
+
|
1900
|
+
const uint8_t * restrict pp = x[i].qs;
|
1901
|
+
|
1902
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1903
|
+
const uint8_t vi = pp[l/2];
|
1904
|
+
|
1905
|
+
const int8_t vi0 = vi & 0x0F;
|
1906
|
+
const int8_t vi1 = vi >> 4;
|
1907
|
+
|
1908
|
+
const float v0 = (vi0 - 8)*d;
|
1909
|
+
const float v1 = (vi1 - 8)*d;
|
1910
|
+
|
1911
|
+
y[i*QK4_2 + l + 0] = v0;
|
1912
|
+
y[i*QK4_2 + l + 1] = v1;
|
1913
|
+
|
1914
|
+
assert(!isnan(y[i*QK4_2 + l + 0]));
|
1915
|
+
assert(!isnan(y[i*QK4_2 + l + 1]));
|
1916
|
+
}
|
1917
|
+
}
|
1918
|
+
}
|
1919
|
+
|
1920
|
+
static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
|
1921
|
+
assert(k % QK4_3 == 0);
|
1922
|
+
const int nb = k / QK4_3;
|
1923
|
+
|
1924
|
+
const block_q4_3 * restrict x = vx;
|
1925
|
+
|
1926
|
+
for (int i = 0; i < nb; i++) {
|
1927
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1928
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
1929
|
+
|
1930
|
+
const uint8_t * restrict pp = x[i].qs;
|
1931
|
+
|
1932
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1933
|
+
const uint8_t vi = pp[l/2];
|
1934
|
+
|
1935
|
+
const int8_t vi0 = vi & 0x0F;
|
1936
|
+
const int8_t vi1 = vi >> 4;
|
1937
|
+
|
1938
|
+
const float v0 = vi0*d + m;
|
1939
|
+
const float v1 = vi1*d + m;
|
1940
|
+
|
1941
|
+
y[i*QK4_3 + l + 0] = v0;
|
1942
|
+
y[i*QK4_3 + l + 1] = v1;
|
1943
|
+
|
1944
|
+
assert(!isnan(y[i*QK4_3 + l + 0]));
|
1945
|
+
assert(!isnan(y[i*QK4_3 + l + 1]));
|
1946
|
+
}
|
1947
|
+
}
|
1948
|
+
}
|
1949
|
+
|
1950
|
+
static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
|
1951
|
+
assert(k % QK5_0 == 0);
|
1952
|
+
const int nb = k / QK5_0;
|
1953
|
+
|
1954
|
+
const block_q5_0 * restrict x = vx;
|
1955
|
+
|
1956
|
+
for (int i = 0; i < nb; i++) {
|
1957
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1958
|
+
|
1959
|
+
const uint8_t * restrict pp = x[i].qs;
|
1960
|
+
|
1961
|
+
uint32_t qh;
|
1962
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
1963
|
+
|
1964
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
1965
|
+
const uint8_t vi = pp[l/2];
|
1966
|
+
|
1967
|
+
// extract the 5-th bit from qh
|
1968
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
1969
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
1970
|
+
|
1971
|
+
const int8_t vi0 = (vi & 0x0F) | vh0;
|
1972
|
+
const int8_t vi1 = (vi >> 4) | vh1;
|
1973
|
+
|
1974
|
+
const float v0 = (vi0 - 16)*d;
|
1975
|
+
const float v1 = (vi1 - 16)*d;
|
1976
|
+
|
1977
|
+
y[i*QK5_0 + l + 0] = v0;
|
1978
|
+
y[i*QK5_0 + l + 1] = v1;
|
1979
|
+
|
1980
|
+
assert(!isnan(y[i*QK5_0 + l + 0]));
|
1981
|
+
assert(!isnan(y[i*QK5_0 + l + 1]));
|
1982
|
+
}
|
1983
|
+
}
|
1984
|
+
}
|
1985
|
+
|
1986
|
+
static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
|
1987
|
+
assert(k % QK5_1 == 0);
|
1988
|
+
const int nb = k / QK5_1;
|
1989
|
+
|
1990
|
+
const block_q5_1 * restrict x = vx;
|
1991
|
+
|
1992
|
+
for (int i = 0; i < nb; i++) {
|
1993
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1994
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
1995
|
+
|
1996
|
+
const uint8_t * restrict pp = x[i].qs;
|
1997
|
+
|
1998
|
+
uint32_t qh;
|
1999
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2000
|
+
|
2001
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
2002
|
+
const uint8_t vi = pp[l/2];
|
2003
|
+
|
2004
|
+
// extract the 5-th bit from qh
|
2005
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
2006
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
2007
|
+
|
2008
|
+
const uint8_t vi0 = (vi & 0x0F) | vh0;
|
2009
|
+
const uint8_t vi1 = (vi >> 4) | vh1;
|
2010
|
+
|
2011
|
+
const float v0 = vi0*d + m;
|
2012
|
+
const float v1 = vi1*d + m;
|
2013
|
+
|
2014
|
+
y[i*QK5_1 + l + 0] = v0;
|
2015
|
+
y[i*QK5_1 + l + 1] = v1;
|
2016
|
+
|
2017
|
+
assert(!isnan(y[i*QK5_1 + l + 0]));
|
2018
|
+
assert(!isnan(y[i*QK5_1 + l + 1]));
|
2019
|
+
}
|
2020
|
+
}
|
2021
|
+
}
|
2022
|
+
|
2023
|
+
static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
|
2024
|
+
assert(k % QK8_0 == 0);
|
2025
|
+
const int nb = k / QK8_0;
|
2026
|
+
|
2027
|
+
const block_q8_0 * restrict x = vx;
|
2028
|
+
|
2029
|
+
for (int i = 0; i < nb; i++) {
|
2030
|
+
const float d = x[i].d;
|
2031
|
+
|
2032
|
+
const int8_t * restrict pp = x[i].qs;
|
2033
|
+
|
2034
|
+
for (int l = 0; l < QK8_0; ++l) {
|
2035
|
+
y[i*QK8_0 + l] = pp[l]*d;
|
2036
|
+
}
|
2037
|
+
}
|
2038
|
+
}
|
2039
|
+
|
1424
2040
|
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2041
|
+
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2042
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2043
|
+
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2044
|
+
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2045
|
+
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
2046
|
+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1425
2047
|
|
1426
2048
|
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
1427
2049
|
[GGML_TYPE_Q4_0] = {
|
@@ -1430,15 +2052,64 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
|
1430
2052
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
1431
2053
|
.quantize_row_q_dot = quantize_row_q8_0,
|
1432
2054
|
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
2055
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
1433
2056
|
},
|
1434
2057
|
[GGML_TYPE_Q4_1] = {
|
1435
2058
|
.dequantize_row_q = dequantize_row_q4_1,
|
1436
2059
|
.quantize_row_q = quantize_row_q4_1,
|
1437
2060
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
1438
|
-
.quantize_row_q_dot =
|
1439
|
-
.vec_dot_q =
|
2061
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2062
|
+
.vec_dot_q = ggml_vec_dot_q4_1_q8_1,
|
2063
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
2064
|
+
},
|
2065
|
+
[GGML_TYPE_Q4_2] = {
|
2066
|
+
.dequantize_row_q = dequantize_row_q4_2,
|
2067
|
+
.quantize_row_q = quantize_row_q4_2,
|
2068
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
2069
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
2070
|
+
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
2071
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
2072
|
+
},
|
2073
|
+
[GGML_TYPE_Q4_3] = {
|
2074
|
+
.dequantize_row_q = dequantize_row_q4_3,
|
2075
|
+
.quantize_row_q = quantize_row_q4_3,
|
2076
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
|
2077
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2078
|
+
.vec_dot_q = ggml_vec_dot_q4_3_q8_1,
|
2079
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
2080
|
+
},
|
2081
|
+
[GGML_TYPE_Q5_0] = {
|
2082
|
+
.dequantize_row_q = dequantize_row_q5_0,
|
2083
|
+
.quantize_row_q = quantize_row_q5_0,
|
2084
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
|
2085
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
2086
|
+
.vec_dot_q = ggml_vec_dot_q5_0_q8_0,
|
2087
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
2088
|
+
},
|
2089
|
+
[GGML_TYPE_Q5_1] = {
|
2090
|
+
.dequantize_row_q = dequantize_row_q5_1,
|
2091
|
+
.quantize_row_q = quantize_row_q5_1,
|
2092
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
|
2093
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2094
|
+
.vec_dot_q = ggml_vec_dot_q5_1_q8_1,
|
2095
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
2096
|
+
},
|
2097
|
+
[GGML_TYPE_Q8_0] = {
|
2098
|
+
.dequantize_row_q = dequantize_row_q8_0,
|
2099
|
+
.quantize_row_q = quantize_row_q8_0,
|
2100
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
|
2101
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
2102
|
+
.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
|
2103
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
2104
|
+
},
|
2105
|
+
[GGML_TYPE_Q8_1] = {
|
2106
|
+
.dequantize_row_q = NULL, // TODO
|
2107
|
+
.quantize_row_q = quantize_row_q8_1,
|
2108
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
|
2109
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2110
|
+
.vec_dot_q = NULL, // TODO
|
2111
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
1440
2112
|
},
|
1441
|
-
// TODO: GGML_TYPE_Q8_0
|
1442
2113
|
};
|
1443
2114
|
|
1444
2115
|
// For internal test use
|
@@ -2004,191 +2675,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|
2004
2675
|
*s = sumf;
|
2005
2676
|
}
|
2006
2677
|
|
2007
|
-
#if __AVX512F__ && QK4_0 == 32
|
2008
|
-
static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
|
2009
|
-
// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
|
2010
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2011
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2012
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2013
|
-
// | :. =_ () [] <> () Zz Yy|
|
2014
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2015
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2016
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2017
|
-
// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
|
2018
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2019
|
-
//
|
2020
|
-
// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
|
2021
|
-
// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
|
2022
|
-
// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
|
2023
|
-
// Bytes 40..63 are masked when loading the data, so they are zeroed out.
|
2024
|
-
#ifdef __AVX512VBMI__
|
2025
|
-
const __m512i byte_perm = _mm512_set_epi8(
|
2026
|
-
39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
|
2027
|
-
31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
|
2028
|
-
19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
|
2029
|
-
11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
|
2030
|
-
);
|
2031
|
-
const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
|
2032
|
-
// After applying VPERMB, `permuted` looks like this:
|
2033
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2034
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2035
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2036
|
-
// |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
|
2037
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2038
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2039
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2040
|
-
// |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
|
2041
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2042
|
-
#else
|
2043
|
-
const __m512i word_perm = _mm512_set_epi16(
|
2044
|
-
19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
|
2045
|
-
9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
|
2046
|
-
);
|
2047
|
-
const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
|
2048
|
-
// This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
|
2049
|
-
// VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
|
2050
|
-
// Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
|
2051
|
-
#endif
|
2052
|
-
|
2053
|
-
// Shift every odd-numbered 16-bit group to the right by 4 bits.
|
2054
|
-
const __mmask32 shift_mask = 0xaaaaaaaa;
|
2055
|
-
const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
|
2056
|
-
// After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
|
2057
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2058
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
|
2059
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2060
|
-
// | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
|
2061
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2062
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2063
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2064
|
-
// | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
|
2065
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2066
|
-
|
2067
|
-
// Now we just need to zero out the higher nibble in each byte, and we're done.
|
2068
|
-
const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
|
2069
|
-
return _mm512_and_si512( low_nibble_mask, shifted );
|
2070
|
-
// The final result looks like this:
|
2071
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2072
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2073
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2074
|
-
// | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
|
2075
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2076
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2077
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2078
|
-
// | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
|
2079
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2080
|
-
}
|
2081
|
-
|
2082
|
-
static inline __m512 dot_q4_0_twoblocks_avx512(
|
2083
|
-
__m512 acc,
|
2084
|
-
const block_q4_0 * restrict x,
|
2085
|
-
const block_q4_0 * restrict y,
|
2086
|
-
int i
|
2087
|
-
) {
|
2088
|
-
// A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
|
2089
|
-
// can potentially be unaddressable, so we make sure to mask them out before the load, even though
|
2090
|
-
// we don't use them at all. This might hurt the performance slightly, since the compiler is forced
|
2091
|
-
// to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
|
2092
|
-
const __mmask8 load_mask = 0x1f;
|
2093
|
-
const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
|
2094
|
-
const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
|
2095
|
-
|
2096
|
-
// We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
|
2097
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2098
|
-
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
|
2099
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2100
|
-
// blocks_0_float
|
2101
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2102
|
-
// | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
|
2103
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2104
|
-
// blocks_1_float
|
2105
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2106
|
-
// | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
|
2107
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2108
|
-
const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
|
2109
|
-
const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
|
2110
|
-
// We absolutely shouldn't touch the floats marked with `xx`: they contain some
|
2111
|
-
// random data, which might very well underflow. At least on Intel, this leads
|
2112
|
-
// to a huge penalty that can't be ignored (easily 100x or more) unless you
|
2113
|
-
// compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
|
2114
|
-
// (and ggml can't assume that you do)...
|
2115
|
-
const __mmask16 scale_mul_mask = 0x21;
|
2116
|
-
#ifdef __clang__
|
2117
|
-
// ...however, clang decides to optimize the multiplication mask away:
|
2118
|
-
// https://godbolt.org/z/P8PqdsfvW
|
2119
|
-
// gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
|
2120
|
-
__m512i scales;
|
2121
|
-
__asm__(
|
2122
|
-
"vmulps %1, %2, %0%{%3%}"
|
2123
|
-
: "=v" ( scales )
|
2124
|
-
: "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
|
2125
|
-
);
|
2126
|
-
#else
|
2127
|
-
const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
|
2128
|
-
#endif
|
2129
|
-
const __m512i scale_perm = _mm512_set_epi32(
|
2130
|
-
5, 5, 5, 5, 5, 5, 5, 5,
|
2131
|
-
0, 0, 0, 0, 0, 0, 0, 0
|
2132
|
-
);
|
2133
|
-
const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
|
2134
|
-
// After VMULPS and VPERMPS, `permuted_scales` looks like this:
|
2135
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2136
|
-
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
|
2137
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2138
|
-
// | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
|
2139
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2140
|
-
|
2141
|
-
const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
|
2142
|
-
const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
|
2143
|
-
|
2144
|
-
// Now we want to compute dot products of 4-element byte vectors and store them in
|
2145
|
-
// 32-bit integers. That is (only one 4-element vector is shown for clarity):
|
2146
|
-
// +----+----+----+----+
|
2147
|
-
// ... | 03 | 02 | 01 | 00 |
|
2148
|
-
// +----+----+----+----+
|
2149
|
-
// bytes_0
|
2150
|
-
// +----+----+----+----+
|
2151
|
-
// ... | D | C | B | A |
|
2152
|
-
// +----+----+----+----+
|
2153
|
-
// bytes_1
|
2154
|
-
// +----+----+----+----+
|
2155
|
-
// ... | H | G | F | E |
|
2156
|
-
// +----+----+----+----+
|
2157
|
-
// final_res_int
|
2158
|
-
// +----+----+----+----+
|
2159
|
-
// ... | A*E+B*F+C*G+D*H |
|
2160
|
-
// +----+----+----+----+
|
2161
|
-
const __m512i plus_8 = _mm512_set1_epi8( 8 );
|
2162
|
-
const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
|
2163
|
-
|
2164
|
-
#ifdef __AVX512VNNI__
|
2165
|
-
// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
|
2166
|
-
// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
|
2167
|
-
// from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
|
2168
|
-
// we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
|
2169
|
-
// which means we only need 2 instructions.
|
2170
|
-
const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
|
2171
|
-
const __m512i minus_8 = _mm512_set1_epi8( -8 );
|
2172
|
-
const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
|
2173
|
-
const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
|
2174
|
-
#else
|
2175
|
-
// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
|
2176
|
-
// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
|
2177
|
-
// This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
|
2178
|
-
// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
|
2179
|
-
const __m512i one = _mm512_set1_epi16( 1 );
|
2180
|
-
const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
|
2181
|
-
const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
|
2182
|
-
const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
|
2183
|
-
const __m512i final_res_int = _mm512_madd_epi16( diff, one );
|
2184
|
-
#endif
|
2185
|
-
|
2186
|
-
// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
|
2187
|
-
const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
|
2188
|
-
return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
|
2189
|
-
}
|
2190
|
-
#endif
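Side note on the VPDPBUSDS workaround described above: it rests on the identity (a - 8)*(b - 8) = a*(b - 8) + (-8)*b + 64, applied to each of the 4 byte pairs that feed one 32-bit lane, which is why the accumulator is seeded with 4*64. A minimal scalar sketch of the same rearrangement (the helper name is illustrative, nibble inputs are assumed to be in [0, 15]):

    // accumulate sum_k (a[k]-8)*(b[k]-8) without ever negating a
    static int dot4_unsigned_left(const int a[4], const int b[4]) {
        int acc = 4 * 64;                 // matches dot_init
        for (int k = 0; k < 4; ++k) {
            acc += (-8) * b[k];           // first accumulate step
            acc += a[k] * (b[k] - 8);     // second accumulate step
        }
        return acc;                       // == sum_k (a[k]-8)*(b[k]-8)
    }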
|
2191
|
-
|
2192
2678
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
2193
2679
|
ggml_float sumf = 0.0;
|
2194
2680
|
|
@@ -2225,67 +2711,62 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
2225
2711
|
*s = sumf;
|
2226
2712
|
}
|
2227
2713
|
|
2228
|
-
static void
|
2229
|
-
const int nb = n /
|
2714
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2715
|
+
const int nb = n / QK8_0;
|
2230
2716
|
|
2231
|
-
assert(n %
|
2717
|
+
assert(n % QK8_0 == 0);
|
2232
2718
|
assert(nb % 2 == 0);
|
2233
2719
|
|
2234
2720
|
const block_q4_0 * restrict x = vx;
|
2235
|
-
const
|
2236
|
-
|
2237
|
-
float sumf = 0.0;
|
2721
|
+
const block_q8_0 * restrict y = vy;
|
2238
2722
|
|
2239
2723
|
#if defined(__ARM_NEON)
|
2240
|
-
|
2241
|
-
|
2724
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2725
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2242
2726
|
|
2243
2727
|
for (int i = 0; i < nb; i += 2) {
|
2244
2728
|
const block_q4_0 * restrict x0 = &x[i + 0];
|
2245
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
2246
2729
|
const block_q4_0 * restrict x1 = &x[i + 1];
|
2247
|
-
const
|
2730
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2731
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2248
2732
|
|
2249
|
-
const uint8x16_t m4b
|
2250
|
-
const int8x16_t s8b
|
2733
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
2734
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2251
2735
|
|
2252
2736
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2253
|
-
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
2254
2737
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2255
|
-
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
2256
2738
|
|
2257
2739
|
// 4-bit -> 8-bit
|
2258
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
2259
|
-
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
|
2740
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2260
2741
|
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2261
|
-
const int8x16_t
|
2262
|
-
|
2263
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
2264
|
-
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
|
2742
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2265
2743
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2266
|
-
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
|
2267
2744
|
|
2268
2745
|
// sub 8
|
2269
2746
|
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2270
|
-
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
|
2271
2747
|
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2272
|
-
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
|
2273
|
-
|
2274
2748
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2275
|
-
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
|
2276
2749
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2277
|
-
|
2750
|
+
|
2751
|
+
// load y
|
2752
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2753
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2754
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2755
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2756
|
+
|
2757
|
+
// interleave
|
2758
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2759
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2760
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2761
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2278
2762
|
|
2279
2763
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2280
2764
|
// dot product into int32x4_t
|
2281
|
-
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
2282
|
-
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
2765
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
2766
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
2283
2767
|
|
2284
|
-
|
2285
|
-
|
2286
|
-
|
2287
|
-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
2288
|
-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2768
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2769
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2289
2770
|
#else
|
2290
2771
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
2291
2772
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
@@ -2297,125 +2778,41 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2297
2778
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
2298
2779
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
2299
2780
|
|
2300
|
-
const
|
2301
|
-
const
|
2302
|
-
|
2303
|
-
const
|
2304
|
-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
2305
|
-
|
2306
|
-
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
2307
|
-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
2781
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2782
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2783
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2784
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2308
2785
|
|
2309
|
-
|
2310
|
-
|
2786
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2787
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
2311
2788
|
#endif
|
2312
2789
|
}
|
2313
2790
|
|
2314
|
-
|
2315
|
-
#elif defined(
|
2791
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2792
|
+
#elif defined(__AVX2__)
|
2316
2793
|
// Initialize accumulator with zeros
|
2317
|
-
|
2318
|
-
__m512 acc1 = _mm512_setzero_ps();
|
2319
|
-
|
2320
|
-
const int superblock_size = 16;
|
2321
|
-
|
2322
|
-
const int superblock_count = nb / superblock_size;
|
2794
|
+
__m256 acc = _mm256_setzero_ps();
|
2323
2795
|
|
2324
|
-
|
2325
|
-
|
2796
|
+
// Main loop
|
2797
|
+
for (int i = 0; i < nb; ++i) {
|
2798
|
+
/* Compute combined scale for the block */
|
2799
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2326
2800
|
|
2327
|
-
|
2328
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
|
2329
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
|
2330
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
|
2331
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
|
2332
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
|
2333
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
|
2334
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
|
2335
|
-
}
|
2801
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2336
2802
|
|
2337
|
-
|
2338
|
-
|
2339
|
-
|
2340
|
-
}
|
2803
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2804
|
+
const __m256i off = _mm256_set1_epi8( 8 );
|
2805
|
+
bx = _mm256_sub_epi8( bx, off );
|
2341
2806
|
|
2342
|
-
|
2343
|
-
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
|
2344
|
-
#elif defined(__AVX2__)
|
2345
|
-
// Initialize accumulator with zeros
|
2346
|
-
__m256 acc = _mm256_setzero_ps();
|
2807
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2347
2808
|
|
2348
|
-
|
2349
|
-
const __m256i lowMask = _mm256_set1_epi8( 0xF );
|
2350
|
-
const __m256i offset_8 = _mm256_set1_epi16( 8 );
|
2809
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
2351
2810
|
|
2352
|
-
|
2353
|
-
|
2354
|
-
|
2811
|
+
/* Multiply q with scale and accumulate */
|
2812
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
2813
|
+
}
|
2355
2814
|
|
2356
|
-
|
2357
|
-
for (int i = 0; i < nb; i+=UNROLL_COUNT) {
|
2358
|
-
// This loop will be unrolled by the compiler
|
2359
|
-
for (int u=0;u<UNROLL_COUNT;u++) {
|
2360
|
-
/* Compute combined scale for the block */
|
2361
|
-
const __m256 scale = _mm256_mul_ps(
|
2362
|
-
_mm256_broadcast_ss( &x[i+u].d ),
|
2363
|
-
_mm256_broadcast_ss( &y[i+u].d ) );
|
2364
|
-
|
2365
|
-
/* get input from x
|
2366
|
-
Input: 32 Nibbles (16 bytes) at *x[i+u]
|
2367
|
-
Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
|
2368
|
-
|
2369
|
-
/* Load 16 bytes from memory */
|
2370
|
-
const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
|
2371
|
-
/* Expand bytes into uint16_t values */
|
2372
|
-
const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
|
2373
|
-
/* Unpack values into individual bytes */
|
2374
|
-
__m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
|
2375
|
-
const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
|
2376
|
-
__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
|
2377
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2378
|
-
x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
|
2379
|
-
x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
|
2380
|
-
|
2381
|
-
/* get input from y
|
2382
|
-
Input: 32 Nibbles (16 bytes) at *y[i+u]
|
2383
|
-
Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
|
2384
|
-
|
2385
|
-
/* Load 16 bytes from memory */
|
2386
|
-
const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
|
2387
|
-
/* Expand bytes into uint16_t values */
|
2388
|
-
const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
|
2389
|
-
/* Unpack values into individual bytes */
|
2390
|
-
const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
|
2391
|
-
__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
|
2392
|
-
__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
|
2393
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2394
|
-
y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
|
2395
|
-
y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
|
2396
|
-
|
2397
|
-
/* Compute products of int16_t integers, add pairwise, store as int32_t */
|
2398
|
-
__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
|
2399
|
-
__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
|
2400
|
-
|
2401
|
-
/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
|
2402
|
-
__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
|
2403
|
-
|
2404
|
-
/* Convert the vector of 8 int32_t to 8 floats */
|
2405
|
-
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2406
|
-
|
2407
|
-
/* Multiply q with scale and accumulate */
|
2408
|
-
acc = _mm256_fmadd_ps( scale, q, acc );
|
2409
|
-
}
|
2410
|
-
}
|
2411
|
-
|
2412
|
-
// Return horizontal sum of the acc vector
|
2413
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2414
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2415
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2416
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2417
|
-
|
2418
|
-
sumf = _mm_cvtss_f32( res );
|
2815
|
+
*s = hsum_float_8(acc);
|
2419
2816
|
#elif defined(__AVX__)
|
2420
2817
|
// Initialize accumulator with zeros
|
2421
2818
|
__m256 acc = _mm256_setzero_ps();
|
@@ -2428,13 +2825,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2428
2825
|
__m128i i32[2];
|
2429
2826
|
for (int j = 0; j < 2; ++j) {
|
2430
2827
|
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2431
|
-
__m128i bx =
|
2432
|
-
__m128i by =
|
2828
|
+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
|
2829
|
+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2433
2830
|
|
2434
2831
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2435
2832
|
const __m128i off = _mm_set1_epi8( 8 );
|
2436
2833
|
bx = _mm_sub_epi8( bx, off );
|
2437
|
-
by = _mm_sub_epi8( by, off );
|
2438
2834
|
|
2439
2835
|
// Get absolute values of x vectors
|
2440
2836
|
const __m128i ax = _mm_sign_epi8(bx, bx);
|
@@ -2445,516 +2841,833 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2445
2841
|
// Perform multiplication and create 16-bit values
|
2446
2842
|
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
2447
2843
|
|
2448
|
-
const __m128i ones = _mm_set1_epi16(1);
|
2449
|
-
i32[j] = _mm_madd_epi16(ones, dot);
|
2450
|
-
}
|
2844
|
+
const __m128i ones = _mm_set1_epi16(1);
|
2845
|
+
i32[j] = _mm_madd_epi16(ones, dot);
|
2846
|
+
}
|
2847
|
+
|
2848
|
+
// Convert int32_t to float
|
2849
|
+
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
2850
|
+
// Apply the scale, and accumulate
|
2851
|
+
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
2852
|
+
}
|
2853
|
+
|
2854
|
+
*s = hsum_float_8(acc);
|
2855
|
+
#else
|
2856
|
+
// scalar
|
2857
|
+
float sumf = 0.0;
|
2858
|
+
for (int i = 0; i < nb; i++) {
|
2859
|
+
const float d0 = x[i].d;
|
2860
|
+
const float d1 = y[i].d;
|
2861
|
+
|
2862
|
+
const uint8_t * restrict p0 = x[i].qs;
|
2863
|
+
const int8_t * restrict p1 = y[i].qs;
|
2864
|
+
|
2865
|
+
int sumi = 0;
|
2866
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2867
|
+
const uint8_t v0 = p0[j];
|
2868
|
+
|
2869
|
+
const int i0 = (int8_t) (v0 & 0x0F) - 8;
|
2870
|
+
const int i1 = (int8_t) (v0 >> 4) - 8;
|
2871
|
+
|
2872
|
+
const int i2 = p1[2*j + 0];
|
2873
|
+
const int i3 = p1[2*j + 1];
|
2874
|
+
|
2875
|
+
sumi += i0*i2 + i1*i3;
|
2876
|
+
}
|
2877
|
+
sumf += d0*d1*sumi;
|
2878
|
+
}
|
2879
|
+
*s = sumf;
|
2880
|
+
#endif
|
2881
|
+
}
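Both SIMD paths of ggml_vec_dot_q4_0_q8_0 assume the y operand has already been quantized into Q8_0 blocks (one float scale d plus QK8_0 signed bytes qs). A minimal sketch of producing one such block, assuming QK8_0 == 32 and a { float d; int8_t qs[32]; } layout; the function name is illustrative, not the library's quantization routine:

    #include <math.h>
    #include <stdint.h>

    // quantize 32 floats into one Q8_0-style block: x[j] ~= d * qs[j]
    static void quantize_block_q8_0_sketch(const float * x, float * d, int8_t * qs) {
        float amax = 0.0f;                         // absolute max of the block
        for (int j = 0; j < 32; ++j) {
            const float v = fabsf(x[j]);
            if (v > amax) amax = v;
        }
        *d = amax / 127.0f;                        // largest value maps to +/-127
        const float id = (*d != 0.0f) ? 1.0f/(*d) : 0.0f;
        for (int j = 0; j < 32; ++j) {
            qs[j] = (int8_t) roundf(x[j] * id);    // round to nearest int8
        }
    }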
|
2882
|
+
|
2883
|
+
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2884
|
+
const int nb = n / QK8_1;
|
2885
|
+
|
2886
|
+
assert(n % QK8_1 == 0);
|
2887
|
+
assert(nb % 2 == 0);
|
2888
|
+
|
2889
|
+
const block_q4_1 * restrict x = vx;
|
2890
|
+
const block_q8_1 * restrict y = vy;
|
2891
|
+
|
2892
|
+
// TODO: add AVX / WASM SIMD / etc
|
2893
|
+
#if defined(__ARM_NEON)
|
2894
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2895
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2896
|
+
|
2897
|
+
float summs = 0;
|
2898
|
+
|
2899
|
+
for (int i = 0; i < nb; i += 2) {
|
2900
|
+
const block_q4_1 * restrict x0 = &x[i + 0];
|
2901
|
+
const block_q4_1 * restrict x1 = &x[i + 1];
|
2902
|
+
const block_q8_1 * restrict y0 = &y[i + 0];
|
2903
|
+
const block_q8_1 * restrict y1 = &y[i + 1];
|
2904
|
+
|
2905
|
+
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
|
2906
|
+
|
2907
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
2908
|
+
|
2909
|
+
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2910
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2911
|
+
|
2912
|
+
// 4-bit -> 8-bit
|
2913
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2914
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2915
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2916
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2917
|
+
|
2918
|
+
// interleave
|
2919
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
2920
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
2921
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
2922
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
2923
|
+
|
2924
|
+
// load y
|
2925
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2926
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2927
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2928
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2929
|
+
|
2930
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2931
|
+
// dot product into int32x4_t
|
2932
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
|
2933
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
|
2934
|
+
|
2935
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2936
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2937
|
+
#else
|
2938
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2939
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2940
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2941
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2942
|
+
|
2943
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2944
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2945
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2946
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2947
|
+
|
2948
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2949
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2950
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2951
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2952
|
+
|
2953
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2954
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
2955
|
+
#endif
|
2956
|
+
}
|
2957
|
+
|
2958
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
2959
|
+
#elif defined(__AVX2__)
|
2960
|
+
// Initialize accumulator with zeros
|
2961
|
+
__m256 acc = _mm256_setzero_ps();
|
2962
|
+
|
2963
|
+
float summs = 0;
|
2964
|
+
|
2965
|
+
// Main loop
|
2966
|
+
for (int i = 0; i < nb; ++i) {
|
2967
|
+
const float * d0 = &x[i].d;
|
2968
|
+
const float * d1 = &y[i].d;
|
2969
|
+
|
2970
|
+
summs += x[i].m * (y[i].s0 + y[i].s1);
|
2971
|
+
|
2972
|
+
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2973
|
+
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2974
|
+
|
2975
|
+
// Compute combined scales
|
2976
|
+
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
2977
|
+
|
2978
|
+
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2979
|
+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2980
|
+
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
2981
|
+
|
2982
|
+
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
2983
|
+
|
2984
|
+
// Accumulate d0*d1*x*y
|
2985
|
+
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
2986
|
+
}
|
2987
|
+
|
2988
|
+
*s = hsum_float_8(acc) + summs;
|
2989
|
+
#else
|
2990
|
+
// scalar
|
2991
|
+
float sumf = 0.0;
|
2992
|
+
for (int i = 0; i < nb; i++) {
|
2993
|
+
const float d0 = x[i].d;
|
2994
|
+
const float m0 = x[i].m;
|
2995
|
+
const float d1 = y[i].d;
|
2996
|
+
|
2997
|
+
const uint8_t * restrict p0 = x[i].qs;
|
2998
|
+
const int8_t * restrict p1 = y[i].qs;
|
2999
|
+
|
3000
|
+
// TODO: this is very slow ..
|
3001
|
+
for (int j = 0; j < QK8_1/2; j++) {
|
3002
|
+
const uint8_t v0 = p0[j];
|
3003
|
+
|
3004
|
+
const float f0 = d0*(v0 & 0x0F) + m0;
|
3005
|
+
const float f1 = d0*(v0 >> 4) + m0;
|
3006
|
+
|
3007
|
+
const float f2 = d1*p1[2*j + 0];
|
3008
|
+
const float f3 = d1*p1[2*j + 1];
|
3009
|
+
|
3010
|
+
sumf += f0*f2 + f1*f3;
|
3011
|
+
}
|
3012
|
+
}
|
3013
|
+
*s = sumf;
|
3014
|
+
#endif
|
3015
|
+
}
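The summs term in ggml_vec_dot_q4_1_q8_1 follows from the Q4_1 dequantization d*q + m: sum_j (d*q_j + m)*y_j = d * sum_j q_j*y_j + m * sum_j y_j. The Q8_1 block carries the two half-block sums of its already-scaled values as s0 and s1 (this is also how the scalar paths of the q4_3 and q5_1 kernels in this diff consume them), so m*(s0 + s1) can be added once per block. A scalar sketch of that factorization, with plain arrays standing in for the block structs and the 32 x quants already unpacked to one byte each:

    #include <stdint.h>

    // dot of one Q4_1-style block (scale d, offset m) against one Q8_1-style block (scale yd)
    static float dot_q4_1_block_sketch(float d, float m, const uint8_t * q,
                                       float yd, const int8_t * yq) {
        int qy = 0;        // sum of q_j * y_j
        int ys = 0;        // sum of y_j
        for (int j = 0; j < 32; ++j) {
            qy += q[j] * yq[j];
            ys += yq[j];
        }
        // equals sum_j (d*q_j + m) * (yd*yq_j); yd*ys plays the role of s0 + s1
        return d*yd*(float)qy + m*(yd*(float)ys);
    }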
|
3016
|
+
|
3017
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3018
|
+
const int nb = n / QK8_0;
|
3019
|
+
|
3020
|
+
assert(n % QK8_0 == 0);
|
3021
|
+
assert(nb % 2 == 0);
|
3022
|
+
assert(QK8_0 == 2*QK4_2);
|
3023
|
+
|
3024
|
+
const block_q4_2 * restrict x = vx;
|
3025
|
+
const block_q8_0 * restrict y = vy;
|
3026
|
+
|
3027
|
+
#if defined(__ARM_NEON)
|
3028
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3029
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3030
|
+
|
3031
|
+
for (int i = 0; i < nb; i += 2) {
|
3032
|
+
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
|
3033
|
+
const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
|
3034
|
+
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
3035
|
+
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
3036
|
+
|
3037
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
3038
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
3039
|
+
|
3040
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3041
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
3042
|
+
|
3043
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
3044
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
3045
|
+
|
3046
|
+
// 4-bit -> 8-bit
|
3047
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
3048
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
3049
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
3050
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3051
|
+
|
3052
|
+
// sub 8
|
3053
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
3054
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
3055
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
3056
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3057
|
+
|
3058
|
+
// interleave
|
3059
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
3060
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
3061
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
3062
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
3063
|
+
|
3064
|
+
// load y
|
3065
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
3066
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
3067
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
3068
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
3069
|
+
|
3070
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3071
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3072
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
|
3073
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3074
|
+
|
3075
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3076
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
|
3077
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3078
|
+
#else
|
3079
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3080
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3081
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3082
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3083
|
+
|
3084
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
3085
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
3086
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
3087
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
3088
|
+
|
3089
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3090
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3091
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
3092
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
3093
|
+
|
3094
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3095
|
+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
3096
|
+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3097
|
+
|
3098
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3099
|
+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
|
3100
|
+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3101
|
+
#endif
|
3102
|
+
}
|
3103
|
+
|
3104
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
3105
|
+
#elif defined(__AVX2__)
|
3106
|
+
// Initialize accumulator with zeros
|
3107
|
+
__m256 acc = _mm256_setzero_ps();
|
3108
|
+
|
3109
|
+
// Main loop
|
3110
|
+
for (int i = 0; i < nb; i++) {
|
3111
|
+
/* Compute combined scale for the block */
|
3112
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
3113
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
3114
|
+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
3115
|
+
|
3116
|
+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
3117
|
+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
3118
|
+
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
3119
|
+
|
3120
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
3121
|
+
const __m256i off = _mm256_set1_epi8(8);
|
3122
|
+
bx = _mm256_sub_epi8(bx, off);
|
3123
|
+
|
3124
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3125
|
+
|
3126
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3127
|
+
|
3128
|
+
/* Multiply q with scale and accumulate */
|
3129
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
3130
|
+
}
|
3131
|
+
|
3132
|
+
*s = hsum_float_8(acc);
|
3133
|
+
#else
|
3134
|
+
// scalar
|
3135
|
+
float sumf = 0.0;
|
3136
|
+
for (int i = 0; i < nb; i++) {
|
3137
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3138
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3139
|
+
const int8_t * restrict y0 = y[i].qs;
|
3140
|
+
|
3141
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3142
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3143
|
+
|
3144
|
+
int sumi_0 = 0;
|
3145
|
+
int sumi_1 = 0;
|
3146
|
+
|
3147
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
3148
|
+
const uint8_t v0 = x0[j];
|
3149
|
+
const uint8_t v1 = x1[j];
|
3150
|
+
|
3151
|
+
const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
|
3152
|
+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
|
3153
|
+
|
3154
|
+
const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
|
3155
|
+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
|
3156
|
+
|
3157
|
+
const int i2_0 = y0[2*j + 0];
|
3158
|
+
const int i3_0 = y0[2*j + 1];
|
3159
|
+
|
3160
|
+
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
3161
|
+
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
3162
|
+
|
3163
|
+
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
3164
|
+
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
3165
|
+
}
|
3166
|
+
|
3167
|
+
sumf += (d0 * y[i].d) * sumi_0;
|
3168
|
+
sumf += (d1 * y[i].d) * sumi_1;
|
3169
|
+
}
|
3170
|
+
*s = sumf;
|
3171
|
+
#endif
|
3172
|
+
}
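Q4_2 blocks hold only 16 quants with an fp16 scale, so every Q8_0 block of 32 values is matched against a pair of consecutive Q4_2 blocks; in the scalar loop above, nibble j of the first sub-block lines up with y bytes 2j and 2j+1, while the second sub-block covers bytes 16..31. A dequantization sketch of one such pair that is consistent with that pairing (scales assumed already converted from fp16; the helper name is illustrative):

    #include <stdint.h>

    // expand two 16-quant Q4_2-style sub-blocks (d0,qs0) and (d1,qs1) into 32 floats
    static void dequant_q4_2_pair_sketch(float d0, const uint8_t * qs0,
                                         float d1, const uint8_t * qs1, float * out) {
        for (int j = 0; j < 8; ++j) {
            out[     2*j + 0] = d0 * (float)((int8_t)(qs0[j] & 0x0F) - 8);
            out[     2*j + 1] = d0 * (float)((int8_t)(qs0[j] >> 4)   - 8);
            out[16 + 2*j + 0] = d1 * (float)((int8_t)(qs1[j] & 0x0F) - 8);
            out[16 + 2*j + 1] = d1 * (float)((int8_t)(qs1[j] >> 4)   - 8);
        }
    }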
|
3173
|
+
|
3174
|
+
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3175
|
+
const int nb = n / QK8_1;
|
3176
|
+
|
3177
|
+
assert(n % QK8_1 == 0);
|
3178
|
+
assert(nb % 2 == 0);
|
3179
|
+
assert(QK8_1 == 2*QK4_3);
|
3180
|
+
|
3181
|
+
const block_q4_3 * restrict x = vx;
|
3182
|
+
const block_q8_1 * restrict y = vy;
|
3183
|
+
|
3184
|
+
#if defined(__ARM_NEON)
|
3185
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3186
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3187
|
+
|
3188
|
+
float summs0 = 0.0f;
|
3189
|
+
float summs1 = 0.0f;
|
3190
|
+
|
3191
|
+
for (int i = 0; i < nb; ++i) {
|
3192
|
+
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
|
3193
|
+
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
|
3194
|
+
|
3195
|
+
const block_q8_1 * restrict y0 = &y[i + 0];
|
3196
|
+
|
3197
|
+
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
|
3198
|
+
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
|
3199
|
+
|
3200
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
3201
|
+
|
3202
|
+
// 4-bit -> 8-bit
|
3203
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
|
3204
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
3205
|
+
|
3206
|
+
// interleave
|
3207
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
3208
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
3209
|
+
|
3210
|
+
// load y
|
3211
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
3212
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
3213
|
+
|
3214
|
+
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
|
3215
|
+
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
|
3216
|
+
|
3217
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3218
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
|
3219
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
|
3220
|
+
#else
|
3221
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3222
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3223
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3224
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3225
|
+
|
3226
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3227
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3228
|
+
|
3229
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
|
3230
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
|
3231
|
+
#endif
|
3232
|
+
}
|
3233
|
+
|
3234
|
+
*s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
|
3235
|
+
#elif defined(__AVX2__)
|
3236
|
+
// Initialize accumulator with zeros
|
3237
|
+
__m256 acc = _mm256_setzero_ps();
|
3238
|
+
float summs = 0.0f;
|
3239
|
+
|
3240
|
+
// Main loop
|
3241
|
+
for (int i = 0; i < nb; i++) {
|
3242
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
3243
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
3244
|
+
const __m256 dx = _mm256_set_m128(d1, d0);
|
3245
|
+
|
3246
|
+
summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
|
3247
|
+
+ GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
|
3248
|
+
|
3249
|
+
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
3250
|
+
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
3251
|
+
const __m256i bx = _mm256_set_m128i(bx1, bx0);
|
3252
|
+
|
3253
|
+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
3254
|
+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3255
|
+
|
3256
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3257
|
+
|
3258
|
+
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
3259
|
+
}
|
3260
|
+
|
3261
|
+
*s = hsum_float_8(acc) + summs;
|
3262
|
+
#else
|
3263
|
+
// scalar
|
3264
|
+
float sumf = 0.0;
|
3265
|
+
for (int i = 0; i < nb; i++) {
|
3266
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3267
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3268
|
+
const int8_t * restrict y0 = y[i].qs;
|
3269
|
+
|
3270
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3271
|
+
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
|
3272
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3273
|
+
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
|
3274
|
+
|
3275
|
+
int sxy_0 = 0;
|
3276
|
+
int sxy_1 = 0;
|
3277
|
+
|
3278
|
+
for (int j = 0; j < QK8_1/4; j++) {
|
3279
|
+
const uint8_t v0 = x0[j];
|
3280
|
+
const uint8_t v1 = x1[j];
|
3281
|
+
|
3282
|
+
const int x0_0 = v0 & 0x0F;
|
3283
|
+
const int x1_0 = v0 >> 4;
|
3284
|
+
|
3285
|
+
const int x0_1 = v1 & 0x0F;
|
3286
|
+
const int x1_1 = v1 >> 4;
|
3287
|
+
|
3288
|
+
const int y0_0 = y0[2*j + 0];
|
3289
|
+
const int y1_0 = y0[2*j + 1];
|
3290
|
+
|
3291
|
+
const int y0_1 = y0[2*(j + QK8_1/4) + 0];
|
3292
|
+
const int y1_1 = y0[2*(j + QK8_1/4) + 1];
|
3293
|
+
|
3294
|
+
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
|
3295
|
+
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
|
3296
|
+
}
|
3297
|
+
|
3298
|
+
sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
|
3299
|
+
}
|
3300
|
+
*s = sumf;
|
3301
|
+
#endif
|
3302
|
+
}
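Because each Q4_3 sub-block has its own offset m, the offset contribution must be paired with the sum of exactly the 16 y values that sub-block multiplies, which is why m0 goes with y[i].s0 and m1 with y[i].s1 in the scalar accumulation above. A compact scalar check of that bookkeeping for one block pair, using plain arrays in place of the structs (the 16 x bytes are the two sub-blocks' qs concatenated):

    #include <stdint.h>

    // sum over 32 positions of (d_half*q + m_half) * (yd*yq)
    static float dot_q4_3_block_sketch(const float d[2], const float m[2],
                                       const uint8_t * q, float yd, const int8_t * yq) {
        float acc = 0.0f;
        for (int h = 0; h < 2; ++h) {             // two 16-value sub-blocks
            int sqy = 0;                          // sum of q*y for this half
            int sy  = 0;                          // sum of y for this half
            for (int j = 0; j < 8; ++j) {
                const uint8_t v  = q[8*h + j];
                const int q0 = v & 0x0F,           q1 = v >> 4;
                const int y0 = yq[16*h + 2*j + 0], y1 = yq[16*h + 2*j + 1];
                sqy += q0*y0 + q1*y1;
                sy  += y0 + y1;
            }
            acc += d[h]*yd*(float)sqy + m[h]*(yd*(float)sy);   // yd*sy plays the role of s0/s1
        }
        return acc;
    }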
|
3303
|
+
|
3304
|
+
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3305
|
+
const int nb = n / QK8_0;
|
3306
|
+
|
3307
|
+
assert(n % QK8_0 == 0);
|
3308
|
+
assert(nb % 2 == 0);
|
3309
|
+
assert(QK8_0 == QK5_0);
|
3310
|
+
|
3311
|
+
const block_q5_0 * restrict x = vx;
|
3312
|
+
const block_q8_0 * restrict y = vy;
|
2451
3313
|
|
2452
|
-
|
2453
|
-
|
2454
|
-
// Apply the scale, and accumulate
|
2455
|
-
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
2456
|
-
}
|
3314
|
+
#if defined(__ARM_NEON)
|
3315
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
2457
3316
|
|
2458
|
-
|
2459
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2460
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2461
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2462
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
3317
|
+
uint64_t tmp[4];
|
2463
3318
|
|
2464
|
-
|
2465
|
-
|
2466
|
-
|
2467
|
-
float sum0 = 0.0f;
|
2468
|
-
float sum1 = 0.0f;
|
3319
|
+
for (int i = 0; i < nb; ++i) {
|
3320
|
+
const block_q5_0 * restrict x0 = &x[i];
|
3321
|
+
const block_q8_0 * restrict y0 = &y[i];
|
2469
3322
|
|
2470
|
-
|
2471
|
-
const
|
2472
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
2473
|
-
const block_q4_0 * restrict x1 = &x[i + 1];
|
2474
|
-
const block_q4_0 * restrict y1 = &y[i + 1];
|
3323
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3324
|
+
const int8x16_t s16b = vdupq_n_s8(0x10);
|
2475
3325
|
|
2476
|
-
|
2477
|
-
|
3326
|
+
// extract the 5th bit
|
3327
|
+
uint32_t qh;
|
3328
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
2478
3329
|
|
2479
|
-
|
2480
|
-
|
2481
|
-
|
2482
|
-
|
3330
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3331
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3332
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3333
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
2483
3334
|
|
2484
|
-
|
2485
|
-
const
|
2486
|
-
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
|
3335
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3336
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
2487
3337
|
|
2488
|
-
const
|
2489
|
-
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
|
3338
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
2490
3339
|
|
2491
|
-
|
2492
|
-
const
|
3340
|
+
// 4-bit -> 8-bit
|
3341
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
|
3342
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
2493
3343
|
|
2494
|
-
|
2495
|
-
const
|
3344
|
+
// interleave
|
3345
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3346
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
2496
3347
|
|
2497
|
-
// sub
|
2498
|
-
const
|
2499
|
-
const
|
3348
|
+
// add high bit and sub 16
|
3349
|
+
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
|
3350
|
+
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
|
2500
3351
|
|
2501
|
-
|
2502
|
-
const
|
3352
|
+
// load y
|
3353
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3354
|
+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
|
2503
3355
|
|
2504
|
-
const
|
2505
|
-
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
|
3356
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
2506
3357
|
|
2507
|
-
|
2508
|
-
|
3358
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3359
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
|
3360
|
+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
|
3361
|
+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
|
3362
|
+
#else
|
3363
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
|
3364
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
|
3365
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
|
3366
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
|
2509
3367
|
|
2510
|
-
|
2511
|
-
const
|
2512
|
-
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
|
3368
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3369
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2513
3370
|
|
2514
|
-
|
2515
|
-
|
3371
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
|
3372
|
+
#endif
|
3373
|
+
}
|
2516
3374
|
|
2517
|
-
|
2518
|
-
|
3375
|
+
*s = vaddvq_f32(sumv);
|
3376
|
+
#elif defined(__AVX2__)
|
3377
|
+
// Initialize accumulator with zeros
|
3378
|
+
__m256 acc = _mm256_setzero_ps();
|
2519
3379
|
|
2520
|
-
|
2521
|
-
|
3380
|
+
// Main loop
|
3381
|
+
for (int i = 0; i < nb; i++) {
|
3382
|
+
/* Compute combined scale for the block */
|
3383
|
+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
|
2522
3384
|
|
2523
|
-
|
2524
|
-
|
3385
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
3386
|
+
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
3387
|
+
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
|
3388
|
+
bx = _mm256_or_si256(bx, bxhi);
|
2525
3389
|
|
2526
|
-
|
2527
|
-
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
|
3390
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2528
3391
|
|
2529
|
-
const
|
2530
|
-
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
|
3392
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
2531
3393
|
|
2532
|
-
|
2533
|
-
|
2534
|
-
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
|
2535
|
-
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
|
2536
|
-
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
|
2537
|
-
sum1 += x1->d * y1->d * (
|
2538
|
-
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
|
2539
|
-
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
|
2540
|
-
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
|
2541
|
-
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
|
3394
|
+
/* Multiply q with scale and accumulate */
|
3395
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
2542
3396
|
}
|
2543
3397
|
|
2544
|
-
|
3398
|
+
*s = hsum_float_8(acc);
|
2545
3399
|
#else
|
2546
3400
|
// scalar
|
3401
|
+
float sumf = 0.0;
|
2547
3402
|
for (int i = 0; i < nb; i++) {
|
2548
|
-
const
|
2549
|
-
const
|
3403
|
+
const uint8_t * restrict x0 = x[i].qs;
|
3404
|
+
const int8_t * restrict y0 = y[i].qs;
|
2550
3405
|
|
2551
|
-
|
2552
|
-
|
3406
|
+
uint32_t qh;
|
3407
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2553
3408
|
|
2554
|
-
|
2555
|
-
|
2556
|
-
|
2557
|
-
const uint8_t v1 = p1[j];
|
3409
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3410
|
+
|
3411
|
+
int sxy = 0;
|
2558
3412
|
|
2559
|
-
|
2560
|
-
const
|
3413
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
3414
|
+
const uint8_t v0 = x0[j];
|
2561
3415
|
|
2562
|
-
const int
|
2563
|
-
const int
|
3416
|
+
const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
|
3417
|
+
const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
|
2564
3418
|
|
2565
|
-
|
3419
|
+
const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
|
3420
|
+
const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
|
3421
|
+
|
3422
|
+
const int y0_0 = y0[2*j + 0];
|
3423
|
+
const int y1_0 = y0[2*j + 1];
|
3424
|
+
|
3425
|
+
sxy += x0_0*y0_0 + x1_0*y1_0;
|
2566
3426
|
}
|
2567
|
-
sumf += d0 * d1 * sumi;
|
2568
|
-
}
|
2569
|
-
#endif
|
2570
3427
|
|
3428
|
+
sumf += (d*sxy)*y[i].d;
|
3429
|
+
}
|
2571
3430
|
*s = sumf;
|
3431
|
+
#endif
|
2572
3432
|
}
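The NEON path of ggml_vec_dot_q5_0_q8_0 expands each byte of qh through table_b2b_u so that every high bit lands at bit position 4 of its own byte and can simply be OR-ed onto the low nibbles; the scalar path rebuilds the same bit with "<< 4". A sketch of building an equivalent 256-entry table at run time (the table name and the builder are illustrative, not the library's initialization code):

    #include <stdint.h>

    static uint64_t b2b_shl4_table[256];

    // entry b: byte k holds ((b >> k) & 1) << 4, for k = 0..7
    static void init_b2b_shl4_table(void) {
        for (int b = 0; b < 256; ++b) {
            uint64_t v = 0;
            for (int k = 0; k < 8; ++k) {
                const uint64_t bit = (uint64_t)((b >> k) & 1);
                v |= (bit << 4) << (8*k);     // park the bit at position 4 of byte k
            }
            b2b_shl4_table[b] = v;
        }
    }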
|
2573
3433
|
|
2574
|
-
static void
|
2575
|
-
const int nb = n /
|
3434
|
+
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3435
|
+
const int nb = n / QK8_1;
|
2576
3436
|
|
2577
|
-
|
2578
|
-
|
2579
|
-
|
2580
|
-
float sumf = 0.0;
|
3437
|
+
assert(n % QK8_1 == 0);
|
3438
|
+
assert(nb % 2 == 0);
|
3439
|
+
assert(QK8_1 == QK5_1);
|
2581
3440
|
|
2582
|
-
|
2583
|
-
|
2584
|
-
__m256 acc = _mm256_setzero_ps();
|
2585
|
-
// Accumulator for constant offsets
|
2586
|
-
float acc_offset = 0.0f;
|
3441
|
+
const block_q5_1 * restrict x = vx;
|
3442
|
+
const block_q8_1 * restrict y = vy;
|
2587
3443
|
|
2588
|
-
|
2589
|
-
|
2590
|
-
const float * d0 = &x[i].d;
|
2591
|
-
const float * d1 = &y[i].d;
|
3444
|
+
#if defined(__ARM_NEON)
|
3445
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
2592
3446
|
|
2593
|
-
|
2594
|
-
const float * m1 = &y[i].m;
|
3447
|
+
float summs = 0.0f;
|
2595
3448
|
|
2596
|
-
|
2597
|
-
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2598
|
-
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2599
|
-
const __m256 m1v = _mm256_broadcast_ss( m1 );
|
3449
|
+
uint64_t tmp[4];
|
2600
3450
|
|
2601
|
-
|
2602
|
-
const
|
3451
|
+
for (int i = 0; i < nb; ++i) {
|
3452
|
+
const block_q5_1 * restrict x0 = &x[i];
|
3453
|
+
const block_q8_1 * restrict y0 = &y[i];
|
2603
3454
|
|
2604
|
-
|
2605
|
-
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
|
2606
|
-
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
|
2607
|
-
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
|
3455
|
+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
2608
3456
|
|
2609
|
-
//
|
2610
|
-
|
2611
|
-
|
2612
|
-
|
2613
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
|
2614
|
-
|
2615
|
-
// Sign-extend first 16 signed bytes into int16_t
|
2616
|
-
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
|
2617
|
-
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2618
|
-
// Compute products of int16_t integers, add pairwise
|
2619
|
-
__m256i i32 = _mm256_madd_epi16( x16, y16 );
|
2620
|
-
|
2621
|
-
// Sign-extend last 16 signed bytes into int16_t vectors
|
2622
|
-
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
|
2623
|
-
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2624
|
-
// Accumulate products of int16_t integers
|
2625
|
-
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
|
2626
|
-
|
2627
|
-
// compute sums of unsigned bytes in bx, by in blocks of 8.
|
2628
|
-
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
|
2629
|
-
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
|
2630
|
-
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
|
2631
|
-
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
|
2632
|
-
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
|
2633
|
-
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
2634
|
-
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
3457
|
+
// extract the 5th bit
|
3458
|
+
uint32_t qh;
|
3459
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
2635
3460
|
|
2636
|
-
|
2637
|
-
|
2638
|
-
|
2639
|
-
|
2640
|
-
acc = _mm256_fmadd_ps( scale_01, p, acc );
|
2641
|
-
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
2642
|
-
// acc_offset += m0*m1 (for each entry in the block)
|
2643
|
-
acc_offset += (*m0)*(*m1);
|
2644
|
-
}
|
3461
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3462
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3463
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3464
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
2645
3465
|
|
2646
|
-
|
2647
|
-
|
2648
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2649
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2650
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
3466
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3467
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
2651
3468
|
|
2652
|
-
|
2653
|
-
#elif defined(__ARM_NEON)
|
2654
|
-
float sum00 = 0.0f;
|
2655
|
-
float sum01 = 0.0f;
|
2656
|
-
float sum10 = 0.0f;
|
2657
|
-
float sum11 = 0.0f;
|
3469
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
2658
3470
|
|
2659
|
-
|
2660
|
-
const
|
2661
|
-
const
|
2662
|
-
const block_q4_1 * restrict x1 = &x[i + 1];
|
2663
|
-
const block_q4_1 * restrict y1 = &y[i + 1];
|
3471
|
+
// 4-bit -> 8-bit
|
3472
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
|
3473
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
2664
3474
|
|
2665
|
-
|
3475
|
+
// interleave
|
3476
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3477
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
2666
3478
|
|
2667
|
-
|
2668
|
-
const
|
2669
|
-
const
|
2670
|
-
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
3479
|
+
// add high bit
|
3480
|
+
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
|
3481
|
+
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
|
2671
3482
|
|
2672
|
-
//
|
2673
|
-
const
|
2674
|
-
const
|
2675
|
-
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
2676
|
-
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
3483
|
+
// load y
|
3484
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3485
|
+
            const int8x16_t v1h = vld1q_s8(y0->qs + 16);

-            const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-            const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-            const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+            const float x0d = GGML_FP16_TO_FP32(x0->d);

+#if defined(__ARM_FEATURE_DOTPROD)
+            sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                            vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                            vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+            const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+            const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+            const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+            const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));

-            sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
+            const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+            const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));

-            uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+            sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }

+    *s = vaddvq_f32(sumv) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;

-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));

-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+        summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);

+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);

+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);

-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
    }

+    *s = hsum_float_8(acc) + summs;
#else
+    float sumf = 0.0;
+
    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;

+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));

+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);

-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
+        int sxy = 0;

+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = x0[j];

+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;

+            const int x0_0 = (v0 & 0x0F) | x0_0h;
+            const int x1_0 = (v0 >>   4) | x1_0h;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
        }
+
+        sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
    }
-#endif

    *s = sumf;
+#endif
}

-static void
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    const int nb = n / QK8_0;

    assert(n % QK8_0 == 0);
    assert(nb % 2 == 0);
+    assert(QK8_0 == QK8_0);

+    const block_q8_0 * restrict x = vx;
    const block_q8_0 * restrict y = vy;

-    float sumf = 0.0;
-
#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);

    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
        const block_q8_0 * restrict y0 = &y[i + 0];
        const block_q8_0 * restrict y1 = &y[i + 1];

-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);

        // load y
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);

+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);

-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
#else
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
-
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
#endif
    }

+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    // Main loop
    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-        __m256i bx = bytesFromNibbles(x[i].qs);
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-        const __m256i off = _mm256_set1_epi8( 8 );
-        bx = _mm256_sub_epi8( bx, off );
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);

-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-        /* Multiply q with scale and accumulate */
+        // Multiply q with scale and accumulate
        acc = _mm256_fmadd_ps( d, q, acc );
    }

-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
-        __m128i i32[2];
-        for (int j = 0; j < 2; ++j) {
-            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
-
-            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-            const __m128i off = _mm_set1_epi8( 8 );
-            bx = _mm_sub_epi8( bx, off );
-
-            // Get absolute values of x vectors
-            const __m128i ax = _mm_sign_epi8(bx, bx);
-
-            // Sign the values of the y vectors
-            const __m128i sy = _mm_sign_epi8(by, bx);
-
-            // Perform multiplication and create 16-bit values
-            const __m128i dot = _mm_maddubs_epi16(ax, sy);
-
-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
-        }
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
#else
    // scalar
-    const float d0 = x[i].d;
-    const float d1 = y[i].d;
+    float sumf = 0.0;

+    for (int i = 0; i < nb; i++) {
+        const int8_t * restrict x0 = x[i].qs;
+        const int8_t * restrict y0 = y[i].qs;

        int sumi = 0;

-        for (int j = 0; j < QK8_0/2; j++) {
-            const uint8_t v0 = p0[j];
+        for (int j = 0; j < QK8_0; j++) {
+            const int v0 = x0[j];
+            const int v1 = y0[j];

-            const int i3 = p1[2*j + 1];
-
-            sumi += i0*i2 + i1*i3;
+            sumi += v0*v1;
        }
+
+        sumf += (x[i].d*y[i].d)*sumi;
    }
-#endif

    *s = sumf;
+#endif
}

// compute GGML_VEC_DOT_UNROLL dot products at once
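The scalar fallbacks above all share one shape: an integer multiply-accumulate inside each block, scaled once per block by the stored deltas. A minimal standalone sketch of that idea for two 8-bit blocks (hypothetical block struct and a hardcoded block size of 32; not the library's exact layout):

#include <stdint.h>
#include <stdio.h>

#define QK 32  // elements per block (assumption for this sketch)

typedef struct {
    float  d;       // block scale
    int8_t qs[QK];  // quantized values
} block_q8;

// dot product of n elements stored as q8 blocks: sum over blocks of d_x*d_y*(qx_j*qy_j)
static float dot_q8_q8(int n, const block_q8 *x, const block_q8 *y) {
    const int nb = n / QK;
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        int sumi = 0;
        for (int j = 0; j < QK; j++) {
            sumi += (int)x[i].qs[j] * (int)y[i].qs[j]; // integer MAC per block
        }
        sumf += x[i].d * y[i].d * (float)sumi;          // apply both scales once per block
    }
    return sumf;
}

int main(void) {
    block_q8 a = { 0.5f, {0} }, b = { 0.25f, {0} };
    for (int j = 0; j < QK; j++) { a.qs[j] = (int8_t)j; b.qs[j] = 2; }
    printf("%f\n", dot_q8_q8(QK, &a, &b)); // 0.5 * 0.25 * 2 * (0+1+...+31) = 124
    return 0;
}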
@@ -3153,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#endif
}

+inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
    float max = -INFINITY;
@@ -3203,24 +3924,34 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F16]  = 1,
    [GGML_TYPE_Q4_0] = QK4_0,
    [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q4_3] = QK4_3,
+    [GGML_TYPE_Q5_0] = QK5_0,
+    [GGML_TYPE_Q5_1] = QK5_1,
    [GGML_TYPE_Q8_0] = QK8_0,
+    [GGML_TYPE_Q8_1] = QK8_1,
    [GGML_TYPE_I8]   = 1,
    [GGML_TYPE_I16]  = 1,
    [GGML_TYPE_I32]  = 1,
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");

static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F32]  = sizeof(float),
    [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
    [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
    [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
+    [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
+    [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
+    [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
    [GGML_TYPE_I8]   = sizeof(int8_t),
    [GGML_TYPE_I16]  = sizeof(int16_t),
    [GGML_TYPE_I32]  = sizeof(int32_t),
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");


static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3959,34 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F16]  = "f16",
    [GGML_TYPE_Q4_0] = "q4_0",
    [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q4_3] = "q4_3",
+    [GGML_TYPE_Q5_0] = "q5_0",
+    [GGML_TYPE_Q5_1] = "q5_1",
    [GGML_TYPE_Q8_0] = "q8_0",
+    [GGML_TYPE_Q8_1] = "q8_1",
    [GGML_TYPE_I8]   = "i8",
    [GGML_TYPE_I16]  = "i16",
    [GGML_TYPE_I32]  = "i32",
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+
+static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = false,
+    [GGML_TYPE_F16]  = false,
+    [GGML_TYPE_Q4_0] = true,
+    [GGML_TYPE_Q4_1] = true,
+    [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q4_3] = true,
+    [GGML_TYPE_Q5_0] = true,
+    [GGML_TYPE_Q5_1] = true,
+    [GGML_TYPE_Q8_0] = true,
+    [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_I8]   = false,
+    [GGML_TYPE_I16]  = false,
+    [GGML_TYPE_I32]  = false,
+};
+static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");

static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
    "NONE",
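The per-type tables extended above are what the rest of the file leans on for size bookkeeping: a row of ne elements occupies ne / BLCK_SIZE * TYPE_SIZE bytes. A hedged sketch of that arithmetic with invented types and made-up sizes (not the real block structs):

#include <stddef.h>
#include <stdio.h>

// Illustrative only: two invented types with a block size and per-block byte size.
enum toy_type { TOY_F32, TOY_Q8, TOY_COUNT };

static const int    BLCK_SIZE[TOY_COUNT] = { [TOY_F32] = 1, [TOY_Q8] = 32 };
static const size_t TYPE_SIZE[TOY_COUNT] = { [TOY_F32] = 4, [TOY_Q8] = 36 }; // 32 int8 + 4-byte scale

// bytes needed for one row of ne elements of type t
static size_t row_size(enum toy_type t, int ne) {
    return (size_t)ne / BLCK_SIZE[t] * TYPE_SIZE[t];
}

int main(void) {
    printf("f32 row of 4096: %zu bytes\n", row_size(TOY_F32, 4096)); // 16384
    printf("q8  row of 4096: %zu bytes\n", row_size(TOY_Q8,  4096)); // 4608
    return 0;
}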
@@ -3495,6 +4248,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
        (t0->ne[3] == t1->ne[3]);
}

+bool ggml_is_quantized(enum ggml_type type) {
+    return GGML_IS_QUANTIZED[type];
+}
+
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
    return tensor->nb[0] > tensor->nb[1];
}
@@ -3605,6 +4362,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
        GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
    }

+    // initialize cuBLAS
+    #if defined(GGML_USE_CUBLAS)
+    ggml_init_cublas();
+    #elif defined(GGML_USE_CLBLAST)
+    ggml_cl_init();
+    #endif
+
    is_first_call = false;
}
@@ -5535,7 +6299,6 @@ static void ggml_compute_forward_dup_f16(
    const struct ggml_compute_params * params,
    const struct ggml_tensor * src0,
    struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5547,6 +6310,11 @@ static void ggml_compute_forward_dup_f16(
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
@@ -5557,19 +6325,40 @@ static void ggml_compute_forward_dup_f16(
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];

+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
        return;
    }

+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5583,21 +6372,21 @@ static void ggml_compute_forward_dup_f16(
    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
            if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr, src0_ptr, rs);
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
@@ -5606,34 +6395,39 @@ static void ggml_compute_forward_dup_f16(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
-            } else if (dst->type
+            } else if (ggml_is_quantized(dst->type)) {
                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
                size_t id = 0;
-                float * src0_f32 = (float *) params->wdata;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
                            for (int i00 = 0; i00 < ne00; i00++) {
                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                            }
+
                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id +=
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
@@ -5648,7 +6442,8 @@ static void ggml_compute_forward_dup_f16(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@@ -5656,6 +6451,7 @@ static void ggml_compute_forward_dup_f16(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5664,7 +6460,8 @@ static void ggml_compute_forward_dup_f16(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@@ -5672,6 +6469,7 @@ static void ggml_compute_forward_dup_f16(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
@@ -5690,7 +6488,20 @@ static void ggml_compute_forward_dup_f16(
    if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5711,25 +6522,51 @@ static void ggml_compute_forward_dup_f16(
                    }
                }
            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
        }
    }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);

-                        if (++i10 ==
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 ==
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 ==
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 ==
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5737,6 +6574,19 @@ static void ggml_compute_forward_dup_f16(
                    }
                }
            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
        }
    }
    } else {
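The parallelized dup kernels above all rely on the same ceiling-divide row split: thread ith handles rows [ir0, ir1) and skips the rest by advancing its output cursor. The partitioning on its own, as a small sketch:

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Split nr rows across nth threads; thread ith gets rows [*ir0, *ir1).
static void row_range(int nr, int nth, int ith, int *ir0, int *ir1) {
    const int dr = (nr + nth - 1) / nth; // rows per thread, rounded up
    *ir0 = dr * ith;
    *ir1 = MIN(*ir0 + dr, nr);           // the last thread may get fewer rows
}

int main(void) {
    for (int ith = 0; ith < 4; ith++) {
        int a, b;
        row_range(10, 4, ith, &a, &b);
        printf("thread %d: rows [%d, %d)\n", ith, a, b); // [0,3) [3,6) [6,9) [9,10)
    }
    return 0;
}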
@@ -5748,7 +6598,6 @@ static void ggml_compute_forward_dup_f32(
    const struct ggml_compute_params * params,
    const struct ggml_tensor * src0,
    struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5760,6 +6609,11 @@ static void ggml_compute_forward_dup_f32(
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
@@ -5770,19 +6624,40 @@ static void ggml_compute_forward_dup_f32(
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];

+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
        return;
    }

+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5795,21 +6670,21 @@ static void ggml_compute_forward_dup_f32(

    if (ggml_is_contiguous(dst)) {
        // TODO: simplify
-        if (
+        if (nb00 == sizeof(float)) {
            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr, src0_ptr, rs);
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5818,7 +6693,8 @@ static void ggml_compute_forward_dup_f32(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@@ -5826,21 +6702,25 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
-            } else if (dst->type
+            } else if (ggml_is_quantized(dst->type)) {
                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+
                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
-                            id +=
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
@@ -5855,7 +6735,8 @@ static void ggml_compute_forward_dup_f32(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@@ -5863,6 +6744,7 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5871,7 +6753,8 @@ static void ggml_compute_forward_dup_f32(

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@@ -5879,6 +6762,7 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
@@ -5890,6 +6774,7 @@ static void ggml_compute_forward_dup_f32(
    }

    // dst counters
+
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
@@ -5898,20 +6783,33 @@ static void ggml_compute_forward_dup_f32(
    if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(float));

-                        if (++i10 ==
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 ==
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 ==
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 ==
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5919,25 +6817,51 @@ static void ggml_compute_forward_dup_f32(
                    }
                }
            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
        }
    }
    } else if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);

-                        if (++i10 ==
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 ==
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 ==
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 ==
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5945,6 +6869,19 @@ static void ggml_compute_forward_dup_f32(
                    }
                }
            }
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
        }
    }
    } else {
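In the transposed-copy branches each thread first advances the flattened destination counters (i10..i13) past the rows owned by other threads, using the same carry propagation that runs per element. A standalone sketch of that counter advance:

#include <stdint.h>
#include <stdio.h>

// Advance a 4-level counter (i0 fastest) by `step` positions, with extents n0..n3.
static void advance(int64_t *i0, int64_t *i1, int64_t *i2, int64_t *i3,
                    int64_t step, int64_t n0, int64_t n1, int64_t n2, int64_t n3) {
    *i0 += step;
    while (*i0 >= n0) {        // propagate carries, mirroring the loops above
        *i0 -= n0;
        if (++*i1 == n1) {
            *i1 = 0;
            if (++*i2 == n2) {
                *i2 = 0;
                if (++*i3 == n3) {
                    *i3 = 0;
                }
            }
        }
    }
}

int main(void) {
    int64_t i0 = 0, i1 = 0, i2 = 0, i3 = 0;
    advance(&i0, &i1, &i2, &i3, 7, 3, 2, 2, 2); // 7 steps in a 3x2x2x2 index space
    printf("%lld %lld %lld %lld\n",
           (long long)i0, (long long)i1, (long long)i2, (long long)i3); // 1 0 1 0
    return 0;
}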
@@ -6191,7 +7128,7 @@ static void ggml_compute_forward_add_q_f32(
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

-    GGML_ASSERT(src0->type
+    GGML_ASSERT(ggml_is_quantized(src0->type));
    GGML_ASSERT(dst->type == src0->type);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

@@ -6205,7 +7142,7 @@ static void ggml_compute_forward_add_q_f32(
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

-    float * wdata = (float*) params->wdata + ne00 * ith;
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
@@ -6261,6 +7198,11 @@ static void ggml_compute_forward_add(
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
            {
                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
            } break;
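The wdata change in ggml_compute_forward_add_q_f32 pads each thread's scratch row by CACHE_LINE_SIZE_F32 so adjacent threads do not write into the same cache line. A sketch of the same layout idea, assuming a 64-byte cache line:

#include <stddef.h>
#include <stdio.h>

#define CACHE_LINE_SIZE     64
#define CACHE_LINE_SIZE_F32 (CACHE_LINE_SIZE/sizeof(float))

// Per-thread scratch pointer inside one shared buffer: row length plus one cache
// line of padding keeps threads from false-sharing the line at the row boundary.
static float * thread_scratch(float *wdata, int ne00, int ith) {
    return wdata + ((size_t)ne00 + CACHE_LINE_SIZE_F32) * ith;
}

int main(void) {
    float buf[2 * (128 + CACHE_LINE_SIZE_F32)];
    float *t0 = thread_scratch(buf, 128, 0);
    float *t1 = thread_scratch(buf, 128, 1);
    printf("gap between threads: %td floats\n", t1 - t0); // 128 + 16 = 144
    return 0;
}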
@@ -6518,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
    const size_t nb02 = src0->nb[2];
    const size_t nb03 = src0->nb[3];

+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
+
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_ggf(ne00,
+                        &row_sum,
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
            }
        }
    }
+    ((float *) dst->data)[0] = sum;
}

static void ggml_compute_forward_sum(
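ggml_vec_sum_ggf accumulates each row in ggml_float before folding it into the total, which keeps rounding error smaller than a single float accumulator over the whole tensor. A sketch of the difference, assuming ggml_float is double (which is how it is commonly defined):

#include <stdio.h>

typedef double acc_t; // stand-in for ggml_float in this sketch

static float sum_f32_naive(int n, const float *x) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) s += x[i];        // float accumulator: error grows with n
    return s;
}

static float sum_f32_wide(int n, const float *x) {
    acc_t s = 0.0;
    for (int i = 0; i < n; ++i) s += (acc_t)x[i]; // double accumulator, rounded once at the end
    return (float)s;
}

int main(void) {
    enum { N = 1 << 20 };
    static float x[N];
    for (int i = 0; i < N; ++i) x[i] = 0.1f;
    printf("naive: %f  wide: %f  reference: %f\n",
           sum_f32_naive(N, x), sum_f32_wide(N, x), 0.1 * N);
    return 0;
}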
@@ -7161,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(

// ggml_compute_forward_mul_mat

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7186,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(

    return false;
}
+
#endif

static void ggml_compute_forward_mul_mat_f32(
@@ -7201,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
    const int64_t ne10 = src1->ne[0];
#endif
    const int64_t ne11 = src1->ne[1];
@@ -7258,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
    // nb01 >= nb00 - src0 is not transposed
    //   compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        if (params->ith != 0) {
            return;
@@ -7272,6 +8220,19 @@ static void ggml_compute_forward_mul_mat_f32(
            return;
        }

+#if defined(GGML_USE_CUBLAS)
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        size_t x_size, y_size, d_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+#endif
+
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +8240,44 @@ static void ggml_compute_forward_mul_mat_f32(

                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne11, ne01, ne10,
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+#endif
        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

        return;
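All three BLAS branches above compute the same row-major product d[ne11 x ne01] = y[ne11 x ne10] * x[ne01 x ne10]^T (the "zT = y * xT" comment). A plain-C reference of that shape convention, with no BLAS, just the index math the sgemm calls stand in for:

#include <stdio.h>

// d[ne11 x ne01] = y[ne11 x ne10] * x[ne01 x ne10]^T, all matrices row-major.
// This mirrors the cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ...) call.
static void matmul_nt(int ne11, int ne01, int ne10,
                      const float *y, const float *x, float *d) {
    for (int i = 0; i < ne11; i++) {
        for (int j = 0; j < ne01; j++) {
            float s = 0.0f;
            for (int k = 0; k < ne10; k++) {
                s += y[i*ne10 + k] * x[j*ne10 + k]; // row of y dot row of x
            }
            d[i*ne01 + j] = s;
        }
    }
}

int main(void) {
    const float y[2*3] = {1,2,3, 4,5,6};   // 2x3
    const float x[2*3] = {1,0,1, 0,1,0};   // 2x3, used transposed
    float d[2*2];
    matmul_nt(2, 2, 3, y, x, d);
    printf("%g %g / %g %g\n", d[0], d[1], d[2], d[3]); // 4 2 / 10 5
    return 0;
}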
@@ -7417,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
    // nb01 >= nb00 - src0 is not transposed
    //   compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        GGML_ASSERT(nb10 == sizeof(float));

@@ -7433,10 +8423,35 @@ static void ggml_compute_forward_mul_mat_f16_f32(
            return;
        }

+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
+
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;

+        size_t x_size, y_size, d_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+#else
+        float * const wdata = params->wdata;
+#endif
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                {
                    size_t id = 0;
                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +8460,44 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                    }
                }
+#endif
+
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);

+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                const float * x = wdata;
                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);

@@ -7457,9 +8509,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }

+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+#endif
        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/

        return;
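The cuBLAS branch avoids converting src0 to fp32: it converts src1 to fp16 on the host and runs a half-precision GEMM with fp32 accumulation. A hedged sketch of just the host-side packing loop, using a simplified fp32-to-fp16 converter rather than ggml's table-based one (rounding and denormals are not handled here):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Simplified fp32 -> fp16 conversion, for illustration only.
static uint16_t f32_to_f16(float f) {
    uint32_t x; memcpy(&x, &f, 4);
    uint32_t sign = (x >> 16) & 0x8000;
    int32_t  exp  = (int32_t)((x >> 23) & 0xff) - 127 + 15;
    uint32_t mant = (x >> 13) & 0x3ff;
    if (exp <= 0)  return (uint16_t)sign;            // flush small values to zero
    if (exp >= 31) return (uint16_t)(sign | 0x7c00); // overflow to infinity
    return (uint16_t)(sign | ((uint32_t)exp << 10) | mant);
}

// Flatten a strided row-major f32 matrix (rows x cols, byte strides rb/cb) into a
// contiguous f16 buffer, the same shape of loop the cuBLAS branch uses for src1.
static void pack_f16(const char *src, size_t rb, size_t cb,
                     int rows, int cols, uint16_t *dst) {
    size_t id = 0;
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            dst[id++] = f32_to_f16(*(const float *)(src + i*rb + j*cb));
}

int main(void) {
    float m[2][4] = {{1,2,3,4},{5,6,7,8}};
    uint16_t h[8];
    pack_f16((const char *)m, sizeof(m[0]), sizeof(float), 2, 4, h);
    printf("0x%04x 0x%04x\n", h[0], h[7]); // 0x3c00 (1.0), 0x4800 (8.0)
    return 0;
}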
@@ -7592,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
    const enum ggml_type type = src0->type;
    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
+    enum ggml_type   const vec_dot_type      = quantize_fns[type].vec_dot_type;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7611,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
    // nb01 >= nb00 - src0 is not transposed
    //   compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        if (params->ith != 0) {
            return;
@@ -7625,11 +8685,66 @@ static void ggml_compute_forward_mul_mat_q_f32(
            return;
        }

+#if defined(GGML_USE_CUBLAS)
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        size_t x_size, y_size, d_size, q_size;
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+
+        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
+        if (type == GGML_TYPE_Q4_0) {
+            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_1) {
+            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_2) {
+            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_3) {
+            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
+        }
+        else if (type == GGML_TYPE_Q5_0) {
+            dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q5_1) {
+            dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q8_0) {
+            dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+#elif !defined(GGML_USE_CLBLAST)
        float * const wdata = params->wdata;
        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+#endif

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(
+                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
+
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+#elif defined(GGML_USE_CLBLAST)
+                const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
+#else
                {
                    size_t id = 0;
                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7637,21 +8752,49 @@ static void ggml_compute_forward_mul_mat_q_f32(
                        id += ne00;
                    }
                }
                const float * x = wdata;
+#endif

-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        type);
+#else
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne11, ne01, ne10,
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }

+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_Q, q_size);
+#endif
        //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

        return;
@@ -7660,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(

    if (params->type == GGML_TASK_INIT) {
        char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[
+        const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -7691,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
    const int ir1 = MIN(ir0 + dr, nr);

    void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[
+    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
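The per-type dequantize selection above is a plain function-pointer dispatch: pick a kernel by tensor type, then run one common pipeline. The same pattern in a self-contained sketch with invented types and kernels (no CUDA):

#include <stdio.h>

typedef void (*dequantize_row_fn)(const void *x, float *y, int k);

// Hypothetical per-type kernels for this sketch; the real code selects CUDA kernels.
static void deq_a(const void *x, float *y, int k) { (void)x; for (int i = 0; i < k; ++i) y[i] = 1.0f; }
static void deq_b(const void *x, float *y, int k) { (void)x; for (int i = 0; i < k; ++i) y[i] = 2.0f; }

enum toy_type { TOY_A, TOY_B };

static dequantize_row_fn pick_dequantize(enum toy_type t) {
    switch (t) {
        case TOY_A: return deq_a;
        case TOY_B: return deq_b;
        default:    return NULL;   // unknown type: caller asserts
    }
}

int main(void) {
    float row[4];
    dequantize_row_fn fn = pick_dequantize(TOY_B);
    fn(NULL, row, 4);
    printf("%g %g %g %g\n", row[0], row[1], row[2], row[3]); // 2 2 2 2
    return 0;
}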
@@ -7739,7 +8882,12 @@ static void ggml_compute_forward_mul_mat(
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
            {
                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
            } break;
@@ -7756,34 +8904,6 @@ static void ggml_compute_forward_mul_mat(
            GGML_ASSERT(false);
        } break;
    }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
}

// ggml_compute_forward_scale
@@ -7994,7 +9114,12 @@ static void ggml_compute_forward_get_rows(
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;
@@ -8132,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(

    uint16_t scvt;
    for (int i = 0; i < nc; i++) {
+        //printf("p[%3d] = %8.4f\n", i, p[i]);
        if (p[i] == -INFINITY) {
            p[i] = 0.0f;
        } else {
@@ -8224,9 +9350,11 @@ static void ggml_compute_forward_rope_f32(

    const float theta_scale = powf(10000.0, -2.0f/n_dims);

+    const bool is_neox = mode & 2;
+
    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
            for (int64_t i1 = 0; i1 < ne1; i1++) {
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;
@@ -8239,14 +9367,25 @@ static void ggml_compute_forward_rope_f32(

                    theta *= theta_scale;

+                    if (!is_neox) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
+
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);

+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];

+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                }
            }
        }
@@ -8301,9 +9440,11 @@ static void ggml_compute_forward_rope_f16(
|
|
8301
9440
|
|
8302
9441
|
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8303
9442
|
|
9443
|
+
const bool is_neox = mode & 2;
|
9444
|
+
|
8304
9445
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
8305
|
-
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8306
|
-
const int p = (mode == 0 ? n_past + i2 : i2);
|
9446
|
+
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
9447
|
+
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
8307
9448
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
8308
9449
|
if (ir++ < ir0) continue;
|
8309
9450
|
if (ir > ir1) break;
|
@@ -8316,14 +9457,25 @@ static void ggml_compute_forward_rope_f16(
|
|
8316
9457
|
|
8317
9458
|
theta *= theta_scale;
|
8318
9459
|
|
8319
|
-
|
8320
|
-
|
9460
|
+
if (!is_neox) {
|
9461
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9462
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
9463
|
+
|
9464
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
9465
|
+
const float x1 = GGML_FP16_TO_FP32(src[1]);
|
9466
|
+
|
9467
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
9468
|
+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
9469
|
+
} else {
|
9470
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
9471
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8321
9472
|
|
8322
|
-
|
8323
|
-
|
9473
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
9474
|
+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
8324
9475
|
|
8325
|
-
|
8326
|
-
|
9476
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
9477
|
+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
9478
|
+
}
|
8327
9479
|
}
|
8328
9480
|
}
|
8329
9481
|
}
|
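The rewritten RoPE kernels above branch on is_neox (mode & 2): the default mode rotates adjacent element pairs (x[0], x[1]) within each head, while GPT-NeoX mode pairs element i with element i + n_dims/2 from the opposite half of the head. A minimal standalone sketch of the two pairings, assuming one contiguous row of n_dims floats and using an illustrative helper name that is not part of ggml:

#include <math.h>

// Sketch only: rotate one row of n_dims floats in place for position p.
// neox != 0 selects the GPT-NeoX pairing (i, i + n_dims/2) instead of (2*i, 2*i + 1).
static void rope_row_sketch(float * x, int n_dims, int p, float theta_scale, int neox) {
    float theta = (float) p;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        theta *= theta_scale;

        if (!neox) {
            const float x0 = x[i0 + 0];
            const float x1 = x[i0 + 1];
            x[i0 + 0] = x0*cos_theta - x1*sin_theta;
            x[i0 + 1] = x0*sin_theta + x1*cos_theta;
        } else {
            const float x0 = x[i0/2];
            const float x1 = x[i0/2 + n_dims/2];
            x[i0/2]            = x0*cos_theta - x1*sin_theta;
            x[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
        }
    }
}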
@@ -10402,11 +11554,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10402
11554
|
case GGML_OP_CPY:
|
10403
11555
|
case GGML_OP_DUP:
|
10404
11556
|
{
|
10405
|
-
node->n_tasks =
|
11557
|
+
node->n_tasks = n_threads;
|
10406
11558
|
|
10407
11559
|
size_t cur = 0;
|
10408
|
-
if (node->type
|
10409
|
-
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
11560
|
+
if (ggml_is_quantized(node->type)) {
|
11561
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
|
10410
11562
|
}
|
10411
11563
|
|
10412
11564
|
work_size = MAX(work_size, cur);
|
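The CPY/DUP work-size case above (and the ADD case in the next hunk) now keys off ggml_is_quantized(node->type) instead of comparing against individual type enums, so the quantized formats added in this change are covered automatically. The predicate is defined elsewhere in ggml.c; the switch below is only an illustrative sketch of the idea, not the actual implementation:

#include <stdbool.h>

// Sketch only: true for the block-quantized tensor types routed through the quantized paths.
static bool is_quantized_sketch(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q4_2:
        case GGML_TYPE_Q4_3:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
            return true;
        default:
            return false;
    }
}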
@@ -10417,7 +11569,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10417
11569
|
|
10418
11570
|
size_t cur = 0;
|
10419
11571
|
|
10420
|
-
if (node->src0->type
|
11572
|
+
if (ggml_is_quantized(node->src0->type)) {
|
10421
11573
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
10422
11574
|
}
|
10423
11575
|
|
@@ -10466,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10466
11618
|
size_t cur = 0;
|
10467
11619
|
|
10468
11620
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
10469
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
11621
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
10470
11622
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10471
11623
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
10472
11624
|
// the threads are still spinning
|
@@ -10482,15 +11634,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10482
11634
|
#endif
|
10483
11635
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
10484
11636
|
cur = 0;
|
10485
|
-
} else if (
|
10486
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
11637
|
+
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
11638
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
10487
11639
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10488
11640
|
node->n_tasks = 1;
|
10489
11641
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10490
11642
|
} else
|
10491
11643
|
#endif
|
10492
11644
|
{
|
10493
|
-
|
11645
|
+
const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
|
11646
|
+
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
|
10494
11647
|
}
|
10495
11648
|
} else {
|
10496
11649
|
GGML_ASSERT(false);
|
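For the quantized x F32 matrix multiplication above, the work buffer now holds src1 re-quantized to the dot-product type associated with src0's format (quantize_fns[...].vec_dot_type) rather than a hard-coded type, so its size is the element count of src1 scaled by that type's bytes per block over elements per block. A sketch of that sizing in isolation, assuming access to the same file-internal tables (quantize_fns, GGML_TYPE_SIZE, GGML_BLCK_SIZE) that ggml.c uses:

#include <stddef.h>

// Sketch only: bytes of scratch needed to hold src1 quantized to src0's vec_dot_type.
static size_t mul_mat_q_work_size_sketch(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
    const enum ggml_type type_q = quantize_fns[src0->type].vec_dot_type;
    return GGML_TYPE_SIZE[type_q]*ggml_nelements(src1)/GGML_BLCK_SIZE[type_q];
}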
@@ -10818,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
10818
11971
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
10819
11972
|
struct ggml_tensor * node = cgraph->nodes[i];
|
10820
11973
|
|
10821
|
-
perf_total_per_op_us[node->op] += node->perf_time_us;
|
11974
|
+
perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
|
10822
11975
|
|
10823
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
11976
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
10824
11977
|
i,
|
10825
11978
|
node->ne[0], node->ne[1], node->ne[2],
|
10826
11979
|
GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
|
@@ -10834,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
10834
11987
|
for (int i = 0; i < cgraph->n_leafs; i++) {
|
10835
11988
|
struct ggml_tensor * node = cgraph->leafs[i];
|
10836
11989
|
|
10837
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
|
11990
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
10838
11991
|
i,
|
10839
11992
|
node->ne[0], node->ne[1],
|
10840
11993
|
GGML_OP_LABEL[node->op]);
|
10841
11994
|
}
|
10842
11995
|
|
10843
11996
|
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
11997
|
+
if (perf_total_per_op_us[i] == 0) {
|
11998
|
+
continue;
|
11999
|
+
}
|
12000
|
+
|
10844
12001
|
GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
|
10845
12002
|
}
|
10846
12003
|
|
@@ -11674,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
|
11674
12831
|
|
11675
12832
|
for (int i = 0; i < nb; i++) {
|
11676
12833
|
for (int l = 0; l < QK4_0; l += 2) {
|
11677
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
12834
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
11678
12835
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
11679
12836
|
|
11680
12837
|
hist[vi0]++;
|
@@ -11697,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
11697
12854
|
|
11698
12855
|
for (int i = 0; i < nb; i++) {
|
11699
12856
|
for (int l = 0; l < QK4_1; l += 2) {
|
11700
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
12857
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
11701
12858
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
11702
12859
|
|
11703
12860
|
hist[vi0]++;
|
@@ -11709,6 +12866,184 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
11709
12866
|
return (n/QK4_1*sizeof(block_q4_1));
|
11710
12867
|
}
|
11711
12868
|
|
12869
|
+
size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12870
|
+
assert(k % QK4_2 == 0);
|
12871
|
+
const int nb = k / QK4_2;
|
12872
|
+
|
12873
|
+
for (int j = 0; j < n; j += k) {
|
12874
|
+
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
12875
|
+
|
12876
|
+
quantize_row_q4_2_reference(src + j, y, k);
|
12877
|
+
|
12878
|
+
for (int i = 0; i < nb; i++) {
|
12879
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
12880
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12881
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12882
|
+
|
12883
|
+
hist[vi0]++;
|
12884
|
+
hist[vi1]++;
|
12885
|
+
}
|
12886
|
+
}
|
12887
|
+
}
|
12888
|
+
|
12889
|
+
return (n/QK4_2*sizeof(block_q4_2));
|
12890
|
+
}
|
12891
|
+
|
12892
|
+
size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12893
|
+
assert(k % QK4_3 == 0);
|
12894
|
+
const int nb = k / QK4_3;
|
12895
|
+
|
12896
|
+
for (int j = 0; j < n; j += k) {
|
12897
|
+
block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
|
12898
|
+
|
12899
|
+
quantize_row_q4_3_reference(src + j, y, k);
|
12900
|
+
|
12901
|
+
for (int i = 0; i < nb; i++) {
|
12902
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
12903
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12904
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12905
|
+
|
12906
|
+
hist[vi0]++;
|
12907
|
+
hist[vi1]++;
|
12908
|
+
}
|
12909
|
+
}
|
12910
|
+
}
|
12911
|
+
|
12912
|
+
return (n/QK4_3*sizeof(block_q4_3));
|
12913
|
+
}
|
12914
|
+
|
12915
|
+
size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12916
|
+
assert(k % QK5_0 == 0);
|
12917
|
+
const int nb = k / QK5_0;
|
12918
|
+
|
12919
|
+
for (int j = 0; j < n; j += k) {
|
12920
|
+
block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
|
12921
|
+
|
12922
|
+
quantize_row_q5_0_reference(src + j, y, k);
|
12923
|
+
|
12924
|
+
for (int i = 0; i < nb; i++) {
|
12925
|
+
uint32_t qh;
|
12926
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
12927
|
+
|
12928
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
12929
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
12930
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
12931
|
+
|
12932
|
+
// cast to 16 bins
|
12933
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
12934
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12935
|
+
|
12936
|
+
hist[vi0]++;
|
12937
|
+
hist[vi1]++;
|
12938
|
+
}
|
12939
|
+
}
|
12940
|
+
}
|
12941
|
+
|
12942
|
+
return (n/QK5_0*sizeof(block_q5_0));
|
12943
|
+
}
|
12944
|
+
|
12945
|
+
size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12946
|
+
assert(k % QK5_1 == 0);
|
12947
|
+
const int nb = k / QK5_1;
|
12948
|
+
|
12949
|
+
for (int j = 0; j < n; j += k) {
|
12950
|
+
block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
|
12951
|
+
|
12952
|
+
quantize_row_q5_1_reference(src + j, y, k);
|
12953
|
+
|
12954
|
+
for (int i = 0; i < nb; i++) {
|
12955
|
+
uint32_t qh;
|
12956
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
12957
|
+
|
12958
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
12959
|
+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
|
12960
|
+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
|
12961
|
+
|
12962
|
+
// cast to 16 bins
|
12963
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
12964
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12965
|
+
|
12966
|
+
hist[vi0]++;
|
12967
|
+
hist[vi1]++;
|
12968
|
+
}
|
12969
|
+
}
|
12970
|
+
}
|
12971
|
+
|
12972
|
+
return (n/QK5_1*sizeof(block_q5_1));
|
12973
|
+
}
|
12974
|
+
|
12975
|
+
size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12976
|
+
assert(k % QK8_0 == 0);
|
12977
|
+
const int nb = k / QK8_0;
|
12978
|
+
|
12979
|
+
for (int j = 0; j < n; j += k) {
|
12980
|
+
block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
|
12981
|
+
|
12982
|
+
quantize_row_q8_0_reference(src + j, y, k);
|
12983
|
+
|
12984
|
+
for (int i = 0; i < nb; i++) {
|
12985
|
+
for (int l = 0; l < QK8_0; ++l) {
|
12986
|
+
const int8_t vi = y[i].qs[l];
|
12987
|
+
|
12988
|
+
hist[vi/16 + 8]++;
|
12989
|
+
}
|
12990
|
+
}
|
12991
|
+
}
|
12992
|
+
|
12993
|
+
return (n/QK8_0*sizeof(block_q8_0));
|
12994
|
+
}
|
12995
|
+
|
12996
|
+
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
|
12997
|
+
size_t result = 0;
|
12998
|
+
switch (type) {
|
12999
|
+
case GGML_TYPE_Q4_0:
|
13000
|
+
{
|
13001
|
+
GGML_ASSERT(start % QK4_0 == 0);
|
13002
|
+
block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
|
13003
|
+
result = ggml_quantize_q4_0(src + start, block, n, n, hist);
|
13004
|
+
} break;
|
13005
|
+
case GGML_TYPE_Q4_1:
|
13006
|
+
{
|
13007
|
+
GGML_ASSERT(start % QK4_1 == 0);
|
13008
|
+
block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
|
13009
|
+
result = ggml_quantize_q4_1(src + start, block, n, n, hist);
|
13010
|
+
} break;
|
13011
|
+
case GGML_TYPE_Q4_2:
|
13012
|
+
{
|
13013
|
+
GGML_ASSERT(start % QK4_2 == 0);
|
13014
|
+
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
|
13015
|
+
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
|
13016
|
+
} break;
|
13017
|
+
case GGML_TYPE_Q4_3:
|
13018
|
+
{
|
13019
|
+
GGML_ASSERT(start % QK4_3 == 0);
|
13020
|
+
block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
|
13021
|
+
result = ggml_quantize_q4_3(src + start, block, n, n, hist);
|
13022
|
+
} break;
|
13023
|
+
case GGML_TYPE_Q5_0:
|
13024
|
+
{
|
13025
|
+
GGML_ASSERT(start % QK5_0 == 0);
|
13026
|
+
block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
|
13027
|
+
result = ggml_quantize_q5_0(src + start, block, n, n, hist);
|
13028
|
+
} break;
|
13029
|
+
case GGML_TYPE_Q5_1:
|
13030
|
+
{
|
13031
|
+
GGML_ASSERT(start % QK5_1 == 0);
|
13032
|
+
block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
|
13033
|
+
result = ggml_quantize_q5_1(src + start, block, n, n, hist);
|
13034
|
+
} break;
|
13035
|
+
case GGML_TYPE_Q8_0:
|
13036
|
+
{
|
13037
|
+
GGML_ASSERT(start % QK8_0 == 0);
|
13038
|
+
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
|
13039
|
+
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
|
13040
|
+
} break;
|
13041
|
+
default:
|
13042
|
+
assert(false);
|
13043
|
+
}
|
13044
|
+
return result;
|
13045
|
+
}
|
13046
|
+
|
11712
13047
|
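The new ggml_quantize_chunk entry point above dispatches on the target type, so callers can quantize an arbitrary slice of a float buffer without picking the per-type routine themselves; each routine also folds the produced quant values into a 16-bucket histogram. A hedged usage sketch (buffer names are illustrative, and start/n are assumed to be multiples of the block size):

#include <stdint.h>
#include <string.h>

// Sketch only: quantize n floats beginning at element `start` into Q4_0 blocks,
// accumulating a 16-bucket histogram of the emitted 4-bit values.
static size_t quantize_slice_sketch(const float * f32_data, void * q_data, int start, int n) {
    int64_t hist[16];
    memset(hist, 0, sizeof(hist));
    return ggml_quantize_chunk(GGML_TYPE_Q4_0, f32_data, q_data, start, n, hist);
}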
////////////////////////////////////////////////////////////////////////////////
|
11713
13048
|
|
11714
13049
|
int ggml_cpu_has_avx(void) {
|
@@ -11800,13 +13135,33 @@ int ggml_cpu_has_wasm_simd(void) {
|
|
11800
13135
|
}
|
11801
13136
|
|
11802
13137
|
int ggml_cpu_has_blas(void) {
|
11803
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
13138
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
13139
|
+
return 1;
|
13140
|
+
#else
|
13141
|
+
return 0;
|
13142
|
+
#endif
|
13143
|
+
}
|
13144
|
+
|
13145
|
+
int ggml_cpu_has_cublas(void) {
|
13146
|
+
#if defined(GGML_USE_CUBLAS)
|
11804
13147
|
return 1;
|
11805
13148
|
#else
|
11806
13149
|
return 0;
|
11807
13150
|
#endif
|
11808
13151
|
}
|
11809
13152
|
|
13153
|
+
int ggml_cpu_has_clblast(void) {
|
13154
|
+
#if defined(GGML_USE_CLBLAST)
|
13155
|
+
return 1;
|
13156
|
+
#else
|
13157
|
+
return 0;
|
13158
|
+
#endif
|
13159
|
+
}
|
13160
|
+
|
13161
|
+
int ggml_cpu_has_gpublas(void) {
|
13162
|
+
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
|
13163
|
+
}
|
13164
|
+
|
11810
13165
|
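With ggml_cpu_has_cublas, ggml_cpu_has_clblast, and the combined ggml_cpu_has_gpublas added above (and ggml_cpu_has_blas now also reporting 1 for the GPU back ends), a caller can report which accelerated paths the library was compiled with. A small hedged sketch:

#include <stdio.h>

// Sketch only: print which BLAS-style back ends this build of ggml reports.
static void print_backends_sketch(void) {
    printf("BLAS:     %d\n", ggml_cpu_has_blas());
    printf("cuBLAS:   %d\n", ggml_cpu_has_cublas());
    printf("CLBlast:  %d\n", ggml_cpu_has_clblast());
    printf("GPU BLAS: %d\n", ggml_cpu_has_gpublas());
}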
int ggml_cpu_has_sse3(void) {
|
11811
13166
|
#if defined(__SSE3__)
|
11812
13167
|
return 1;
|