llama_cpp 0.0.6 → 0.1.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +20 -1
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +762 -36
- data/ext/llama_cpp/src/ggml-cuda.h +11 -4
- data/ext/llama_cpp/src/ggml-opencl.c +398 -0
- data/ext/llama_cpp/src/ggml-opencl.h +24 -0
- data/ext/llama_cpp/src/ggml.c +1957 -909
- data/ext/llama_cpp/src/ggml.h +696 -627
- data/ext/llama_cpp/src/{llama_util.h → llama-util.h} +91 -12
- data/ext/llama_cpp/src/llama.cpp +755 -159
- data/ext/llama_cpp/src/llama.h +85 -34
- data/lib/llama_cpp/client.rb +174 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +43 -11
- data/sig/llama_cpp.rbs +53 -3
- metadata +6 -3
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -135,57 +135,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)

-#define GGML_ASSERT(x) \
-    do { \
-        if (!(x)) { \
-            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            abort(); \
-        } \
-    } while (0)
-
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err) \
-    do { \
-        cudaError_t err_ = (err); \
-        if (err_ != cudaSuccess) { \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
-                cudaGetErrorString(err_)); \
-            exit(1); \
-        } \
-    } while (0)
-
-#define CUBLAS_CHECK(err) \
-    do { \
-        cublasStatus_t err_ = (err); \
-        if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
-            exit(1); \
-        } \
-    } while (0)
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif

 #undef MIN
@@ -223,9 +180,13 @@ typedef double ggml_float;
 #undef bool
 #define bool _Bool
 #else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
 #include <immintrin.h>
 #endif
 #endif
+#endif

 #ifdef __F16C__
@@ -365,6 +326,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];

+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8bits to 8 bytes (shl 4)
+static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
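The B1–B8 macros above build, at preprocessing time, the 256-entry table table_b2b_u: entry i is a 64-bit value whose byte j is 0x10 when bit j of i is set and 0x00 otherwise (the "shl 4" in the comment). The following is only an illustrative sketch of the same mapping computed at runtime; the helper name is made up and is not part of the diff.

```c
#include <stdint.h>
#include <stdio.h>

// Hypothetical scalar equivalent of one table_b2b_u entry: expand the 8 bits of b
// into 8 bytes, each holding the bit value shifted left by 4 (0x00 or 0x10).
static uint64_t b2b_u_entry(uint8_t b) {
    uint64_t r = 0;
    for (int j = 0; j < 8; j++) {
        r |= (uint64_t)(((b >> j) & 1u) << 4) << (8 * j);
    }
    return r;
}

int main(void) {
    // prints 1000100000100010 for input 0xA5 (bits 0,2,5,7 set)
    printf("%016llx\n", (unsigned long long) b2b_u_entry(0xA5));
    return 0;
}
```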
@@ -391,6 +366,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return GGML_FP32_TO_FP16(x);
 }

+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
+    for (size_t i = 0; i < n; i++) {
+        y[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+}
+
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
+    size_t i = 0;
+#if defined(__F16C__)
+    for (; i + 7 < n; i += 8) {
+        __m256 x_vec = _mm256_loadu_ps(x + i);
+        __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storeu_si128((__m128i *)(y + i), y_vec);
+    }
+    for(; i + 3 < n; i += 4) {
+        __m128 x_vec = _mm_loadu_ps(x + i);
+        __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storel_epi64((__m128i *)(y + i), y_vec);
+    }
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_FP32_TO_FP16(x[i]);
+    }
+}
+
+
 //
 // timing
 //
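The two new row converters process a whole row at once; the F16C path of ggml_fp32_to_fp16_row steps 8 floats, then 4, then falls back to scalar, so any length n is handled without out-of-bounds loads. Below is a small usage sketch, assuming the ggml.h shipped with this version declares both helpers (the header diff suggests it does, but that is an assumption).

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // n = 5 is deliberately not a multiple of 8 or 4, so the scalar tail also runs.
    float       src[5] = {0.5f, -1.25f, 3.0f, 65504.0f, 1e-4f};
    ggml_fp16_t half[5];
    float       back[5];

    ggml_fp32_to_fp16_row(src, half, 5);
    ggml_fp16_to_fp32_row(half, back, 5);

    for (int i = 0; i < 5; i++) {
        printf("%g -> %g\n", src[i], back[i]); // small rounding differences are expected
    }
    return 0;
}
```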
@@ -473,7 +474,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
 {
     // Load 8 bytes from memory
-    __m128i tmp =
+    __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );

     // Expand bytes into uint16_t values
     __m128i bytes = _mm_cvtepu8_epi16( tmp );
@@ -487,7 +488,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
     return bytes;
 }

+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
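hsum_float_8 and hsum_i32_8 reduce a 256-bit register to one scalar by repeatedly folding the upper half onto the lower half (128 → 64 → 32 bits). A plain-C reference of what those reductions compute, added here purely for illustration:

```c
#include <stdint.h>

// Reference semantics of hsum_float_8 / hsum_i32_8: the AVX versions above compute
// the same total with three halving steps instead of a loop.
static float hsum_float_8_ref(const float x[8]) {
    float s = 0.0f;
    for (int i = 0; i < 8; i++) s += x[i];
    return s;
}

static int hsum_i32_8_ref(const int32_t a[8]) {
    int s = 0;
    for (int i = 0; i < 8; i++) s += a[i];
    return s;
}
```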
@@ -507,9 +547,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     return bytes;
 }

+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);             // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                       // abcd_efgh
+#else
     const __m256i lowByte = _mm256_set1_epi16( 0xFF );
     __m256i high = _mm256_andnot_si256( lowByte, bytes );
     __m256i low = _mm256_and_si256( lowByte, bytes );
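mul_sum_i8_pairs_float feeds _mm256_maddubs_epi16 (or _mm256_dpbusd_epi32 on AVX-VNNI), which multiply an unsigned operand by a signed one. The sign of x is therefore moved onto y with _mm256_sign_epi8 first, leaving every product x[i]*y[i] unchanged. A scalar model of the trick, for illustration only (not the ggml kernel):

```c
#include <stdint.h>
#include <stdlib.h>

// Scalar model: |x[i]| * (y[i] with x's sign applied) == x[i] * y[i] in every lane,
// so the unsigned-by-signed multiply-add still yields the true dot-product terms.
// The SIMD version keeps eight 32-bit partial sums; here all 32 lanes are totalled.
static int32_t mul_sum_i8_pairs_ref(const int8_t x[32], const int8_t y[32]) {
    int32_t sum = 0;
    for (int i = 0; i < 32; i++) {
        const int ax = abs((int) x[i]);
        const int sy = (x[i] > 0) ? y[i] : (x[i] < 0 ? -y[i] : 0);
        sum += ax * sy; // equals x[i] * y[i]
    }
    return sum;
}
```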
@@ -520,6 +589,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r0 = _mm256_castsi256_si128( bytes );
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
+#endif
 }
 #else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
@@ -605,19 +675,102 @@ float vmaxvq_f32(float32x4_t v) {
 }

 int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
-
+    int8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
 }

 int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
-
+    int8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
 }

 uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
-
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
 }

 uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
-
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
+int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
+    int8x16_t res;
+
+    res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
+    res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
+    res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
+
+    return res;
+}
+
+int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
+    int8x16_t res;
+
+    res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
+    res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
+    res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
+    res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
+
+    return res;
+}
+
+uint8x16_t vzip1q_u8(uint8x16_t a, uint8x16_t b) {
+    uint8x16_t res;
+
+    res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
+    res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
+    res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
+
+    return res;
+}
+
+uint8x16_t vzip2q_u8(uint8x16_t a, uint8x16_t b) {
+    uint8x16_t res;
+
+    res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
+    res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
+    res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
+    res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
+
+    return res;
+}
+
+int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
 }

 #endif
@@ -646,13 +799,22 @@ typedef struct {
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");

-#define
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
 typedef struct {
     ggml_fp16_t d;         // delta
     ggml_fp16_t m;         // min
-    uint8_t
-
-
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");

 #define QK8_0 32
 typedef struct {
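In the new 5-bit formats each block still holds 32 weights: the low four bits of every quant sit in the qs nibbles, and the 32 fifth bits are collected into the 32-bit qh field. The accessor below is only an illustrative sketch (not part of the diff) of how quant number l of a q5_0 block is reassembled:

```c
#include <stdint.h>
#include <string.h>

// Hypothetical helper: rebuild the 5-bit quant with index l (0..31) of a q5_0 block
// from the nibble stored in qs and the corresponding bit stored in qh.
static uint8_t q5_0_quant(const uint8_t qs[16], const uint8_t qh_bytes[4], int l) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));                  // same memcpy the diff uses
    const uint8_t nibble = (l % 2 == 0) ? (qs[l/2] & 0x0F)  // even index: low nibble
                                        : (qs[l/2] >> 4);   // odd index: high nibble
    const uint8_t bit5 = (qh >> l) & 1;                 // fifth bit for this weight
    return (uint8_t)(nibble | (bit5 << 4));             // value in [0, 31]
}
```

Dequantization then maps the 5-bit value v to (v - 16)*d for q5_0, or to v*d + m for q5_1, as the functions later in this diff show.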
@@ -661,6 +823,14 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

+#define QK8_1 32
+typedef struct {
+    float  d;         // delta
+    float  s0;        // d * sum(qs[i]) low
+    float  s1;        // d * sum(qs[i]) high
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");

 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
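block_q8_1 differs from block_q8_0 by the two precomputed sums s0 and s1 (the delta times the sum of the first and of the second 16 quants). They let the offset-carrying formats fold their min/offset term into the dot product without re-summing the activations. The sketch below only illustrates the layout; it is written independently of the ggml code:

```c
#include <math.h>
#include <stdint.h>

typedef struct {
    float  d;       // delta
    float  s0;      // d * sum of qs[0..15]
    float  s1;      // d * sum of qs[16..31]
    int8_t qs[32];  // quants
} block_q8_1_sketch;

// Scalar sketch of what quantize_row_q8_1 stores for one 32-value block.
static void quantize_q8_1_sketch(const float x[32], block_q8_1_sketch * y) {
    float amax = 0.0f;
    for (int l = 0; l < 32; l++) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / 127.0f;
    const float id = d ? 1.0f / d : 0.0f;

    int sum0 = 0, sum1 = 0;
    for (int l = 0; l < 16; l++) {
        y->qs[l]      = (int8_t)roundf(x[l]      * id);
        y->qs[16 + l] = (int8_t)roundf(x[16 + l] * id);
        sum0 += y->qs[l];
        sum1 += y->qs[16 + l];
    }
    y->d  = d;
    y->s0 = d * sum0;
    y->s1 = d * sum1;
}
```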
@@ -671,13 +841,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r

     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
+        float max = 0.0f;

         for (int l = 0; l < QK4_0; l++) {
             const float v = x[i*QK4_0 + l];
-
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
         }

-        const float d =
+        const float d = max / -8;
         const float id = d ? 1.0f/d : 0.0f;

         y[i].d = d;
@@ -686,8 +860,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
             const float v0 = x[i*QK4_0 + l + 0]*id;
             const float v1 = x[i*QK4_0 + l + 1]*id;

-            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
-            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+            const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
+            const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);

             assert(vi0 < 16);
             assert(vi1 < 16);
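Q4_0 quantization now divides by the signed value of largest magnitude over -8, so that extreme weight lands exactly on the end of the [-8, 7] range, and the +8-offset index is clamped to 15. A scalar sketch of the new scheme (illustrative only; it stores one byte per quant for clarity, while the real code packs two per byte):

```c
#include <math.h>
#include <stdint.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Sketch of the Q4_0 scheme introduced in this diff:
// d = (value of largest magnitude) / -8, quants = round(x/d) + 8, clamped to 15.
static void quantize_q4_0_sketch(const float x[32], float * d_out, uint8_t q[32]) {
    float amax = 0.0f, max = 0.0f;
    for (int l = 0; l < 32; l++) {
        if (amax < fabsf(x[l])) { amax = fabsf(x[l]); max = x[l]; }
    }
    const float d  = max / -8.0f;          // signed: the extreme value maps to -8
    const float id = d ? 1.0f / d : 0.0f;
    for (int l = 0; l < 32; l++) {
        q[l] = (uint8_t)MIN(15, (int8_t)roundf(x[l] * id) + 8);
    }
    *d_out = d;
}
```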
@@ -707,28 +881,43 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int

 #if defined(__POWER9_VECTOR__)
     const vector float v85 = vec_splats(8.5f);
+    const vector signed int v15 = vec_splats(15);
     for (int i = 0; i < nb; i++) {
-        float
+        float max = 0.0f;
+        float min = 0.0f;

+        vector float asrcv [8];
         vector float srcv [8];
-        vector float
-        vector float
+        vector float maxv[8];
+        vector float minv[8];

         for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
-
-        for (int l = 0; l < 4; l++)
-        //for (int l = 0; l < 2; l++)
-
-
-        //for (int l = 0; l < 1; l++)
-
-
-
-
-
-
-
+        //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
+        maxv[0] = vec_max(maxv[0], maxv[2]);
+        maxv[4] = vec_max(maxv[4], maxv[6]);
+        //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
+        maxv[0] = vec_max(maxv[0], maxv[4]);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
+        minv[0] = vec_min(minv[0], minv[2]);
+        minv[4] = vec_min(minv[4], minv[6]);
+        //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
+        minv[0] = vec_min(minv[0], minv[4]);
+
+
+        max = MAX(
+                MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
+                MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
+        min = MIN(
+                MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
+                MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
+
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0/d : 0.0;

         y[i].d = d;
@@ -738,27 +927,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         for (int l = 0; l < 8; l++) {
             const vector float vf = vec_madd(srcv[l], vid, v85);
             const vector signed int vi = vec_signed(vf);
+            const vector signed int vc = vec_min(vi, v15);

-            pb[2*l + 0] = vec_extract(
-            pb[2*l + 1] = vec_extract(
+            pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
+            pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
         }
     }
 #elif __ARM_NEON
     for (int i = 0; i < nb; i++) {
         float32x4_t srcv [8];
-        float32x4_t
-        float32x4_t
+        float32x4_t maxv[8];
+        float32x4_t minv[8];

         for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
-        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);

-        for (int l = 0; l < 4; l++)
-        for (int l = 0; l < 2; l++)
-        for (int l = 0; l < 1; l++)
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);

-
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
+
+        const float max = vmaxvq_f32(maxv[0]);
+        const float min = vminvq_f32(minv[0]);

-        const float
+        const float magnitude = max >= fabsf(min) ? max : min;
+        const float d = magnitude / -8;
         const float id = d ? 1.0f/d : 0.0f;

         y[i].d = d;
@@ -767,9 +962,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const float32x4_t v = vmulq_n_f32(srcv[l], id);
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t vi = vcvtq_s32_f32(vf);
+            const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));

-            y[i].qs[2*l + 0] = vgetq_lane_s32(
-            y[i].qs[2*l + 1] = vgetq_lane_s32(
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
         }
     }
 #elif defined(__AVX2__)
@@ -781,22 +977,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         __m256 v3 = _mm256_loadu_ps( x + 24 );
         x += 32;

-        // Compute max
-
-        __m256
-
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+        // Compute max for the block
+        __m256 max = _mm256_max_ps( v0, v1 );
+        __m256 maxTmp = _mm256_max_ps( v2, v3 );
+        max = _mm256_max_ps( max, maxTmp );

-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
         max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
         max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
         const float maxScalar = _mm_cvtss_f32( max4 );

+        // Compute min for the block
+        __m256 min = _mm256_min_ps( v0, v1 );
+        __m256 minTmp = _mm256_min_ps( v2, v3 );
+        min = _mm256_min_ps( min, minTmp );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
         // Quantize these floats
-        const float
+        const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+        const float d = magnitude / -8.0f;
         y[i].d = d;
-        const float id = (
+        const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );

         // Apply the multiplier
@@ -829,9 +1034,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
         i0 = _mm256_permutevar8x32_epi32( i0, perm );

-        // Apply offset to translate the range from [ -
+        // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
         const __m256i off = _mm256_set1_epi8( 8 );
         i0 = _mm256_add_epi8( i0, off );
+        const __m256i maxNibble = _mm256_set1_epi8( 15 );
+        i0 = _mm256_min_epi8( i0, maxNibble );

         // Compress the vector into 4 bit/value, and store
         __m128i res = packNibbles( i0 );
@@ -846,22 +1053,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         __m256 v3 = _mm256_loadu_ps( x + 24 );
         x += 32;

-        // Compute max
-
-        __m256
-
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
-        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+        // Compute max for the block
+        __m256 max = _mm256_max_ps( v0, v1 );
+        __m256 maxTmp = _mm256_max_ps( v2, v3 );
+        max = _mm256_max_ps( max, maxTmp );

-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
         max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
         max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
         const float maxScalar = _mm_cvtss_f32( max4 );

+        // Compute min for the block
+        __m256 min = _mm256_min_ps( v0, v1 );
+        __m256 minTmp = _mm256_min_ps( v2, v3 );
+        min = _mm256_min_ps( min, minTmp );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
         // Quantize these floats
-        const float
+        const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+        const float d = magnitude / -8.0f;
         y[i].d = d;
-        const float id = (
+        const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );

         // Apply the multiplier
@@ -902,10 +1118,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         ni0 = _mm_packs_epi16( ni0, ni2 );
         ni4 = _mm_packs_epi16( ni4, ni6 );

-        // Apply offset to translate the range from [ -
-        const __m128i off = _mm_set1_epi8( 8);
+        // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
+        const __m128i off = _mm_set1_epi8( 8 );
         ni0 = _mm_add_epi8( ni0, off );
         ni4 = _mm_add_epi8( ni4, off );
+        const __m128i maxNibble = _mm_set1_epi8( 15 );
+        ni0 = _mm_min_epi8( ni0, maxNibble );
+        ni4 = _mm_min_epi8( ni4, maxNibble );

         // Compress the vector into 4 bit/value, and store
         __m128i res = packNibbles( ni0, ni4 );
@@ -913,24 +1132,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
913
1132
|
}
|
914
1133
|
#elif defined(__wasm_simd128__)
|
915
1134
|
for (int i = 0; i < nb; i++) {
|
916
|
-
float
|
1135
|
+
float max = 0.0f;
|
1136
|
+
float min = 0.0f;
|
917
1137
|
|
918
1138
|
v128_t srcv [8];
|
919
|
-
v128_t
|
920
|
-
v128_t
|
1139
|
+
v128_t maxv[8];
|
1140
|
+
v128_t minv[8];
|
921
1141
|
|
922
1142
|
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
|
923
|
-
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
|
924
1143
|
|
925
|
-
for (int l = 0; l < 4; l++)
|
926
|
-
for (int l = 0; l < 2; l++)
|
927
|
-
for (int l = 0; l < 1; l++)
|
1144
|
+
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
|
1145
|
+
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
|
1146
|
+
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
|
1147
|
+
|
1148
|
+
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
|
1149
|
+
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
|
1150
|
+
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
|
928
1151
|
|
929
|
-
|
930
|
-
MAX(wasm_f32x4_extract_lane(
|
931
|
-
MAX(wasm_f32x4_extract_lane(
|
1152
|
+
max = MAX(
|
1153
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
|
1154
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
|
1155
|
+
min = MIN(
|
1156
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
|
1157
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
|
932
1158
|
|
933
|
-
const float
|
1159
|
+
const float magnitude = max >= fabsf(min) ? max : min;
|
1160
|
+
const float d = magnitude / -8;
|
934
1161
|
const float id = d ? 1.0/d : 0.0;
|
935
1162
|
|
936
1163
|
y[i].d = d;
|
@@ -939,9 +1166,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
939
1166
|
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
|
940
1167
|
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
|
941
1168
|
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
|
1169
|
+
const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
|
942
1170
|
|
943
|
-
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(
|
944
|
-
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(
|
1171
|
+
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
|
1172
|
+
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
|
945
1173
|
}
|
946
1174
|
}
|
947
1175
|
#else
|
@@ -1122,13 +1350,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r

     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
+        float max = 0.0f;

         for (int l = 0; l < QK4_2; l++) {
             const float v = x[i*QK4_2 + l];
-
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
         }

-        const float d =
+        const float d = max / -8;

         const float id = d ? 1.0f/d : 0.0f;

@@ -1138,8 +1370,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
             const float v0 = x[i*QK4_2 + l + 0]*id;
             const float v1 = x[i*QK4_2 + l + 1]*id;

-            const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
-            const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
+            const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
+            const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));

             assert(vi0 < 16);
             assert(vi1 < 16);
@@ -1149,136 +1381,109 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
     }
 }

-static
-    assert(
-
-
-
-
-
-
-
-    assert
-
-
-    for (int i=0; i<
-
-
-
-
-
-
-
-
-        float sumlxM = 0; int suml2M = 0;
-        for (int i=0; i<n; ++i) {
-            int l = nearest_int(iscale*X[i]);
-            int lp = MAX(nmin, MIN(nmax, +l));
-            int lm = MAX(nmin, MIN(nmax, -l));
-            sumlxP += X[i]*lp; suml2P += lp*lp;
-            sumlxM += X[i]*lm; suml2M += lm*lm;
-        }
-        float sumlxP2 = sumlxP*sumlxP;
-        float sumlxM2 = sumlxM*sumlxM;
-        if (sumlxP2*suml2M > sumlxM2*suml2P) {
-            if (sumlxP2 > best*suml2P) {
-                best = sumlxP2/suml2P; bestScale = iscale;
-            }
-        } else {
-            if (sumlxM2 > best*suml2M) {
-                best = sumlxM2/suml2M; bestScale = -iscale;
+static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_2 == 0);
+
+    block_q4_2 * restrict y = vy;
+
+    quantize_row_q4_2_reference(x, y, k);
+}
+
+static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+
+        for (int l = 0; l < QK5_0; l++) {
+            const float v = x[i*QK5_0 + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
             }
         }
-    }
-    float sumlx = 0; int suml2 = 0;
-    for (int i=0; i<n; ++i) {
-        int l = nearest_int(bestScale*X[i]);
-        l = MAX(nmin, MIN(nmax, l));
-        sumlx += X[i]*l; suml2 += l*l;
-        L[i] = l;
-    }
-    float scale = sumlx/suml2;
-    return scale;
-}

-
-
-    static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
-    assert(k % QK4_2 == 0);
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;

-
+        y[i].d = GGML_FP32_TO_FP16(d);

-
+        uint32_t qh = 0;

-
-
-
+        for (int l = 0; l < QK5_0; l += 2) {
+            const float v0 = x[i*QK5_0 + l + 0]*id;
+            const float v1 = x[i*QK5_0 + l + 1]*id;

-
-            const
-            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+            const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
+            const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));

-
-            assert(vi1 < 16);
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);

-
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
         }

-
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
     }
 }

-static void
-    assert(k %
+static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_0 == 0);

-
+    block_q5_0 * restrict y = vy;

-
-    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
-    quantize_row_q4_2_rmse(x, y, k);
+    quantize_row_q5_0_reference(x, y, k);
 }

-static void
-    assert(k %
-    const int nb = k /
+static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;

     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;

-        for (int l = 0; l <
-            const float v = x[i*
+        for (int l = 0; l < QK5_1; l++) {
+            const float v = x[i*QK5_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }

-        const float d = (max - min) / ((1 <<
+        const float d = (max - min) / ((1 << 5) - 1);
         const float id = d ? 1.0f/d : 0.0f;

         y[i].d = GGML_FP32_TO_FP16(d);
         y[i].m = GGML_FP32_TO_FP16(min);

-
-            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
-            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+        uint32_t qh = 0;

-
-            const
+        for (int l = 0; l < QK5_1; l += 2) {
+            const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK5_1 + l + 1] - min)*id;

-
-
+            const uint32_t vi0 = (int) (v0 + 0.5f);
+            const uint32_t vi1 = (int) (v1 + 0.5f);

-            y[i].qs[l/2] = vi0 | (vi1 << 4);
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
         }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
     }
 }

-static void
-    assert(k %
+static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_1 == 0);

-
+    block_q5_1 * restrict y = vy;

-
+    quantize_row_q5_1_reference(x, y, k);
 }

 // reference implementation for deterministic creation of model files
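As a quick sanity check of the q5_0 mapping used in quantize_row_q5_0_reference above, here is a hedged single-value round trip; it is illustrative only and not part of the diff.

```c
#include <assert.h>
#include <stdio.h>

int main(void) {
    const float max = -2.0f;            // value with the largest magnitude in a block
    const float d   = max / -16.0f;     // 0.125f, as in quantize_row_q5_0_reference
    const float id  = 1.0f / d;

    const float x  = 1.0f;
    const int   vi = (int)(x * id + 16.5f);   // 24, a 5-bit value in [0, 31]
    const int   nibble = vi & 0x0F;           // 8  -> would be stored in qs
    const int   bit5   = (vi & 0x10) >> 4;    // 1  -> would be stored in qh
    const float back   = ((nibble | (bit5 << 4)) - 16) * d;

    printf("x = %g, quant = %d, back = %g\n", x, vi, back);
    assert(back == 1.0f);
    return 0;
}
```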
@@ -1300,13 +1505,15 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         y[i].d = d;

         for (int l = 0; l < QK8_0; ++l) {
-            const float
-
+            const float v0 = x[i*QK8_0 + l]*id;
+
+            y[i].qs[l] = roundf(v0);
         }
     }
 }

 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;

@@ -1432,95 +1639,295 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1432
1639
|
#endif
|
1433
1640
|
}
|
1434
1641
|
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
const
|
1642
|
+
// reference implementation for deterministic creation of model files
|
1643
|
+
static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
|
1644
|
+
assert(QK8_1 == 32);
|
1645
|
+
assert(k % QK8_1 == 0);
|
1646
|
+
const int nb = k / QK8_1;
|
1440
1647
|
|
1441
|
-
#if defined(__AVX2__)
|
1442
1648
|
for (int i = 0; i < nb; i++) {
|
1443
|
-
//
|
1444
|
-
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
|
1649
|
+
float amax = 0.0f; // absolute max
|
1445
1650
|
|
1446
|
-
|
1651
|
+
for (int l = 0; l < QK8_1; l++) {
|
1652
|
+
const float v = x[i*QK8_1 + l];
|
1653
|
+
amax = MAX(amax, fabsf(v));
|
1654
|
+
}
|
1447
1655
|
|
1448
|
-
|
1449
|
-
|
1450
|
-
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1656
|
+
const float d = amax / ((1 << 7) - 1);
|
1657
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1451
1658
|
|
1452
|
-
|
1453
|
-
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
1659
|
+
y[i].d = d;
|
1454
1660
|
|
1455
|
-
|
1456
|
-
|
1457
|
-
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
1661
|
+
int sum0 = 0;
|
1662
|
+
int sum1 = 0;
|
1458
1663
|
|
1459
|
-
|
1460
|
-
const
|
1461
|
-
|
1462
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
1463
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
1464
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
1465
|
-
};
|
1664
|
+
for (int l = 0; l < QK8_1/2; ++l) {
|
1665
|
+
const float v0 = x[i*QK8_1 + l]*id;
|
1666
|
+
const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
|
1466
1667
|
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1668
|
+
y[i].qs[ l] = roundf(v0);
|
1669
|
+
y[i].qs[QK8_1/2 + l] = roundf(v1);
|
1670
|
+
|
1671
|
+
sum0 += y[i].qs[ l];
|
1672
|
+
sum1 += y[i].qs[QK8_1/2 + l];
|
1472
1673
|
}
|
1674
|
+
|
1675
|
+
y[i].s0 = d * sum0;
|
1676
|
+
y[i].s1 = d * sum1;
|
1473
1677
|
}
|
1474
|
-
|
1475
|
-
for (int i = 0; i < nb; i++) {
|
1476
|
-
const float32x4_t vd = vdupq_n_f32(x[i].d);
|
1678
|
+
}
|
1477
1679
|
|
1478
|
-
|
1680
|
+
static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
|
1681
|
+
assert(k % QK8_1 == 0);
|
1682
|
+
const int nb = k / QK8_1;
|
1479
1683
|
|
1480
|
-
|
1481
|
-
// Load 16x4-bit integers into 8x8-bit integers
|
1482
|
-
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1684
|
+
block_q8_1 * restrict y = vy;
|
1483
1685
|
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1686
|
+
#if defined(__ARM_NEON)
|
1687
|
+
for (int i = 0; i < nb; i++) {
|
1688
|
+
float32x4_t srcv [8];
|
1689
|
+
float32x4_t asrcv[8];
|
1690
|
+
float32x4_t amaxv[8];
|
1487
1691
|
|
1488
|
-
|
1489
|
-
|
1490
|
-
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
1692
|
+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
1693
|
+
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
|
1491
1694
|
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1695
|
+
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
|
1696
|
+
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
1697
|
+
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
1495
1698
|
|
1496
|
-
|
1497
|
-
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
1498
|
-
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
1699
|
+
const float amax = vmaxvq_f32(amaxv[0]);
|
1499
1700
|
|
1500
|
-
|
1701
|
+
const float d = amax / ((1 << 7) - 1);
|
1702
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1501
1703
|
|
1502
|
-
|
1503
|
-
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
1504
|
-
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
1704
|
+
y[i].d = d;
|
1505
1705
|
|
1506
|
-
|
1507
|
-
|
1508
|
-
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
1509
|
-
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
1510
|
-
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
1706
|
+
int32x4_t accv0 = vdupq_n_s32(0);
|
1707
|
+
int32x4_t accv1 = vdupq_n_s32(0);
|
1511
1708
|
|
1512
|
-
|
1513
|
-
|
1514
|
-
const float32x4_t
|
1515
|
-
const
|
1516
|
-
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
1709
|
+
// low half
|
1710
|
+
for (int l = 0; l < 4; l++) {
|
1711
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1712
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1517
1713
|
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1522
|
-
|
1523
|
-
|
1714
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1715
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1716
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1717
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1718
|
+
|
1719
|
+
accv0 = vaddq_s32(accv0, vi);
|
1720
|
+
}
|
1721
|
+
|
1722
|
+
// high half
|
1723
|
+
for (int l = 4; l < 8; l++) {
|
1724
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1725
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1726
|
+
|
1727
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1728
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1729
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1730
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1731
|
+
|
1732
|
+
accv1 = vaddq_s32(accv1, vi);
|
1733
|
+
}
|
1734
|
+
|
1735
|
+
const int32_t sum0 = vaddvq_s32(accv0);
|
1736
|
+
const int32_t sum1 = vaddvq_s32(accv1);
|
1737
|
+
|
1738
|
+
y[i].s0 = d * sum0;
|
1739
|
+
y[i].s1 = d * sum1;
|
1740
|
+
}
|
1741
|
+
#elif defined(__AVX2__) || defined(__AVX__)
|
1742
|
+
for (int i = 0; i < nb; i++) {
|
1743
|
+
// Load elements into 4 AVX vectors
|
1744
|
+
__m256 v0 = _mm256_loadu_ps( x );
|
1745
|
+
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
1746
|
+
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
1747
|
+
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
1748
|
+
x += 32;
|
1749
|
+
|
1750
|
+
// Compute max(abs(e)) for the block
|
1751
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
1752
|
+
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
1753
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
1754
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
1755
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
1756
|
+
|
1757
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
1758
|
+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
1759
|
+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
1760
|
+
const float maxScalar = _mm_cvtss_f32( max4 );
|
1761
|
+
|
1762
|
+
// Quantize these floats
|
1763
|
+
const float d = maxScalar / 127.f;
|
1764
|
+
y[i].d = d;
|
1765
|
+
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
1766
|
+
const __m256 mul = _mm256_set1_ps( id );
|
1767
|
+
|
1768
|
+
// Apply the multiplier
|
1769
|
+
v0 = _mm256_mul_ps( v0, mul );
|
1770
|
+
v1 = _mm256_mul_ps( v1, mul );
|
1771
|
+
v2 = _mm256_mul_ps( v2, mul );
|
1772
|
+
v3 = _mm256_mul_ps( v3, mul );
|
1773
|
+
|
1774
|
+
// Round to nearest integer
|
1775
|
+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
1776
|
+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
1777
|
+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
1778
|
+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
1779
|
+
|
1780
|
+
// Convert floats to integers
|
1781
|
+
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
1782
|
+
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
1783
|
+
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
1784
|
+
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
1785
|
+
|
1786
|
+
#if defined(__AVX2__)
|
1787
|
+
// Compute the sum of the quants and set y[i].s
|
1788
|
+
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
1789
|
+
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
|
1790
|
+
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
|
1791
|
+
|
1792
|
+
// Convert int32 to int16
|
1793
|
+
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
1794
|
+
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
1795
|
+
// Convert int16 to int8
|
1796
|
+
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
1797
|
+
|
1798
|
+
// We got our precious signed bytes, but the order is now wrong
|
1799
|
+
// These AVX2 pack instructions process 16-byte pieces independently
|
1800
|
+
// The following instruction is fixing the order
|
1801
|
+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
1802
|
+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
1803
|
+
|
1804
|
+
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
|
1805
|
+
#else
|
1806
|
+
// Since we don't have in AVX some necessary functions,
|
1807
|
+
// we split the registers in half and call AVX2 analogs from SSE
|
1808
|
+
__m128i ni0 = _mm256_castsi256_si128( i0 );
|
1809
|
+
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
|
1810
|
+
__m128i ni2 = _mm256_castsi256_si128( i1 );
|
1811
|
+
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
|
1812
|
+
__m128i ni4 = _mm256_castsi256_si128( i2 );
|
1813
|
+
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
|
1814
|
+
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
1815
|
+
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
1816
|
+
|
1817
|
+
// Compute the sum of the quants and set y[i].s
|
1818
|
+
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
|
1819
|
+
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
|
1820
|
+
y[i].s0 = d * hsum_i32_4(s0);
|
1821
|
+
y[i].s1 = d * hsum_i32_4(s1);
|
1822
|
+
|
1823
|
+
// Convert int32 to int16
|
1824
|
+
ni0 = _mm_packs_epi32( ni0, ni1 );
|
1825
|
+
ni2 = _mm_packs_epi32( ni2, ni3 );
|
1826
|
+
ni4 = _mm_packs_epi32( ni4, ni5 );
|
1827
|
+
ni6 = _mm_packs_epi32( ni6, ni7 );
|
1828
|
+
// Convert int16 to int8
|
1829
|
+
ni0 = _mm_packs_epi16( ni0, ni2 );
|
1830
|
+
ni4 = _mm_packs_epi16( ni4, ni6 );
|
1831
|
+
|
1832
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
|
1833
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
|
1834
|
+
#endif
|
1835
|
+
}
|
1836
|
+
#else
|
1837
|
+
// scalar
|
1838
|
+
quantize_row_q8_1_reference(x, y, k);
|
1839
|
+
#endif
|
1840
|
+
}
|
1841
|
+
|
1842
|
+
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
1843
|
+
assert(k % QK4_0 == 0);
|
1844
|
+
const int nb = k / QK4_0;
|
1845
|
+
|
1846
|
+
const block_q4_0 * restrict x = vx;
|
1847
|
+
|
1848
|
+
#if defined(__AVX2__)
|
1849
|
+
for (int i = 0; i < nb; i++) {
|
1850
|
+
// scale factor
|
1851
|
+
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
|
1852
|
+
|
1853
|
+
const uint8_t * restrict pp = x[i].qs;
|
1854
|
+
|
1855
|
+
for (int l = 0; l < QK4_0; l += 32) {
|
1856
|
+
// Load 32x4-bit integers into 32x8-bit integers
|
1857
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1858
|
+
|
1859
|
+
// Subtract 8 from the integers
|
1860
|
+
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
1861
|
+
|
1862
|
+
// Convert to 16-bit int
|
1863
|
+
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
1864
|
+
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
1865
|
+
|
1866
|
+
// Convert to 32-bit int -> float 32
|
1867
|
+
const __m256 vf[4] = {
|
1868
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
|
1869
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
1870
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
1871
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
1872
|
+
};
|
1873
|
+
|
1874
|
+
// Scale and store
|
1875
|
+
for (int j = 0; j < 4; j++) {
|
1876
|
+
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
1877
|
+
_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
|
1878
|
+
}
|
1879
|
+
}
|
1880
|
+
}
|
1881
|
+
#elif defined(__ARM_NEON)
|
1882
|
+
for (int i = 0; i < nb; i++) {
|
1883
|
+
const float32x4_t vd = vdupq_n_f32(x[i].d);
|
1884
|
+
|
1885
|
+
const uint8_t * restrict pp = x[i].qs;
|
1886
|
+
|
1887
|
+
for (int l = 0; l < QK4_0; l += 16) {
|
1888
|
+
// Load 16x4-bit integers into 8x8-bit integers
|
1889
|
+
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1890
|
+
|
1891
|
+
// Expand 4-bit qs to 8-bit bytes
|
1892
|
+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
1893
|
+
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
1894
|
+
|
1895
|
+
// Convert to signed 8-bit integers
|
1896
|
+
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
|
1897
|
+
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
1898
|
+
|
1899
|
+
// Subtract 8 from each byte
|
1900
|
+
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
|
1901
|
+
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
|
1902
|
+
|
1903
|
+
// Interleave and combine
|
1904
|
+
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
1905
|
+
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
1906
|
+
|
1907
|
+
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
|
1908
|
+
|
1909
|
+
// convert to 2x int16x8_t
|
1910
|
+
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
1911
|
+
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
1912
|
+
|
1913
|
+
// convert to 4x float32x4_t
|
1914
|
+
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
|
1915
|
+
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
1916
|
+
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
1917
|
+
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
1918
|
+
|
1919
|
+
// Multiply by d
|
1920
|
+
const float32x4_t r0 = vmulq_f32(vf_0, vd);
|
1921
|
+
const float32x4_t r1 = vmulq_f32(vf_1, vd);
|
1922
|
+
const float32x4_t r2 = vmulq_f32(vf_2, vd);
|
1923
|
+
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
1924
|
+
|
1925
|
+
// Store
|
1926
|
+
vst1q_f32(y + i*QK4_0 + l + 0, r0);
|
1927
|
+
vst1q_f32(y + i*QK4_0 + l + 4, r1);
|
1928
|
+
vst1q_f32(y + i*QK4_0 + l + 8, r2);
|
1929
|
+
vst1q_f32(y + i*QK4_0 + l + 12, r3);
|
1930
|
+
}
|
1524
1931
|
}
|
1525
1932
|
#else
|
1526
1933
|
// scalar
|
@@ -1532,7 +1939,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
     for (int l = 0; l < QK4_0; l += 2) {
         const uint8_t vi = pp[l/2];

-        const int8_t vi0 = vi &
+        const int8_t vi0 = vi & 0x0F;
         const int8_t vi1 = vi >> 4;

         const float v0 = (vi0 - 8)*d;
@@ -1598,7 +2005,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const uint8x8_t v8 = vld1_u8(pp + l/2);

             // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
             const uint8x8_t v1 = vshr_n_u8(v8, 4);

             // Interleave and combine
@@ -1640,7 +2047,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
     for (int l = 0; l < QK4_1; l += 2) {
         const uint8_t vi = pp[l/2];

-        const int8_t vi0 = vi &
+        const int8_t vi0 = vi & 0x0F;
         const int8_t vi1 = vi >> 4;

         const float v0 = vi0*d + m;
@@ -1670,7 +2077,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     for (int l = 0; l < QK4_2; l += 2) {
         const uint8_t vi = pp[l/2];

-        const int8_t vi0 = vi &
+        const int8_t vi0 = vi & 0x0F;
         const int8_t vi1 = vi >> 4;

         const float v0 = (vi0 - 8)*d;
@@ -1685,11 +2092,47 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }

-static void
-    assert(k %
-    const int nb = k /
+static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    const block_q5_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_0; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
+
+            const int8_t vi0 = (vi & 0x0F) | vh0;
+            const int8_t vi1 = (vi >> 4) | vh1;
+
+            const float v0 = (vi0 - 16)*d;
+            const float v1 = (vi1 - 16)*d;
+
+            y[i*QK5_0 + l + 0] = v0;
+            y[i*QK5_0 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_0 + l + 0]));
+            assert(!isnan(y[i*QK5_0 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;

-    const
+    const block_q5_1 * restrict x = vx;

     for (int i = 0; i < nb; i++) {
         const float d = GGML_FP16_TO_FP32(x[i].d);
@@ -1697,28 +2140,54 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in

         const uint8_t * restrict pp = x[i].qs;

-
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_1; l += 2) {
             const uint8_t vi = pp[l/2];

-
-            const
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
+
+            const uint8_t vi0 = (vi & 0x0F) | vh0;
+            const uint8_t vi1 = (vi >> 4) | vh1;

             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;

-            y[i*
-            y[i*
+            y[i*QK5_1 + l + 0] = v0;
+            y[i*QK5_1 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_1 + l + 0]));
+            assert(!isnan(y[i*QK5_1 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    const block_q8_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;

-
-
+        const int8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            y[i*QK8_0 + l] = pp[l]*d;
         }
     }
 }

 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void
+static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);

 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
@@ -1727,34 +2196,55 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
|
1727
2196
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
1728
2197
|
.quantize_row_q_dot = quantize_row_q8_0,
|
1729
2198
|
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
2199
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
1730
2200
|
},
|
1731
2201
|
[GGML_TYPE_Q4_1] = {
|
1732
2202
|
.dequantize_row_q = dequantize_row_q4_1,
|
1733
2203
|
.quantize_row_q = quantize_row_q4_1,
|
1734
2204
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
1735
|
-
.quantize_row_q_dot =
|
1736
|
-
.vec_dot_q =
|
2205
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2206
|
+
.vec_dot_q = ggml_vec_dot_q4_1_q8_1,
|
2207
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
1737
2208
|
},
|
1738
2209
|
[GGML_TYPE_Q4_2] = {
|
1739
2210
|
.dequantize_row_q = dequantize_row_q4_2,
|
1740
2211
|
.quantize_row_q = quantize_row_q4_2,
|
1741
|
-
.quantize_row_q_reference = (quantize_row_q_t)
|
2212
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
1742
2213
|
.quantize_row_q_dot = quantize_row_q8_0,
|
1743
2214
|
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
2215
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
1744
2216
|
},
|
1745
|
-
[
|
1746
|
-
.dequantize_row_q =
|
1747
|
-
.quantize_row_q =
|
1748
|
-
.quantize_row_q_reference = (quantize_row_q_t)
|
2217
|
+
[GGML_TYPE_Q5_0] = {
|
2218
|
+
.dequantize_row_q = dequantize_row_q5_0,
|
2219
|
+
.quantize_row_q = quantize_row_q5_0,
|
2220
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
|
1749
2221
|
.quantize_row_q_dot = quantize_row_q8_0,
|
1750
|
-
.vec_dot_q =
|
2222
|
+
.vec_dot_q = ggml_vec_dot_q5_0_q8_0,
|
2223
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
2224
|
+
},
|
2225
|
+
[GGML_TYPE_Q5_1] = {
|
2226
|
+
.dequantize_row_q = dequantize_row_q5_1,
|
2227
|
+
.quantize_row_q = quantize_row_q5_1,
|
2228
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
|
2229
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
2230
|
+
.vec_dot_q = ggml_vec_dot_q5_1_q8_1,
|
2231
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
1751
2232
|
},
|
1752
2233
|
[GGML_TYPE_Q8_0] = {
|
1753
|
-
.dequantize_row_q =
|
2234
|
+
.dequantize_row_q = dequantize_row_q8_0,
|
1754
2235
|
.quantize_row_q = quantize_row_q8_0,
|
1755
2236
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
|
1756
2237
|
.quantize_row_q_dot = quantize_row_q8_0,
|
2238
|
+
.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
|
2239
|
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
2240
|
+
},
|
2241
|
+
[GGML_TYPE_Q8_1] = {
|
2242
|
+
.dequantize_row_q = NULL, // TODO
|
2243
|
+
.quantize_row_q = quantize_row_q8_1,
|
2244
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
|
2245
|
+
.quantize_row_q_dot = quantize_row_q8_1,
|
1757
2246
|
.vec_dot_q = NULL, // TODO
|
2247
|
+
.vec_dot_type = GGML_TYPE_Q8_1,
|
1758
2248
|
},
|
1759
2249
|
};
|
1760
2250
|
|
@@ -2366,8 +2856,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2366
2856
|
const block_q4_0 * restrict x = vx;
|
2367
2857
|
const block_q8_0 * restrict y = vy;
|
2368
2858
|
|
2369
|
-
float sumf = 0.0;
|
2370
|
-
|
2371
2859
|
#if defined(__ARM_NEON)
|
2372
2860
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2373
2861
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
@@ -2378,7 +2866,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2378
2866
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
2379
2867
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
2380
2868
|
|
2381
|
-
const uint8x16_t m4b = vdupq_n_u8(
|
2869
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
2382
2870
|
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2383
2871
|
|
2384
2872
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
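Note (not part of the upstream diff): the dequantization routines above imply the rough shape of the new block types. The sketch below is inferred only from the field accesses in this diff (x[i].d, x[i].m, x[i].qh, x[i].qs, and the s0/s1 sums used by the *_q8_1 dot products later on); the field order, the "_sketch" names, and the block size of 32 are assumptions, and the authoritative definitions live elsewhere in ggml.c.

#include <stdint.h>

typedef uint16_t ggml_fp16_t;       // half-precision storage type, as used by ggml

enum { QK5_SKETCH = 32, QK8_SKETCH = 32 };   // assumed block sizes

typedef struct {
    ggml_fp16_t d;                  // scale
    uint8_t     qh[4];              // 5th (high) bit of each quant, packed one bit per quant
    uint8_t     qs[QK5_SKETCH/2];   // low 4 bits, two quants per byte
} block_q5_0_sketch;                // value = (q - 16) * d, with q in [0, 31]

typedef struct {
    ggml_fp16_t d;                  // scale
    ggml_fp16_t m;                  // minimum
    uint8_t     qh[4];              // high bits
    uint8_t     qs[QK5_SKETCH/2];   // low 4 bits
} block_q5_1_sketch;                // value = q * d + m

typedef struct {
    float  d;                       // scale
    float  s0, s1;                  // per-half sums; judging from how summs is used below,
                                    // they appear to already include the d scale (assumption)
    int8_t qs[QK8_SKETCH];          // quants
} block_q8_1_sketch;                // activation-side format selected via .vec_dot_type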
@@ -2396,35 +2884,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2396
2884
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2397
2885
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2398
2886
|
|
2887
|
+
// interleave
|
2888
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
2889
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
2890
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
2891
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
2892
|
+
|
2399
2893
|
// load y
|
2400
2894
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2401
2895
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2402
2896
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2403
2897
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2404
2898
|
|
2405
|
-
// interleave
|
2406
|
-
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2407
|
-
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2408
|
-
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2409
|
-
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2410
|
-
|
2411
2899
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2412
2900
|
// dot product into int32x4_t
|
2413
|
-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
2414
|
-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
2901
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
|
2902
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
|
2415
2903
|
|
2416
2904
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2417
2905
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2418
2906
|
#else
|
2419
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (
|
2420
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(
|
2421
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (
|
2422
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(
|
2907
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2908
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2909
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2910
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2423
2911
|
|
2424
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (
|
2425
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(
|
2426
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (
|
2427
|
-
const int16x8_t ph1h = vmull_s8(vget_high_s8(
|
2912
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2913
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2914
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2915
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2428
2916
|
|
2429
2917
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2430
2918
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
@@ -2436,7 +2924,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2436
2924
|
#endif
|
2437
2925
|
}
|
2438
2926
|
|
2439
|
-
|
2927
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2440
2928
|
#elif defined(__AVX2__)
|
2441
2929
|
// Initialize accumulator with zeros
|
2442
2930
|
__m256 acc = _mm256_setzero_ps();
|
@@ -2454,32 +2942,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);

        /* Multiply q with scale and accumulate */
        acc = _mm256_fmadd_ps( d, q, acc );
    }

-
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
#elif defined(__AVX__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2987,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
    }

-
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
#else
    // scalar
+    float sumf = 0.0;
    for (int i = 0; i < nb; i++) {
        const float d0 = x[i].d;
        const float d1 = y[i].d;
@@ -2538,8 +3002,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
        for (int j = 0; j < QK8_0/2; j++) {
            const uint8_t v0 = p0[j];

-            const int i0 = (int8_t) (v0 &
-            const int i1 = (int8_t) (v0 >>
+            const int i0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1 = (int8_t) (v0 >> 4) - 8;

            const int i2 = p1[2*j + 0];
            const int i3 = p1[2*j + 1];
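Note (not part of the upstream diff): the AVX2 refactor above replaces an inline sign/maddubs/madd/convert sequence and a four-step horizontal reduction with two helpers, mul_sum_i8_pairs_float and hsum_float_8, whose definitions are outside the hunks shown here. The sketch below reconstructs what they compute from the removed lines; treat the bodies as an approximation, not the exact upstream definitions.

#include <immintrin.h>

// Pairwise multiply signed bytes of x and y and return the eight 32-bit lane
// sums converted to float (approximation of mul_sum_i8_pairs_float).
static inline __m256 mul_sum_i8_pairs_float_sketch(const __m256i x, const __m256i y) {
    const __m256i ax  = _mm256_sign_epi8(x, x);        // |x|
    const __m256i sy  = _mm256_sign_epi8(y, x);        // y * sign(x)
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);  // unsigned*signed -> 16-bit pair sums
    const __m256i one = _mm256_set1_epi16(1);
    const __m256i sum = _mm256_madd_epi16(one, dot);   // 16-bit -> 32-bit pair sums
    return _mm256_cvtepi32_ps(sum);
}

// Horizontal sum of the 8 floats in an __m256 accumulator (approximation of hsum_float_8).
static inline float hsum_float_8_sketch(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}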
@@ -2548,34 +3012,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2548
3012
|
}
|
2549
3013
|
sumf += d0*d1*sumi;
|
2550
3014
|
}
|
2551
|
-
#endif
|
2552
|
-
|
2553
3015
|
*s = sumf;
|
3016
|
+
#endif
|
2554
3017
|
}
|
2555
3018
|
|
2556
|
-
static void
|
2557
|
-
const int nb = n /
|
3019
|
+
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3020
|
+
const int nb = n / QK8_1;
|
2558
3021
|
|
2559
|
-
assert(n %
|
3022
|
+
assert(n % QK8_1 == 0);
|
2560
3023
|
assert(nb % 2 == 0);
|
2561
3024
|
|
2562
3025
|
const block_q4_1 * restrict x = vx;
|
2563
|
-
const
|
2564
|
-
|
2565
|
-
float sumf = 0.0;
|
3026
|
+
const block_q8_1 * restrict y = vy;
|
2566
3027
|
|
2567
3028
|
// TODO: add AVX / WASM SIMD / etc
|
2568
3029
|
#if defined(__ARM_NEON)
|
2569
3030
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2570
3031
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2571
3032
|
|
3033
|
+
float summs = 0;
|
3034
|
+
|
2572
3035
|
for (int i = 0; i < nb; i += 2) {
|
2573
3036
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
2574
3037
|
const block_q4_1 * restrict x1 = &x[i + 1];
|
2575
|
-
const
|
2576
|
-
const
|
3038
|
+
const block_q8_1 * restrict y0 = &y[i + 0];
|
3039
|
+
const block_q8_1 * restrict y1 = &y[i + 1];
|
3040
|
+
|
3041
|
+
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
|
2577
3042
|
|
2578
|
-
const uint8x16_t m4b = vdupq_n_u8(
|
3043
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
2579
3044
|
|
2580
3045
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2581
3046
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
@@ -2586,46 +3051,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2586
3051
|
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2587
3052
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2588
3053
|
|
3054
|
+
// interleave
|
3055
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
3056
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
3057
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
3058
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
3059
|
+
|
2589
3060
|
// load y
|
2590
3061
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2591
3062
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2592
3063
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2593
3064
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2594
3065
|
|
2595
|
-
// interleave
|
2596
|
-
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2597
|
-
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2598
|
-
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2599
|
-
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2600
|
-
|
2601
|
-
const int16x8_t s0i = vaddq_s16(
|
2602
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
2603
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
2604
|
-
|
2605
|
-
const int16x8_t s1i = vaddq_s16(
|
2606
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
2607
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
2608
|
-
|
2609
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
2610
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
2611
|
-
|
2612
3066
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2613
3067
|
// dot product into int32x4_t
|
2614
|
-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
2615
|
-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
3068
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
|
3069
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
|
2616
3070
|
|
2617
3071
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2618
3072
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2619
3073
|
#else
|
2620
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (
|
2621
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(
|
2622
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (
|
2623
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(
|
3074
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3075
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3076
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3077
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2624
3078
|
|
2625
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (
|
2626
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(
|
2627
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (
|
2628
|
-
const int16x8_t ph1h = vmull_s8(vget_high_s8(
|
3079
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
3080
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
3081
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
3082
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2629
3083
|
|
2630
3084
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2631
3085
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
@@ -2637,65 +3091,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2637
3091
|
#endif
|
2638
3092
|
}
|
2639
3093
|
|
2640
|
-
|
3094
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
2641
3095
|
#elif defined(__AVX2__)
|
2642
3096
|
// Initialize accumulator with zeros
|
2643
3097
|
__m256 acc = _mm256_setzero_ps();
|
2644
3098
|
|
3099
|
+
float summs = 0;
|
3100
|
+
|
2645
3101
|
// Main loop
|
2646
3102
|
for (int i = 0; i < nb; ++i) {
|
2647
3103
|
const float * d0 = &x[i].d;
|
2648
3104
|
const float * d1 = &y[i].d;
|
2649
|
-
|
3105
|
+
|
3106
|
+
summs += x[i].m * (y[i].s0 + y[i].s1);
|
2650
3107
|
|
2651
3108
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2652
3109
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2653
|
-
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2654
3110
|
|
2655
3111
|
// Compute combined scales
|
2656
3112
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
2657
|
-
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
2658
3113
|
|
2659
3114
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2660
3115
|
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2661
3116
|
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
2662
3117
|
|
2663
|
-
|
2664
|
-
const __m256i ax = _mm256_sign_epi8( bx, bx );
|
2665
|
-
|
2666
|
-
// Sign the values of the y vectors
|
2667
|
-
const __m256i sy = _mm256_sign_epi8( by, bx );
|
2668
|
-
|
2669
|
-
// Perform multiplication and create 16-bit values
|
2670
|
-
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
|
2671
|
-
const __m256i ones = _mm256_set1_epi16( 1 );
|
2672
|
-
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
|
2673
|
-
|
2674
|
-
// Convert to vector of 8 int32_t to 8 floats
|
2675
|
-
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
|
3118
|
+
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
2676
3119
|
|
2677
3120
|
// Accumulate d0*d1*x*y
|
2678
3121
|
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
2679
|
-
|
2680
|
-
// Compute sum of y values
|
2681
|
-
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2682
|
-
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2683
|
-
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
2684
|
-
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
2685
|
-
|
2686
|
-
// Accumulate d1*m0*y
|
2687
|
-
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
2688
3122
|
}
|
2689
3123
|
|
2690
|
-
|
2691
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2692
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2693
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2694
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2695
|
-
|
2696
|
-
sumf = _mm_cvtss_f32( res );
|
3124
|
+
*s = hsum_float_8(acc) + summs;
|
2697
3125
|
#else
|
2698
3126
|
// scalar
|
3127
|
+
float sumf = 0.0;
|
2699
3128
|
for (int i = 0; i < nb; i++) {
|
2700
3129
|
const float d0 = x[i].d;
|
2701
3130
|
const float m0 = x[i].m;
|
@@ -2705,347 +3134,685 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2705
3134
|
const int8_t * restrict p1 = y[i].qs;
|
2706
3135
|
|
2707
3136
|
// TODO: this is very slow ..
|
2708
|
-
for (int j = 0; j <
|
3137
|
+
for (int j = 0; j < QK8_1/2; j++) {
|
2709
3138
|
const uint8_t v0 = p0[j];
|
2710
3139
|
|
2711
|
-
const float f0 = d0*(v0 &
|
2712
|
-
const float f1 = d0*(v0 >>
|
3140
|
+
const float f0 = d0*(v0 & 0x0F) + m0;
|
3141
|
+
const float f1 = d0*(v0 >> 4) + m0;
|
3142
|
+
|
3143
|
+
const float f2 = d1*p1[2*j + 0];
|
3144
|
+
const float f3 = d1*p1[2*j + 1];
|
3145
|
+
|
3146
|
+
sumf += f0*f2 + f1*f3;
|
3147
|
+
}
|
3148
|
+
}
|
3149
|
+
*s = sumf;
|
3150
|
+
#endif
|
3151
|
+
}
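Note (not part of the upstream diff): the q4_1 x q8_1 kernel above folds the per-block minimum out of the inner loop. Since a Q4_1 value is d*q + m, a block dot product splits into d*(sum of q*y_value) plus m*(sum of y_value), and the second sum is exactly what the Q8_1 block appears to precompute as s0 + s1 (hence the "summs" accumulator). A scalar restatement, using a hypothetical helper and an assumed block size of 32:

#include <stdint.h>

// Identity used above: sum_j (d*q_j + m) * yv_j = d*sum_j q_j*yv_j + m*sum_j yv_j,
// where yv_j = y_d * y_q[j]; the Q8_1 block appears to store sum_j yv_j as s0 + s1.
static float dot_q4_1_block_sketch(float d, float m,
                                   const uint8_t *q4,   // 16 bytes, two 4-bit quants each
                                   float y_d, const int8_t *y_q /* 32 quants */) {
    int   sumi = 0;
    float sumy = 0.0f;
    for (int j = 0; j < 16; ++j) {
        const int q0 = q4[j] & 0x0F;   // low nibble pairs with y_q[2*j + 0]
        const int q1 = q4[j] >> 4;     // high nibble pairs with y_q[2*j + 1]
        sumi += q0 * y_q[2*j + 0] + q1 * y_q[2*j + 1];
        sumy += y_d * (y_q[2*j + 0] + y_q[2*j + 1]);   // corresponds to s0 + s1
    }
    return d * y_d * sumi + m * sumy;
}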
|
3152
|
+
|
3153
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3154
|
+
const int nb = n / QK8_0;
|
3155
|
+
|
3156
|
+
assert(n % QK8_0 == 0);
|
3157
|
+
assert(nb % 2 == 0);
|
3158
|
+
assert(QK8_0 == 2*QK4_2);
|
3159
|
+
|
3160
|
+
const block_q4_2 * restrict x = vx;
|
3161
|
+
const block_q8_0 * restrict y = vy;
|
3162
|
+
|
3163
|
+
#if defined(__ARM_NEON)
|
3164
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3165
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3166
|
+
|
3167
|
+
for (int i = 0; i < nb; i += 2) {
|
3168
|
+
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
|
3169
|
+
const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
|
3170
|
+
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
3171
|
+
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
3172
|
+
|
3173
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
3174
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
3175
|
+
|
3176
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3177
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
3178
|
+
|
3179
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
3180
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
3181
|
+
|
3182
|
+
// 4-bit -> 8-bit
|
3183
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
3184
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
3185
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
3186
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3187
|
+
|
3188
|
+
// sub 8
|
3189
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
3190
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
3191
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
3192
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3193
|
+
|
3194
|
+
// interleave
|
3195
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
3196
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
3197
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
3198
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
3199
|
+
|
3200
|
+
// load y
|
3201
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
3202
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
3203
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
3204
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
3205
|
+
|
3206
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3207
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3208
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
|
3209
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3210
|
+
|
3211
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3212
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
|
3213
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3214
|
+
#else
|
3215
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3216
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3217
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3218
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3219
|
+
|
3220
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
3221
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
3222
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
3223
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
3224
|
+
|
3225
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3226
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3227
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
3228
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
3229
|
+
|
3230
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3231
|
+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
3232
|
+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3233
|
+
|
3234
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3235
|
+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
|
3236
|
+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3237
|
+
#endif
|
3238
|
+
}
|
3239
|
+
|
3240
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
3241
|
+
#elif defined(__AVX2__)
|
3242
|
+
// Initialize accumulator with zeros
|
3243
|
+
__m256 acc = _mm256_setzero_ps();
|
3244
|
+
|
3245
|
+
// Main loop
|
3246
|
+
for (int i = 0; i < nb; i++) {
|
3247
|
+
/* Compute combined scale for the block */
|
3248
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
3249
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
3250
|
+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
3251
|
+
|
3252
|
+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
3253
|
+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
3254
|
+
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
3255
|
+
|
3256
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
3257
|
+
const __m256i off = _mm256_set1_epi8(8);
|
3258
|
+
bx = _mm256_sub_epi8(bx, off);
|
3259
|
+
|
3260
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3261
|
+
|
3262
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3263
|
+
|
3264
|
+
/* Multiply q with scale and accumulate */
|
3265
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
3266
|
+
}
|
3267
|
+
|
3268
|
+
*s = hsum_float_8(acc);
|
3269
|
+
#else
|
3270
|
+
// scalar
|
3271
|
+
float sumf = 0.0;
|
3272
|
+
for (int i = 0; i < nb; i++) {
|
3273
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3274
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3275
|
+
const int8_t * restrict y0 = y[i].qs;
|
3276
|
+
|
3277
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3278
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3279
|
+
|
3280
|
+
int sumi_0 = 0;
|
3281
|
+
int sumi_1 = 0;
|
3282
|
+
|
3283
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
3284
|
+
const uint8_t v0 = x0[j];
|
3285
|
+
const uint8_t v1 = x1[j];
|
3286
|
+
|
3287
|
+
const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
|
3288
|
+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
|
3289
|
+
|
3290
|
+
const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
|
3291
|
+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
|
3292
|
+
|
3293
|
+
const int i2_0 = y0[2*j + 0];
|
3294
|
+
const int i3_0 = y0[2*j + 1];
|
3295
|
+
|
3296
|
+
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
3297
|
+
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
3298
|
+
|
3299
|
+
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
3300
|
+
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
3301
|
+
}
|
3302
|
+
|
3303
|
+
sumf += (d0 * y[i].d) * sumi_0;
|
3304
|
+
sumf += (d1 * y[i].d) * sumi_1;
|
3305
|
+
}
|
3306
|
+
*s = sumf;
|
3307
|
+
#endif
|
3308
|
+
}
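Note (not part of the upstream diff): the q4_2 x q8_0 kernel above relies on assert(QK8_0 == 2*QK4_2), so each 32-quant activation block pairs with two 16-quant q4_2 sub-blocks, each carrying its own fp16 scale. Per block the result is y_d * (d_lo * sum_lo + d_hi * sum_hi). A hypothetical scalar helper restating this:

#include <stdint.h>

static float dot_q4_2_pair_sketch(float d_lo, const uint8_t *q_lo,   // 8 bytes, 16 quants
                                  float d_hi, const uint8_t *q_hi,   // 8 bytes, 16 quants
                                  float y_d,  const int8_t *y_q) {   // 32 quants
    int s_lo = 0, s_hi = 0;
    for (int j = 0; j < 8; ++j) {
        s_lo += ((q_lo[j] & 0x0F) - 8) * y_q[2*j + 0]
              + ((q_lo[j] >> 4)   - 8) * y_q[2*j + 1];
        s_hi += ((q_hi[j] & 0x0F) - 8) * y_q[2*(j + 8) + 0]
              + ((q_hi[j] >> 4)   - 8) * y_q[2*(j + 8) + 1];
    }
    return y_d * (d_lo * (float)s_lo + d_hi * (float)s_hi);
}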
|
3309
|
+
|
3310
|
+
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3311
|
+
const int nb = n / QK8_0;
|
3312
|
+
|
3313
|
+
assert(n % QK8_0 == 0);
|
3314
|
+
assert(nb % 2 == 0);
|
3315
|
+
assert(QK8_0 == QK5_0);
|
3316
|
+
|
3317
|
+
const block_q5_0 * restrict x = vx;
|
3318
|
+
const block_q8_0 * restrict y = vy;
|
3319
|
+
|
3320
|
+
#if defined(__ARM_NEON)
|
3321
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
3322
|
+
|
3323
|
+
uint64_t tmp[4];
|
3324
|
+
|
3325
|
+
for (int i = 0; i < nb; ++i) {
|
3326
|
+
const block_q5_0 * restrict x0 = &x[i];
|
3327
|
+
const block_q8_0 * restrict y0 = &y[i];
|
3328
|
+
|
3329
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3330
|
+
const int8x16_t s16b = vdupq_n_s8(0x10);
|
3331
|
+
|
3332
|
+
// extract the 5th bit
|
3333
|
+
uint32_t qh;
|
3334
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3335
|
+
|
3336
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3337
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3338
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3339
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3340
|
+
|
3341
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3342
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
3343
|
+
|
3344
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
3345
|
+
|
3346
|
+
// 4-bit -> 8-bit
|
3347
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
|
3348
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
3349
|
+
|
3350
|
+
// interleave
|
3351
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3352
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
3353
|
+
|
3354
|
+
// add high bit and sub 16
|
3355
|
+
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
|
3356
|
+
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
|
3357
|
+
|
3358
|
+
// load y
|
3359
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3360
|
+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
|
3361
|
+
|
3362
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
3363
|
+
|
3364
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3365
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
|
3366
|
+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
|
3367
|
+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
|
3368
|
+
#else
|
3369
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
|
3370
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
|
3371
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
|
3372
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
|
3373
|
+
|
3374
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3375
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3376
|
+
|
3377
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
|
3378
|
+
#endif
|
3379
|
+
}
|
3380
|
+
|
3381
|
+
*s = vaddvq_f32(sumv);
|
3382
|
+
#elif defined(__wasm_simd128__)
|
3383
|
+
v128_t sumv = wasm_f32x4_splat(0.0f);
|
3384
|
+
|
3385
|
+
uint64_t tmp[4];
|
3386
|
+
|
3387
|
+
for (int i = 0; i < nb; ++i) {
|
3388
|
+
const block_q5_0 * restrict x0 = &x[i];
|
3389
|
+
const block_q8_0 * restrict y0 = &y[i];
|
3390
|
+
|
3391
|
+
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
3392
|
+
const v128_t s16b = wasm_i8x16_splat(0x10);
|
3393
|
+
|
3394
|
+
// extract the 5th bit
|
3395
|
+
uint32_t qh;
|
3396
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3397
|
+
|
3398
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3399
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3400
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3401
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3402
|
+
|
3403
|
+
const v128_t qhl = wasm_v128_load(tmp + 0);
|
3404
|
+
const v128_t qhh = wasm_v128_load(tmp + 2);
|
3405
|
+
|
3406
|
+
const v128_t v0 = wasm_v128_load(x0->qs);
|
3407
|
+
|
3408
|
+
// 4-bit -> 8-bit
|
3409
|
+
const v128_t v0l = wasm_v128_and (v0, m4b);
|
3410
|
+
const v128_t v0h = wasm_u8x16_shr(v0, 4);
|
3411
|
+
|
3412
|
+
// interleave
|
3413
|
+
const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
|
3414
|
+
const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
|
3415
|
+
|
3416
|
+
// add high bit and sub 16
|
3417
|
+
const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0lz, qhl), s16b);
|
3418
|
+
const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0hz, qhh), s16b);
|
3419
|
+
|
3420
|
+
// load y
|
3421
|
+
const v128_t v1l = wasm_v128_load(y0->qs);
|
3422
|
+
const v128_t v1h = wasm_v128_load(y0->qs + 16);
|
3423
|
+
|
3424
|
+
// int8x16 -> int16x8
|
3425
|
+
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
|
3426
|
+
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
|
3427
|
+
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
|
3428
|
+
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
|
3429
|
+
|
3430
|
+
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
|
3431
|
+
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
|
3432
|
+
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
|
3433
|
+
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
|
3434
|
+
|
3435
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
3436
|
+
|
3437
|
+
// dot product
|
3438
|
+
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
|
3439
|
+
wasm_i32x4_add(
|
3440
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
|
3441
|
+
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
|
3442
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
|
3443
|
+
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
|
3444
|
+
}
|
3445
|
+
|
3446
|
+
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
|
3447
|
+
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
|
3448
|
+
#elif defined(__AVX2__)
|
3449
|
+
// Initialize accumulator with zeros
|
3450
|
+
__m256 acc = _mm256_setzero_ps();
|
3451
|
+
|
3452
|
+
// Main loop
|
3453
|
+
for (int i = 0; i < nb; i++) {
|
3454
|
+
/* Compute combined scale for the block */
|
3455
|
+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
|
3456
|
+
|
3457
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
3458
|
+
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
3459
|
+
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
|
3460
|
+
bx = _mm256_or_si256(bx, bxhi);
|
3461
|
+
|
3462
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3463
|
+
|
3464
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3465
|
+
|
3466
|
+
/* Multiply q with scale and accumulate */
|
3467
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
3468
|
+
}
|
3469
|
+
|
3470
|
+
*s = hsum_float_8(acc);
|
3471
|
+
#else
|
3472
|
+
// scalar
|
3473
|
+
float sumf = 0.0;
|
3474
|
+
for (int i = 0; i < nb; i++) {
|
3475
|
+
const uint8_t * restrict x0 = x[i].qs;
|
3476
|
+
const int8_t * restrict y0 = y[i].qs;
|
3477
|
+
|
3478
|
+
uint32_t qh;
|
3479
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
3480
|
+
|
3481
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3482
|
+
|
3483
|
+
int sxy = 0;
|
3484
|
+
|
3485
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
3486
|
+
const uint8_t v0 = x0[j];
|
3487
|
+
|
3488
|
+
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
|
3489
|
+
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
|
3490
|
+
|
3491
|
+
const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
|
3492
|
+
const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
|
3493
|
+
|
3494
|
+
const int y0_0 = y0[2*j + 0];
|
3495
|
+
const int y1_0 = y0[2*j + 1];
|
3496
|
+
|
3497
|
+
sxy += x0_0*y0_0 + x1_0*y1_0;
|
3498
|
+
}
|
3499
|
+
|
3500
|
+
sumf += (d*sxy)*y[i].d;
|
3501
|
+
}
|
3502
|
+
*s = sumf;
|
3503
|
+
#endif
|
3504
|
+
}
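Note (not part of the upstream diff): the NEON and WASM paths for q5_0 above (and q5_1 below) first spread the 32 packed high bits in qh into one byte per quant using the table_b2b_u lookup table defined elsewhere in ggml.c. Judging from how the result is OR-ed into the low nibbles and then 16 is subtracted, each table entry appears to expand 8 bits into 8 bytes that are 0x10 where the bit is set and 0x00 otherwise; the exact table contents are an assumption. A portable sketch of the same expansion without the table:

#include <stdint.h>
#include <string.h>

// Expand the 32 packed high bits into one byte per quant: 0x10 if the quant's
// 5th bit is set, 0x00 otherwise. Mirrors what the table_b2b_u lookups achieve.
static void spread_qh_sketch(const uint8_t qh_packed[4], uint8_t out[32]) {
    uint32_t qh;
    memcpy(&qh, qh_packed, sizeof(qh));
    for (int l = 0; l < 32; ++l) {
        out[l] = ((qh >> l) & 1u) ? 0x10 : 0x00;
    }
}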
|
3505
|
+
|
3506
|
+
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3507
|
+
const int nb = n / QK8_1;
|
3508
|
+
|
3509
|
+
assert(n % QK8_1 == 0);
|
3510
|
+
assert(nb % 2 == 0);
|
3511
|
+
assert(QK8_1 == QK5_1);
|
3512
|
+
|
3513
|
+
const block_q5_1 * restrict x = vx;
|
3514
|
+
const block_q8_1 * restrict y = vy;
|
3515
|
+
|
3516
|
+
#if defined(__ARM_NEON)
|
3517
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
3518
|
+
|
3519
|
+
float summs = 0.0f;
|
3520
|
+
|
3521
|
+
uint64_t tmp[4];
|
3522
|
+
|
3523
|
+
for (int i = 0; i < nb; ++i) {
|
3524
|
+
const block_q5_1 * restrict x0 = &x[i];
|
3525
|
+
const block_q8_1 * restrict y0 = &y[i];
|
3526
|
+
|
3527
|
+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
3528
|
+
|
3529
|
+
// extract the 5th bit
|
3530
|
+
uint32_t qh;
|
3531
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3532
|
+
|
3533
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3534
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3535
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3536
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3537
|
+
|
3538
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3539
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
3540
|
+
|
3541
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
3542
|
+
|
3543
|
+
// 4-bit -> 8-bit
|
3544
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
|
3545
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
3546
|
+
|
3547
|
+
// interleave
|
3548
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3549
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
3550
|
+
|
3551
|
+
// add
|
3552
|
+
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
|
3553
|
+
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
|
3554
|
+
|
3555
|
+
// load y
|
3556
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3557
|
+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
|
3558
|
+
|
3559
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
2713
3560
|
|
2714
|
-
|
2715
|
-
|
3561
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3562
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
|
3563
|
+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
|
3564
|
+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
|
3565
|
+
#else
|
3566
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
|
3567
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
|
3568
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
|
3569
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
|
2716
3570
|
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
3571
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3572
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3573
|
+
|
3574
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
|
2720
3575
|
#endif
|
3576
|
+
}
|
2721
3577
|
|
2722
|
-
*s =
|
2723
|
-
|
3578
|
+
*s = vaddvq_f32(sumv) + summs;
|
3579
|
+
#elif defined(__wasm_simd128__)
|
3580
|
+
v128_t sumv = wasm_f32x4_splat(0.0f);
|
2724
3581
|
|
2725
|
-
|
2726
|
-
const int nb = n / QK8_0;
|
3582
|
+
float summs = 0.0f;
|
2727
3583
|
|
2728
|
-
|
2729
|
-
assert(nb % 2 == 0);
|
2730
|
-
assert(QK8_0 == 2*QK4_2);
|
3584
|
+
uint64_t tmp[4];
|
2731
3585
|
|
2732
|
-
|
2733
|
-
|
3586
|
+
for (int i = 0; i < nb; ++i) {
|
3587
|
+
const block_q5_1 * restrict x0 = &x[i];
|
3588
|
+
const block_q8_1 * restrict y0 = &y[i];
|
2734
3589
|
|
2735
|
-
|
3590
|
+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
2736
3591
|
|
2737
|
-
|
2738
|
-
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2739
|
-
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3592
|
+
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
2740
3593
|
|
2741
|
-
|
2742
|
-
|
2743
|
-
|
2744
|
-
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2745
|
-
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
3594
|
+
// extract the 5th bit
|
3595
|
+
uint32_t qh;
|
3596
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
2746
3597
|
|
2747
|
-
|
2748
|
-
|
3598
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3599
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3600
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3601
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
2749
3602
|
|
2750
|
-
const
|
2751
|
-
const
|
3603
|
+
const v128_t qhl = wasm_v128_load(tmp + 0);
|
3604
|
+
const v128_t qhh = wasm_v128_load(tmp + 2);
|
2752
3605
|
|
2753
|
-
const
|
2754
|
-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
3606
|
+
const v128_t v0 = wasm_v128_load(x0->qs);
|
2755
3607
|
|
2756
3608
|
// 4-bit -> 8-bit
|
2757
|
-
const
|
2758
|
-
const
|
2759
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2760
|
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3609
|
+
const v128_t v0l = wasm_v128_and (v0, m4b);
|
3610
|
+
const v128_t v0h = wasm_u8x16_shr(v0, 4);
|
2761
3611
|
|
2762
|
-
|
2763
|
-
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2764
|
-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2765
|
-
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2766
|
-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3612
|
+
static bool x = true;
|
2767
3613
|
|
2768
3614
|
// interleave
|
2769
|
-
const
|
2770
|
-
const
|
2771
|
-
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
2772
|
-
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
2773
|
-
|
2774
|
-
// load y
|
2775
|
-
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2776
|
-
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2777
|
-
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2778
|
-
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
3615
|
+
const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
|
3616
|
+
const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
|
2779
3617
|
|
2780
|
-
|
2781
|
-
|
2782
|
-
|
2783
|
-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3618
|
+
// add high bit
|
3619
|
+
const v128_t v0lf = wasm_v128_or(v0lz, qhl);
|
3620
|
+
const v128_t v0hf = wasm_v128_or(v0hz, qhh);
|
2784
3621
|
|
2785
|
-
|
2786
|
-
|
2787
|
-
|
2788
|
-
#else
|
2789
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2790
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2791
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2792
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3622
|
+
// load y
|
3623
|
+
const v128_t v1l = wasm_v128_load(y0->qs);
|
3624
|
+
const v128_t v1h = wasm_v128_load(y0->qs + 16);
|
2793
3625
|
|
2794
|
-
|
2795
|
-
const
|
2796
|
-
const
|
2797
|
-
const
|
3626
|
+
// int8x16 -> int16x8
|
3627
|
+
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
|
3628
|
+
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
|
3629
|
+
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
|
3630
|
+
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
|
2798
3631
|
|
2799
|
-
const
|
2800
|
-
const
|
2801
|
-
const
|
2802
|
-
const
|
3632
|
+
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
|
3633
|
+
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
|
3634
|
+
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
|
3635
|
+
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
|
2803
3636
|
|
2804
|
-
|
2805
|
-
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
2806
|
-
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3637
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
2807
3638
|
|
2808
|
-
|
2809
|
-
|
2810
|
-
|
2811
|
-
|
3639
|
+
// dot product
|
3640
|
+
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
|
3641
|
+
wasm_i32x4_add(
|
3642
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
|
3643
|
+
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
|
3644
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
|
3645
|
+
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
|
2812
3646
|
}
|
2813
3647
|
|
2814
|
-
|
3648
|
+
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
|
3649
|
+
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
|
2815
3650
|
#elif defined(__AVX2__)
|
2816
3651
|
// Initialize accumulator with zeros
|
2817
3652
|
__m256 acc = _mm256_setzero_ps();
|
3653
|
+
float summs = 0.0f;
|
2818
3654
|
|
2819
3655
|
// Main loop
|
2820
3656
|
for (int i = 0; i < nb; i++) {
|
2821
|
-
|
2822
|
-
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
2823
|
-
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
2824
|
-
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
2825
|
-
|
2826
|
-
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
2827
|
-
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
2828
|
-
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
2829
|
-
|
2830
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2831
|
-
const __m256i off = _mm256_set1_epi8(8);
|
2832
|
-
bx = _mm256_sub_epi8(bx, off);
|
3657
|
+
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
2833
3658
|
|
2834
|
-
|
3659
|
+
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
|
2835
3660
|
|
2836
|
-
|
2837
|
-
|
2838
|
-
|
2839
|
-
|
2840
|
-
// Perform multiplication and create 16-bit values
|
2841
|
-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
3661
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
3662
|
+
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
3663
|
+
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
|
3664
|
+
bx = _mm256_or_si256(bx, bxhi);
|
2842
3665
|
|
2843
|
-
const
|
2844
|
-
__m256i
|
3666
|
+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
3667
|
+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2845
3668
|
|
2846
|
-
|
2847
|
-
__m256 q = _mm256_cvtepi32_ps(xy_q);
|
3669
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
2848
3670
|
|
2849
|
-
|
2850
|
-
acc = _mm256_fmadd_ps(d, q, acc);
|
3671
|
+
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
2851
3672
|
}
|
2852
3673
|
|
2853
|
-
|
2854
|
-
__m128 res = _mm256_extractf128_ps(acc, 1);
|
2855
|
-
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
2856
|
-
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
2857
|
-
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
2858
|
-
|
2859
|
-
sumf = _mm_cvtss_f32(res);
|
3674
|
+
*s = hsum_float_8(acc) + summs;
|
2860
3675
|
#else
|
2861
|
-
|
3676
|
+
float sumf = 0.0;
|
3677
|
+
|
2862
3678
|
for (int i = 0; i < nb; i++) {
|
2863
|
-
const uint8_t * restrict x0 = x[
|
2864
|
-
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3679
|
+
const uint8_t * restrict x0 = x[i].qs;
|
2865
3680
|
const int8_t * restrict y0 = y[i].qs;
|
2866
3681
|
|
2867
|
-
|
2868
|
-
|
3682
|
+
uint32_t qh;
|
3683
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2869
3684
|
|
2870
|
-
|
2871
|
-
|
3685
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3686
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
2872
3687
|
|
2873
|
-
|
2874
|
-
const uint8_t v0 = x0[j];
|
2875
|
-
const uint8_t v1 = x1[j];
|
3688
|
+
int sxy = 0;
|
2876
3689
|
|
2877
|
-
|
2878
|
-
const
|
3690
|
+
for (int j = 0; j < QK8_1/2; j++) {
|
3691
|
+
const uint8_t v0 = x0[j];
|
2879
3692
|
|
2880
|
-
const int
|
2881
|
-
const int
|
3693
|
+
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
|
3694
|
+
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
|
2882
3695
|
|
2883
|
-
const int
|
2884
|
-
const int
|
3696
|
+
const int x0_0 = (v0 & 0x0F) | x0_0h;
|
3697
|
+
const int x1_0 = (v0 >> 4) | x1_0h;
|
2885
3698
|
|
2886
|
-
const int
|
2887
|
-
const int
|
3699
|
+
const int y0_0 = y0[2*j + 0];
|
3700
|
+
const int y1_0 = y0[2*j + 1];
|
2888
3701
|
|
2889
|
-
|
2890
|
-
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
3702
|
+
sxy += x0_0*y0_0 + x1_0*y1_0;
|
2891
3703
|
}
|
2892
3704
|
|
2893
|
-
sumf += (
|
2894
|
-
sumf += (d1 * y[i].d) * sumi_1;
|
3705
|
+
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
|
2895
3706
|
}
|
2896
|
-
#endif
|
2897
3707
|
|
2898
3708
|
*s = sumf;
|
3709
|
+
#endif
|
2899
3710
|
}
|
2900
3711
|
|
2901
|
-
static void
|
3712
|
+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2902
3713
|
const int nb = n / QK8_0;
|
2903
3714
|
|
2904
3715
|
assert(n % QK8_0 == 0);
|
2905
3716
|
assert(nb % 2 == 0);
|
2906
|
-
assert(QK8_0 ==
|
3717
|
+
assert(QK8_0 == QK8_0);
|
2907
3718
|
|
2908
|
-
const
|
3719
|
+
const block_q8_0 * restrict x = vx;
|
2909
3720
|
const block_q8_0 * restrict y = vy;
|
2910
3721
|
|
2911
|
-
float sumf = 0.0;
|
2912
|
-
|
2913
3722
|
#if defined(__ARM_NEON)
|
2914
3723
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2915
3724
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2916
3725
|
|
2917
3726
|
for (int i = 0; i < nb; i += 2) {
|
2918
|
-
const
|
2919
|
-
const
|
2920
|
-
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2921
|
-
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2922
|
-
|
3727
|
+
const block_q8_0 * restrict x0 = &x[i + 0];
|
3728
|
+
const block_q8_0 * restrict x1 = &x[i + 1];
|
2923
3729
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
2924
3730
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
2925
3731
|
|
2926
|
-
const
|
2927
|
-
|
2928
|
-
const
|
2929
|
-
const
|
2930
|
-
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
|
2931
|
-
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
|
2932
|
-
|
2933
|
-
const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
|
2934
|
-
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
|
2935
|
-
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
|
2936
|
-
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
|
2937
|
-
|
2938
|
-
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2939
|
-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2940
|
-
|
2941
|
-
// 4-bit -> 8-bit
|
2942
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2943
|
-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2944
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2945
|
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2946
|
-
|
2947
|
-
// interleave
|
2948
|
-
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
2949
|
-
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
2950
|
-
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
2951
|
-
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
3732
|
+
const int8x16_t x0_0 = vld1q_s8(x0->qs);
|
3733
|
+
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
|
3734
|
+
const int8x16_t x1_0 = vld1q_s8(x1->qs);
|
3735
|
+
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
|
2952
3736
|
|
2953
3737
|
// load y
|
2954
|
-
const int8x16_t
|
2955
|
-
const int8x16_t
|
2956
|
-
const int8x16_t
|
2957
|
-
const int8x16_t
|
2958
|
-
|
2959
|
-
const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
|
2960
|
-
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
|
2961
|
-
|
2962
|
-
const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
|
2963
|
-
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
|
2964
|
-
|
2965
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
|
2966
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
|
2967
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
|
2968
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
|
3738
|
+
const int8x16_t y0_0 = vld1q_s8(y0->qs);
|
3739
|
+
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
|
3740
|
+
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
3741
|
+
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
2969
3742
|
|
2970
3743
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2971
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(
|
2972
|
-
|
2973
|
-
|
2974
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
|
2975
|
-
#else
|
2976
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2977
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2978
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2979
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2980
|
-
|
2981
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2982
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2983
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2984
|
-
-const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);

-
-
-
-const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);

-
-
-
-
+#else
+const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
+sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
#endif
}

-
-#
-//
-
-const uint8_t * restrict x0 = x[2*i + 0].qs;
-const uint8_t * restrict x1 = x[2*i + 1].qs;
-const int8_t * restrict y0 = y[i].qs;
-
-const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
-const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
-const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
-const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
-
-int sy_0 = 0;
-int sy_1 = 0;
+*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+// Initialize accumulator with zeros
+__m256 acc = _mm256_setzero_ps();

-
-
+// Main loop
+for (int i = 0; i < nb; ++i) {
+// Compute combined scale for the block
+const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+__m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-
-const uint8_t v0 = x0[j];
-const uint8_t v1 = x1[j];
+const __m256 q = mul_sum_i8_pairs_float(bx, by);

-
-
+// Multiply q with scale and accumulate
+acc = _mm256_fmadd_ps( d, q, acc );
+}

-
-
+*s = hsum_float_8(acc);
+#else
+// scalar
+float sumf = 0.0;

-
-
+for (int i = 0; i < nb; i++) {
+const int8_t * restrict x0 = x[i].qs;
+const int8_t * restrict y0 = y[i].qs;

-
-const int y1_1 = y0[2*(j + QK8_0/4) + 1];
+int sumi = 0;

-
-
+for (int j = 0; j < QK8_0; j++) {
+const int v0 = x0[j];
+const int v1 = y0[j];

-
-sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+sumi += v0*v1;
}

-sumf += (
-sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+sumf += (x[i].d*y[i].d)*sumi;
}
-#endif

*s = sumf;
+#endif
}

-
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3242,6 +4009,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#endif
}

+inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
+ggml_float sum = 0.0;
+for (int i = 0; i < n; ++i) {
+sum += (ggml_float)x[i];
+}
+*s = sum;
+}
+
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
float max = -INFINITY;
@@ -3293,13 +4068,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = QK4_0,
[GGML_TYPE_Q4_1] = QK4_1,
[GGML_TYPE_Q4_2] = QK4_2,
-[
+[GGML_TYPE_Q5_0] = QK5_0,
+[GGML_TYPE_Q5_1] = QK5_1,
[GGML_TYPE_Q8_0] = QK8_0,
+[GGML_TYPE_Q8_1] = QK8_1,
[GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");

static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
@@ -3307,13 +4084,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
-[
+[GGML_TYPE_Q5_0] = sizeof(block_q5_0),
+[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
+[GGML_TYPE_Q8_1] = sizeof(block_q8_1),
[GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");


static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3322,13 +4101,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = "q4_0",
[GGML_TYPE_Q4_1] = "q4_1",
[GGML_TYPE_Q4_2] = "q4_2",
-[
+[GGML_TYPE_Q5_0] = "q5_0",
+[GGML_TYPE_Q5_1] = "q5_1",
[GGML_TYPE_Q8_0] = "q8_0",
+[GGML_TYPE_Q8_1] = "q8_1",
[GGML_TYPE_I8] = "i8",
[GGML_TYPE_I16] = "i16",
[GGML_TYPE_I32] = "i32",
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");

static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = false,
@@ -3336,13 +4117,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = true,
[GGML_TYPE_Q4_1] = true,
[GGML_TYPE_Q4_2] = true,
-[
+[GGML_TYPE_Q5_0] = true,
+[GGML_TYPE_Q5_1] = true,
[GGML_TYPE_Q8_0] = true,
+[GGML_TYPE_Q8_1] = true,
[GGML_TYPE_I8] = false,
[GGML_TYPE_I16] = false,
[GGML_TYPE_I32] = false,
};
-static_assert(GGML_TYPE_COUNT ==
+static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");

static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@@ -3380,6 +4163,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"DIAG_MASK_INF",
"SOFT_MAX",
"ROPE",
+"ALIBI",
"CONV_1D_1S",
"CONV_1D_2S",

@@ -3390,7 +4174,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};

-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3428,6 +4212,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"diag_mask_inf(x)",
"soft_max(x)",
"rope(x)",
+"alibi(x)",
"conv_1d_1s(x)",
"conv_1d_2s(x)",

@@ -3438,7 +4223,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};

-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3608,6 +4393,27 @@ bool ggml_is_quantized(enum ggml_type type) {
return GGML_IS_QUANTIZED[type];
}

+enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
+enum ggml_type wtype = GGML_TYPE_COUNT;
+
+switch (ftype) {
+case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
+case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
+case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
+case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
+case GGML_FTYPE_MOSTLY_Q4_2: wtype = GGML_TYPE_Q4_2; break;
+case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
+case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
+case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
+case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
+}
+
+GGML_ASSERT(wtype != GGML_TYPE_COUNT);
+
+return wtype;
+}
+
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
return tensor->nb[0] > tensor->nb[1];
}
@@ -3718,10 +4524,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}

-
-
-
-
+#if defined(GGML_USE_CUBLAS)
+ggml_init_cublas();
+#elif defined(GGML_USE_CLBLAST)
+ggml_cl_init();
+#endif

is_first_call = false;
}
@@ -3802,7 +4609,7 @@ void ggml_free(struct ggml_context * ctx) {
}

size_t ggml_used_mem(const struct ggml_context * ctx) {
-return ctx->objects_end->offs + ctx->objects_end->size;
+return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
}

size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
@@ -3915,6 +4722,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
+/*.name =*/ { 0 },
/*.pad =*/ { 0 },
};

@@ -4269,6 +5077,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
return (float *)(tensor->data);
}

+const char * ggml_get_name(const struct ggml_tensor * tensor) {
+return tensor->name;
+}
+
+void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+strncpy(tensor->name, name, sizeof(tensor->name));
+tensor->name[sizeof(tensor->name) - 1] = '\0';
+}
+
struct ggml_tensor * ggml_view_tensor(
struct ggml_context * ctx,
const struct ggml_tensor * src) {
@@ -5368,6 +6185,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+ggml_set_name(b, "n_past");

result->op = GGML_OP_DIAG_MASK_INF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5393,22 +6211,55 @@ struct ggml_tensor * ggml_soft_max(
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);

-result->op = GGML_OP_SOFT_MAX;
+result->op = GGML_OP_SOFT_MAX;
+result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+result->src0 = a;
+result->src1 = NULL;
+
+return result;
+}
+
+// ggml_rope
+
+struct ggml_tensor * ggml_rope(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+int n_past,
+int n_dims,
+int mode) {
+GGML_ASSERT(n_past >= 0);
+bool is_node = false;
+
+if (a->grad) {
+GGML_ASSERT(false); // TODO: implement backward
+is_node = true;
+}
+
+// TODO: when implement backward, fix this:
+//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+((int32_t *) b->data)[0] = n_past;
+((int32_t *) b->data)[1] = n_dims;
+((int32_t *) b->data)[2] = mode;
+ggml_set_name(b, "n_past, n_dims, mode");
+
+result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
-result->src1 =
+result->src1 = b;

return result;
}

-//
+// ggml_alibi

-struct ggml_tensor *
+struct ggml_tensor * ggml_alibi(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
-int
-int mode) {
+int n_head) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;

@@ -5421,12 +6272,11 @@ struct ggml_tensor * ggml_rope(
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);

-struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,
+struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
((int32_t *) b->data)[0] = n_past;
-((int32_t *) b->data)[1] =
-((int32_t *) b->data)[2] = mode;
+((int32_t *) b->data)[1] = n_head;

-result->op =
+result->op = GGML_OP_ALIBI;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
@@ -6553,7 +7403,9 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
-case
+case GGML_TYPE_Q5_0:
+case GGML_TYPE_Q5_1:
+case GGML_TYPE_Q8_0:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
@@ -6811,15 +7663,20 @@ static void ggml_compute_forward_sum_f32(
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];

+ggml_float sum = 0;
+ggml_float row_sum = 0;
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
-
-
+ggml_vec_sum_ggf(ne00,
+&row_sum,
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+sum += row_sum;
}
}
}
+((float *) dst->data)[0] = sum;
}

static void ggml_compute_forward_sum(
@@ -7454,7 +8311,7 @@ static void ggml_compute_forward_rms_norm(

// ggml_compute_forward_mul_mat

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7471,7 +8328,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(

// TODO: find the optimal values for these
if (ggml_is_contiguous(src0) &&
-ggml_is_contiguous(src1) &&
+ggml_is_contiguous(src1) &&
+(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {

/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
return true;
@@ -7494,7 +8352,7 @@ static void ggml_compute_forward_mul_mat_f32(
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
const int64_t ne10 = src1->ne[0];
#endif
const int64_t ne11 = src1->ne[1];
@@ -7551,7 +8409,16 @@ static void ggml_compute_forward_mul_mat_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

-#if defined(
+#if defined(GGML_USE_CUBLAS)
+if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+}
+return;
+}
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@@ -7565,45 +8432,21 @@ static void ggml_compute_forward_mul_mat_f32(
return;
}

-#if defined(GGML_USE_CUBLAS)
-float *d_X = NULL;
-float *d_Y = NULL;
-float *d_D = NULL;
-const float alpha = 1.0f;
-const float beta = 0.0f;
-const int x_ne = ne01 * ne10;
-const int y_ne = ne11 * ne10;
-const int d_ne = ne11 * ne01;
-
-CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-#endif
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

-#if defined(
-// copy data to device
-CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
-
-// compute
-CUBLAS_CHECK(
-cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-ne01, ne11, ne10,
-&alpha, d_X, ne00,
-d_Y, ne10,
-&beta, d_D, ne01));
-
-// copy data to host
-CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-#else
+#if defined(GGML_USE_CLBLAST)
// zT = y * xT
+ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+ne11, ne01, ne10,
+1.0f, y, ne10,
+x, ne10,
+0.0f, d, ne01,
+GGML_TYPE_F32);
+#else
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
@@ -7612,12 +8455,6 @@ static void ggml_compute_forward_mul_mat_f32(
#endif
}
}
-#if defined(GGML_USE_CUBLAS)
-CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-CUDA_CHECK(cudaFree(d_X));
-CUDA_CHECK(cudaFree(d_Y));
-CUDA_CHECK(cudaFree(d_D));
-#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

return;
@@ -7747,7 +8584,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

-#if defined(
+#if defined(GGML_USE_CUBLAS)
+if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+}
+return;
+}
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));

@@ -7763,37 +8609,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
return;
}

-#if defined(GGML_USE_CUBLAS)
-ggml_fp16_t * const wdata = params->wdata;
-
-float *d_X = NULL;
-float *d_Y = NULL;
-float *d_D = NULL;
-const float alpha = 1.0f;
-const float beta = 0.0f;
-const int x_ne = ne01 * ne10;
-const int y_ne = ne11 * ne10;
-const int d_ne = ne11 * ne01;
-
-CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
-CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-#else
-float * const wdata = params->wdata;
-#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
-
-// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
-{
-size_t id = 0;
-for (int64_t i01 = 0; i01 < ne11; ++i01) {
-for (int64_t i00 = 0; i00 < ne10; ++i00) {
-wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
-}
-}
-}
-#else
+float * const wdata = params->wdata;
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7801,31 +8619,23 @@ static void ggml_compute_forward_mul_mat_f16_f32(
wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
}
}
+
+assert(id*sizeof(float) <= params->wsize);
}
-#endif

-#if defined(
-const
-const
+#if defined(GGML_USE_CLBLAST)
+const float * x = wdata;
+const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);

float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

-//
-
-
-
-
-
-
-ne01, ne11, ne10,
-&alpha, d_X, CUDA_R_16F, ne00,
-d_Y, CUDA_R_16F, ne10,
-&beta, d_D, CUDA_R_32F, ne01,
-CUBLAS_COMPUTE_32F,
-CUBLAS_GEMM_DEFAULT));
-
-// copy data to host
-CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+// zT = y * xT
+ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+ne11, ne01, ne10,
+1.0f, y, ne10,
+x, ne10,
+0.0f, d, ne01,
+GGML_TYPE_F32);
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7842,12 +8652,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
}

-#if defined(GGML_USE_CUBLAS)
-CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-CUDA_CHECK(cudaFree(d_X));
-CUDA_CHECK(cudaFree(d_Y));
-CUDA_CHECK(cudaFree(d_D));
-#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/

return;
@@ -7980,6 +8784,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const enum ggml_type type = src0->type;
quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
+enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7999,7 +8804,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

-#if defined(
+#if defined(GGML_USE_CUBLAS)
+if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+}
+return;
+}
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@@ -8013,39 +8827,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8013
8827
|
return;
|
8014
8828
|
}
|
8015
8829
|
|
8016
|
-
#if defined(GGML_USE_CUBLAS)
|
8017
|
-
float *d_X = NULL;
|
8018
|
-
float *d_Y = NULL;
|
8019
|
-
float *d_D = NULL;
|
8020
|
-
float *d_Q = NULL;
|
8021
|
-
const float alpha = 1.0f;
|
8022
|
-
const float beta = 0.0f;
|
8023
|
-
const int x_ne = ne01 * ne10;
|
8024
|
-
const int y_ne = ne11 * ne10;
|
8025
|
-
const int d_ne = ne11 * ne01;
|
8026
|
-
|
8027
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
8028
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
8029
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
8030
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
8031
|
-
|
8032
|
-
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
8033
|
-
if (type == GGML_TYPE_Q4_0) {
|
8034
|
-
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
8035
|
-
}
|
8036
|
-
else if (type == GGML_TYPE_Q4_1) {
|
8037
|
-
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
|
8038
|
-
}
|
8039
|
-
else if (type == GGML_TYPE_Q4_2) {
|
8040
|
-
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
|
8041
|
-
}
|
8042
|
-
else {
|
8043
|
-
GGML_ASSERT(false);
|
8044
|
-
}
|
8045
|
-
#else
|
8046
8830
|
float * const wdata = params->wdata;
|
8047
8831
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
8048
|
-
#endif
|
8049
8832
|
|
8050
8833
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
8051
8834
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
@@ -8053,14 +8836,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8053
8836
|
|
8054
8837
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8055
8838
|
|
8056
|
-
#if defined(
|
8057
|
-
|
8058
|
-
CUDA_CHECK(
|
8059
|
-
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
8060
|
-
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
8061
|
-
|
8062
|
-
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
8063
|
-
CUDA_CHECK(cudaGetLastError());
|
8839
|
+
#if defined(GGML_USE_CLBLAST)
|
8840
|
+
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
8064
8841
|
#else
|
8065
8842
|
{
|
8066
8843
|
size_t id = 0;
|
@@ -8068,27 +8845,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8068
8845
|
dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
8069
8846
|
id += ne00;
|
8070
8847
|
}
|
8848
|
+
|
8849
|
+
assert(id*sizeof(float) <= params->wsize);
|
8071
8850
|
}
|
8851
|
+
|
8072
8852
|
const float * x = wdata;
|
8073
8853
|
#endif
|
8074
8854
|
|
8075
|
-
|
8076
|
-
#if defined(GGML_USE_CUBLAS)
|
8077
|
-
// copy data to device
|
8078
|
-
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
8079
|
-
|
8080
|
-
// compute
|
8081
|
-
CUBLAS_CHECK(
|
8082
|
-
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8083
|
-
ne01, ne11, ne10,
|
8084
|
-
&alpha, d_X, ne00,
|
8085
|
-
d_Y, ne10,
|
8086
|
-
&beta, d_D, ne01));
|
8087
|
-
|
8088
|
-
// copy data to host
|
8089
|
-
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
8090
|
-
#else
|
8855
|
+
#if defined(GGML_USE_CLBLAST)
|
8091
8856
|
// zT = y * xT
|
8857
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8858
|
+
ne11, ne01, ne10,
|
8859
|
+
1.0f, y, ne10,
|
8860
|
+
x, ne10,
|
8861
|
+
0.0f, d, ne01,
|
8862
|
+
type);
|
8863
|
+
#else
|
8092
8864
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
8093
8865
|
ne11, ne01, ne10,
|
8094
8866
|
1.0f, y, ne10,
|
@@ -8098,13 +8870,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8098
8870
|
}
|
8099
8871
|
}
|
8100
8872
|
|
8101
|
-
#if defined(GGML_USE_CUBLAS)
|
8102
|
-
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
8103
|
-
CUDA_CHECK(cudaFree(d_X));
|
8104
|
-
CUDA_CHECK(cudaFree(d_Y));
|
8105
|
-
CUDA_CHECK(cudaFree(d_D));
|
8106
|
-
CUDA_CHECK(cudaFree(d_Q));
|
8107
|
-
#endif
|
8108
8873
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
8109
8874
|
|
8110
8875
|
return;
|
@@ -8113,7 +8878,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8113
8878
|
|
8114
8879
|
if (params->type == GGML_TASK_INIT) {
|
8115
8880
|
char * wdata = params->wdata;
|
8116
|
-
const size_t row_size = ne10*GGML_TYPE_SIZE[
|
8881
|
+
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
8117
8882
|
|
8118
8883
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
8119
8884
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
@@ -8144,7 +8909,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8144
8909
|
const int ir1 = MIN(ir0 + dr, nr);
|
8145
8910
|
|
8146
8911
|
void * wdata = params->wdata;
|
8147
|
-
const size_t row_size = ne00*GGML_TYPE_SIZE[
|
8912
|
+
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
8148
8913
|
|
8149
8914
|
for (int ir = ir0; ir < ir1; ++ir) {
|
8150
8915
|
// src0 indices
|
@@ -8193,8 +8958,10 @@ static void ggml_compute_forward_mul_mat(
|
|
8193
8958
|
case GGML_TYPE_Q4_0:
|
8194
8959
|
case GGML_TYPE_Q4_1:
|
8195
8960
|
case GGML_TYPE_Q4_2:
|
8196
|
-
case
|
8961
|
+
case GGML_TYPE_Q5_0:
|
8962
|
+
case GGML_TYPE_Q5_1:
|
8197
8963
|
case GGML_TYPE_Q8_0:
|
8964
|
+
case GGML_TYPE_Q8_1:
|
8198
8965
|
{
|
8199
8966
|
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
|
8200
8967
|
} break;
|
@@ -8422,8 +9189,10 @@ static void ggml_compute_forward_get_rows(
|
|
8422
9189
|
case GGML_TYPE_Q4_0:
|
8423
9190
|
case GGML_TYPE_Q4_1:
|
8424
9191
|
case GGML_TYPE_Q4_2:
|
8425
|
-
case
|
9192
|
+
case GGML_TYPE_Q5_0:
|
9193
|
+
case GGML_TYPE_Q5_1:
|
8426
9194
|
case GGML_TYPE_Q8_0:
|
9195
|
+
case GGML_TYPE_Q8_1:
|
8427
9196
|
{
|
8428
9197
|
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
8429
9198
|
} break;
|
@@ -8561,6 +9330,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
8561
9330
|
|
8562
9331
|
uint16_t scvt;
|
8563
9332
|
for (int i = 0; i < nc; i++) {
|
9333
|
+
//printf("p[%3d] = %8.4f\n", i, p[i]);
|
8564
9334
|
if (p[i] == -INFINITY) {
|
8565
9335
|
p[i] = 0.0f;
|
8566
9336
|
} else {
|
@@ -8603,6 +9373,161 @@ static void ggml_compute_forward_soft_max(
|
|
8603
9373
|
}
|
8604
9374
|
}
|
8605
9375
|
|
9376
|
+
// ggml_compute_forward_alibi
|
9377
|
+
|
9378
|
+
static void ggml_compute_forward_alibi_f32(
|
9379
|
+
const struct ggml_compute_params * params,
|
9380
|
+
const struct ggml_tensor * src0,
|
9381
|
+
const struct ggml_tensor * src1,
|
9382
|
+
struct ggml_tensor * dst) {
|
9383
|
+
assert(params->ith == 0);
|
9384
|
+
assert(src1->type == GGML_TYPE_I32);
|
9385
|
+
assert(ggml_nelements(src1) == 2);
|
9386
|
+
|
9387
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9388
|
+
return;
|
9389
|
+
}
|
9390
|
+
|
9391
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
9392
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
9393
|
+
|
9394
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
9395
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
9396
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
9397
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
9398
|
+
|
9399
|
+
const int n = ggml_nrows(src0);
|
9400
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
9401
|
+
|
9402
|
+
const int nb0 = src0->nb[0];
|
9403
|
+
const int nb1 = src0->nb[1];
|
9404
|
+
const int nb2 = src0->nb[2];
|
9405
|
+
//const int nb3 = src0->nb[3];
|
9406
|
+
|
9407
|
+
assert(nb0 == sizeof(float));
|
9408
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
9409
|
+
|
9410
|
+
// add alibi to src0 (KQ_scaled)
|
9411
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
9412
|
+
|
9413
|
+
const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
|
9414
|
+
const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
|
9415
|
+
|
9416
|
+
for (int i = 0; i < ne0; i++) {
|
9417
|
+
for (int j = 0; j < ne1; j++) {
|
9418
|
+
for (int k = 0; k < ne2_ne3; k++) {
|
9419
|
+
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
9420
|
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
9421
|
+
|
9422
|
+
// TODO: k*nb2 or k*nb3
|
9423
|
+
|
9424
|
+
float m_k;
|
9425
|
+
|
9426
|
+
if (k < n_heads_log2_floor) {
|
9427
|
+
m_k = powf(m0, k + 1);
|
9428
|
+
} else {
|
9429
|
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
9430
|
+
}
|
9431
|
+
|
9432
|
+
pdst[0] = (j+1) * m_k + src[0];
|
9433
|
+
}
|
9434
|
+
}
|
9435
|
+
}
|
9436
|
+
}
|
9437
|
+
|
9438
|
+
|
9439
|
+
static void ggml_compute_forward_alibi_f16(
|
9440
|
+
const struct ggml_compute_params * params,
|
9441
|
+
const struct ggml_tensor * src0,
|
9442
|
+
const struct ggml_tensor * src1,
|
9443
|
+
struct ggml_tensor * dst) {
|
9444
|
+
assert(params->ith == 0);
|
9445
|
+
assert(src1->type == GGML_TYPE_I32);
|
9446
|
+
assert(ggml_nelements(src1) == 2);
|
9447
|
+
|
9448
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9449
|
+
return;
|
9450
|
+
}
|
9451
|
+
|
9452
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
9453
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
9454
|
+
|
9455
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
9456
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
9457
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
9458
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
9459
|
+
|
9460
|
+
const int n = ggml_nrows(src0);
|
9461
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
9462
|
+
|
9463
|
+
const int nb0 = src0->nb[0];
|
9464
|
+
const int nb1 = src0->nb[1];
|
9465
|
+
const int nb2 = src0->nb[2];
|
9466
|
+
//const int nb3 = src0->nb[3];
|
9467
|
+
|
9468
|
+
assert(nb0 == sizeof(ggml_fp16_t));
|
9469
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
9470
|
+
|
9471
|
+
// add alibi to src0 (KQ_scaled)
|
9472
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
9473
|
+
|
9474
|
+
const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
|
9475
|
+
const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
|
9476
|
+
|
9477
|
+
for (int i = 0; i < ne0; i++) {
|
9478
|
+
for (int j = 0; j < ne1; j++) {
|
9479
|
+
for (int k = 0; k < ne2_ne3; k++) {
|
9480
|
+
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
9481
|
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
9482
|
+
|
9483
|
+
// TODO: k*nb2 or k*nb3
|
9484
|
+
|
9485
|
+
float m_k;
|
9486
|
+
|
9487
|
+
if (k < n_heads_log2_floor) {
|
9488
|
+
m_k = powf(m0, k + 1);
|
9489
|
+
} else {
|
9490
|
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
9491
|
+
}
|
9492
|
+
|
9493
|
+
// we return F32
|
9494
|
+
pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
|
9495
|
+
}
|
9496
|
+
}
|
9497
|
+
}
|
9498
|
+
}
|
9499
|
+
|
9500
|
+
static void ggml_compute_forward_alibi(
|
9501
|
+
const struct ggml_compute_params * params,
|
9502
|
+
const struct ggml_tensor * src0,
|
9503
|
+
const struct ggml_tensor * src1,
|
9504
|
+
struct ggml_tensor * dst) {
|
9505
|
+
switch (src0->type) {
|
9506
|
+
case GGML_TYPE_F16:
|
9507
|
+
{
|
9508
|
+
ggml_compute_forward_alibi_f16(params, src0, src1, dst);
|
9509
|
+
} break;
|
9510
|
+
case GGML_TYPE_F32:
|
9511
|
+
{
|
9512
|
+
ggml_compute_forward_alibi_f32(params, src0, src1, dst);
|
9513
|
+
} break;
|
9514
|
+
case GGML_TYPE_Q4_0:
|
9515
|
+
case GGML_TYPE_Q4_1:
|
9516
|
+
case GGML_TYPE_Q4_2:
|
9517
|
+
case GGML_TYPE_Q5_0:
|
9518
|
+
case GGML_TYPE_Q5_1:
|
9519
|
+
case GGML_TYPE_Q8_0:
|
9520
|
+
case GGML_TYPE_Q8_1:
|
9521
|
+
case GGML_TYPE_I8:
|
9522
|
+
case GGML_TYPE_I16:
|
9523
|
+
case GGML_TYPE_I32:
|
9524
|
+
case GGML_TYPE_COUNT:
|
9525
|
+
{
|
9526
|
+
GGML_ASSERT(false);
|
9527
|
+
} break;
|
9528
|
+
}
|
9529
|
+
}
|
9530
|
+
|
8606
9531
|
// ggml_compute_forward_rope
|
8607
9532
|
|
8608
9533
|
static void ggml_compute_forward_rope_f32(
|
@@ -10241,6 +11166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
10241
11166
|
{
|
10242
11167
|
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
10243
11168
|
} break;
|
11169
|
+
case GGML_OP_ALIBI:
|
11170
|
+
{
|
11171
|
+
ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
|
11172
|
+
} break;
|
10244
11173
|
case GGML_OP_CONV_1D_1S:
|
10245
11174
|
{
|
10246
11175
|
ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
|
@@ -10443,6 +11372,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
10443
11372
|
{
|
10444
11373
|
GGML_ASSERT(false); // TODO: not implemented
|
10445
11374
|
} break;
|
11375
|
+
case GGML_OP_ALIBI:
|
11376
|
+
{
|
11377
|
+
GGML_ASSERT(false); // TODO: not implemented
|
11378
|
+
} break;
|
10446
11379
|
case GGML_OP_SILU:
|
10447
11380
|
{
|
10448
11381
|
GGML_ASSERT(false); // TODO: not implemented
|
@@ -10920,15 +11853,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10920
11853
|
|
10921
11854
|
size_t cur = 0;
|
10922
11855
|
|
11856
|
+
#if defined(GGML_USE_CUBLAS)
|
11857
|
+
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
11858
|
+
node->n_tasks = 1; // TODO: this actually is doing nothing
|
11859
|
+
// the threads are still spinning
|
11860
|
+
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
11861
|
+
}
|
11862
|
+
else
|
11863
|
+
#endif
|
10923
11864
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
10924
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
11865
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
10925
11866
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10926
11867
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
10927
11868
|
// the threads are still spinning
|
11869
|
+
// here we need memory just for single 2D matrix from src0
|
10928
11870
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10929
|
-
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
|
10930
|
-
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
|
10931
|
-
//printf("cur = %zu\n", cur);
|
10932
11871
|
} else {
|
10933
11872
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
10934
11873
|
}
|
@@ -10937,15 +11876,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10937
11876
|
#endif
|
10938
11877
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
10939
11878
|
cur = 0;
|
11879
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
11880
|
+
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
11881
|
+
node->n_tasks = 1;
|
11882
|
+
}
|
11883
|
+
#endif
|
10940
11884
|
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
10941
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
11885
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
10942
11886
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10943
11887
|
node->n_tasks = 1;
|
10944
11888
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10945
11889
|
} else
|
10946
11890
|
#endif
|
10947
11891
|
{
|
10948
|
-
|
11892
|
+
const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
|
11893
|
+
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
|
10949
11894
|
}
|
10950
11895
|
} else {
|
10951
11896
|
GGML_ASSERT(false);
|
@@ -10975,6 +11920,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10975
11920
|
{
|
10976
11921
|
node->n_tasks = n_threads;
|
10977
11922
|
} break;
|
11923
|
+
case GGML_OP_ALIBI:
|
11924
|
+
{
|
11925
|
+
node->n_tasks = 1; //TODO
|
11926
|
+
} break;
|
10978
11927
|
case GGML_OP_CONV_1D_1S:
|
10979
11928
|
case GGML_OP_CONV_1D_2S:
|
10980
11929
|
{
|
@@ -11273,9 +12222,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
11273
12222
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
11274
12223
|
struct ggml_tensor * node = cgraph->nodes[i];
|
11275
12224
|
|
11276
|
-
perf_total_per_op_us[node->op] += node->perf_time_us;
|
12225
|
+
perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
|
11277
12226
|
|
11278
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
12227
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
11279
12228
|
i,
|
11280
12229
|
node->ne[0], node->ne[1], node->ne[2],
|
11281
12230
|
GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
|
@@ -11289,13 +12238,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
11289
12238
|
for (int i = 0; i < cgraph->n_leafs; i++) {
|
11290
12239
|
struct ggml_tensor * node = cgraph->leafs[i];
|
11291
12240
|
|
11292
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
|
12241
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
11293
12242
|
i,
|
11294
12243
|
node->ne[0], node->ne[1],
|
11295
12244
|
GGML_OP_LABEL[node->op]);
|
11296
12245
|
}
|
11297
12246
|
|
11298
12247
|
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
12248
|
+
if (perf_total_per_op_us[i] == 0) {
|
12249
|
+
continue;
|
12250
|
+
}
|
12251
|
+
|
11299
12252
|
GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
|
11300
12253
|
}
|
11301
12254
|
|
@@ -11358,10 +12311,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
|
|
11358
12311
|
snprintf(color, sizeof(color), "white");
|
11359
12312
|
}
|
11360
12313
|
|
11361
|
-
fprintf(fp, " \"%p\" [
|
11362
|
-
style = filled; fillcolor = %s; shape = record;
|
11363
|
-
label=\"
|
11364
|
-
(void *) node, color
|
12314
|
+
fprintf(fp, " \"%p\" [ "
|
12315
|
+
"style = filled; fillcolor = %s; shape = record; "
|
12316
|
+
"label=\"",
|
12317
|
+
(void *) node, color);
|
12318
|
+
|
12319
|
+
if (strlen(node->name) > 0) {
|
12320
|
+
fprintf(fp, "%s |", node->name);
|
12321
|
+
}
|
12322
|
+
|
12323
|
+
fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
11365
12324
|
i, node->ne[0], node->ne[1],
|
11366
12325
|
GGML_OP_SYMBOL[node->op]);
|
11367
12326
|
|
@@ -11377,18 +12336,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
|
11377
12336
|
|
11378
12337
|
snprintf(color, sizeof(color), "pink");
|
11379
12338
|
|
12339
|
+
fprintf(fp, " \"%p\" [ "
|
12340
|
+
"style = filled; fillcolor = %s; shape = record; "
|
12341
|
+
"label=\"<x>",
|
12342
|
+
(void *) node, color);
|
12343
|
+
|
12344
|
+
if (strlen(node->name) > 0) {
|
12345
|
+
fprintf(fp, "%s | ", node->name);
|
12346
|
+
}
|
11380
12347
|
if (ggml_nelements(node) == 1) {
|
11381
|
-
|
11382
|
-
|
11383
|
-
|
11384
|
-
|
11385
|
-
|
11386
|
-
|
11387
|
-
style = filled; fillcolor = %s; shape = record; \
|
11388
|
-
label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
|
11389
|
-
(void *) node, color,
|
11390
|
-
i, node->ne[0], node->ne[1]);
|
12348
|
+
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
|
12349
|
+
fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
|
12350
|
+
}
|
12351
|
+
else {
|
12352
|
+
fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
|
12353
|
+
}
|
11391
12354
|
}
|
12355
|
+
else {
|
12356
|
+
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
|
12357
|
+
}
|
12358
|
+
fprintf(fp, "\"; ]\n");
|
11392
12359
|
}
|
11393
12360
|
|
11394
12361
|
for (int i = 0; i < gb->n_nodes; i++) {
|
@@ -12129,7 +13096,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
|
12129
13096
|
|
12130
13097
|
for (int i = 0; i < nb; i++) {
|
12131
13098
|
for (int l = 0; l < QK4_0; l += 2) {
|
12132
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13099
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12133
13100
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12134
13101
|
|
12135
13102
|
hist[vi0]++;
|
@@ -12152,7 +13119,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
12152
13119
|
|
12153
13120
|
for (int i = 0; i < nb; i++) {
|
12154
13121
|
for (int l = 0; l < QK4_1; l += 2) {
|
12155
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13122
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12156
13123
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12157
13124
|
|
12158
13125
|
hist[vi0]++;
|
@@ -12171,12 +13138,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
|
|
12171
13138
|
for (int j = 0; j < n; j += k) {
|
12172
13139
|
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
12173
13140
|
|
12174
|
-
|
12175
|
-
quantize_row_q4_2_rmse(src + j, y, k);
|
13141
|
+
quantize_row_q4_2_reference(src + j, y, k);
|
12176
13142
|
|
12177
13143
|
for (int i = 0; i < nb; i++) {
|
12178
13144
|
for (int l = 0; l < QK4_2; l += 2) {
|
12179
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13145
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12180
13146
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12181
13147
|
|
12182
13148
|
hist[vi0]++;
|
@@ -12188,19 +13154,56 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
|
|
12188
13154
|
return (n/QK4_2*sizeof(block_q4_2));
|
12189
13155
|
}
|
12190
13156
|
|
12191
|
-
size_t
|
12192
|
-
assert(k %
|
12193
|
-
const int nb = k /
|
13157
|
+
size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13158
|
+
assert(k % QK5_0 == 0);
|
13159
|
+
const int nb = k / QK5_0;
|
12194
13160
|
|
12195
13161
|
for (int j = 0; j < n; j += k) {
|
12196
|
-
|
13162
|
+
block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
|
12197
13163
|
|
12198
|
-
|
13164
|
+
quantize_row_q5_0_reference(src + j, y, k);
|
12199
13165
|
|
12200
13166
|
for (int i = 0; i < nb; i++) {
|
12201
|
-
|
12202
|
-
|
12203
|
-
|
13167
|
+
uint32_t qh;
|
13168
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
13169
|
+
|
13170
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
13171
|
+
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
|
13172
|
+
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
|
13173
|
+
|
13174
|
+
// cast to 16 bins
|
13175
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
13176
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
13177
|
+
|
13178
|
+
hist[vi0]++;
|
13179
|
+
hist[vi1]++;
|
13180
|
+
}
|
13181
|
+
}
|
13182
|
+
}
|
13183
|
+
|
13184
|
+
return (n/QK5_0*sizeof(block_q5_0));
|
13185
|
+
}
|
13186
|
+
|
13187
|
+
size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13188
|
+
assert(k % QK5_1 == 0);
|
13189
|
+
const int nb = k / QK5_1;
|
13190
|
+
|
13191
|
+
for (int j = 0; j < n; j += k) {
|
13192
|
+
block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
|
13193
|
+
|
13194
|
+
quantize_row_q5_1_reference(src + j, y, k);
|
13195
|
+
|
13196
|
+
for (int i = 0; i < nb; i++) {
|
13197
|
+
uint32_t qh;
|
13198
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
13199
|
+
|
13200
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
13201
|
+
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
|
13202
|
+
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
|
13203
|
+
|
13204
|
+
// cast to 16 bins
|
13205
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
13206
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12204
13207
|
|
12205
13208
|
hist[vi0]++;
|
12206
13209
|
hist[vi1]++;
|
@@ -12208,7 +13211,28 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
|
|
12208
13211
|
}
|
12209
13212
|
}
|
12210
13213
|
|
12211
|
-
return (n/
|
13214
|
+
return (n/QK5_1*sizeof(block_q5_1));
|
13215
|
+
}
|
13216
|
+
|
13217
|
+
size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13218
|
+
assert(k % QK8_0 == 0);
|
13219
|
+
const int nb = k / QK8_0;
|
13220
|
+
|
13221
|
+
for (int j = 0; j < n; j += k) {
|
13222
|
+
block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
|
13223
|
+
|
13224
|
+
quantize_row_q8_0_reference(src + j, y, k);
|
13225
|
+
|
13226
|
+
for (int i = 0; i < nb; i++) {
|
13227
|
+
for (int l = 0; l < QK8_0; ++l) {
|
13228
|
+
const int8_t vi = y[i].qs[l];
|
13229
|
+
|
13230
|
+
hist[vi/16 + 8]++;
|
13231
|
+
}
|
13232
|
+
}
|
13233
|
+
}
|
13234
|
+
|
13235
|
+
return (n/QK8_0*sizeof(block_q8_0));
|
12212
13236
|
}
|
12213
13237
|
|
12214
13238
|
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
|
@@ -12232,11 +13256,23 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
12232
13256
|
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
|
12233
13257
|
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
|
12234
13258
|
} break;
|
12235
|
-
case
|
13259
|
+
case GGML_TYPE_Q5_0:
|
13260
|
+
{
|
13261
|
+
GGML_ASSERT(start % QK5_0 == 0);
|
13262
|
+
block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
|
13263
|
+
result = ggml_quantize_q5_0(src + start, block, n, n, hist);
|
13264
|
+
} break;
|
13265
|
+
case GGML_TYPE_Q5_1:
|
13266
|
+
{
|
13267
|
+
GGML_ASSERT(start % QK5_1 == 0);
|
13268
|
+
block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
|
13269
|
+
result = ggml_quantize_q5_1(src + start, block, n, n, hist);
|
13270
|
+
} break;
|
13271
|
+
case GGML_TYPE_Q8_0:
|
12236
13272
|
{
|
12237
|
-
GGML_ASSERT(start %
|
12238
|
-
|
12239
|
-
result =
|
13273
|
+
GGML_ASSERT(start % QK8_0 == 0);
|
13274
|
+
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
|
13275
|
+
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
|
12240
13276
|
} break;
|
12241
13277
|
default:
|
12242
13278
|
assert(false);
|
@@ -12335,7 +13371,7 @@ int ggml_cpu_has_wasm_simd(void) {
|
|
12335
13371
|
}
|
12336
13372
|
|
12337
13373
|
int ggml_cpu_has_blas(void) {
|
12338
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
13374
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
12339
13375
|
return 1;
|
12340
13376
|
#else
|
12341
13377
|
return 0;
|
@@ -12350,6 +13386,18 @@ int ggml_cpu_has_cublas(void) {
|
|
12350
13386
|
#endif
|
12351
13387
|
}
|
12352
13388
|
|
13389
|
+
int ggml_cpu_has_clblast(void) {
|
13390
|
+
#if defined(GGML_USE_CLBLAST)
|
13391
|
+
return 1;
|
13392
|
+
#else
|
13393
|
+
return 0;
|
13394
|
+
#endif
|
13395
|
+
}
|
13396
|
+
|
13397
|
+
int ggml_cpu_has_gpublas(void) {
|
13398
|
+
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
|
13399
|
+
}
|
13400
|
+
|
12353
13401
|
int ggml_cpu_has_sse3(void) {
|
12354
13402
|
#if defined(__SSE3__)
|
12355
13403
|
return 1;
|