llama_cpp 0.0.6 → 0.1.0
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +20 -1
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +762 -36
- data/ext/llama_cpp/src/ggml-cuda.h +11 -4
- data/ext/llama_cpp/src/ggml-opencl.c +398 -0
- data/ext/llama_cpp/src/ggml-opencl.h +24 -0
- data/ext/llama_cpp/src/ggml.c +1957 -909
- data/ext/llama_cpp/src/ggml.h +696 -627
- data/ext/llama_cpp/src/{llama_util.h → llama-util.h} +91 -12
- data/ext/llama_cpp/src/llama.cpp +755 -159
- data/ext/llama_cpp/src/llama.h +85 -34
- data/lib/llama_cpp/client.rb +174 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +43 -11
- data/sig/llama_cpp.rbs +53 -3
- metadata +6 -3
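The bulk of the ggml.c changes below swap the inline cuBLAS setup for the ggml-cuda/ggml-opencl backends and add new quantization formats (Q5_0, Q5_1, Q8_1). As a quick orientation before the raw diff, here is a minimal, self-contained C sketch of how a Q5_0 block is laid out and filled. The field names (d, qh, qs), the QK5_0 constant, and the packing scheme are taken from the quantize_row_q5_0_reference code visible in the diff; the fp16 handling and all SIMD paths are omitted, so treat this as an illustrative sketch rather than the library's exact code.

#include <math.h>
#include <stdint.h>
#include <string.h>

#define QK5_0 32

typedef struct {
    uint16_t d;             // scale ("delta"); ggml stores this as an fp16 (ggml_fp16_t)
    uint8_t  qh[4];         // 5-th bit of each of the 32 quants, packed into a 32-bit mask
    uint8_t  qs[QK5_0 / 2]; // low 4 bits of each quant, two per byte
} block_q5_0;

// Quantize one block of 32 floats to 5-bit values in [0, 31]:
// the low nibble of each value goes into qs, the 5-th bits are collected into qh.
// The fp16 conversion of the scale is left out; y->d is set to 0 as a placeholder.
static void quantize_block_q5_0_sketch(const float x[QK5_0], block_q5_0 *y) {
    float amax = 0.0f; // largest magnitude in the block
    float max  = 0.0f; // the value that has that magnitude (sign preserved)
    for (int l = 0; l < QK5_0; l++) {
        if (amax < fabsf(x[l])) { amax = fabsf(x[l]); max = x[l]; }
    }

    const float d  = max / -16;          // scale chosen so that "max" maps to -16
    const float id = d ? 1.0f/d : 0.0f;
    y->d = 0;                            // placeholder: ggml stores GGML_FP32_TO_FP16(d) here

    uint32_t qh = 0;
    for (int l = 0; l < QK5_0; l += 2) {
        int v0 = (int)(x[l + 0]*id + 16.5f); // shift to [0, 32], then clamp to 31
        int v1 = (int)(x[l + 1]*id + 16.5f);
        if (v0 > 31) v0 = 31;
        if (v1 > 31) v1 = 31;

        y->qs[l/2] = (uint8_t)((v0 & 0x0F) | ((v1 & 0x0F) << 4));

        qh |= (uint32_t)((v0 & 0x10) >> 4) << (l + 0); // 5-th bit of the even element
        qh |= (uint32_t)((v1 & 0x10) >> 4) << (l + 1); // 5-th bit of the odd element
    }
    memcpy(y->qh, &qh, sizeof(qh));
}

Dequantization reverses this: the low nibble from qs is recombined with the matching bit from qh, shifted back by 16, and scaled by d, as shown in dequantize_row_q5_0 further down in the diff.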
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -135,57 +135,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
|
135
135
|
#define UNUSED(x) (void)(x)
|
136
136
|
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
137
137
|
|
138
|
-
#define GGML_ASSERT(x) \
|
139
|
-
do { \
|
140
|
-
if (!(x)) { \
|
141
|
-
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
142
|
-
abort(); \
|
143
|
-
} \
|
144
|
-
} while (0)
|
145
|
-
|
146
138
|
#if defined(GGML_USE_ACCELERATE)
|
147
139
|
#include <Accelerate/Accelerate.h>
|
148
140
|
#elif defined(GGML_USE_OPENBLAS)
|
149
141
|
#include <cblas.h>
|
150
142
|
#elif defined(GGML_USE_CUBLAS)
|
151
|
-
#include <cublas_v2.h>
|
152
|
-
#include <cuda_runtime.h>
|
153
143
|
#include "ggml-cuda.h"
|
154
|
-
|
155
|
-
#
|
156
|
-
do { \
|
157
|
-
cudaError_t err_ = (err); \
|
158
|
-
if (err_ != cudaSuccess) { \
|
159
|
-
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
160
|
-
cudaGetErrorString(err_)); \
|
161
|
-
exit(1); \
|
162
|
-
} \
|
163
|
-
} while (0)
|
164
|
-
|
165
|
-
#define CUBLAS_CHECK(err) \
|
166
|
-
do { \
|
167
|
-
cublasStatus_t err_ = (err); \
|
168
|
-
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
169
|
-
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
170
|
-
exit(1); \
|
171
|
-
} \
|
172
|
-
} while (0)
|
173
|
-
|
174
|
-
static cublasHandle_t cublasH = NULL;
|
175
|
-
static cudaStream_t cudaStream = NULL;
|
176
|
-
static void init_cublas(void) {
|
177
|
-
if (cublasH == NULL) {
|
178
|
-
// create cublas handle, bind a stream
|
179
|
-
CUBLAS_CHECK(cublasCreate(&cublasH));
|
180
|
-
|
181
|
-
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
182
|
-
|
183
|
-
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
184
|
-
|
185
|
-
// configure logging to stdout
|
186
|
-
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
187
|
-
}
|
188
|
-
}
|
144
|
+
#elif defined(GGML_USE_CLBLAST)
|
145
|
+
#include "ggml-opencl.h"
|
189
146
|
#endif
|
190
147
|
|
191
148
|
#undef MIN
|
@@ -223,9 +180,13 @@ typedef double ggml_float;
|
|
223
180
|
#undef bool
|
224
181
|
#define bool _Bool
|
225
182
|
#else
|
183
|
+
#if defined(_MSC_VER) || defined(__MINGW32__)
|
184
|
+
#include <intrin.h>
|
185
|
+
#else
|
226
186
|
#include <immintrin.h>
|
227
187
|
#endif
|
228
188
|
#endif
|
189
|
+
#endif
|
229
190
|
|
230
191
|
#ifdef __F16C__
|
231
192
|
|
@@ -365,6 +326,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
|
|
365
326
|
// precomputed f32 table for f16 (256 KB)
|
366
327
|
static float table_f32_f16[1 << 16];
|
367
328
|
|
329
|
+
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
|
330
|
+
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
|
331
|
+
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
|
332
|
+
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
|
333
|
+
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
|
334
|
+
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
|
335
|
+
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
|
336
|
+
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
|
337
|
+
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
|
338
|
+
|
339
|
+
// precomputed tables for expanding 8bits to 8 bytes (shl 4)
|
340
|
+
static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
|
341
|
+
#endif
|
342
|
+
|
368
343
|
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
|
369
344
|
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
|
370
345
|
// This is also true for POWER9.
|
@@ -391,6 +366,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
|
391
366
|
return GGML_FP32_TO_FP16(x);
|
392
367
|
}
|
393
368
|
|
369
|
+
void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
|
370
|
+
for (size_t i = 0; i < n; i++) {
|
371
|
+
y[i] = GGML_FP16_TO_FP32(x[i]);
|
372
|
+
}
|
373
|
+
}
|
374
|
+
|
375
|
+
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
|
376
|
+
size_t i = 0;
|
377
|
+
#if defined(__F16C__)
|
378
|
+
for (; i + 7 < n; i += 8) {
|
379
|
+
__m256 x_vec = _mm256_loadu_ps(x + i);
|
380
|
+
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
|
381
|
+
_mm_storeu_si128((__m128i *)(y + i), y_vec);
|
382
|
+
}
|
383
|
+
for(; i + 3 < n; i += 4) {
|
384
|
+
__m128 x_vec = _mm_loadu_ps(x + i);
|
385
|
+
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
|
386
|
+
_mm_storel_epi64((__m128i *)(y + i), y_vec);
|
387
|
+
}
|
388
|
+
#endif
|
389
|
+
for (; i < n; i++) {
|
390
|
+
y[i] = GGML_FP32_TO_FP16(x[i]);
|
391
|
+
}
|
392
|
+
}
|
393
|
+
|
394
|
+
|
394
395
|
//
|
395
396
|
// timing
|
396
397
|
//
|
@@ -473,7 +474,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|
473
474
|
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
474
475
|
{
|
475
476
|
// Load 8 bytes from memory
|
476
|
-
__m128i tmp =
|
477
|
+
__m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
|
477
478
|
|
478
479
|
// Expand bytes into uint16_t values
|
479
480
|
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
@@ -487,7 +488,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
|
487
488
|
return bytes;
|
488
489
|
}
|
489
490
|
|
491
|
+
// horizontally add 8 floats
|
492
|
+
static inline float hsum_float_8(const __m256 x) {
|
493
|
+
__m128 res = _mm256_extractf128_ps(x, 1);
|
494
|
+
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
|
495
|
+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
496
|
+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
497
|
+
return _mm_cvtss_f32(res);
|
498
|
+
}
|
499
|
+
|
500
|
+
// horizontally add 8 int32_t
|
501
|
+
static inline int hsum_i32_8(const __m256i a) {
|
502
|
+
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
|
503
|
+
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
504
|
+
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
|
505
|
+
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
506
|
+
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
507
|
+
}
|
508
|
+
|
509
|
+
// horizontally add 4 int32_t
|
510
|
+
static inline int hsum_i32_4(const __m128i a) {
|
511
|
+
const __m128i hi64 = _mm_unpackhi_epi64(a, a);
|
512
|
+
const __m128i sum64 = _mm_add_epi32(hi64, a);
|
513
|
+
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
514
|
+
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
515
|
+
}
|
516
|
+
|
490
517
|
#if __AVX2__ || __AVX512F__
|
518
|
+
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
519
|
+
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
520
|
+
uint32_t x32;
|
521
|
+
memcpy(&x32, x, sizeof(uint32_t));
|
522
|
+
const __m256i shuf_mask = _mm256_set_epi64x(
|
523
|
+
0x0303030303030303, 0x0202020202020202,
|
524
|
+
0x0101010101010101, 0x0000000000000000);
|
525
|
+
__m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
|
526
|
+
const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
|
527
|
+
bytes = _mm256_or_si256(bytes, bit_mask);
|
528
|
+
return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
|
529
|
+
}
|
530
|
+
|
491
531
|
// Unpack 32 4-bit fields into 32 bytes
|
492
532
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
493
533
|
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
@@ -507,9 +547,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
|
507
547
|
return bytes;
|
508
548
|
}
|
509
549
|
|
550
|
+
// add int16_t pairwise and return as float vector
|
551
|
+
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
552
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
553
|
+
const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
|
554
|
+
return _mm256_cvtepi32_ps(summed_pairs);
|
555
|
+
}
|
556
|
+
|
557
|
+
// multiply int8_t, add results pairwise twice and return as float vector
|
558
|
+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
559
|
+
// Get absolute values of x vectors
|
560
|
+
const __m256i ax = _mm256_sign_epi8(x, x);
|
561
|
+
// Sign the values of the y vectors
|
562
|
+
const __m256i sy = _mm256_sign_epi8(y, x);
|
563
|
+
#if __AVXVNNI__
|
564
|
+
const __m256i zero = _mm256_setzero_si256();
|
565
|
+
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
566
|
+
return _mm256_cvtepi32_ps(summed_pairs);
|
567
|
+
#else
|
568
|
+
// Perform multiplication and create 16-bit values
|
569
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
570
|
+
return sum_i16_pairs_float(dot);
|
571
|
+
#endif
|
572
|
+
}
|
573
|
+
|
510
574
|
static inline __m128i packNibbles( __m256i bytes )
|
511
575
|
{
|
512
576
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
577
|
+
#if __AVX512F__
|
578
|
+
const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
|
579
|
+
bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
|
580
|
+
return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
|
581
|
+
#else
|
513
582
|
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
|
514
583
|
__m256i high = _mm256_andnot_si256( lowByte, bytes );
|
515
584
|
__m256i low = _mm256_and_si256( lowByte, bytes );
|
@@ -520,6 +589,7 @@ static inline __m128i packNibbles( __m256i bytes )
|
|
520
589
|
__m128i r0 = _mm256_castsi256_si128( bytes );
|
521
590
|
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
522
591
|
return _mm_packus_epi16( r0, r1 );
|
592
|
+
#endif
|
523
593
|
}
|
524
594
|
#else
|
525
595
|
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
@@ -605,19 +675,102 @@ float vmaxvq_f32(float32x4_t v) {
|
|
605
675
|
}
|
606
676
|
|
607
677
|
int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
|
608
|
-
|
678
|
+
int8x8_t res;
|
679
|
+
|
680
|
+
res[0] = a[0]; res[1] = b[0];
|
681
|
+
res[2] = a[1]; res[3] = b[1];
|
682
|
+
res[4] = a[2]; res[5] = b[2];
|
683
|
+
res[6] = a[3]; res[7] = b[3];
|
684
|
+
|
685
|
+
return res;
|
609
686
|
}
|
610
687
|
|
611
688
|
int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
|
612
|
-
|
689
|
+
int8x8_t res;
|
690
|
+
|
691
|
+
res[0] = a[4]; res[1] = b[4];
|
692
|
+
res[2] = a[5]; res[3] = b[5];
|
693
|
+
res[4] = a[6]; res[5] = b[6];
|
694
|
+
res[6] = a[7]; res[7] = b[7];
|
695
|
+
|
696
|
+
return res;
|
613
697
|
}
|
614
698
|
|
615
699
|
uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
616
|
-
|
700
|
+
uint8x8_t res;
|
701
|
+
|
702
|
+
res[0] = a[0]; res[1] = b[0];
|
703
|
+
res[2] = a[1]; res[3] = b[1];
|
704
|
+
res[4] = a[2]; res[5] = b[2];
|
705
|
+
res[6] = a[3]; res[7] = b[3];
|
706
|
+
|
707
|
+
return res;
|
617
708
|
}
|
618
709
|
|
619
710
|
uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
620
|
-
|
711
|
+
uint8x8_t res;
|
712
|
+
|
713
|
+
res[0] = a[4]; res[1] = b[4];
|
714
|
+
res[2] = a[5]; res[3] = b[5];
|
715
|
+
res[4] = a[6]; res[5] = b[6];
|
716
|
+
res[6] = a[7]; res[7] = b[7];
|
717
|
+
|
718
|
+
return res;
|
719
|
+
}
|
720
|
+
|
721
|
+
int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
|
722
|
+
int8x16_t res;
|
723
|
+
|
724
|
+
res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
|
725
|
+
res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
|
726
|
+
res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
|
727
|
+
res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
|
728
|
+
|
729
|
+
return res;
|
730
|
+
}
|
731
|
+
|
732
|
+
int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
|
733
|
+
int8x16_t res;
|
734
|
+
|
735
|
+
res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
|
736
|
+
res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
|
737
|
+
res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
|
738
|
+
res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
|
739
|
+
|
740
|
+
return res;
|
741
|
+
}
|
742
|
+
|
743
|
+
uint8x16_t vzip1q_u8(uint8x16_t a, uint8x16_t b) {
|
744
|
+
uint8x16_t res;
|
745
|
+
|
746
|
+
res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
|
747
|
+
res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
|
748
|
+
res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
|
749
|
+
res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
|
750
|
+
|
751
|
+
return res;
|
752
|
+
}
|
753
|
+
|
754
|
+
uint8x16_t vzip2q_u8(uint8x16_t a, uint8x16_t b) {
|
755
|
+
uint8x16_t res;
|
756
|
+
|
757
|
+
res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
|
758
|
+
res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
|
759
|
+
res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
|
760
|
+
res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
|
761
|
+
|
762
|
+
return res;
|
763
|
+
}
|
764
|
+
|
765
|
+
int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
766
|
+
int32x4_t res;
|
767
|
+
|
768
|
+
res[0] = roundf(vgetq_lane_f32(v, 0));
|
769
|
+
res[1] = roundf(vgetq_lane_f32(v, 1));
|
770
|
+
res[2] = roundf(vgetq_lane_f32(v, 2));
|
771
|
+
res[3] = roundf(vgetq_lane_f32(v, 3));
|
772
|
+
|
773
|
+
return res;
|
621
774
|
}
|
622
775
|
|
623
776
|
#endif
|
@@ -646,13 +799,22 @@ typedef struct {
|
|
646
799
|
} block_q4_2;
|
647
800
|
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
|
648
801
|
|
649
|
-
#define
|
802
|
+
#define QK5_0 32
|
803
|
+
typedef struct {
|
804
|
+
ggml_fp16_t d; // delta
|
805
|
+
uint8_t qh[4]; // 5-th bit of quants
|
806
|
+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
807
|
+
} block_q5_0;
|
808
|
+
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
809
|
+
|
810
|
+
#define QK5_1 32
|
650
811
|
typedef struct {
|
651
812
|
ggml_fp16_t d; // delta
|
652
813
|
ggml_fp16_t m; // min
|
653
|
-
uint8_t
|
654
|
-
|
655
|
-
|
814
|
+
uint8_t qh[4]; // 5-th bit of quants
|
815
|
+
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
816
|
+
} block_q5_1;
|
817
|
+
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
656
818
|
|
657
819
|
#define QK8_0 32
|
658
820
|
typedef struct {
|
@@ -661,6 +823,14 @@ typedef struct {
|
|
661
823
|
} block_q8_0;
|
662
824
|
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
663
825
|
|
826
|
+
#define QK8_1 32
|
827
|
+
typedef struct {
|
828
|
+
float d; // delta
|
829
|
+
float s0; // d * sum(qs[i]) low
|
830
|
+
float s1; // d * sum(qs[i]) high
|
831
|
+
int8_t qs[QK8_1]; // quants
|
832
|
+
} block_q8_1;
|
833
|
+
static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
|
664
834
|
|
665
835
|
// reference implementation for deterministic creation of model files
|
666
836
|
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
@@ -671,13 +841,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
|
671
841
|
|
672
842
|
for (int i = 0; i < nb; i++) {
|
673
843
|
float amax = 0.0f; // absolute max
|
844
|
+
float max = 0.0f;
|
674
845
|
|
675
846
|
for (int l = 0; l < QK4_0; l++) {
|
676
847
|
const float v = x[i*QK4_0 + l];
|
677
|
-
|
848
|
+
if (amax < fabsf(v)) {
|
849
|
+
amax = fabsf(v);
|
850
|
+
max = v;
|
851
|
+
}
|
678
852
|
}
|
679
853
|
|
680
|
-
const float d =
|
854
|
+
const float d = max / -8;
|
681
855
|
const float id = d ? 1.0f/d : 0.0f;
|
682
856
|
|
683
857
|
y[i].d = d;
|
@@ -686,8 +860,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
|
686
860
|
const float v0 = x[i*QK4_0 + l + 0]*id;
|
687
861
|
const float v1 = x[i*QK4_0 + l + 1]*id;
|
688
862
|
|
689
|
-
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
|
690
|
-
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
|
863
|
+
const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
|
864
|
+
const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
|
691
865
|
|
692
866
|
assert(vi0 < 16);
|
693
867
|
assert(vi1 < 16);
|
@@ -707,28 +881,43 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
707
881
|
|
708
882
|
#if defined(__POWER9_VECTOR__)
|
709
883
|
const vector float v85 = vec_splats(8.5f);
|
884
|
+
const vector signed int v15 = vec_splats(15);
|
710
885
|
for (int i = 0; i < nb; i++) {
|
711
|
-
float
|
886
|
+
float max = 0.0f;
|
887
|
+
float min = 0.0f;
|
712
888
|
|
889
|
+
vector float asrcv [8];
|
713
890
|
vector float srcv [8];
|
714
|
-
vector float
|
715
|
-
vector float
|
891
|
+
vector float maxv[8];
|
892
|
+
vector float minv[8];
|
716
893
|
|
717
894
|
for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
|
718
|
-
for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
|
719
|
-
|
720
|
-
for (int l = 0; l < 4; l++)
|
721
|
-
//for (int l = 0; l < 2; l++)
|
722
|
-
|
723
|
-
|
724
|
-
//for (int l = 0; l < 1; l++)
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
895
|
+
//for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
|
896
|
+
|
897
|
+
for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
|
898
|
+
//for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
|
899
|
+
maxv[0] = vec_max(maxv[0], maxv[2]);
|
900
|
+
maxv[4] = vec_max(maxv[4], maxv[6]);
|
901
|
+
//for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
|
902
|
+
maxv[0] = vec_max(maxv[0], maxv[4]);
|
903
|
+
|
904
|
+
for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
|
905
|
+
//for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
|
906
|
+
minv[0] = vec_min(minv[0], minv[2]);
|
907
|
+
minv[4] = vec_min(minv[4], minv[6]);
|
908
|
+
//for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
|
909
|
+
minv[0] = vec_min(minv[0], minv[4]);
|
910
|
+
|
911
|
+
|
912
|
+
max = MAX(
|
913
|
+
MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
|
914
|
+
MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
|
915
|
+
min = MIN(
|
916
|
+
MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
|
917
|
+
MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
|
918
|
+
|
919
|
+
const float magnitude = max >= fabsf(min) ? max : min;
|
920
|
+
const float d = magnitude / -8;
|
732
921
|
const float id = d ? 1.0/d : 0.0;
|
733
922
|
|
734
923
|
y[i].d = d;
|
@@ -738,27 +927,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
738
927
|
for (int l = 0; l < 8; l++) {
|
739
928
|
const vector float vf = vec_madd(srcv[l], vid, v85);
|
740
929
|
const vector signed int vi = vec_signed(vf);
|
930
|
+
const vector signed int vc = vec_min(vi, v15);
|
741
931
|
|
742
|
-
pb[2*l + 0] = vec_extract(
|
743
|
-
pb[2*l + 1] = vec_extract(
|
932
|
+
pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
|
933
|
+
pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
|
744
934
|
}
|
745
935
|
}
|
746
936
|
#elif __ARM_NEON
|
747
937
|
for (int i = 0; i < nb; i++) {
|
748
938
|
float32x4_t srcv [8];
|
749
|
-
float32x4_t
|
750
|
-
float32x4_t
|
939
|
+
float32x4_t maxv[8];
|
940
|
+
float32x4_t minv[8];
|
751
941
|
|
752
942
|
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
753
|
-
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
|
754
943
|
|
755
|
-
for (int l = 0; l < 4; l++)
|
756
|
-
for (int l = 0; l < 2; l++)
|
757
|
-
for (int l = 0; l < 1; l++)
|
944
|
+
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
|
945
|
+
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
|
946
|
+
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
|
758
947
|
|
759
|
-
|
948
|
+
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
|
949
|
+
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
|
950
|
+
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
|
951
|
+
|
952
|
+
const float max = vmaxvq_f32(maxv[0]);
|
953
|
+
const float min = vminvq_f32(minv[0]);
|
760
954
|
|
761
|
-
const float
|
955
|
+
const float magnitude = max >= fabsf(min) ? max : min;
|
956
|
+
const float d = magnitude / -8;
|
762
957
|
const float id = d ? 1.0f/d : 0.0f;
|
763
958
|
|
764
959
|
y[i].d = d;
|
@@ -767,9 +962,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
767
962
|
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
768
963
|
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
|
769
964
|
const int32x4_t vi = vcvtq_s32_f32(vf);
|
965
|
+
const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
|
770
966
|
|
771
|
-
y[i].qs[2*l + 0] = vgetq_lane_s32(
|
772
|
-
y[i].qs[2*l + 1] = vgetq_lane_s32(
|
967
|
+
y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
|
968
|
+
y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
|
773
969
|
}
|
774
970
|
}
|
775
971
|
#elif defined(__AVX2__)
|
@@ -781,22 +977,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
781
977
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
782
978
|
x += 32;
|
783
979
|
|
784
|
-
// Compute max
|
785
|
-
|
786
|
-
__m256
|
787
|
-
|
788
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
789
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
980
|
+
// Compute max for the block
|
981
|
+
__m256 max = _mm256_max_ps( v0, v1 );
|
982
|
+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
983
|
+
max = _mm256_max_ps( max, maxTmp );
|
790
984
|
|
791
|
-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
|
985
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
792
986
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
793
987
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
794
988
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
795
989
|
|
990
|
+
// Compute min for the block
|
991
|
+
__m256 min = _mm256_min_ps( v0, v1 );
|
992
|
+
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
993
|
+
min = _mm256_min_ps( min, minTmp );
|
994
|
+
|
995
|
+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
996
|
+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
997
|
+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
998
|
+
const float minScalar = _mm_cvtss_f32( min4 );
|
999
|
+
|
796
1000
|
// Quantize these floats
|
797
|
-
const float
|
1001
|
+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
1002
|
+
const float d = magnitude / -8.0f;
|
798
1003
|
y[i].d = d;
|
799
|
-
const float id = (
|
1004
|
+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
800
1005
|
const __m256 mul = _mm256_set1_ps( id );
|
801
1006
|
|
802
1007
|
// Apply the multiplier
|
@@ -829,9 +1034,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
829
1034
|
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
830
1035
|
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
831
1036
|
|
832
|
-
// Apply offset to translate the range from [ -
|
1037
|
+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
833
1038
|
const __m256i off = _mm256_set1_epi8( 8 );
|
834
1039
|
i0 = _mm256_add_epi8( i0, off );
|
1040
|
+
const __m256i maxNibble = _mm256_set1_epi8( 15 );
|
1041
|
+
i0 = _mm256_min_epi8( i0, maxNibble );
|
835
1042
|
|
836
1043
|
// Compress the vector into 4 bit/value, and store
|
837
1044
|
__m128i res = packNibbles( i0 );
|
@@ -846,22 +1053,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
846
1053
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
847
1054
|
x += 32;
|
848
1055
|
|
849
|
-
// Compute max
|
850
|
-
|
851
|
-
__m256
|
852
|
-
|
853
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
854
|
-
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
1056
|
+
// Compute max for the block
|
1057
|
+
__m256 max = _mm256_max_ps( v0, v1 );
|
1058
|
+
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
1059
|
+
max = _mm256_max_ps( max, maxTmp );
|
855
1060
|
|
856
|
-
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps(
|
1061
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
857
1062
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
858
1063
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
859
1064
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
860
1065
|
|
1066
|
+
// Compute min for the block
|
1067
|
+
__m256 min = _mm256_min_ps( v0, v1 );
|
1068
|
+
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
1069
|
+
min = _mm256_min_ps( min, minTmp );
|
1070
|
+
|
1071
|
+
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
1072
|
+
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
1073
|
+
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
1074
|
+
const float minScalar = _mm_cvtss_f32( min4 );
|
1075
|
+
|
861
1076
|
// Quantize these floats
|
862
|
-
const float
|
1077
|
+
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
1078
|
+
const float d = magnitude / -8.0f;
|
863
1079
|
y[i].d = d;
|
864
|
-
const float id = (
|
1080
|
+
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
865
1081
|
const __m256 mul = _mm256_set1_ps( id );
|
866
1082
|
|
867
1083
|
// Apply the multiplier
|
@@ -902,10 +1118,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
902
1118
|
ni0 = _mm_packs_epi16( ni0, ni2 );
|
903
1119
|
ni4 = _mm_packs_epi16( ni4, ni6 );
|
904
1120
|
|
905
|
-
// Apply offset to translate the range from [ -
|
906
|
-
const __m128i off = _mm_set1_epi8( 8);
|
1121
|
+
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
1122
|
+
const __m128i off = _mm_set1_epi8( 8 );
|
907
1123
|
ni0 = _mm_add_epi8( ni0, off );
|
908
1124
|
ni4 = _mm_add_epi8( ni4, off );
|
1125
|
+
const __m128i maxNibble = _mm_set1_epi8( 15 );
|
1126
|
+
ni0 = _mm_min_epi8( ni0, maxNibble );
|
1127
|
+
ni4 = _mm_min_epi8( ni4, maxNibble );
|
909
1128
|
|
910
1129
|
// Compress the vector into 4 bit/value, and store
|
911
1130
|
__m128i res = packNibbles( ni0, ni4 );
|
@@ -913,24 +1132,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
913
1132
|
}
|
914
1133
|
#elif defined(__wasm_simd128__)
|
915
1134
|
for (int i = 0; i < nb; i++) {
|
916
|
-
float
|
1135
|
+
float max = 0.0f;
|
1136
|
+
float min = 0.0f;
|
917
1137
|
|
918
1138
|
v128_t srcv [8];
|
919
|
-
v128_t
|
920
|
-
v128_t
|
1139
|
+
v128_t maxv[8];
|
1140
|
+
v128_t minv[8];
|
921
1141
|
|
922
1142
|
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
|
923
|
-
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
|
924
1143
|
|
925
|
-
for (int l = 0; l < 4; l++)
|
926
|
-
for (int l = 0; l < 2; l++)
|
927
|
-
for (int l = 0; l < 1; l++)
|
1144
|
+
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
|
1145
|
+
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
|
1146
|
+
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
|
1147
|
+
|
1148
|
+
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
|
1149
|
+
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
|
1150
|
+
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
|
928
1151
|
|
929
|
-
|
930
|
-
MAX(wasm_f32x4_extract_lane(
|
931
|
-
MAX(wasm_f32x4_extract_lane(
|
1152
|
+
max = MAX(
|
1153
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
|
1154
|
+
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
|
1155
|
+
min = MIN(
|
1156
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
|
1157
|
+
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
|
932
1158
|
|
933
|
-
const float
|
1159
|
+
const float magnitude = max >= fabsf(min) ? max : min;
|
1160
|
+
const float d = magnitude / -8;
|
934
1161
|
const float id = d ? 1.0/d : 0.0;
|
935
1162
|
|
936
1163
|
y[i].d = d;
|
@@ -939,9 +1166,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
939
1166
|
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
|
940
1167
|
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
|
941
1168
|
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
|
1169
|
+
const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
|
942
1170
|
|
943
|
-
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(
|
944
|
-
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(
|
1171
|
+
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
|
1172
|
+
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
|
945
1173
|
}
|
946
1174
|
}
|
947
1175
|
#else
|
@@ -1122,13 +1350,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
|
|
1122
1350
|
|
1123
1351
|
for (int i = 0; i < nb; i++) {
|
1124
1352
|
float amax = 0.0f; // absolute max
|
1353
|
+
float max = 0.0f;
|
1125
1354
|
|
1126
1355
|
for (int l = 0; l < QK4_2; l++) {
|
1127
1356
|
const float v = x[i*QK4_2 + l];
|
1128
|
-
|
1357
|
+
if (amax < fabsf(v)) {
|
1358
|
+
amax = fabsf(v);
|
1359
|
+
max = v;
|
1360
|
+
}
|
1129
1361
|
}
|
1130
1362
|
|
1131
|
-
const float d =
|
1363
|
+
const float d = max / -8;
|
1132
1364
|
|
1133
1365
|
const float id = d ? 1.0f/d : 0.0f;
|
1134
1366
|
|
@@ -1138,8 +1370,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
|
|
1138
1370
|
const float v0 = x[i*QK4_2 + l + 0]*id;
|
1139
1371
|
const float v1 = x[i*QK4_2 + l + 1]*id;
|
1140
1372
|
|
1141
|
-
const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
|
1142
|
-
const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
|
1373
|
+
const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
|
1374
|
+
const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
|
1143
1375
|
|
1144
1376
|
assert(vi0 < 16);
|
1145
1377
|
assert(vi1 < 16);
|
@@ -1149,136 +1381,109 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
|
|
1149
1381
|
}
|
1150
1382
|
}
|
1151
1383
|
|
1152
|
-
static
|
1153
|
-
assert(
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
assert
|
1162
|
-
|
1163
|
-
|
1164
|
-
for (int i=0; i<
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
float sumlxM = 0; int suml2M = 0;
|
1174
|
-
for (int i=0; i<n; ++i) {
|
1175
|
-
int l = nearest_int(iscale*X[i]);
|
1176
|
-
int lp = MAX(nmin, MIN(nmax, +l));
|
1177
|
-
int lm = MAX(nmin, MIN(nmax, -l));
|
1178
|
-
sumlxP += X[i]*lp; suml2P += lp*lp;
|
1179
|
-
sumlxM += X[i]*lm; suml2M += lm*lm;
|
1180
|
-
}
|
1181
|
-
float sumlxP2 = sumlxP*sumlxP;
|
1182
|
-
float sumlxM2 = sumlxM*sumlxM;
|
1183
|
-
if (sumlxP2*suml2M > sumlxM2*suml2P) {
|
1184
|
-
if (sumlxP2 > best*suml2P) {
|
1185
|
-
best = sumlxP2/suml2P; bestScale = iscale;
|
1186
|
-
}
|
1187
|
-
} else {
|
1188
|
-
if (sumlxM2 > best*suml2M) {
|
1189
|
-
best = sumlxM2/suml2M; bestScale = -iscale;
|
1384
|
+
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
|
1385
|
+
assert(k % QK4_2 == 0);
|
1386
|
+
|
1387
|
+
block_q4_2 * restrict y = vy;
|
1388
|
+
|
1389
|
+
quantize_row_q4_2_reference(x, y, k);
|
1390
|
+
}
|
1391
|
+
|
1392
|
+
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
|
1393
|
+
assert(k % QK5_0 == 0);
|
1394
|
+
const int nb = k / QK5_0;
|
1395
|
+
|
1396
|
+
for (int i = 0; i < nb; i++) {
|
1397
|
+
float amax = 0.0f; // absolute max
|
1398
|
+
float max = 0.0f;
|
1399
|
+
|
1400
|
+
for (int l = 0; l < QK5_0; l++) {
|
1401
|
+
const float v = x[i*QK5_0 + l];
|
1402
|
+
if (amax < fabsf(v)) {
|
1403
|
+
amax = fabsf(v);
|
1404
|
+
max = v;
|
1190
1405
|
}
|
1191
1406
|
}
|
1192
|
-
}
|
1193
|
-
float sumlx = 0; int suml2 = 0;
|
1194
|
-
for (int i=0; i<n; ++i) {
|
1195
|
-
int l = nearest_int(bestScale*X[i]);
|
1196
|
-
l = MAX(nmin, MIN(nmax, l));
|
1197
|
-
sumlx += X[i]*l; suml2 += l*l;
|
1198
|
-
L[i] = l;
|
1199
|
-
}
|
1200
|
-
float scale = sumlx/suml2;
|
1201
|
-
return scale;
|
1202
|
-
}
|
1203
1407
|
|
1204
|
-
|
1205
|
-
|
1206
|
-
static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
|
1207
|
-
assert(k % QK4_2 == 0);
|
1408
|
+
const float d = max / -16;
|
1409
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1208
1410
|
|
1209
|
-
|
1411
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1210
1412
|
|
1211
|
-
|
1413
|
+
uint32_t qh = 0;
|
1212
1414
|
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1415
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
1416
|
+
const float v0 = x[i*QK5_0 + l + 0]*id;
|
1417
|
+
const float v1 = x[i*QK5_0 + l + 1]*id;
|
1216
1418
|
|
1217
|
-
|
1218
|
-
const
|
1219
|
-
const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
|
1419
|
+
const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
|
1420
|
+
const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
|
1220
1421
|
|
1221
|
-
|
1222
|
-
assert(vi1 < 16);
|
1422
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1223
1423
|
|
1224
|
-
|
1424
|
+
// get the 5-th bit and store it in qh at the right position
|
1425
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1426
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1225
1427
|
}
|
1226
1428
|
|
1227
|
-
|
1429
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1228
1430
|
}
|
1229
1431
|
}
|
1230
1432
|
|
1231
|
-
static void
|
1232
|
-
assert(k %
|
1433
|
+
static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
|
1434
|
+
assert(k % QK5_0 == 0);
|
1233
1435
|
|
1234
|
-
|
1436
|
+
block_q5_0 * restrict y = vy;
|
1235
1437
|
|
1236
|
-
|
1237
|
-
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
|
1238
|
-
quantize_row_q4_2_rmse(x, y, k);
|
1438
|
+
quantize_row_q5_0_reference(x, y, k);
|
1239
1439
|
}
|
1240
1440
|
|
1241
|
-
static void
|
1242
|
-
assert(k %
|
1243
|
-
const int nb = k /
|
1441
|
+
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
|
1442
|
+
assert(k % QK5_1 == 0);
|
1443
|
+
const int nb = k / QK5_1;
|
1244
1444
|
|
1245
1445
|
for (int i = 0; i < nb; i++) {
|
1246
1446
|
float min = FLT_MAX;
|
1247
1447
|
float max = -FLT_MAX;
|
1248
1448
|
|
1249
|
-
for (int l = 0; l <
|
1250
|
-
const float v = x[i*
|
1449
|
+
for (int l = 0; l < QK5_1; l++) {
|
1450
|
+
const float v = x[i*QK5_1 + l];
|
1251
1451
|
if (v < min) min = v;
|
1252
1452
|
if (v > max) max = v;
|
1253
1453
|
}
|
1254
1454
|
|
1255
|
-
const float d = (max - min) / ((1 <<
|
1455
|
+
const float d = (max - min) / ((1 << 5) - 1);
|
1256
1456
|
const float id = d ? 1.0f/d : 0.0f;
|
1257
1457
|
|
1258
1458
|
y[i].d = GGML_FP32_TO_FP16(d);
|
1259
1459
|
y[i].m = GGML_FP32_TO_FP16(min);
|
1260
1460
|
|
1261
|
-
|
1262
|
-
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
|
1263
|
-
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
|
1461
|
+
uint32_t qh = 0;
|
1264
1462
|
|
1265
|
-
|
1266
|
-
const
|
1463
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
1464
|
+
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
|
1465
|
+
const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
|
1267
1466
|
|
1268
|
-
|
1269
|
-
|
1467
|
+
const uint32_t vi0 = (int) (v0 + 0.5f);
|
1468
|
+
const uint32_t vi1 = (int) (v1 + 0.5f);
|
1270
1469
|
|
1271
|
-
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1470
|
+
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
|
1471
|
+
|
1472
|
+
// get the 5-th bit and store it in qh at the right position
|
1473
|
+
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
|
1474
|
+
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
|
1272
1475
|
}
|
1476
|
+
|
1477
|
+
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
|
1273
1478
|
}
|
1274
1479
|
}
|
1275
1480
|
|
1276
|
-
static void
|
1277
|
-
assert(k %
|
1481
|
+
static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
|
1482
|
+
assert(k % QK5_1 == 0);
|
1278
1483
|
|
1279
|
-
|
1484
|
+
block_q5_1 * restrict y = vy;
|
1280
1485
|
|
1281
|
-
|
1486
|
+
quantize_row_q5_1_reference(x, y, k);
|
1282
1487
|
}
|
1283
1488
|
|
1284
1489
|
// reference implementation for deterministic creation of model files
|
@@ -1300,13 +1505,15 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
|
|
1300
1505
|
y[i].d = d;
|
1301
1506
|
|
1302
1507
|
for (int l = 0; l < QK8_0; ++l) {
|
1303
|
-
const float
|
1304
|
-
|
1508
|
+
const float v0 = x[i*QK8_0 + l]*id;
|
1509
|
+
|
1510
|
+
y[i].qs[l] = roundf(v0);
|
1305
1511
|
}
|
1306
1512
|
}
|
1307
1513
|
}
|
1308
1514
|
|
1309
1515
|
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
1516
|
+
assert(QK8_0 == 32);
|
1310
1517
|
assert(k % QK8_0 == 0);
|
1311
1518
|
const int nb = k / QK8_0;
|
1312
1519
|
|
@@ -1432,95 +1639,295 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|
1432
1639
|
#endif
|
1433
1640
|
}
|
1434
1641
|
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
const
|
1642
|
+
// reference implementation for deterministic creation of model files
|
1643
|
+
static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
|
1644
|
+
assert(QK8_1 == 32);
|
1645
|
+
assert(k % QK8_1 == 0);
|
1646
|
+
const int nb = k / QK8_1;
|
1440
1647
|
|
1441
|
-
#if defined(__AVX2__)
|
1442
1648
|
for (int i = 0; i < nb; i++) {
|
1443
|
-
//
|
1444
|
-
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
|
1649
|
+
float amax = 0.0f; // absolute max
|
1445
1650
|
|
1446
|
-
|
1651
|
+
for (int l = 0; l < QK8_1; l++) {
|
1652
|
+
const float v = x[i*QK8_1 + l];
|
1653
|
+
amax = MAX(amax, fabsf(v));
|
1654
|
+
}
|
1447
1655
|
|
1448
|
-
|
1449
|
-
|
1450
|
-
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1656
|
+
const float d = amax / ((1 << 7) - 1);
|
1657
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1451
1658
|
|
1452
|
-
|
1453
|
-
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
1659
|
+
y[i].d = d;
|
1454
1660
|
|
1455
|
-
|
1456
|
-
|
1457
|
-
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
1661
|
+
int sum0 = 0;
|
1662
|
+
int sum1 = 0;
|
1458
1663
|
|
1459
|
-
|
1460
|
-
const
|
1461
|
-
|
1462
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
1463
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
1464
|
-
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
1465
|
-
};
|
1664
|
+
for (int l = 0; l < QK8_1/2; ++l) {
|
1665
|
+
const float v0 = x[i*QK8_1 + l]*id;
|
1666
|
+
const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
|
1466
1667
|
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1668
|
+
y[i].qs[ l] = roundf(v0);
|
1669
|
+
y[i].qs[QK8_1/2 + l] = roundf(v1);
|
1670
|
+
|
1671
|
+
sum0 += y[i].qs[ l];
|
1672
|
+
sum1 += y[i].qs[QK8_1/2 + l];
|
1472
1673
|
}
|
1674
|
+
|
1675
|
+
y[i].s0 = d * sum0;
|
1676
|
+
y[i].s1 = d * sum1;
|
1473
1677
|
}
|
1474
|
-
|
1475
|
-
for (int i = 0; i < nb; i++) {
|
1476
|
-
const float32x4_t vd = vdupq_n_f32(x[i].d);
|
1678
|
+
}
|
1477
1679
|
|
1478
|
-
|
1680
|
+
static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
|
1681
|
+
assert(k % QK8_1 == 0);
|
1682
|
+
const int nb = k / QK8_1;
|
1479
1683
|
|
1480
|
-
|
1481
|
-
// Load 16x4-bit integers into 8x8-bit integers
|
1482
|
-
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1684
|
+
block_q8_1 * restrict y = vy;
|
1483
1685
|
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1686
|
+
#if defined(__ARM_NEON)
|
1687
|
+
for (int i = 0; i < nb; i++) {
|
1688
|
+
float32x4_t srcv [8];
|
1689
|
+
float32x4_t asrcv[8];
|
1690
|
+
float32x4_t amaxv[8];
|
1487
1691
|
|
1488
|
-
|
1489
|
-
|
1490
|
-
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
1692
|
+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
1693
|
+
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
|
1491
1694
|
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1695
|
+
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
|
1696
|
+
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
1697
|
+
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
1495
1698
|
|
1496
|
-
|
1497
|
-
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
1498
|
-
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
1699
|
+
const float amax = vmaxvq_f32(amaxv[0]);
|
1499
1700
|
|
1500
|
-
|
1701
|
+
const float d = amax / ((1 << 7) - 1);
|
1702
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1501
1703
|
|
1502
|
-
|
1503
|
-
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
1504
|
-
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
1704
|
+
y[i].d = d;
|
1505
1705
|
|
1506
|
-
|
1507
|
-
|
1508
|
-
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
1509
|
-
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
1510
|
-
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
1706
|
+
int32x4_t accv0 = vdupq_n_s32(0);
|
1707
|
+
int32x4_t accv1 = vdupq_n_s32(0);
|
1511
1708
|
|
1512
|
-
|
1513
|
-
|
1514
|
-
const float32x4_t
|
1515
|
-
const
|
1516
|
-
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
1709
|
+
// low half
|
1710
|
+
for (int l = 0; l < 4; l++) {
|
1711
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1712
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1517
1713
|
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1522
|
-
|
1523
|
-
|
1714
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1715
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1716
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1717
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1718
|
+
|
1719
|
+
accv0 = vaddq_s32(accv0, vi);
|
1720
|
+
}
|
1721
|
+
|
1722
|
+
// high half
|
1723
|
+
for (int l = 4; l < 8; l++) {
|
1724
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1725
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1726
|
+
|
1727
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1728
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1729
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1730
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1731
|
+
|
1732
|
+
accv1 = vaddq_s32(accv1, vi);
|
1733
|
+
}
|
1734
|
+
|
1735
|
+
const int32_t sum0 = vaddvq_s32(accv0);
|
1736
|
+
const int32_t sum1 = vaddvq_s32(accv1);
|
1737
|
+
|
1738
|
+
y[i].s0 = d * sum0;
|
1739
|
+
y[i].s1 = d * sum1;
|
1740
|
+
}
|
1741
|
+
#elif defined(__AVX2__) || defined(__AVX__)
|
1742
|
+
for (int i = 0; i < nb; i++) {
|
1743
|
+
// Load elements into 4 AVX vectors
|
1744
|
+
__m256 v0 = _mm256_loadu_ps( x );
|
1745
|
+
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
1746
|
+
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
1747
|
+
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
1748
|
+
x += 32;
|
1749
|
+
|
1750
|
+
// Compute max(abs(e)) for the block
|
1751
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
1752
|
+
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
1753
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
1754
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
1755
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
1756
|
+
|
1757
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
1758
|
+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
1759
|
+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
1760
|
+
const float maxScalar = _mm_cvtss_f32( max4 );
|
1761
|
+
|
1762
|
+
// Quantize these floats
|
1763
|
+
const float d = maxScalar / 127.f;
|
1764
|
+
y[i].d = d;
|
1765
|
+
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
1766
|
+
const __m256 mul = _mm256_set1_ps( id );
|
1767
|
+
|
1768
|
+
// Apply the multiplier
|
1769
|
+
v0 = _mm256_mul_ps( v0, mul );
|
1770
|
+
v1 = _mm256_mul_ps( v1, mul );
|
1771
|
+
v2 = _mm256_mul_ps( v2, mul );
|
1772
|
+
v3 = _mm256_mul_ps( v3, mul );
|
1773
|
+
|
1774
|
+
// Round to nearest integer
|
1775
|
+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
1776
|
+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
1777
|
+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
1778
|
+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
1779
|
+
|
1780
|
+
// Convert floats to integers
|
1781
|
+
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
1782
|
+
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
1783
|
+
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
1784
|
+
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
1785
|
+
|
1786
|
+
#if defined(__AVX2__)
|
1787
|
+
// Compute the sum of the quants and set y[i].s
|
1788
|
+
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
1789
|
+
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
|
1790
|
+
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
|
1791
|
+
|
1792
|
+
// Convert int32 to int16
|
1793
|
+
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
1794
|
+
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
1795
|
+
// Convert int16 to int8
|
1796
|
+
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
1797
|
+
|
1798
|
+
// We got our precious signed bytes, but the order is now wrong
|
1799
|
+
// These AVX2 pack instructions process 16-byte pieces independently
|
1800
|
+
// The following instruction is fixing the order
|
1801
|
+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
1802
|
+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
1803
|
+
|
1804
|
+
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
|
1805
|
+
#else
|
1806
|
+
// Since we don't have in AVX some necessary functions,
|
1807
|
+
// we split the registers in half and call AVX2 analogs from SSE
|
1808
|
+
__m128i ni0 = _mm256_castsi256_si128( i0 );
|
1809
|
+
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
|
1810
|
+
__m128i ni2 = _mm256_castsi256_si128( i1 );
|
1811
|
+
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
|
1812
|
+
__m128i ni4 = _mm256_castsi256_si128( i2 );
|
1813
|
+
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
|
1814
|
+
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
1815
|
+
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
1816
|
+
|
1817
|
+
// Compute the sum of the quants and set y[i].s
|
1818
|
+
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
|
1819
|
+
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
|
1820
|
+
y[i].s0 = d * hsum_i32_4(s0);
|
1821
|
+
y[i].s1 = d * hsum_i32_4(s1);
|
1822
|
+
|
1823
|
+
// Convert int32 to int16
|
1824
|
+
ni0 = _mm_packs_epi32( ni0, ni1 );
|
1825
|
+
ni2 = _mm_packs_epi32( ni2, ni3 );
|
1826
|
+
ni4 = _mm_packs_epi32( ni4, ni5 );
|
1827
|
+
ni6 = _mm_packs_epi32( ni6, ni7 );
|
1828
|
+
// Convert int16 to int8
|
1829
|
+
ni0 = _mm_packs_epi16( ni0, ni2 );
|
1830
|
+
ni4 = _mm_packs_epi16( ni4, ni6 );
|
1831
|
+
|
1832
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
|
1833
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
|
1834
|
+
#endif
|
1835
|
+
}
|
1836
|
+
#else
|
1837
|
+
// scalar
|
1838
|
+
quantize_row_q8_1_reference(x, y, k);
|
1839
|
+
#endif
|
1840
|
+
}
|
1841
|
+
|
1842
|
+
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
1843
|
+
assert(k % QK4_0 == 0);
|
1844
|
+
const int nb = k / QK4_0;
|
1845
|
+
|
1846
|
+
const block_q4_0 * restrict x = vx;
|
1847
|
+
|
1848
|
+
#if defined(__AVX2__)
|
1849
|
+
for (int i = 0; i < nb; i++) {
|
1850
|
+
// scale factor
|
1851
|
+
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
|
1852
|
+
|
1853
|
+
const uint8_t * restrict pp = x[i].qs;
|
1854
|
+
|
1855
|
+
for (int l = 0; l < QK4_0; l += 32) {
|
1856
|
+
// Load 32x4-bit integers into 32x8-bit integers
|
1857
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1858
|
+
|
1859
|
+
// Subtract 8 from the integers
|
1860
|
+
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
1861
|
+
|
1862
|
+
// Convert to 16-bit int
|
1863
|
+
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
1864
|
+
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
1865
|
+
|
1866
|
+
// Convert to 32-bit int -> float 32
|
1867
|
+
const __m256 vf[4] = {
|
1868
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
|
1869
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
1870
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
1871
|
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
1872
|
+
};
|
1873
|
+
|
1874
|
+
// Scale and store
|
1875
|
+
for (int j = 0; j < 4; j++) {
|
1876
|
+
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
1877
|
+
_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
|
1878
|
+
}
|
1879
|
+
}
|
1880
|
+
}
|
1881
|
+
#elif defined(__ARM_NEON)
|
1882
|
+
for (int i = 0; i < nb; i++) {
|
1883
|
+
const float32x4_t vd = vdupq_n_f32(x[i].d);
|
1884
|
+
|
1885
|
+
const uint8_t * restrict pp = x[i].qs;
|
1886
|
+
|
1887
|
+
for (int l = 0; l < QK4_0; l += 16) {
|
1888
|
+
// Load 16x4-bit integers into 8x8-bit integers
|
1889
|
+
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1890
|
+
|
1891
|
+
// Expand 4-bit qs to 8-bit bytes
|
1892
|
+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
1893
|
+
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
1894
|
+
|
1895
|
+
// Convert to signed 8-bit integers
|
1896
|
+
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
|
1897
|
+
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
1898
|
+
|
1899
|
+
// Subtract 8 from each byte
|
1900
|
+
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
|
1901
|
+
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
|
1902
|
+
|
1903
|
+
// Interleave and combine
|
1904
|
+
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
1905
|
+
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
1906
|
+
|
1907
|
+
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
|
1908
|
+
|
1909
|
+
// convert to 2x int16x8_t
|
1910
|
+
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
1911
|
+
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
1912
|
+
|
1913
|
+
// convert to 4x float32x4_t
|
1914
|
+
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
|
1915
|
+
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
1916
|
+
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
1917
|
+
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
1918
|
+
|
1919
|
+
// Multiply by d
|
1920
|
+
const float32x4_t r0 = vmulq_f32(vf_0, vd);
|
1921
|
+
const float32x4_t r1 = vmulq_f32(vf_1, vd);
|
1922
|
+
const float32x4_t r2 = vmulq_f32(vf_2, vd);
|
1923
|
+
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
1924
|
+
|
1925
|
+
// Store
|
1926
|
+
vst1q_f32(y + i*QK4_0 + l + 0, r0);
|
1927
|
+
vst1q_f32(y + i*QK4_0 + l + 4, r1);
|
1928
|
+
vst1q_f32(y + i*QK4_0 + l + 8, r2);
|
1929
|
+
vst1q_f32(y + i*QK4_0 + l + 12, r3);
|
1930
|
+
}
|
1524
1931
|
}
|
1525
1932
|
#else
|
1526
1933
|
// scalar
|
@@ -1532,7 +1939,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1532
1939
|
for (int l = 0; l < QK4_0; l += 2) {
|
1533
1940
|
const uint8_t vi = pp[l/2];
|
1534
1941
|
|
1535
|
-
const int8_t vi0 = vi &
|
1942
|
+
const int8_t vi0 = vi & 0x0F;
|
1536
1943
|
const int8_t vi1 = vi >> 4;
|
1537
1944
|
|
1538
1945
|
const float v0 = (vi0 - 8)*d;
|
@@ -1598,7 +2005,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1598
2005
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1599
2006
|
|
1600
2007
|
// Expand 4-bit qs to 8-bit bytes
|
1601
|
-
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(
|
2008
|
+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
1602
2009
|
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
1603
2010
|
|
1604
2011
|
// Interleave and combine
|
@@ -1640,7 +2047,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1640
2047
|
for (int l = 0; l < QK4_1; l += 2) {
|
1641
2048
|
const uint8_t vi = pp[l/2];
|
1642
2049
|
|
1643
|
-
const int8_t vi0 = vi &
|
2050
|
+
const int8_t vi0 = vi & 0x0F;
|
1644
2051
|
const int8_t vi1 = vi >> 4;
|
1645
2052
|
|
1646
2053
|
const float v0 = vi0*d + m;
|
@@ -1670,7 +2077,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
|
|
1670
2077
|
for (int l = 0; l < QK4_2; l += 2) {
|
1671
2078
|
const uint8_t vi = pp[l/2];
|
1672
2079
|
|
1673
|
-
const int8_t vi0 = vi &
|
2080
|
+
const int8_t vi0 = vi & 0x0F;
|
1674
2081
|
const int8_t vi1 = vi >> 4;
|
1675
2082
|
|
1676
2083
|
const float v0 = (vi0 - 8)*d;
|
@@ -1685,11 +2092,47 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
|
|
1685
2092
|
}
|
1686
2093
|
}
|
1687
2094
|
|
1688
|
-
static void
|
1689
|
-
assert(k %
|
1690
|
-
const int nb = k /
|
2095
|
+
static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
|
2096
|
+
assert(k % QK5_0 == 0);
|
2097
|
+
const int nb = k / QK5_0;
|
2098
|
+
|
2099
|
+
const block_q5_0 * restrict x = vx;
|
2100
|
+
|
2101
|
+
for (int i = 0; i < nb; i++) {
|
2102
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
2103
|
+
|
2104
|
+
const uint8_t * restrict pp = x[i].qs;
|
2105
|
+
|
2106
|
+
uint32_t qh;
|
2107
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2108
|
+
|
2109
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
2110
|
+
const uint8_t vi = pp[l/2];
|
2111
|
+
|
2112
|
+
// extract the 5-th bit from qh
|
2113
|
+
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
|
2114
|
+
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
|
2115
|
+
|
2116
|
+
const int8_t vi0 = (vi & 0x0F) | vh0;
|
2117
|
+
const int8_t vi1 = (vi >> 4) | vh1;
|
2118
|
+
|
2119
|
+
const float v0 = (vi0 - 16)*d;
|
2120
|
+
const float v1 = (vi1 - 16)*d;
|
2121
|
+
|
2122
|
+
y[i*QK5_0 + l + 0] = v0;
|
2123
|
+
y[i*QK5_0 + l + 1] = v1;
|
2124
|
+
|
2125
|
+
assert(!isnan(y[i*QK5_0 + l + 0]));
|
2126
|
+
assert(!isnan(y[i*QK5_0 + l + 1]));
|
2127
|
+
}
|
2128
|
+
}
|
2129
|
+
}
|
2130
|
+
|
2131
|
+
static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
|
2132
|
+
assert(k % QK5_1 == 0);
|
2133
|
+
const int nb = k / QK5_1;
|
1691
2134
|
|
1692
|
-
const
|
2135
|
+
const block_q5_1 * restrict x = vx;
|
1693
2136
|
|
1694
2137
|
for (int i = 0; i < nb; i++) {
|
1695
2138
|
const float d = GGML_FP16_TO_FP32(x[i].d);
|
@@ -1697,28 +2140,54 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {

         const uint8_t * restrict pp = x[i].qs;

-
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_1; l += 2) {
             const uint8_t vi = pp[l/2];

-
-            const
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
+
+            const uint8_t vi0 = (vi & 0x0F) | vh0;
+            const uint8_t vi1 = (vi >> 4) | vh1;

             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;

-            y[i*
-            y[i*
+            y[i*QK5_1 + l + 0] = v0;
+            y[i*QK5_1 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_1 + l + 0]));
+            assert(!isnan(y[i*QK5_1 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    const block_q8_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = x[i].d;

-
-
+        const int8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            y[i*QK8_0 + l] = pp[l]*d;
         }
     }
 }

 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void
+static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);

 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
@@ -1727,34 +2196,55 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
         .quantize_row_q_dot = quantize_row_q8_0,
         .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q4_1] = {
         .dequantize_row_q = dequantize_row_q4_1,
         .quantize_row_q = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot =
-        .vec_dot_q =
+        .quantize_row_q_dot = quantize_row_q8_1,
+        .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
+        .vec_dot_type = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q = dequantize_row_q4_2,
         .quantize_row_q = quantize_row_q4_2,
-        .quantize_row_q_reference = (quantize_row_q_t)
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
         .quantize_row_q_dot = quantize_row_q8_0,
         .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
     },
-    [
-        .dequantize_row_q =
-        .quantize_row_q =
-        .quantize_row_q_reference = (quantize_row_q_t)
+    [GGML_TYPE_Q5_0] = {
+        .dequantize_row_q = dequantize_row_q5_0,
+        .quantize_row_q = quantize_row_q5_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
         .quantize_row_q_dot = quantize_row_q8_0,
-        .vec_dot_q =
+        .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .dequantize_row_q = dequantize_row_q5_1,
+        .quantize_row_q = quantize_row_q5_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
+        .quantize_row_q_dot = quantize_row_q8_1,
+        .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q8_0] = {
-        .dequantize_row_q =
+        .dequantize_row_q = dequantize_row_q8_0,
         .quantize_row_q = quantize_row_q8_0,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
         .quantize_row_q_dot = quantize_row_q8_0,
+        .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q8_1] = {
+        .dequantize_row_q = NULL, // TODO
+        .quantize_row_q = quantize_row_q8_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
+        .quantize_row_q_dot = quantize_row_q8_1,
         .vec_dot_q = NULL, // TODO
+        .vec_dot_type = GGML_TYPE_Q8_1,
     },
 };

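The reworked table adds a vec_dot_type field next to each vec_dot_q kernel, so callers know which 8-bit format the second operand has to be quantized to before the kernel runs (the Q8_1 blocks used later in this diff also carry s0/s1 partial sums that the *_1 kernels consume). A hedged sketch of that lookup pattern with toy names, not the actual ggml structs or functions:

    #include <stdio.h>

    /* Toy mirror of the dispatch-table idea: each entry pairs a dot-product
     * kernel with the type its second operand must be quantized to. */
    enum toy_type { TOY_Q4_0, TOY_Q8_0, TOY_TYPE_COUNT };

    typedef void (*toy_vec_dot_t)(int n, float *s, const void *x, const void *y);

    static void toy_dot_q4_0_q8_0(int n, float *s, const void *x, const void *y) {
        (void)n; (void)x; (void)y;
        *s = 0.0f; /* placeholder body for illustration */
    }

    static const struct {
        toy_vec_dot_t vec_dot_q;
        enum toy_type vec_dot_type; /* format to quantize the activations into */
    } toy_fns[TOY_TYPE_COUNT] = {
        [TOY_Q4_0] = { toy_dot_q4_0_q8_0, TOY_Q8_0 },
    };

    int main(void) {
        /* Caller side: quantize y to toy_fns[t].vec_dot_type, then dispatch. */
        float s = 0.0f;
        toy_fns[TOY_Q4_0].vec_dot_q(32, &s, NULL, NULL);
        printf("%f\n", s);
        return 0;
    }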
@@ -2366,8 +2856,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;

-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2378,7 +2866,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

-        const uint8x16_t m4b = vdupq_n_u8(
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
         const int8x16_t s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2396,35 +2884,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2924,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #endif
     }

-
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2942,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {

         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

-
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);

         /* Multiply q with scale and accumulate */
         acc = _mm256_fmadd_ps( d, q, acc );
     }

-
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
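The removed AVX2 lines performed the horizontal reduction of the 8-lane accumulator inline; the new code calls a shared hsum_float_8 helper whose definition sits outside this hunk. As a rough sketch, a helper equivalent to the removed sequence could look like the following (the name is taken from the call site, and the body is simply the removed intrinsic sequence wrapped up):

    #include <immintrin.h>

    // Horizontal sum of the 8 floats in an __m256 accumulator, wrapping the
    // extract/movehl/movehdup sequence that the removed lines used inline.
    static inline float hsum_float_8_sketch(const __m256 acc) {
        __m128 res = _mm256_extractf128_ps(acc, 1);          // high 128 bits
        res = _mm_add_ps(res, _mm256_castps256_ps128(acc));  // + low 128 bits
        res = _mm_add_ps(res, _mm_movehl_ps(res, res));      // 4 lanes -> 2
        res = _mm_add_ss(res, _mm_movehdup_ps(res));         // 2 lanes -> 1
        return _mm_cvtss_f32(res);
    }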
@@ -2518,15 +2987,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }

-
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
@@ -2538,8 +3002,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         for (int j = 0; j < QK8_0/2; j++) {
             const uint8_t v0 = p0[j];

-            const int i0 = (int8_t) (v0 &
-            const int i1 = (int8_t) (v0 >>
+            const int i0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1 = (int8_t) (v0 >> 4) - 8;

             const int i2 = p1[2*j + 0];
             const int i3 = p1[2*j + 1];
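In this scalar fallback each packed byte of the q4_0 block yields two signed quants in [-8, 7] once the nibble is masked or shifted and the offset of 8 is removed; those are multiplied against consecutive int8 activations and the integer sum is scaled by d0*d1. A tiny standalone illustration of that per-pair arithmetic, using toy data rather than the ggml block structs:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // One packed q4 byte holding the quants -5 and 4 (stored biased by +8 as 3 and 12).
        const uint8_t p0   = (uint8_t)((3 & 0x0F) | (12 << 4));
        const int8_t  p1[] = { 7, -2 };           // two int8 activations
        const float   d0 = 0.1f, d1 = 0.05f;      // block scales

        const int i0 = (int8_t)(p0 & 0x0F) - 8;   // -5
        const int i1 = (int8_t)(p0 >> 4)   - 8;   //  4

        const int   sumi = i0*p1[0] + i1*p1[1];   // -5*7 + 4*(-2) = -43
        const float dot  = d0*d1*sumi;            // scaled contribution of this pair
        printf("%d %f\n", sumi, dot);
        return 0;
    }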
@@ -2548,34 +3012,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2548
3012
|
}
|
2549
3013
|
sumf += d0*d1*sumi;
|
2550
3014
|
}
|
2551
|
-
#endif
|
2552
|
-
|
2553
3015
|
*s = sumf;
|
3016
|
+
#endif
|
2554
3017
|
}
|
2555
3018
|
|
2556
|
-
static void
|
2557
|
-
const int nb = n /
|
3019
|
+
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3020
|
+
const int nb = n / QK8_1;
|
2558
3021
|
|
2559
|
-
assert(n %
|
3022
|
+
assert(n % QK8_1 == 0);
|
2560
3023
|
assert(nb % 2 == 0);
|
2561
3024
|
|
2562
3025
|
const block_q4_1 * restrict x = vx;
|
2563
|
-
const
|
2564
|
-
|
2565
|
-
float sumf = 0.0;
|
3026
|
+
const block_q8_1 * restrict y = vy;
|
2566
3027
|
|
2567
3028
|
// TODO: add AVX / WASM SIMD / etc
|
2568
3029
|
#if defined(__ARM_NEON)
|
2569
3030
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2570
3031
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2571
3032
|
|
3033
|
+
float summs = 0;
|
3034
|
+
|
2572
3035
|
for (int i = 0; i < nb; i += 2) {
|
2573
3036
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
2574
3037
|
const block_q4_1 * restrict x1 = &x[i + 1];
|
2575
|
-
const
|
2576
|
-
const
|
3038
|
+
const block_q8_1 * restrict y0 = &y[i + 0];
|
3039
|
+
const block_q8_1 * restrict y1 = &y[i + 1];
|
3040
|
+
|
3041
|
+
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
|
2577
3042
|
|
2578
|
-
const uint8x16_t m4b = vdupq_n_u8(
|
3043
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
2579
3044
|
|
2580
3045
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2581
3046
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
@@ -2586,46 +3051,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2586
3051
|
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2587
3052
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2588
3053
|
|
3054
|
+
// interleave
|
3055
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
3056
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
3057
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
3058
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
3059
|
+
|
2589
3060
|
// load y
|
2590
3061
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2591
3062
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2592
3063
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2593
3064
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2594
3065
|
|
2595
|
-
// interleave
|
2596
|
-
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2597
|
-
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2598
|
-
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2599
|
-
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2600
|
-
|
2601
|
-
const int16x8_t s0i = vaddq_s16(
|
2602
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
2603
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
2604
|
-
|
2605
|
-
const int16x8_t s1i = vaddq_s16(
|
2606
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
2607
|
-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
2608
|
-
|
2609
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
2610
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
2611
|
-
|
2612
3066
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2613
3067
|
// dot product into int32x4_t
|
2614
|
-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
2615
|
-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0),
|
3068
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
|
3069
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
|
2616
3070
|
|
2617
3071
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2618
3072
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2619
3073
|
#else
|
2620
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (
|
2621
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(
|
2622
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (
|
2623
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(
|
3074
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3075
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3076
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3077
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2624
3078
|
|
2625
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (
|
2626
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(
|
2627
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (
|
2628
|
-
const int16x8_t ph1h = vmull_s8(vget_high_s8(
|
3079
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
3080
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
3081
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
3082
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2629
3083
|
|
2630
3084
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2631
3085
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
@@ -2637,65 +3091,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2637
3091
|
#endif
|
2638
3092
|
}
|
2639
3093
|
|
2640
|
-
|
3094
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
2641
3095
|
#elif defined(__AVX2__)
|
2642
3096
|
// Initialize accumulator with zeros
|
2643
3097
|
__m256 acc = _mm256_setzero_ps();
|
2644
3098
|
|
3099
|
+
float summs = 0;
|
3100
|
+
|
2645
3101
|
// Main loop
|
2646
3102
|
for (int i = 0; i < nb; ++i) {
|
2647
3103
|
const float * d0 = &x[i].d;
|
2648
3104
|
const float * d1 = &y[i].d;
|
2649
|
-
|
3105
|
+
|
3106
|
+
summs += x[i].m * (y[i].s0 + y[i].s1);
|
2650
3107
|
|
2651
3108
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2652
3109
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2653
|
-
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2654
3110
|
|
2655
3111
|
// Compute combined scales
|
2656
3112
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
2657
|
-
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
2658
3113
|
|
2659
3114
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2660
3115
|
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2661
3116
|
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
2662
3117
|
|
2663
|
-
|
2664
|
-
const __m256i ax = _mm256_sign_epi8( bx, bx );
|
2665
|
-
|
2666
|
-
// Sign the values of the y vectors
|
2667
|
-
const __m256i sy = _mm256_sign_epi8( by, bx );
|
2668
|
-
|
2669
|
-
// Perform multiplication and create 16-bit values
|
2670
|
-
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
|
2671
|
-
const __m256i ones = _mm256_set1_epi16( 1 );
|
2672
|
-
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
|
2673
|
-
|
2674
|
-
// Convert to vector of 8 int32_t to 8 floats
|
2675
|
-
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
|
3118
|
+
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
2676
3119
|
|
2677
3120
|
// Accumulate d0*d1*x*y
|
2678
3121
|
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
2679
|
-
|
2680
|
-
// Compute sum of y values
|
2681
|
-
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2682
|
-
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2683
|
-
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
2684
|
-
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
2685
|
-
|
2686
|
-
// Accumulate d1*m0*y
|
2687
|
-
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
2688
3122
|
}
|
2689
3123
|
|
2690
|
-
|
2691
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2692
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2693
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2694
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2695
|
-
|
2696
|
-
sumf = _mm_cvtss_f32( res );
|
3124
|
+
*s = hsum_float_8(acc) + summs;
|
2697
3125
|
#else
|
2698
3126
|
// scalar
|
3127
|
+
float sumf = 0.0;
|
2699
3128
|
for (int i = 0; i < nb; i++) {
|
2700
3129
|
const float d0 = x[i].d;
|
2701
3130
|
const float m0 = x[i].m;
|
@@ -2705,347 +3134,685 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|
2705
3134
|
const int8_t * restrict p1 = y[i].qs;
|
2706
3135
|
|
2707
3136
|
// TODO: this is very slow ..
|
2708
|
-
for (int j = 0; j <
|
3137
|
+
for (int j = 0; j < QK8_1/2; j++) {
|
2709
3138
|
const uint8_t v0 = p0[j];
|
2710
3139
|
|
2711
|
-
const float f0 = d0*(v0 &
|
2712
|
-
const float f1 = d0*(v0 >>
|
3140
|
+
const float f0 = d0*(v0 & 0x0F) + m0;
|
3141
|
+
const float f1 = d0*(v0 >> 4) + m0;
|
3142
|
+
|
3143
|
+
const float f2 = d1*p1[2*j + 0];
|
3144
|
+
const float f3 = d1*p1[2*j + 1];
|
3145
|
+
|
3146
|
+
sumf += f0*f2 + f1*f3;
|
3147
|
+
}
|
3148
|
+
}
|
3149
|
+
*s = sumf;
|
3150
|
+
#endif
|
3151
|
+
}
|
3152
|
+
|
3153
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3154
|
+
const int nb = n / QK8_0;
|
3155
|
+
|
3156
|
+
assert(n % QK8_0 == 0);
|
3157
|
+
assert(nb % 2 == 0);
|
3158
|
+
assert(QK8_0 == 2*QK4_2);
|
3159
|
+
|
3160
|
+
const block_q4_2 * restrict x = vx;
|
3161
|
+
const block_q8_0 * restrict y = vy;
|
3162
|
+
|
3163
|
+
#if defined(__ARM_NEON)
|
3164
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3165
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3166
|
+
|
3167
|
+
for (int i = 0; i < nb; i += 2) {
|
3168
|
+
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
|
3169
|
+
const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
|
3170
|
+
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
3171
|
+
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
3172
|
+
|
3173
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
3174
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
3175
|
+
|
3176
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3177
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
3178
|
+
|
3179
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
3180
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
3181
|
+
|
3182
|
+
// 4-bit -> 8-bit
|
3183
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
3184
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
3185
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
3186
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3187
|
+
|
3188
|
+
// sub 8
|
3189
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
3190
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
3191
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
3192
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3193
|
+
|
3194
|
+
// interleave
|
3195
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
3196
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
3197
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
3198
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
3199
|
+
|
3200
|
+
// load y
|
3201
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
3202
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
3203
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
3204
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
3205
|
+
|
3206
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3207
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3208
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
|
3209
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3210
|
+
|
3211
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3212
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
|
3213
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3214
|
+
#else
|
3215
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
3216
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
3217
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
3218
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3219
|
+
|
3220
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
3221
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
3222
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
3223
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
3224
|
+
|
3225
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3226
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3227
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
3228
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
3229
|
+
|
3230
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
3231
|
+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
3232
|
+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3233
|
+
|
3234
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
3235
|
+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
|
3236
|
+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
3237
|
+
#endif
|
3238
|
+
}
|
3239
|
+
|
3240
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
3241
|
+
#elif defined(__AVX2__)
|
3242
|
+
// Initialize accumulator with zeros
|
3243
|
+
__m256 acc = _mm256_setzero_ps();
|
3244
|
+
|
3245
|
+
// Main loop
|
3246
|
+
for (int i = 0; i < nb; i++) {
|
3247
|
+
/* Compute combined scale for the block */
|
3248
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
3249
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
3250
|
+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
3251
|
+
|
3252
|
+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
3253
|
+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
3254
|
+
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
3255
|
+
|
3256
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
3257
|
+
const __m256i off = _mm256_set1_epi8(8);
|
3258
|
+
bx = _mm256_sub_epi8(bx, off);
|
3259
|
+
|
3260
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3261
|
+
|
3262
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3263
|
+
|
3264
|
+
/* Multiply q with scale and accumulate */
|
3265
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
3266
|
+
}
|
3267
|
+
|
3268
|
+
*s = hsum_float_8(acc);
|
3269
|
+
#else
|
3270
|
+
// scalar
|
3271
|
+
float sumf = 0.0;
|
3272
|
+
for (int i = 0; i < nb; i++) {
|
3273
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3274
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3275
|
+
const int8_t * restrict y0 = y[i].qs;
|
3276
|
+
|
3277
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3278
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3279
|
+
|
3280
|
+
int sumi_0 = 0;
|
3281
|
+
int sumi_1 = 0;
|
3282
|
+
|
3283
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
3284
|
+
const uint8_t v0 = x0[j];
|
3285
|
+
const uint8_t v1 = x1[j];
|
3286
|
+
|
3287
|
+
const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
|
3288
|
+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
|
3289
|
+
|
3290
|
+
const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
|
3291
|
+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
|
3292
|
+
|
3293
|
+
const int i2_0 = y0[2*j + 0];
|
3294
|
+
const int i3_0 = y0[2*j + 1];
|
3295
|
+
|
3296
|
+
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
3297
|
+
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
3298
|
+
|
3299
|
+
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
3300
|
+
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
3301
|
+
}
|
3302
|
+
|
3303
|
+
sumf += (d0 * y[i].d) * sumi_0;
|
3304
|
+
sumf += (d1 * y[i].d) * sumi_1;
|
3305
|
+
}
|
3306
|
+
*s = sumf;
|
3307
|
+
#endif
|
3308
|
+
}
|
3309
|
+
|
3310
|
+
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3311
|
+
const int nb = n / QK8_0;
|
3312
|
+
|
3313
|
+
assert(n % QK8_0 == 0);
|
3314
|
+
assert(nb % 2 == 0);
|
3315
|
+
assert(QK8_0 == QK5_0);
|
3316
|
+
|
3317
|
+
const block_q5_0 * restrict x = vx;
|
3318
|
+
const block_q8_0 * restrict y = vy;
|
3319
|
+
|
3320
|
+
#if defined(__ARM_NEON)
|
3321
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
3322
|
+
|
3323
|
+
uint64_t tmp[4];
|
3324
|
+
|
3325
|
+
for (int i = 0; i < nb; ++i) {
|
3326
|
+
const block_q5_0 * restrict x0 = &x[i];
|
3327
|
+
const block_q8_0 * restrict y0 = &y[i];
|
3328
|
+
|
3329
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3330
|
+
const int8x16_t s16b = vdupq_n_s8(0x10);
|
3331
|
+
|
3332
|
+
// extract the 5th bit
|
3333
|
+
uint32_t qh;
|
3334
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3335
|
+
|
3336
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3337
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3338
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3339
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3340
|
+
|
3341
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3342
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
3343
|
+
|
3344
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
3345
|
+
|
3346
|
+
// 4-bit -> 8-bit
|
3347
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
|
3348
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
3349
|
+
|
3350
|
+
// interleave
|
3351
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3352
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
3353
|
+
|
3354
|
+
// add high bit and sub 16
|
3355
|
+
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
|
3356
|
+
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
|
3357
|
+
|
3358
|
+
// load y
|
3359
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3360
|
+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
|
3361
|
+
|
3362
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
3363
|
+
|
3364
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3365
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
|
3366
|
+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
|
3367
|
+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
|
3368
|
+
#else
|
3369
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
|
3370
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
|
3371
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
|
3372
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
|
3373
|
+
|
3374
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3375
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3376
|
+
|
3377
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
|
3378
|
+
#endif
|
3379
|
+
}
|
3380
|
+
|
3381
|
+
*s = vaddvq_f32(sumv);
|
3382
|
+
#elif defined(__wasm_simd128__)
|
3383
|
+
v128_t sumv = wasm_f32x4_splat(0.0f);
|
3384
|
+
|
3385
|
+
uint64_t tmp[4];
|
3386
|
+
|
3387
|
+
for (int i = 0; i < nb; ++i) {
|
3388
|
+
const block_q5_0 * restrict x0 = &x[i];
|
3389
|
+
const block_q8_0 * restrict y0 = &y[i];
|
3390
|
+
|
3391
|
+
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
3392
|
+
const v128_t s16b = wasm_i8x16_splat(0x10);
|
3393
|
+
|
3394
|
+
// extract the 5th bit
|
3395
|
+
uint32_t qh;
|
3396
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3397
|
+
|
3398
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3399
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3400
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3401
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3402
|
+
|
3403
|
+
const v128_t qhl = wasm_v128_load(tmp + 0);
|
3404
|
+
const v128_t qhh = wasm_v128_load(tmp + 2);
|
3405
|
+
|
3406
|
+
const v128_t v0 = wasm_v128_load(x0->qs);
|
3407
|
+
|
3408
|
+
// 4-bit -> 8-bit
|
3409
|
+
const v128_t v0l = wasm_v128_and (v0, m4b);
|
3410
|
+
const v128_t v0h = wasm_u8x16_shr(v0, 4);
|
3411
|
+
|
3412
|
+
// interleave
|
3413
|
+
const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
|
3414
|
+
const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
|
3415
|
+
|
3416
|
+
// add high bit and sub 16
|
3417
|
+
const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0lz, qhl), s16b);
|
3418
|
+
const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0hz, qhh), s16b);
|
3419
|
+
|
3420
|
+
// load y
|
3421
|
+
const v128_t v1l = wasm_v128_load(y0->qs);
|
3422
|
+
const v128_t v1h = wasm_v128_load(y0->qs + 16);
|
3423
|
+
|
3424
|
+
// int8x16 -> int16x8
|
3425
|
+
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
|
3426
|
+
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
|
3427
|
+
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
|
3428
|
+
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
|
3429
|
+
|
3430
|
+
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
|
3431
|
+
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
|
3432
|
+
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
|
3433
|
+
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
|
3434
|
+
|
3435
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
3436
|
+
|
3437
|
+
// dot product
|
3438
|
+
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
|
3439
|
+
wasm_i32x4_add(
|
3440
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
|
3441
|
+
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
|
3442
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
|
3443
|
+
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
|
3444
|
+
}
|
3445
|
+
|
3446
|
+
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
|
3447
|
+
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
|
3448
|
+
#elif defined(__AVX2__)
|
3449
|
+
// Initialize accumulator with zeros
|
3450
|
+
__m256 acc = _mm256_setzero_ps();
|
3451
|
+
|
3452
|
+
// Main loop
|
3453
|
+
for (int i = 0; i < nb; i++) {
|
3454
|
+
/* Compute combined scale for the block */
|
3455
|
+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
|
3456
|
+
|
3457
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
3458
|
+
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
3459
|
+
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
|
3460
|
+
bx = _mm256_or_si256(bx, bxhi);
|
3461
|
+
|
3462
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3463
|
+
|
3464
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3465
|
+
|
3466
|
+
/* Multiply q with scale and accumulate */
|
3467
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
3468
|
+
}
|
3469
|
+
|
3470
|
+
*s = hsum_float_8(acc);
|
3471
|
+
#else
|
3472
|
+
// scalar
|
3473
|
+
float sumf = 0.0;
|
3474
|
+
for (int i = 0; i < nb; i++) {
|
3475
|
+
const uint8_t * restrict x0 = x[i].qs;
|
3476
|
+
const int8_t * restrict y0 = y[i].qs;
|
3477
|
+
|
3478
|
+
uint32_t qh;
|
3479
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
3480
|
+
|
3481
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3482
|
+
|
3483
|
+
int sxy = 0;
|
3484
|
+
|
3485
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
3486
|
+
const uint8_t v0 = x0[j];
|
3487
|
+
|
3488
|
+
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
|
3489
|
+
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
|
3490
|
+
|
3491
|
+
const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
|
3492
|
+
const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
|
3493
|
+
|
3494
|
+
const int y0_0 = y0[2*j + 0];
|
3495
|
+
const int y1_0 = y0[2*j + 1];
|
3496
|
+
|
3497
|
+
sxy += x0_0*y0_0 + x1_0*y1_0;
|
3498
|
+
}
|
3499
|
+
|
3500
|
+
sumf += (d*sxy)*y[i].d;
|
3501
|
+
}
|
3502
|
+
*s = sumf;
|
3503
|
+
#endif
|
3504
|
+
}
|
3505
|
+
|
3506
|
+
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3507
|
+
const int nb = n / QK8_1;
|
3508
|
+
|
3509
|
+
assert(n % QK8_1 == 0);
|
3510
|
+
assert(nb % 2 == 0);
|
3511
|
+
assert(QK8_1 == QK5_1);
|
3512
|
+
|
3513
|
+
const block_q5_1 * restrict x = vx;
|
3514
|
+
const block_q8_1 * restrict y = vy;
|
3515
|
+
|
3516
|
+
#if defined(__ARM_NEON)
|
3517
|
+
float32x4_t sumv = vdupq_n_f32(0.0f);
|
3518
|
+
|
3519
|
+
float summs = 0.0f;
|
3520
|
+
|
3521
|
+
uint64_t tmp[4];
|
3522
|
+
|
3523
|
+
for (int i = 0; i < nb; ++i) {
|
3524
|
+
const block_q5_1 * restrict x0 = &x[i];
|
3525
|
+
const block_q8_1 * restrict y0 = &y[i];
|
3526
|
+
|
3527
|
+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
3528
|
+
|
3529
|
+
// extract the 5th bit
|
3530
|
+
uint32_t qh;
|
3531
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
3532
|
+
|
3533
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3534
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3535
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3536
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
3537
|
+
|
3538
|
+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
|
3539
|
+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
|
3540
|
+
|
3541
|
+
const uint8x16_t v0 = vld1q_u8(x0->qs);
|
3542
|
+
|
3543
|
+
// 4-bit -> 8-bit
|
3544
|
+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
|
3545
|
+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
|
3546
|
+
|
3547
|
+
// interleave
|
3548
|
+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
|
3549
|
+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
|
3550
|
+
|
3551
|
+
// add
|
3552
|
+
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
|
3553
|
+
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
|
3554
|
+
|
3555
|
+
// load y
|
3556
|
+
const int8x16_t v1l = vld1q_s8(y0->qs);
|
3557
|
+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
|
3558
|
+
|
3559
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
2713
3560
|
|
2714
|
-
|
2715
|
-
|
3561
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3562
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
|
3563
|
+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
|
3564
|
+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
|
3565
|
+
#else
|
3566
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
|
3567
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
|
3568
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
|
3569
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
|
2716
3570
|
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
3571
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
3572
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
3573
|
+
|
3574
|
+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
|
2720
3575
|
#endif
|
3576
|
+
}
|
2721
3577
|
|
2722
|
-
*s =
|
2723
|
-
|
3578
|
+
*s = vaddvq_f32(sumv) + summs;
|
3579
|
+
#elif defined(__wasm_simd128__)
|
3580
|
+
v128_t sumv = wasm_f32x4_splat(0.0f);
|
2724
3581
|
|
2725
|
-
|
2726
|
-
const int nb = n / QK8_0;
|
3582
|
+
float summs = 0.0f;
|
2727
3583
|
|
2728
|
-
|
2729
|
-
assert(nb % 2 == 0);
|
2730
|
-
assert(QK8_0 == 2*QK4_2);
|
3584
|
+
uint64_t tmp[4];
|
2731
3585
|
|
2732
|
-
|
2733
|
-
|
3586
|
+
for (int i = 0; i < nb; ++i) {
|
3587
|
+
const block_q5_1 * restrict x0 = &x[i];
|
3588
|
+
const block_q8_1 * restrict y0 = &y[i];
|
2734
3589
|
|
2735
|
-
|
3590
|
+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
2736
3591
|
|
2737
|
-
|
2738
|
-
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2739
|
-
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3592
|
+
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
2740
3593
|
|
2741
|
-
|
2742
|
-
|
2743
|
-
|
2744
|
-
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2745
|
-
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
3594
|
+
// extract the 5th bit
|
3595
|
+
uint32_t qh;
|
3596
|
+
memcpy(&qh, x0->qh, sizeof(qh));
|
2746
3597
|
|
2747
|
-
|
2748
|
-
|
3598
|
+
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
|
3599
|
+
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
|
3600
|
+
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
|
3601
|
+
tmp[3] = table_b2b_u[(qh >> 24) ];
|
2749
3602
|
|
2750
|
-
const
|
2751
|
-
const
|
3603
|
+
const v128_t qhl = wasm_v128_load(tmp + 0);
|
3604
|
+
const v128_t qhh = wasm_v128_load(tmp + 2);
|
2752
3605
|
|
2753
|
-
const
|
2754
|
-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
3606
|
+
const v128_t v0 = wasm_v128_load(x0->qs);
|
2755
3607
|
|
2756
3608
|
// 4-bit -> 8-bit
|
2757
|
-
const
|
2758
|
-
const
|
2759
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2760
|
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3609
|
+
const v128_t v0l = wasm_v128_and (v0, m4b);
|
3610
|
+
const v128_t v0h = wasm_u8x16_shr(v0, 4);
|
2761
3611
|
|
2762
|
-
|
2763
|
-
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2764
|
-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2765
|
-
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2766
|
-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
3612
|
+
static bool x = true;
|
2767
3613
|
|
2768
3614
|
// interleave
|
2769
|
-
const
|
2770
|
-
const
|
2771
|
-
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
2772
|
-
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
2773
|
-
|
2774
|
-
// load y
|
2775
|
-
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2776
|
-
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2777
|
-
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2778
|
-
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
3615
|
+
const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
|
3616
|
+
const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
|
2779
3617
|
|
2780
|
-
|
2781
|
-
|
2782
|
-
|
2783
|
-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3618
|
+
// add high bit
|
3619
|
+
const v128_t v0lf = wasm_v128_or(v0lz, qhl);
|
3620
|
+
const v128_t v0hf = wasm_v128_or(v0hz, qhh);
|
2784
3621
|
|
2785
|
-
|
2786
|
-
|
2787
|
-
|
2788
|
-
#else
|
2789
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2790
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2791
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2792
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
3622
|
+
// load y
|
3623
|
+
const v128_t v1l = wasm_v128_load(y0->qs);
|
3624
|
+
const v128_t v1h = wasm_v128_load(y0->qs + 16);
|
2793
3625
|
|
2794
|
-
|
2795
|
-
const
|
2796
|
-
const
|
2797
|
-
const
|
3626
|
+
// int8x16 -> int16x8
|
3627
|
+
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
|
3628
|
+
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
|
3629
|
+
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
|
3630
|
+
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
|
2798
3631
|
|
2799
|
-
const
|
2800
|
-
const
|
2801
|
-
const
|
2802
|
-
const
|
3632
|
+
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
|
3633
|
+
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
|
3634
|
+
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
|
3635
|
+
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
|
2803
3636
|
|
2804
|
-
|
2805
|
-
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
2806
|
-
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
3637
|
+
const float x0d = GGML_FP16_TO_FP32(x0->d);
|
2807
3638
|
|
2808
|
-
|
2809
|
-
|
2810
|
-
|
2811
|
-
|
3639
|
+
// dot product
|
3640
|
+
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
|
3641
|
+
wasm_i32x4_add(
|
3642
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
|
3643
|
+
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
|
3644
|
+
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
|
3645
|
+
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
|
2812
3646
|
}
|
2813
3647
|
|
2814
|
-
|
3648
|
+
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
|
3649
|
+
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
|
2815
3650
|
#elif defined(__AVX2__)
|
2816
3651
|
// Initialize accumulator with zeros
|
2817
3652
|
__m256 acc = _mm256_setzero_ps();
|
3653
|
+
float summs = 0.0f;
|
2818
3654
|
|
2819
3655
|
// Main loop
|
2820
3656
|
for (int i = 0; i < nb; i++) {
|
2821
|
-
|
2822
|
-
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
2823
|
-
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
2824
|
-
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
2825
|
-
|
2826
|
-
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
2827
|
-
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
2828
|
-
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
2829
|
-
|
2830
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2831
|
-
const __m256i off = _mm256_set1_epi8(8);
|
2832
|
-
bx = _mm256_sub_epi8(bx, off);
|
3657
|
+
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
2833
3658
|
|
2834
|
-
|
3659
|
+
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
|
2835
3660
|
|
2836
|
-
|
2837
|
-
|
2838
|
-
|
2839
|
-
|
2840
|
-
// Perform multiplication and create 16-bit values
|
2841
|
-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
3661
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
3662
|
+
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
3663
|
+
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
|
3664
|
+
bx = _mm256_or_si256(bx, bxhi);
|
2842
3665
|
|
2843
|
-
const
|
2844
|
-
__m256i
|
3666
|
+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
3667
|
+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2845
3668
|
|
2846
|
-
|
2847
|
-
__m256 q = _mm256_cvtepi32_ps(xy_q);
|
3669
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
2848
3670
|
|
2849
|
-
|
2850
|
-
acc = _mm256_fmadd_ps(d, q, acc);
|
3671
|
+
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
2851
3672
|
}
|
2852
3673
|
|
2853
|
-
|
2854
|
-
__m128 res = _mm256_extractf128_ps(acc, 1);
|
2855
|
-
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
2856
|
-
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
2857
|
-
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
2858
|
-
|
2859
|
-
sumf = _mm_cvtss_f32(res);
|
3674
|
+
*s = hsum_float_8(acc) + summs;
|
2860
3675
|
#else
|
2861
|
-
|
3676
|
+
float sumf = 0.0;
|
3677
|
+
|
2862
3678
|
for (int i = 0; i < nb; i++) {
|
2863
|
-
const uint8_t * restrict x0 = x[
|
2864
|
-
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3679
|
+
const uint8_t * restrict x0 = x[i].qs;
|
2865
3680
|
const int8_t * restrict y0 = y[i].qs;
|
2866
3681
|
|
2867
|
-
|
2868
|
-
|
3682
|
+
uint32_t qh;
|
3683
|
+
memcpy(&qh, x[i].qh, sizeof(qh));
|
2869
3684
|
|
2870
|
-
|
2871
|
-
|
3685
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3686
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
2872
3687
|
|
2873
|
-
|
2874
|
-
const uint8_t v0 = x0[j];
|
2875
|
-
const uint8_t v1 = x1[j];
|
3688
|
+
int sxy = 0;
|
2876
3689
|
|
2877
|
-
|
2878
|
-
const
|
3690
|
+
for (int j = 0; j < QK8_1/2; j++) {
|
3691
|
+
const uint8_t v0 = x0[j];
|
2879
3692
|
|
2880
|
-
const int
|
2881
|
-
const int
|
3693
|
+
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
|
3694
|
+
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
|
2882
3695
|
|
2883
|
-
const int
|
2884
|
-
const int
|
3696
|
+
const int x0_0 = (v0 & 0x0F) | x0_0h;
|
3697
|
+
const int x1_0 = (v0 >> 4) | x1_0h;
|
2885
3698
|
|
2886
|
-
const int
|
2887
|
-
const int
|
3699
|
+
const int y0_0 = y0[2*j + 0];
|
3700
|
+
const int y1_0 = y0[2*j + 1];
|
2888
3701
|
|
2889
|
-
|
2890
|
-
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
3702
|
+
sxy += x0_0*y0_0 + x1_0*y1_0;
|
2891
3703
|
}
|
2892
3704
|
|
2893
|
-
sumf += (
|
2894
|
-
sumf += (d1 * y[i].d) * sumi_1;
|
3705
|
+
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
|
2895
3706
|
}
|
2896
|
-
#endif
|
2897
3707
|
|
2898
3708
|
*s = sumf;
|
3709
|
+
#endif
|
2899
3710
|
}
|
2900
3711
|
|
2901
|
-
static void
|
3712
|
+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2902
3713
|
const int nb = n / QK8_0;
|
2903
3714
|
|
2904
3715
|
assert(n % QK8_0 == 0);
|
2905
3716
|
assert(nb % 2 == 0);
|
2906
|
-
assert(QK8_0 ==
|
3717
|
+
assert(QK8_0 == QK8_0);
|
2907
3718
|
|
2908
|
-
const
|
3719
|
+
const block_q8_0 * restrict x = vx;
|
2909
3720
|
const block_q8_0 * restrict y = vy;
|
2910
3721
|
|
2911
|
-
float sumf = 0.0;
|
2912
|
-
|
2913
3722
|
#if defined(__ARM_NEON)
|
2914
3723
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2915
3724
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2916
3725
|
|
2917
3726
|
for (int i = 0; i < nb; i += 2) {
|
2918
|
-
const
|
2919
|
-
const
|
2920
|
-
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2921
|
-
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2922
|
-
|
3727
|
+
const block_q8_0 * restrict x0 = &x[i + 0];
|
3728
|
+
const block_q8_0 * restrict x1 = &x[i + 1];
|
2923
3729
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
2924
3730
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
2925
3731
|
|
2926
|
-
const
|
2927
|
-
|
2928
|
-
const
|
2929
|
-
const
|
2930
|
-
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
|
2931
|
-
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
|
2932
|
-
|
2933
|
-
const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
|
2934
|
-
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
|
2935
|
-
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
|
2936
|
-
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
|
2937
|
-
|
2938
|
-
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2939
|
-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2940
|
-
|
2941
|
-
// 4-bit -> 8-bit
|
2942
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2943
|
-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2944
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2945
|
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2946
|
-
|
2947
|
-
// interleave
|
2948
|
-
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
2949
|
-
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
2950
|
-
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
2951
|
-
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
3732
|
+
const int8x16_t x0_0 = vld1q_s8(x0->qs);
|
3733
|
+
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
|
3734
|
+
const int8x16_t x1_0 = vld1q_s8(x1->qs);
|
3735
|
+
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
|
2952
3736
|
|
2953
3737
|
// load y
|
2954
|
-
const int8x16_t
|
2955
|
-
const int8x16_t
|
2956
|
-
const int8x16_t
|
2957
|
-
const int8x16_t
|
2958
|
-
|
2959
|
-
const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
|
2960
|
-
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
|
2961
|
-
|
2962
|
-
const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
|
2963
|
-
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
|
2964
|
-
|
2965
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
|
2966
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
|
2967
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
|
2968
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
|
3738
|
+
const int8x16_t y0_0 = vld1q_s8(y0->qs);
|
3739
|
+
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
|
3740
|
+
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
3741
|
+
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
2969
3742
|
|
2970
3743
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2971
|
-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(
|
2972
|
-
|
2973
|
-
|
2974
|
-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
|
2975
|
-
#else
|
2976
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2977
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2978
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2979
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2980
|
-
|
2981
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2982
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2983
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2984
|
-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
3744
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
3745
|
+
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
3746
|
+
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
|
2985
3747
|
|
2986
|
-
|
2987
|
-
|
2988
|
-
|
2989
|
-
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
3748
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
3749
|
+
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
3750
|
+
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
|
2990
3751
|
|
2991
|
-
|
2992
|
-
|
2993
|
-
|
2994
|
-
|
3752
|
+
#else
|
3753
|
+
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
|
3754
|
+
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
3755
|
+
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
|
3756
|
+
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
3757
|
+
|
3758
|
+
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
|
3759
|
+
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
3760
|
+
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
|
3761
|
+
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
3762
|
+
|
3763
|
+
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
3764
|
+
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
3765
|
+
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
3766
|
+
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
3767
|
+
|
3768
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
|
3769
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
|
2995
3770
|
#endif
|
2996
3771
|
}
|
2997
3772
|
|
2998
|
-
|
2999
|
-
#
|
3000
|
-
//
|
3001
|
-
|
3002
|
-
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3003
|
-
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3004
|
-
const int8_t * restrict y0 = y[i].qs;
|
3005
|
-
|
3006
|
-
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3007
|
-
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
|
3008
|
-
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3009
|
-
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
|
3010
|
-
|
3011
|
-
int sy_0 = 0;
|
3012
|
-
int sy_1 = 0;
|
3773
|
+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
3774
|
+
#elif defined(__AVX2__)
|
3775
|
+
// Initialize accumulator with zeros
|
3776
|
+
__m256 acc = _mm256_setzero_ps();
|
3013
3777
|
|
3014
|
-
|
3015
|
-
|
3778
|
+
// Main loop
|
3779
|
+
for (int i = 0; i < nb; ++i) {
|
3780
|
+
// Compute combined scale for the block
|
3781
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
3782
|
+
__m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
3783
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3016
3784
|
|
3017
|
-
|
3018
|
-
const uint8_t v0 = x0[j];
|
3019
|
-
const uint8_t v1 = x1[j];
|
3785
|
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
3020
3786
|
|
3021
|
-
|
3022
|
-
|
3787
|
+
// Multiply q with scale and accumulate
|
3788
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
3789
|
+
}
|
3023
3790
|
|
3024
|
-
|
3025
|
-
|
3791
|
+
*s = hsum_float_8(acc);
|
3792
|
+
#else
|
3793
|
+
// scalar
|
3794
|
+
float sumf = 0.0;
|
3026
3795
|
|
3027
|
-
|
3028
|
-
|
3796
|
+
for (int i = 0; i < nb; i++) {
|
3797
|
+
const int8_t * restrict x0 = x[i].qs;
|
3798
|
+
const int8_t * restrict y0 = y[i].qs;
|
3029
3799
|
|
3030
|
-
|
3031
|
-
const int y1_1 = y0[2*(j + QK8_0/4) + 1];
|
3800
|
+
int sumi = 0;
|
3032
3801
|
|
3033
|
-
|
3034
|
-
|
3802
|
+
for (int j = 0; j < QK8_0; j++) {
|
3803
|
+
const int v0 = x0[j];
|
3804
|
+
const int v1 = y0[j];
|
3035
3805
|
|
3036
|
-
|
3037
|
-
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
|
3806
|
+
sumi += v0*v1;
|
3038
3807
|
}
|
3039
3808
|
|
3040
|
-
sumf += (
|
3041
|
-
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
|
3809
|
+
sumf += (x[i].d*y[i].d)*sumi;
|
3042
3810
|
}
|
3043
|
-
#endif
|
3044
3811
|
|
3045
3812
|
*s = sumf;
|
3813
|
+
#endif
|
3046
3814
|
}
|
3047
3815
|
|
3048
|
-
|
3049
3816
|
// compute GGML_VEC_DOT_UNROLL dot products at once
|
3050
3817
|
// xs - x row stride in bytes
|
3051
3818
|
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
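The scalar fallback in the hunk above makes the new Q8_0 x Q8_0 dot product easy to follow: each block contributes its integer dot product scaled by the product of the two block scales. A minimal standalone sketch of that reference path, assuming QK8_0 == 32 and a block_q8_0 layout of one float scale followed by QK8_0 int8 values (as used elsewhere in this file):

    /* Reference Q8_0 x Q8_0 dot product, mirroring the scalar branch above. */
    #define QK8_0 32
    typedef struct { float d; signed char qs[QK8_0]; } block_q8_0;

    static float dot_q8_0(int n, const block_q8_0 *x, const block_q8_0 *y) {
        float sumf = 0.0f;
        for (int i = 0; i < n / QK8_0; i++) {
            int sumi = 0;
            for (int j = 0; j < QK8_0; j++) {
                sumi += x[i].qs[j] * y[i].qs[j];   /* int8 dot product within the block */
            }
            sumf += x[i].d * y[i].d * sumi;        /* rescale by the two block scales */
        }
        return sumf;
    }

The NEON and AVX2 branches above compute exactly this quantity, just 16 or 32 lanes at a time.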
@@ -3242,6 +4009,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
|
3242
4009
|
#endif
|
3243
4010
|
}
|
3244
4011
|
|
4012
|
+
inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
|
4013
|
+
ggml_float sum = 0.0;
|
4014
|
+
for (int i = 0; i < n; ++i) {
|
4015
|
+
sum += (ggml_float)x[i];
|
4016
|
+
}
|
4017
|
+
*s = sum;
|
4018
|
+
}
|
4019
|
+
|
3245
4020
|
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
|
3246
4021
|
#ifndef GGML_USE_ACCELERATE
|
3247
4022
|
float max = -INFINITY;
|
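The new ggml_vec_sum_ggf helper accumulates into ggml_float (a double on typical builds) so that summing a long f32 row loses less precision than a plain float accumulator; ggml_compute_forward_sum_f32 later in this diff calls it once per row and then adds the row sums. A small usage sketch, reusing the function and typedef from the hunk above:

    /* Sum one short row with the wide accumulator. */
    float row[4] = { 1e8f, 1.0f, -1e8f, 1.0f };
    ggml_float s = 0.0;
    ggml_vec_sum_ggf(4, &s, row);
    /* s == 2.0 here; a float accumulator would give 1.0f, because
       1e8f + 1.0f rounds back to 1e8f before the subtraction. */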
@@ -3293,13 +4068,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
|
3293
4068
|
[GGML_TYPE_Q4_0] = QK4_0,
|
3294
4069
|
[GGML_TYPE_Q4_1] = QK4_1,
|
3295
4070
|
[GGML_TYPE_Q4_2] = QK4_2,
|
3296
|
-
[
|
4071
|
+
[GGML_TYPE_Q5_0] = QK5_0,
|
4072
|
+
[GGML_TYPE_Q5_1] = QK5_1,
|
3297
4073
|
[GGML_TYPE_Q8_0] = QK8_0,
|
4074
|
+
[GGML_TYPE_Q8_1] = QK8_1,
|
3298
4075
|
[GGML_TYPE_I8] = 1,
|
3299
4076
|
[GGML_TYPE_I16] = 1,
|
3300
4077
|
[GGML_TYPE_I32] = 1,
|
3301
4078
|
};
|
3302
|
-
static_assert(GGML_TYPE_COUNT ==
|
4079
|
+
static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
|
3303
4080
|
|
3304
4081
|
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
3305
4082
|
[GGML_TYPE_F32] = sizeof(float),
|
@@ -3307,13 +4084,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
|
3307
4084
|
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
3308
4085
|
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
3309
4086
|
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
|
3310
|
-
[
|
4087
|
+
[GGML_TYPE_Q5_0] = sizeof(block_q5_0),
|
4088
|
+
[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
|
3311
4089
|
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
4090
|
+
[GGML_TYPE_Q8_1] = sizeof(block_q8_1),
|
3312
4091
|
[GGML_TYPE_I8] = sizeof(int8_t),
|
3313
4092
|
[GGML_TYPE_I16] = sizeof(int16_t),
|
3314
4093
|
[GGML_TYPE_I32] = sizeof(int32_t),
|
3315
4094
|
};
|
3316
|
-
static_assert(GGML_TYPE_COUNT ==
|
4095
|
+
static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
|
3317
4096
|
|
3318
4097
|
|
3319
4098
|
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
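The GGML_BLCK_SIZE and GGML_TYPE_SIZE tables above are what turn an element count into a byte count: a row of ne elements of a quantized type occupies ne/GGML_BLCK_SIZE[type] blocks of GGML_TYPE_SIZE[type] bytes each. A hedged example; the 22-byte figure assumes the usual block_q5_0 layout of one fp16 scale, 4 bytes of high bits, and 16 bytes of packed nibbles:

    /* Row size in bytes for a given type, as the tables above imply.
       For example, 4096 Q5_0 weights -> (4096/32) * sizeof(block_q5_0)
       = 128 * 22 = 2816 bytes, i.e. roughly 5.5 bits per weight. */
    size_t row_bytes = (ne / GGML_BLCK_SIZE[type]) * GGML_TYPE_SIZE[type];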
@@ -3322,13 +4101,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
|
3322
4101
|
[GGML_TYPE_Q4_0] = "q4_0",
|
3323
4102
|
[GGML_TYPE_Q4_1] = "q4_1",
|
3324
4103
|
[GGML_TYPE_Q4_2] = "q4_2",
|
3325
|
-
[
|
4104
|
+
[GGML_TYPE_Q5_0] = "q5_0",
|
4105
|
+
[GGML_TYPE_Q5_1] = "q5_1",
|
3326
4106
|
[GGML_TYPE_Q8_0] = "q8_0",
|
4107
|
+
[GGML_TYPE_Q8_1] = "q8_1",
|
3327
4108
|
[GGML_TYPE_I8] = "i8",
|
3328
4109
|
[GGML_TYPE_I16] = "i16",
|
3329
4110
|
[GGML_TYPE_I32] = "i32",
|
3330
4111
|
};
|
3331
|
-
static_assert(GGML_TYPE_COUNT ==
|
4112
|
+
static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
|
3332
4113
|
|
3333
4114
|
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
3334
4115
|
[GGML_TYPE_F32] = false,
|
@@ -3336,13 +4117,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
|
3336
4117
|
[GGML_TYPE_Q4_0] = true,
|
3337
4118
|
[GGML_TYPE_Q4_1] = true,
|
3338
4119
|
[GGML_TYPE_Q4_2] = true,
|
3339
|
-
[
|
4120
|
+
[GGML_TYPE_Q5_0] = true,
|
4121
|
+
[GGML_TYPE_Q5_1] = true,
|
3340
4122
|
[GGML_TYPE_Q8_0] = true,
|
4123
|
+
[GGML_TYPE_Q8_1] = true,
|
3341
4124
|
[GGML_TYPE_I8] = false,
|
3342
4125
|
[GGML_TYPE_I16] = false,
|
3343
4126
|
[GGML_TYPE_I32] = false,
|
3344
4127
|
};
|
3345
|
-
static_assert(GGML_TYPE_COUNT ==
|
4128
|
+
static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
|
3346
4129
|
|
3347
4130
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
3348
4131
|
"NONE",
|
@@ -3380,6 +4163,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
3380
4163
|
"DIAG_MASK_INF",
|
3381
4164
|
"SOFT_MAX",
|
3382
4165
|
"ROPE",
|
4166
|
+
"ALIBI",
|
3383
4167
|
"CONV_1D_1S",
|
3384
4168
|
"CONV_1D_2S",
|
3385
4169
|
|
@@ -3390,7 +4174,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
3390
4174
|
"MAP_BINARY",
|
3391
4175
|
};
|
3392
4176
|
|
3393
|
-
static_assert(GGML_OP_COUNT ==
|
4177
|
+
static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
|
3394
4178
|
|
3395
4179
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
3396
4180
|
"none",
|
@@ -3428,6 +4212,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3428
4212
|
"diag_mask_inf(x)",
|
3429
4213
|
"soft_max(x)",
|
3430
4214
|
"rope(x)",
|
4215
|
+
"alibi(x)",
|
3431
4216
|
"conv_1d_1s(x)",
|
3432
4217
|
"conv_1d_2s(x)",
|
3433
4218
|
|
@@ -3438,7 +4223,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
3438
4223
|
"f(x,y)",
|
3439
4224
|
};
|
3440
4225
|
|
3441
|
-
static_assert(GGML_OP_COUNT ==
|
4226
|
+
static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
|
3442
4227
|
|
3443
4228
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
3444
4229
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
@@ -3608,6 +4393,27 @@ bool ggml_is_quantized(enum ggml_type type) {
|
|
3608
4393
|
return GGML_IS_QUANTIZED[type];
|
3609
4394
|
}
|
3610
4395
|
|
4396
|
+
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
4397
|
+
enum ggml_type wtype = GGML_TYPE_COUNT;
|
4398
|
+
|
4399
|
+
switch (ftype) {
|
4400
|
+
case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
|
4401
|
+
case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
|
4402
|
+
case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
|
4403
|
+
case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
|
4404
|
+
case GGML_FTYPE_MOSTLY_Q4_2: wtype = GGML_TYPE_Q4_2; break;
|
4405
|
+
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
|
4406
|
+
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
|
4407
|
+
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
|
4408
|
+
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
4409
|
+
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
4410
|
+
}
|
4411
|
+
|
4412
|
+
GGML_ASSERT(wtype != GGML_TYPE_COUNT);
|
4413
|
+
|
4414
|
+
return wtype;
|
4415
|
+
}
|
4416
|
+
|
3611
4417
|
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
3612
4418
|
return tensor->nb[0] > tensor->nb[1];
|
3613
4419
|
}
|
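The new ggml_ftype_to_ggml_type helper maps a file-level ftype to the tensor type used for the bulk of the weights, and deliberately asserts on the two values that have no single answer (GGML_FTYPE_UNKNOWN and GGML_FTYPE_MOSTLY_Q4_1_SOME_F16). A hedged sketch of how a loader might use it to size its buffers:

    /* Hypothetical loader fragment: pick the weight type from the file header. */
    enum ggml_ftype ftype = GGML_FTYPE_MOSTLY_Q5_1;          /* read from the model file */
    enum ggml_type  wtype = ggml_ftype_to_ggml_type(ftype);  /* -> GGML_TYPE_Q5_1 */
    /* wtype can now be combined with GGML_TYPE_SIZE / GGML_BLCK_SIZE above
       to compute per-tensor buffer sizes before reading the weights. */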
@@ -3718,10 +4524,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
3718
4524
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
3719
4525
|
}
|
3720
4526
|
|
3721
|
-
|
3722
|
-
|
3723
|
-
|
3724
|
-
|
4527
|
+
#if defined(GGML_USE_CUBLAS)
|
4528
|
+
ggml_init_cublas();
|
4529
|
+
#elif defined(GGML_USE_CLBLAST)
|
4530
|
+
ggml_cl_init();
|
4531
|
+
#endif
|
3725
4532
|
|
3726
4533
|
is_first_call = false;
|
3727
4534
|
}
|
@@ -3802,7 +4609,7 @@ void ggml_free(struct ggml_context * ctx) {
|
|
3802
4609
|
}
|
3803
4610
|
|
3804
4611
|
size_t ggml_used_mem(const struct ggml_context * ctx) {
|
3805
|
-
return ctx->objects_end->offs + ctx->objects_end->size;
|
4612
|
+
return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
|
3806
4613
|
}
|
3807
4614
|
|
3808
4615
|
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
|
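The ggml_used_mem change above makes the call safe on a context that has not allocated any object yet (objects_end is still NULL), returning 0 instead of dereferencing a null pointer. A brief sketch of the case it covers, with the usual ggml_init_params fields spelled out as comments:

    /* Before this guard, calling ggml_used_mem right after ggml_init could crash. */
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);
    size_t used = ggml_used_mem(ctx);   /* 0: no tensors created yet */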
@@ -3915,6 +4722,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|
3915
4722
|
/*.perf_cycles =*/ 0,
|
3916
4723
|
/*.perf_time_us =*/ 0,
|
3917
4724
|
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
4725
|
+
/*.name =*/ { 0 },
|
3918
4726
|
/*.pad =*/ { 0 },
|
3919
4727
|
};
|
3920
4728
|
|
@@ -4269,6 +5077,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
|
|
4269
5077
|
return (float *)(tensor->data);
|
4270
5078
|
}
|
4271
5079
|
|
5080
|
+
const char * ggml_get_name(const struct ggml_tensor * tensor) {
|
5081
|
+
return tensor->name;
|
5082
|
+
}
|
5083
|
+
|
5084
|
+
void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
|
5085
|
+
strncpy(tensor->name, name, sizeof(tensor->name));
|
5086
|
+
tensor->name[sizeof(tensor->name) - 1] = '\0';
|
5087
|
+
}
|
5088
|
+
|
4272
5089
|
struct ggml_tensor * ggml_view_tensor(
|
4273
5090
|
struct ggml_context * ctx,
|
4274
5091
|
const struct ggml_tensor * src) {
|
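ggml_set_name above copies at most sizeof(tensor->name) - 1 characters and always NUL-terminates, so long labels are silently truncated rather than overflowing the fixed-size field; ggml_get_name simply returns the stored string. A small usage sketch (the tensor shape and label are hypothetical):

    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);
    ggml_set_name(cur, "layers.0.attention.wq");
    printf("%s\n", ggml_get_name(cur));   /* prints the (possibly truncated) name */

The names set this way also show up in ggml_graph_dump_dot further down in this diff, which now prints tensor->name in the node labels.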
@@ -5368,6 +6185,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
|
|
5368
6185
|
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5369
6186
|
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
5370
6187
|
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
6188
|
+
ggml_set_name(b, "n_past");
|
5371
6189
|
|
5372
6190
|
result->op = GGML_OP_DIAG_MASK_INF;
|
5373
6191
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
@@ -5393,22 +6211,55 @@ struct ggml_tensor * ggml_soft_max(
|
|
5393
6211
|
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5394
6212
|
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
5395
6213
|
|
5396
|
-
result->op = GGML_OP_SOFT_MAX;
|
6214
|
+
result->op = GGML_OP_SOFT_MAX;
|
6215
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
6216
|
+
result->src0 = a;
|
6217
|
+
result->src1 = NULL;
|
6218
|
+
|
6219
|
+
return result;
|
6220
|
+
}
|
6221
|
+
|
6222
|
+
// ggml_rope
|
6223
|
+
|
6224
|
+
struct ggml_tensor * ggml_rope(
|
6225
|
+
struct ggml_context * ctx,
|
6226
|
+
struct ggml_tensor * a,
|
6227
|
+
int n_past,
|
6228
|
+
int n_dims,
|
6229
|
+
int mode) {
|
6230
|
+
GGML_ASSERT(n_past >= 0);
|
6231
|
+
bool is_node = false;
|
6232
|
+
|
6233
|
+
if (a->grad) {
|
6234
|
+
GGML_ASSERT(false); // TODO: implement backward
|
6235
|
+
is_node = true;
|
6236
|
+
}
|
6237
|
+
|
6238
|
+
// TODO: when implement backward, fix this:
|
6239
|
+
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
6240
|
+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
6241
|
+
|
6242
|
+
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
|
6243
|
+
((int32_t *) b->data)[0] = n_past;
|
6244
|
+
((int32_t *) b->data)[1] = n_dims;
|
6245
|
+
((int32_t *) b->data)[2] = mode;
|
6246
|
+
ggml_set_name(b, "n_past, n_dims, mode");
|
6247
|
+
|
6248
|
+
result->op = GGML_OP_ROPE;
|
5397
6249
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5398
6250
|
result->src0 = a;
|
5399
|
-
result->src1 =
|
6251
|
+
result->src1 = b;
|
5400
6252
|
|
5401
6253
|
return result;
|
5402
6254
|
}
|
5403
6255
|
|
5404
|
-
//
|
6256
|
+
// ggml_alibi
|
5405
6257
|
|
5406
|
-
struct ggml_tensor *
|
6258
|
+
struct ggml_tensor * ggml_alibi(
|
5407
6259
|
struct ggml_context * ctx,
|
5408
6260
|
struct ggml_tensor * a,
|
5409
6261
|
int n_past,
|
5410
|
-
int
|
5411
|
-
int mode) {
|
6262
|
+
int n_head) {
|
5412
6263
|
GGML_ASSERT(n_past >= 0);
|
5413
6264
|
bool is_node = false;
|
5414
6265
|
|
@@ -5421,12 +6272,11 @@ struct ggml_tensor * ggml_rope(
|
|
5421
6272
|
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5422
6273
|
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
5423
6274
|
|
5424
|
-
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,
|
6275
|
+
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
|
5425
6276
|
((int32_t *) b->data)[0] = n_past;
|
5426
|
-
((int32_t *) b->data)[1] =
|
5427
|
-
((int32_t *) b->data)[2] = mode;
|
6277
|
+
((int32_t *) b->data)[1] = n_head;
|
5428
6278
|
|
5429
|
-
result->op =
|
6279
|
+
result->op = GGML_OP_ALIBI;
|
5430
6280
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5431
6281
|
result->src0 = a;
|
5432
6282
|
result->src1 = b;
|
@@ -6553,7 +7403,9 @@ static void ggml_compute_forward_add(
|
|
6553
7403
|
case GGML_TYPE_Q4_0:
|
6554
7404
|
case GGML_TYPE_Q4_1:
|
6555
7405
|
case GGML_TYPE_Q4_2:
|
6556
|
-
case
|
7406
|
+
case GGML_TYPE_Q5_0:
|
7407
|
+
case GGML_TYPE_Q5_1:
|
7408
|
+
case GGML_TYPE_Q8_0:
|
6557
7409
|
{
|
6558
7410
|
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
6559
7411
|
} break;
|
@@ -6811,15 +7663,20 @@ static void ggml_compute_forward_sum_f32(
|
|
6811
7663
|
const size_t nb02 = src0->nb[2];
|
6812
7664
|
const size_t nb03 = src0->nb[3];
|
6813
7665
|
|
7666
|
+
ggml_float sum = 0;
|
7667
|
+
ggml_float row_sum = 0;
|
7668
|
+
|
6814
7669
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
6815
7670
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
6816
7671
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
6817
|
-
|
6818
|
-
|
7672
|
+
ggml_vec_sum_ggf(ne00,
|
7673
|
+
&row_sum,
|
6819
7674
|
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
|
7675
|
+
sum += row_sum;
|
6820
7676
|
}
|
6821
7677
|
}
|
6822
7678
|
}
|
7679
|
+
((float *) dst->data)[0] = sum;
|
6823
7680
|
}
|
6824
7681
|
|
6825
7682
|
static void ggml_compute_forward_sum(
|
@@ -7454,7 +8311,7 @@ static void ggml_compute_forward_rms_norm(
|
|
7454
8311
|
|
7455
8312
|
// ggml_compute_forward_mul_mat
|
7456
8313
|
|
7457
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
8314
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
7458
8315
|
// helper function to determine if it is better to use BLAS or not
|
7459
8316
|
// for large matrices, BLAS is faster
|
7460
8317
|
static bool ggml_compute_forward_mul_mat_use_blas(
|
@@ -7471,7 +8328,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|
7471
8328
|
|
7472
8329
|
// TODO: find the optimal values for these
|
7473
8330
|
if (ggml_is_contiguous(src0) &&
|
7474
|
-
ggml_is_contiguous(src1) &&
|
8331
|
+
ggml_is_contiguous(src1) &&
|
8332
|
+
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
7475
8333
|
|
7476
8334
|
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
|
7477
8335
|
return true;
|
@@ -7494,7 +8352,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7494
8352
|
const int64_t ne02 = src0->ne[2];
|
7495
8353
|
const int64_t ne03 = src0->ne[3];
|
7496
8354
|
|
7497
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
8355
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
7498
8356
|
const int64_t ne10 = src1->ne[0];
|
7499
8357
|
#endif
|
7500
8358
|
const int64_t ne11 = src1->ne[1];
|
@@ -7551,7 +8409,16 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7551
8409
|
// nb01 >= nb00 - src0 is not transposed
|
7552
8410
|
// compute by src0 rows
|
7553
8411
|
|
7554
|
-
#if defined(
|
8412
|
+
#if defined(GGML_USE_CUBLAS)
|
8413
|
+
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
8414
|
+
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
8415
|
+
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
8416
|
+
}
|
8417
|
+
return;
|
8418
|
+
}
|
8419
|
+
#endif
|
8420
|
+
|
8421
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
7555
8422
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
7556
8423
|
if (params->ith != 0) {
|
7557
8424
|
return;
|
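The three mul_mat variants in this diff now share the same shape of dispatch: an optional cuBLAS early-out guarded by ggml_cuda_can_mul_mat, then the BLAS/CLBlast branch, then the threaded CPU fallback. Only the thread with ith == 0 performs the offloaded GEMM during GGML_TASK_COMPUTE; every other thread just returns. A condensed sketch of that control flow, using the names from the surrounding code (not a drop-in replacement for the real functions):

    #if defined(GGML_USE_CUBLAS)
        if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
            if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
                ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); /* whole op on the GPU */
            }
            return; /* all threads leave; nothing else to do for this node */
        }
    #endif
        /* ... otherwise: BLAS / CLBlast path, then the regular multi-threaded CPU path ... */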
@@ -7565,45 +8432,21 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7565
8432
|
return;
|
7566
8433
|
}
|
7567
8434
|
|
7568
|
-
#if defined(GGML_USE_CUBLAS)
|
7569
|
-
float *d_X = NULL;
|
7570
|
-
float *d_Y = NULL;
|
7571
|
-
float *d_D = NULL;
|
7572
|
-
const float alpha = 1.0f;
|
7573
|
-
const float beta = 0.0f;
|
7574
|
-
const int x_ne = ne01 * ne10;
|
7575
|
-
const int y_ne = ne11 * ne10;
|
7576
|
-
const int d_ne = ne11 * ne01;
|
7577
|
-
|
7578
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
7579
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
7580
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
7581
|
-
#endif
|
7582
|
-
|
7583
8435
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7584
8436
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
7585
8437
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
7586
8438
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
7587
|
-
|
7588
8439
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7589
8440
|
|
7590
|
-
#if defined(
|
7591
|
-
// copy data to device
|
7592
|
-
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
7593
|
-
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
7594
|
-
|
7595
|
-
// compute
|
7596
|
-
CUBLAS_CHECK(
|
7597
|
-
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
7598
|
-
ne01, ne11, ne10,
|
7599
|
-
&alpha, d_X, ne00,
|
7600
|
-
d_Y, ne10,
|
7601
|
-
&beta, d_D, ne01));
|
7602
|
-
|
7603
|
-
// copy data to host
|
7604
|
-
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
7605
|
-
#else
|
8441
|
+
#if defined(GGML_USE_CLBLAST)
|
7606
8442
|
// zT = y * xT
|
8443
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8444
|
+
ne11, ne01, ne10,
|
8445
|
+
1.0f, y, ne10,
|
8446
|
+
x, ne10,
|
8447
|
+
0.0f, d, ne01,
|
8448
|
+
GGML_TYPE_F32);
|
8449
|
+
#else
|
7607
8450
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
7608
8451
|
ne11, ne01, ne10,
|
7609
8452
|
1.0f, y, ne10,
|
@@ -7612,12 +8455,6 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7612
8455
|
#endif
|
7613
8456
|
}
|
7614
8457
|
}
|
7615
|
-
#if defined(GGML_USE_CUBLAS)
|
7616
|
-
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
7617
|
-
CUDA_CHECK(cudaFree(d_X));
|
7618
|
-
CUDA_CHECK(cudaFree(d_Y));
|
7619
|
-
CUDA_CHECK(cudaFree(d_D));
|
7620
|
-
#endif
|
7621
8458
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
7622
8459
|
|
7623
8460
|
return;
|
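The ggml_cl_sgemm_wrapper call above computes the same GEMM as the cblas_sgemm it sits next to: with row-major storage and op(B) = B transposed it forms d = y * x^T, where y is ne11 x ne10, x is ne01 x ne10, and d is ne11 x ne01, with alpha = 1 and beta = 0. A naive reference loop for what either call produces, useful when checking a backend, assuming contiguous row-major buffers:

    /* d[i][j] = sum_k y[i][k] * x[j][k];  i < ne11, j < ne01, k < ne10 */
    static void sgemm_ref(int ne11, int ne01, int ne10,
                          const float *y, const float *x, float *d) {
        for (int i = 0; i < ne11; i++) {
            for (int j = 0; j < ne01; j++) {
                float acc = 0.0f;
                for (int k = 0; k < ne10; k++) {
                    acc += y[i*ne10 + k] * x[j*ne10 + k];
                }
                d[i*ne01 + j] = acc;   /* alpha = 1, beta = 0 in the calls above */
            }
        }
    }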
@@ -7747,7 +8584,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7747
8584
|
// nb01 >= nb00 - src0 is not transposed
|
7748
8585
|
// compute by src0 rows
|
7749
8586
|
|
7750
|
-
#if defined(
|
8587
|
+
#if defined(GGML_USE_CUBLAS)
|
8588
|
+
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
8589
|
+
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
8590
|
+
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
8591
|
+
}
|
8592
|
+
return;
|
8593
|
+
}
|
8594
|
+
#endif
|
8595
|
+
|
8596
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
7751
8597
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
7752
8598
|
GGML_ASSERT(nb10 == sizeof(float));
|
7753
8599
|
|
@@ -7763,37 +8609,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7763
8609
|
return;
|
7764
8610
|
}
|
7765
8611
|
|
7766
|
-
#if defined(GGML_USE_CUBLAS)
|
7767
|
-
ggml_fp16_t * const wdata = params->wdata;
|
7768
|
-
|
7769
|
-
float *d_X = NULL;
|
7770
|
-
float *d_Y = NULL;
|
7771
|
-
float *d_D = NULL;
|
7772
|
-
const float alpha = 1.0f;
|
7773
|
-
const float beta = 0.0f;
|
7774
|
-
const int x_ne = ne01 * ne10;
|
7775
|
-
const int y_ne = ne11 * ne10;
|
7776
|
-
const int d_ne = ne11 * ne01;
|
7777
|
-
|
7778
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
|
7779
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
7780
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
7781
|
-
#else
|
7782
|
-
float * const wdata = params->wdata;
|
7783
|
-
#endif
|
7784
8612
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7785
8613
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
7786
|
-
|
7787
|
-
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
7788
|
-
{
|
7789
|
-
size_t id = 0;
|
7790
|
-
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
7791
|
-
for (int64_t i00 = 0; i00 < ne10; ++i00) {
|
7792
|
-
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
|
7793
|
-
}
|
7794
|
-
}
|
7795
|
-
}
|
7796
|
-
#else
|
8614
|
+
float * const wdata = params->wdata;
|
7797
8615
|
{
|
7798
8616
|
size_t id = 0;
|
7799
8617
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
@@ -7801,31 +8619,23 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7801
8619
|
wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
|
7802
8620
|
}
|
7803
8621
|
}
|
8622
|
+
|
8623
|
+
assert(id*sizeof(float) <= params->wsize);
|
7804
8624
|
}
|
7805
|
-
#endif
|
7806
8625
|
|
7807
|
-
#if defined(
|
7808
|
-
const
|
7809
|
-
const
|
8626
|
+
#if defined(GGML_USE_CLBLAST)
|
8627
|
+
const float * x = wdata;
|
8628
|
+
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
7810
8629
|
|
7811
8630
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7812
8631
|
|
7813
|
-
//
|
7814
|
-
|
7815
|
-
|
7816
|
-
|
7817
|
-
|
7818
|
-
|
7819
|
-
|
7820
|
-
ne01, ne11, ne10,
|
7821
|
-
&alpha, d_X, CUDA_R_16F, ne00,
|
7822
|
-
d_Y, CUDA_R_16F, ne10,
|
7823
|
-
&beta, d_D, CUDA_R_32F, ne01,
|
7824
|
-
CUBLAS_COMPUTE_32F,
|
7825
|
-
CUBLAS_GEMM_DEFAULT));
|
7826
|
-
|
7827
|
-
// copy data to host
|
7828
|
-
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
8632
|
+
// zT = y * xT
|
8633
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8634
|
+
ne11, ne01, ne10,
|
8635
|
+
1.0f, y, ne10,
|
8636
|
+
x, ne10,
|
8637
|
+
0.0f, d, ne01,
|
8638
|
+
GGML_TYPE_F32);
|
7829
8639
|
#else
|
7830
8640
|
const float * x = wdata;
|
7831
8641
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
@@ -7842,12 +8652,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7842
8652
|
}
|
7843
8653
|
}
|
7844
8654
|
|
7845
|
-
#if defined(GGML_USE_CUBLAS)
|
7846
|
-
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
7847
|
-
CUDA_CHECK(cudaFree(d_X));
|
7848
|
-
CUDA_CHECK(cudaFree(d_Y));
|
7849
|
-
CUDA_CHECK(cudaFree(d_D));
|
7850
|
-
#endif
|
7851
8655
|
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
7852
8656
|
|
7853
8657
|
return;
|
@@ -7980,6 +8784,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7980
8784
|
const enum ggml_type type = src0->type;
|
7981
8785
|
quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
|
7982
8786
|
vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
|
8787
|
+
enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
|
7983
8788
|
|
7984
8789
|
// we don't support permuted src0 or src1
|
7985
8790
|
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
@@ -7999,7 +8804,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7999
8804
|
// nb01 >= nb00 - src0 is not transposed
|
8000
8805
|
// compute by src0 rows
|
8001
8806
|
|
8002
|
-
#if defined(
|
8807
|
+
#if defined(GGML_USE_CUBLAS)
|
8808
|
+
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
8809
|
+
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
8810
|
+
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
8811
|
+
}
|
8812
|
+
return;
|
8813
|
+
}
|
8814
|
+
#endif
|
8815
|
+
|
8816
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
8003
8817
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
8004
8818
|
if (params->ith != 0) {
|
8005
8819
|
return;
|
@@ -8013,39 +8827,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8013
8827
|
return;
|
8014
8828
|
}
|
8015
8829
|
|
8016
|
-
#if defined(GGML_USE_CUBLAS)
|
8017
|
-
float *d_X = NULL;
|
8018
|
-
float *d_Y = NULL;
|
8019
|
-
float *d_D = NULL;
|
8020
|
-
float *d_Q = NULL;
|
8021
|
-
const float alpha = 1.0f;
|
8022
|
-
const float beta = 0.0f;
|
8023
|
-
const int x_ne = ne01 * ne10;
|
8024
|
-
const int y_ne = ne11 * ne10;
|
8025
|
-
const int d_ne = ne11 * ne01;
|
8026
|
-
|
8027
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
8028
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
8029
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
8030
|
-
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
8031
|
-
|
8032
|
-
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
8033
|
-
if (type == GGML_TYPE_Q4_0) {
|
8034
|
-
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
8035
|
-
}
|
8036
|
-
else if (type == GGML_TYPE_Q4_1) {
|
8037
|
-
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
|
8038
|
-
}
|
8039
|
-
else if (type == GGML_TYPE_Q4_2) {
|
8040
|
-
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
|
8041
|
-
}
|
8042
|
-
else {
|
8043
|
-
GGML_ASSERT(false);
|
8044
|
-
}
|
8045
|
-
#else
|
8046
8830
|
float * const wdata = params->wdata;
|
8047
8831
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
8048
|
-
#endif
|
8049
8832
|
|
8050
8833
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
8051
8834
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
@@ -8053,14 +8836,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8053
8836
|
|
8054
8837
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8055
8838
|
|
8056
|
-
#if defined(
|
8057
|
-
|
8058
|
-
CUDA_CHECK(
|
8059
|
-
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
8060
|
-
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
8061
|
-
|
8062
|
-
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
8063
|
-
CUDA_CHECK(cudaGetLastError());
|
8839
|
+
#if defined(GGML_USE_CLBLAST)
|
8840
|
+
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
8064
8841
|
#else
|
8065
8842
|
{
|
8066
8843
|
size_t id = 0;
|
@@ -8068,27 +8845,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8068
8845
|
dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
8069
8846
|
id += ne00;
|
8070
8847
|
}
|
8848
|
+
|
8849
|
+
assert(id*sizeof(float) <= params->wsize);
|
8071
8850
|
}
|
8851
|
+
|
8072
8852
|
const float * x = wdata;
|
8073
8853
|
#endif
|
8074
8854
|
|
8075
|
-
|
8076
|
-
#if defined(GGML_USE_CUBLAS)
|
8077
|
-
// copy data to device
|
8078
|
-
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
8079
|
-
|
8080
|
-
// compute
|
8081
|
-
CUBLAS_CHECK(
|
8082
|
-
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8083
|
-
ne01, ne11, ne10,
|
8084
|
-
&alpha, d_X, ne00,
|
8085
|
-
d_Y, ne10,
|
8086
|
-
&beta, d_D, ne01));
|
8087
|
-
|
8088
|
-
// copy data to host
|
8089
|
-
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
8090
|
-
#else
|
8855
|
+
#if defined(GGML_USE_CLBLAST)
|
8091
8856
|
// zT = y * xT
|
8857
|
+
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
8858
|
+
ne11, ne01, ne10,
|
8859
|
+
1.0f, y, ne10,
|
8860
|
+
x, ne10,
|
8861
|
+
0.0f, d, ne01,
|
8862
|
+
type);
|
8863
|
+
#else
|
8092
8864
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
8093
8865
|
ne11, ne01, ne10,
|
8094
8866
|
1.0f, y, ne10,
|
@@ -8098,13 +8870,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8098
8870
|
}
|
8099
8871
|
}
|
8100
8872
|
|
8101
|
-
#if defined(GGML_USE_CUBLAS)
|
8102
|
-
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
8103
|
-
CUDA_CHECK(cudaFree(d_X));
|
8104
|
-
CUDA_CHECK(cudaFree(d_Y));
|
8105
|
-
CUDA_CHECK(cudaFree(d_D));
|
8106
|
-
CUDA_CHECK(cudaFree(d_Q));
|
8107
|
-
#endif
|
8108
8873
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
8109
8874
|
|
8110
8875
|
return;
|
@@ -8113,7 +8878,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8113
8878
|
|
8114
8879
|
if (params->type == GGML_TASK_INIT) {
|
8115
8880
|
char * wdata = params->wdata;
|
8116
|
-
const size_t row_size = ne10*GGML_TYPE_SIZE[
|
8881
|
+
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
8117
8882
|
|
8118
8883
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
8119
8884
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
@@ -8144,7 +8909,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
8144
8909
|
const int ir1 = MIN(ir0 + dr, nr);
|
8145
8910
|
|
8146
8911
|
void * wdata = params->wdata;
|
8147
|
-
const size_t row_size = ne00*GGML_TYPE_SIZE[
|
8912
|
+
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
8148
8913
|
|
8149
8914
|
for (int ir = ir0; ir < ir1; ++ir) {
|
8150
8915
|
// src0 indices
|
@@ -8193,8 +8958,10 @@ static void ggml_compute_forward_mul_mat(
|
|
8193
8958
|
case GGML_TYPE_Q4_0:
|
8194
8959
|
case GGML_TYPE_Q4_1:
|
8195
8960
|
case GGML_TYPE_Q4_2:
|
8196
|
-
case
|
8961
|
+
case GGML_TYPE_Q5_0:
|
8962
|
+
case GGML_TYPE_Q5_1:
|
8197
8963
|
case GGML_TYPE_Q8_0:
|
8964
|
+
case GGML_TYPE_Q8_1:
|
8198
8965
|
{
|
8199
8966
|
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
|
8200
8967
|
} break;
|
@@ -8422,8 +9189,10 @@ static void ggml_compute_forward_get_rows(
|
|
8422
9189
|
case GGML_TYPE_Q4_0:
|
8423
9190
|
case GGML_TYPE_Q4_1:
|
8424
9191
|
case GGML_TYPE_Q4_2:
|
8425
|
-
case
|
9192
|
+
case GGML_TYPE_Q5_0:
|
9193
|
+
case GGML_TYPE_Q5_1:
|
8426
9194
|
case GGML_TYPE_Q8_0:
|
9195
|
+
case GGML_TYPE_Q8_1:
|
8427
9196
|
{
|
8428
9197
|
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
8429
9198
|
} break;
|
@@ -8561,6 +9330,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
8561
9330
|
|
8562
9331
|
uint16_t scvt;
|
8563
9332
|
for (int i = 0; i < nc; i++) {
|
9333
|
+
//printf("p[%3d] = %8.4f\n", i, p[i]);
|
8564
9334
|
if (p[i] == -INFINITY) {
|
8565
9335
|
p[i] = 0.0f;
|
8566
9336
|
} else {
|
@@ -8603,6 +9373,161 @@ static void ggml_compute_forward_soft_max(
|
|
8603
9373
|
}
|
8604
9374
|
}
|
8605
9375
|
|
9376
|
+
// ggml_compute_forward_alibi
|
9377
|
+
|
9378
|
+
static void ggml_compute_forward_alibi_f32(
|
9379
|
+
const struct ggml_compute_params * params,
|
9380
|
+
const struct ggml_tensor * src0,
|
9381
|
+
const struct ggml_tensor * src1,
|
9382
|
+
struct ggml_tensor * dst) {
|
9383
|
+
assert(params->ith == 0);
|
9384
|
+
assert(src1->type == GGML_TYPE_I32);
|
9385
|
+
assert(ggml_nelements(src1) == 2);
|
9386
|
+
|
9387
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9388
|
+
return;
|
9389
|
+
}
|
9390
|
+
|
9391
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
9392
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
9393
|
+
|
9394
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
9395
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
9396
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
9397
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
9398
|
+
|
9399
|
+
const int n = ggml_nrows(src0);
|
9400
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
9401
|
+
|
9402
|
+
const int nb0 = src0->nb[0];
|
9403
|
+
const int nb1 = src0->nb[1];
|
9404
|
+
const int nb2 = src0->nb[2];
|
9405
|
+
//const int nb3 = src0->nb[3];
|
9406
|
+
|
9407
|
+
assert(nb0 == sizeof(float));
|
9408
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
9409
|
+
|
9410
|
+
// add alibi to src0 (KQ_scaled)
|
9411
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
9412
|
+
|
9413
|
+
const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
|
9414
|
+
const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
|
9415
|
+
|
9416
|
+
for (int i = 0; i < ne0; i++) {
|
9417
|
+
for (int j = 0; j < ne1; j++) {
|
9418
|
+
for (int k = 0; k < ne2_ne3; k++) {
|
9419
|
+
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
9420
|
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
9421
|
+
|
9422
|
+
// TODO: k*nb2 or k*nb3
|
9423
|
+
|
9424
|
+
float m_k;
|
9425
|
+
|
9426
|
+
if (k < n_heads_log2_floor) {
|
9427
|
+
m_k = powf(m0, k + 1);
|
9428
|
+
} else {
|
9429
|
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
9430
|
+
}
|
9431
|
+
|
9432
|
+
pdst[0] = (j+1) * m_k + src[0];
|
9433
|
+
}
|
9434
|
+
}
|
9435
|
+
}
|
9436
|
+
}
|
9437
|
+
|
9438
|
+
|
9439
|
+
static void ggml_compute_forward_alibi_f16(
|
9440
|
+
const struct ggml_compute_params * params,
|
9441
|
+
const struct ggml_tensor * src0,
|
9442
|
+
const struct ggml_tensor * src1,
|
9443
|
+
struct ggml_tensor * dst) {
|
9444
|
+
assert(params->ith == 0);
|
9445
|
+
assert(src1->type == GGML_TYPE_I32);
|
9446
|
+
assert(ggml_nelements(src1) == 2);
|
9447
|
+
|
9448
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9449
|
+
return;
|
9450
|
+
}
|
9451
|
+
|
9452
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
9453
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
9454
|
+
|
9455
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
9456
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
9457
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
9458
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
9459
|
+
|
9460
|
+
const int n = ggml_nrows(src0);
|
9461
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
9462
|
+
|
9463
|
+
const int nb0 = src0->nb[0];
|
9464
|
+
const int nb1 = src0->nb[1];
|
9465
|
+
const int nb2 = src0->nb[2];
|
9466
|
+
//const int nb3 = src0->nb[3];
|
9467
|
+
|
9468
|
+
assert(nb0 == sizeof(ggml_fp16_t));
|
9469
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
9470
|
+
|
9471
|
+
// add alibi to src0 (KQ_scaled)
|
9472
|
+
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
9473
|
+
|
9474
|
+
const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
|
9475
|
+
const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
|
9476
|
+
|
9477
|
+
for (int i = 0; i < ne0; i++) {
|
9478
|
+
for (int j = 0; j < ne1; j++) {
|
9479
|
+
for (int k = 0; k < ne2_ne3; k++) {
|
9480
|
+
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
9481
|
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
9482
|
+
|
9483
|
+
// TODO: k*nb2 or k*nb3
|
9484
|
+
|
9485
|
+
float m_k;
|
9486
|
+
|
9487
|
+
if (k < n_heads_log2_floor) {
|
9488
|
+
m_k = powf(m0, k + 1);
|
9489
|
+
} else {
|
9490
|
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
9491
|
+
}
|
9492
|
+
|
9493
|
+
// we return F32
|
9494
|
+
pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
|
9495
|
+
}
|
9496
|
+
}
|
9497
|
+
}
|
9498
|
+
}
|
9499
|
+
|
9500
|
+
static void ggml_compute_forward_alibi(
|
9501
|
+
const struct ggml_compute_params * params,
|
9502
|
+
const struct ggml_tensor * src0,
|
9503
|
+
const struct ggml_tensor * src1,
|
9504
|
+
struct ggml_tensor * dst) {
|
9505
|
+
switch (src0->type) {
|
9506
|
+
case GGML_TYPE_F16:
|
9507
|
+
{
|
9508
|
+
ggml_compute_forward_alibi_f16(params, src0, src1, dst);
|
9509
|
+
} break;
|
9510
|
+
case GGML_TYPE_F32:
|
9511
|
+
{
|
9512
|
+
ggml_compute_forward_alibi_f32(params, src0, src1, dst);
|
9513
|
+
} break;
|
9514
|
+
case GGML_TYPE_Q4_0:
|
9515
|
+
case GGML_TYPE_Q4_1:
|
9516
|
+
case GGML_TYPE_Q4_2:
|
9517
|
+
case GGML_TYPE_Q5_0:
|
9518
|
+
case GGML_TYPE_Q5_1:
|
9519
|
+
case GGML_TYPE_Q8_0:
|
9520
|
+
case GGML_TYPE_Q8_1:
|
9521
|
+
case GGML_TYPE_I8:
|
9522
|
+
case GGML_TYPE_I16:
|
9523
|
+
case GGML_TYPE_I32:
|
9524
|
+
case GGML_TYPE_COUNT:
|
9525
|
+
{
|
9526
|
+
GGML_ASSERT(false);
|
9527
|
+
} break;
|
9528
|
+
}
|
9529
|
+
}
|
9530
|
+
|
8606
9531
|
// ggml_compute_forward_rope
|
8607
9532
|
|
8608
9533
|
static void ggml_compute_forward_rope_f32(
|
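The two alibi kernels above add a per-head linear bias (j+1) * m_k to the attention scores, where the slope m_k follows the ALiBi geometric schedule. When n_head is a power of two, n_heads_log2_floor equals n_head and the slopes reduce to m_k = 2^(-8*(k+1)/n_head); the second branch fills in the remaining heads when n_head is not a power of two. A tiny standalone check of the slope computation exactly as written above:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int n_head = 8;   /* power of two, so only the first branch fires */
        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
        const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
        const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
        for (int k = 0; k < n_head; k++) {
            float m_k = (k < n_heads_log2_floor) ? powf(m0, k + 1)
                                                 : powf(m1, 2*(k - n_heads_log2_floor) + 1);
            printf("head %d: slope %g\n", k, m_k);   /* 0.5, 0.25, ..., 1/256 for n_head = 8 */
        }
        return 0;
    }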
@@ -10241,6 +11166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
10241
11166
|
{
|
10242
11167
|
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
10243
11168
|
} break;
|
11169
|
+
case GGML_OP_ALIBI:
|
11170
|
+
{
|
11171
|
+
ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
|
11172
|
+
} break;
|
10244
11173
|
case GGML_OP_CONV_1D_1S:
|
10245
11174
|
{
|
10246
11175
|
ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
|
@@ -10443,6 +11372,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
10443
11372
|
{
|
10444
11373
|
GGML_ASSERT(false); // TODO: not implemented
|
10445
11374
|
} break;
|
11375
|
+
case GGML_OP_ALIBI:
|
11376
|
+
{
|
11377
|
+
GGML_ASSERT(false); // TODO: not implemented
|
11378
|
+
} break;
|
10446
11379
|
case GGML_OP_SILU:
|
10447
11380
|
{
|
10448
11381
|
GGML_ASSERT(false); // TODO: not implemented
|
@@ -10920,15 +11853,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10920
11853
|
|
10921
11854
|
size_t cur = 0;
|
10922
11855
|
|
11856
|
+
#if defined(GGML_USE_CUBLAS)
|
11857
|
+
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
11858
|
+
node->n_tasks = 1; // TODO: this actually is doing nothing
|
11859
|
+
// the threads are still spinning
|
11860
|
+
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
11861
|
+
}
|
11862
|
+
else
|
11863
|
+
#endif
|
10923
11864
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
10924
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
11865
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
10925
11866
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10926
11867
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
10927
11868
|
// the threads are still spinning
|
11869
|
+
// here we need memory just for single 2D matrix from src0
|
10928
11870
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10929
|
-
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
|
10930
|
-
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
|
10931
|
-
//printf("cur = %zu\n", cur);
|
10932
11871
|
} else {
|
10933
11872
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
10934
11873
|
}
|
@@ -10937,15 +11876,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10937
11876
|
#endif
|
10938
11877
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
10939
11878
|
cur = 0;
|
11879
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
11880
|
+
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
11881
|
+
node->n_tasks = 1;
|
11882
|
+
}
|
11883
|
+
#endif
|
10940
11884
|
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
10941
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(
|
11885
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
10942
11886
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10943
11887
|
node->n_tasks = 1;
|
10944
11888
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
10945
11889
|
} else
|
10946
11890
|
#endif
|
10947
11891
|
{
|
10948
|
-
|
11892
|
+
const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
|
11893
|
+
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
|
10949
11894
|
}
|
10950
11895
|
} else {
|
10951
11896
|
GGML_ASSERT(false);
|
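The work-buffer sizing above now derives the scratch requirement for the quantized mul_mat path from the dot-product type instead of hard-coding f32: src1 is re-quantized into vec_dot_type inside the op, so the buffer needs nelements(src1)/GGML_BLCK_SIZE[type_q] blocks of GGML_TYPE_SIZE[type_q] bytes. A worked example under the assumption that the quantized weight types here dot against Q8_0-sized blocks (one float scale plus 32 int8 values, i.e. 36 bytes per 32 activations):

    /* Hedged numbers: src1 is a 4096 x 32 f32 activation matrix (131072 elements). */
    const size_t n   = 4096 * 32;        /* ggml_nelements(node->src1) */
    const size_t cur = 36 * (n / 32);    /* GGML_TYPE_SIZE[type_q]*n/GGML_BLCK_SIZE[type_q] */
    /* cur == 147456 bytes (~144 KiB), versus 524288 bytes if src1 stayed in f32. */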
@@ -10975,6 +11920,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10975
11920
|
{
|
10976
11921
|
node->n_tasks = n_threads;
|
10977
11922
|
} break;
|
11923
|
+
case GGML_OP_ALIBI:
|
11924
|
+
{
|
11925
|
+
node->n_tasks = 1; //TODO
|
11926
|
+
} break;
|
10978
11927
|
case GGML_OP_CONV_1D_1S:
|
10979
11928
|
case GGML_OP_CONV_1D_2S:
|
10980
11929
|
{
|
@@ -11273,9 +12222,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
11273
12222
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
11274
12223
|
struct ggml_tensor * node = cgraph->nodes[i];
|
11275
12224
|
|
11276
|
-
perf_total_per_op_us[node->op] += node->perf_time_us;
|
12225
|
+
perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
|
11277
12226
|
|
11278
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
12227
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
|
11279
12228
|
i,
|
11280
12229
|
node->ne[0], node->ne[1], node->ne[2],
|
11281
12230
|
GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
|
@@ -11289,13 +12238,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
11289
12238
|
for (int i = 0; i < cgraph->n_leafs; i++) {
|
11290
12239
|
struct ggml_tensor * node = cgraph->leafs[i];
|
11291
12240
|
|
11292
|
-
GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
|
12241
|
+
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
11293
12242
|
i,
|
11294
12243
|
node->ne[0], node->ne[1],
|
11295
12244
|
GGML_OP_LABEL[node->op]);
|
11296
12245
|
}
|
11297
12246
|
|
11298
12247
|
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
12248
|
+
if (perf_total_per_op_us[i] == 0) {
|
12249
|
+
continue;
|
12250
|
+
}
|
12251
|
+
|
11299
12252
|
GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
|
11300
12253
|
}
|
11301
12254
|
|
@@ -11358,10 +12311,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
|
|
11358
12311
|
snprintf(color, sizeof(color), "white");
|
11359
12312
|
}
|
11360
12313
|
|
11361
|
-
fprintf(fp, " \"%p\" [
|
11362
|
-
style = filled; fillcolor = %s; shape = record;
|
11363
|
-
label=\"
|
11364
|
-
(void *) node, color
|
12314
|
+
fprintf(fp, " \"%p\" [ "
|
12315
|
+
"style = filled; fillcolor = %s; shape = record; "
|
12316
|
+
"label=\"",
|
12317
|
+
(void *) node, color);
|
12318
|
+
|
12319
|
+
if (strlen(node->name) > 0) {
|
12320
|
+
fprintf(fp, "%s |", node->name);
|
12321
|
+
}
|
12322
|
+
|
12323
|
+
fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
11365
12324
|
i, node->ne[0], node->ne[1],
|
11366
12325
|
GGML_OP_SYMBOL[node->op]);
|
11367
12326
|
|
@@ -11377,18 +12336,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
|
11377
12336
|
|
11378
12337
|
snprintf(color, sizeof(color), "pink");
|
11379
12338
|
|
12339
|
+
fprintf(fp, " \"%p\" [ "
|
12340
|
+
"style = filled; fillcolor = %s; shape = record; "
|
12341
|
+
"label=\"<x>",
|
12342
|
+
(void *) node, color);
|
12343
|
+
|
12344
|
+
if (strlen(node->name) > 0) {
|
12345
|
+
fprintf(fp, "%s | ", node->name);
|
12346
|
+
}
|
11380
12347
|
if (ggml_nelements(node) == 1) {
|
11381
|
-
|
11382
|
-
|
11383
|
-
|
11384
|
-
|
11385
|
-
|
11386
|
-
|
11387
|
-
style = filled; fillcolor = %s; shape = record; \
|
11388
|
-
label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
|
11389
|
-
(void *) node, color,
|
11390
|
-
i, node->ne[0], node->ne[1]);
|
12348
|
+
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
|
12349
|
+
fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
|
12350
|
+
}
|
12351
|
+
else {
|
12352
|
+
fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
|
12353
|
+
}
|
11391
12354
|
}
|
12355
|
+
else {
|
12356
|
+
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
|
12357
|
+
}
|
12358
|
+
fprintf(fp, "\"; ]\n");
|
11392
12359
|
}
|
11393
12360
|
|
11394
12361
|
for (int i = 0; i < gb->n_nodes; i++) {
|
@@ -12129,7 +13096,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
|
12129
13096
|
|
12130
13097
|
for (int i = 0; i < nb; i++) {
|
12131
13098
|
for (int l = 0; l < QK4_0; l += 2) {
|
12132
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13099
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12133
13100
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12134
13101
|
|
12135
13102
|
hist[vi0]++;
|
@@ -12152,7 +13119,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
12152
13119
|
|
12153
13120
|
for (int i = 0; i < nb; i++) {
|
12154
13121
|
for (int l = 0; l < QK4_1; l += 2) {
|
12155
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13122
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12156
13123
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12157
13124
|
|
12158
13125
|
hist[vi0]++;
|
@@ -12171,12 +13138,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
|
|
12171
13138
|
for (int j = 0; j < n; j += k) {
|
12172
13139
|
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
12173
13140
|
|
12174
|
-
|
12175
|
-
quantize_row_q4_2_rmse(src + j, y, k);
|
13141
|
+
quantize_row_q4_2_reference(src + j, y, k);
|
12176
13142
|
|
12177
13143
|
for (int i = 0; i < nb; i++) {
|
12178
13144
|
for (int l = 0; l < QK4_2; l += 2) {
|
12179
|
-
const uint8_t vi0 = y[i].qs[l/2] &
|
13145
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
|
12180
13146
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12181
13147
|
|
12182
13148
|
hist[vi0]++;
|
@@ -12188,19 +13154,56 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
|
|
12188
13154
|
return (n/QK4_2*sizeof(block_q4_2));
|
12189
13155
|
}
|
12190
13156
|
|
12191
|
-
size_t
|
12192
|
-
assert(k %
|
12193
|
-
const int nb = k /
|
13157
|
+
size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13158
|
+
assert(k % QK5_0 == 0);
|
13159
|
+
const int nb = k / QK5_0;
|
12194
13160
|
|
12195
13161
|
for (int j = 0; j < n; j += k) {
|
12196
|
-
|
13162
|
+
block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
|
12197
13163
|
|
12198
|
-
|
13164
|
+
quantize_row_q5_0_reference(src + j, y, k);
|
12199
13165
|
|
12200
13166
|
for (int i = 0; i < nb; i++) {
|
12201
|
-
|
12202
|
-
|
12203
|
-
|
13167
|
+
uint32_t qh;
|
13168
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
13169
|
+
|
13170
|
+
for (int l = 0; l < QK5_0; l += 2) {
|
13171
|
+
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
|
13172
|
+
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
|
13173
|
+
|
13174
|
+
// cast to 16 bins
|
13175
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
13176
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
13177
|
+
|
13178
|
+
hist[vi0]++;
|
13179
|
+
hist[vi1]++;
|
13180
|
+
}
|
13181
|
+
}
|
13182
|
+
}
|
13183
|
+
|
13184
|
+
return (n/QK5_0*sizeof(block_q5_0));
|
13185
|
+
}
|
13186
|
+
|
13187
|
+
size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13188
|
+
assert(k % QK5_1 == 0);
|
13189
|
+
const int nb = k / QK5_1;
|
13190
|
+
|
13191
|
+
for (int j = 0; j < n; j += k) {
|
13192
|
+
block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
|
13193
|
+
|
13194
|
+
quantize_row_q5_1_reference(src + j, y, k);
|
13195
|
+
|
13196
|
+
for (int i = 0; i < nb; i++) {
|
13197
|
+
uint32_t qh;
|
13198
|
+
memcpy(&qh, &y[i].qh, sizeof(qh));
|
13199
|
+
|
13200
|
+
for (int l = 0; l < QK5_1; l += 2) {
|
13201
|
+
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
|
13202
|
+
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
|
13203
|
+
|
13204
|
+
// cast to 16 bins
|
13205
|
+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
|
13206
|
+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
|
12204
13207
|
|
12205
13208
|
hist[vi0]++;
|
12206
13209
|
hist[vi1]++;
|
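The Q5_0/Q5_1 histogram code above has to reassemble each 5-bit weight from two places: the low 4 bits live in a packed nibble in qs, and the fifth bit lives in the 32-bit qh mask, one bit per value. The final "/ 2" then folds the 32 possible 5-bit codes into the 16 histogram bins shared with the 4-bit formats. A small sketch of that reassembly for a single pair of values, with hypothetical inputs standing in for y[i].qs and the copied y[i].qh:

    uint8_t  qs_byte = 0xA7;          /* low nibbles of values l (0x7) and l+1 (0xA) */
    uint32_t qh      = 0x00000002;    /* only value l+1 has its high bit set */
    int      l       = 0;

    const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;   /* 0x00 */
    const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;   /* 0x10 */

    const uint8_t v0   = (uint8_t)((qs_byte & 0x0F) | vh0);   /* 5-bit code 0x07 */
    const uint8_t v1   = (uint8_t)((qs_byte >> 4)   | vh1);   /* 5-bit code 0x1A */
    const uint8_t bin0 = v0 / 2;                              /* histogram bin 3  */
    const uint8_t bin1 = v1 / 2;                              /* histogram bin 13 */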
@@ -12208,7 +13211,28 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
|
|
12208
13211
|
}
|
12209
13212
|
}
|
12210
13213
|
|
12211
|
-
return (n/
|
13214
|
+
return (n/QK5_1*sizeof(block_q5_1));
|
13215
|
+
}
|
13216
|
+
|
13217
|
+
size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
13218
|
+
assert(k % QK8_0 == 0);
|
13219
|
+
const int nb = k / QK8_0;
|
13220
|
+
|
13221
|
+
for (int j = 0; j < n; j += k) {
|
13222
|
+
block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
|
13223
|
+
|
13224
|
+
quantize_row_q8_0_reference(src + j, y, k);
|
13225
|
+
|
13226
|
+
for (int i = 0; i < nb; i++) {
|
13227
|
+
for (int l = 0; l < QK8_0; ++l) {
|
13228
|
+
const int8_t vi = y[i].qs[l];
|
13229
|
+
|
13230
|
+
hist[vi/16 + 8]++;
|
13231
|
+
}
|
13232
|
+
}
|
13233
|
+
}
|
13234
|
+
|
13235
|
+
return (n/QK8_0*sizeof(block_q8_0));
|
12212
13236
|
}
|
12213
13237
|
|
12214
13238
|
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
|
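ggml_quantize_q8_0 above fills the same 16-bin histogram from signed 8-bit values: vi/16 maps the range [-128, 127] onto [-8, 7] (integer division truncates toward zero in C), and the +8 shifts that into bins 0..15, with the central bin collecting every value in [-15, 15]. A few concrete mappings:

    int vi  = -17;            /* one signed 8-bit quant value                  */
    int bin = vi / 16 + 8;    /* -> 7                                          */
    /* vi = -128 -> -8 + 8 = 0   (lowest bin)
       vi =   -1 ->  0 + 8 = 8   (all of [-15, 15] lands in bin 8)
       vi =   15 ->  0 + 8 = 8
       vi =  127 ->  7 + 8 = 15  (highest bin)                                 */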
@@ -12232,11 +13256,23 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
12232
13256
|
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
|
12233
13257
|
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
|
12234
13258
|
} break;
|
12235
|
-
case
|
13259
|
+
case GGML_TYPE_Q5_0:
|
13260
|
+
{
|
13261
|
+
GGML_ASSERT(start % QK5_0 == 0);
|
13262
|
+
block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
|
13263
|
+
result = ggml_quantize_q5_0(src + start, block, n, n, hist);
|
13264
|
+
} break;
|
13265
|
+
case GGML_TYPE_Q5_1:
|
13266
|
+
{
|
13267
|
+
GGML_ASSERT(start % QK5_1 == 0);
|
13268
|
+
block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
|
13269
|
+
result = ggml_quantize_q5_1(src + start, block, n, n, hist);
|
13270
|
+
} break;
|
13271
|
+
case GGML_TYPE_Q8_0:
|
12236
13272
|
{
|
12237
|
-
GGML_ASSERT(start %
|
12238
|
-
|
12239
|
-
result =
|
13273
|
+
GGML_ASSERT(start % QK8_0 == 0);
|
13274
|
+
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
|
13275
|
+
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
|
12240
13276
|
} break;
|
12241
13277
|
default:
|
12242
13278
|
assert(false);
|
@@ -12335,7 +13371,7 @@ int ggml_cpu_has_wasm_simd(void) {
|
|
12335
13371
|
}
|
12336
13372
|
|
12337
13373
|
int ggml_cpu_has_blas(void) {
|
12338
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
13374
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
12339
13375
|
return 1;
|
12340
13376
|
#else
|
12341
13377
|
return 0;
|
@@ -12350,6 +13386,18 @@ int ggml_cpu_has_cublas(void) {
|
|
12350
13386
|
#endif
|
12351
13387
|
}
|
12352
13388
|
|
13389
|
+
int ggml_cpu_has_clblast(void) {
|
13390
|
+
#if defined(GGML_USE_CLBLAST)
|
13391
|
+
return 1;
|
13392
|
+
#else
|
13393
|
+
return 0;
|
13394
|
+
#endif
|
13395
|
+
}
|
13396
|
+
|
13397
|
+
int ggml_cpu_has_gpublas(void) {
|
13398
|
+
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
|
13399
|
+
}
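ggml_cpu_has_gpublas above is just a convenience OR over the two GPU backends, so callers can report or branch on "some GPU BLAS is compiled in" without caring which one. A hedged usage sketch built only from the capability queries visible in this diff:

    /* Hypothetical startup log using the new capability queries. */
    fprintf(stderr, "BLAS = %d | cuBLAS = %d | CLBlast = %d | GPU BLAS = %d\n",
            ggml_cpu_has_blas(), ggml_cpu_has_cublas(),
            ggml_cpu_has_clblast(), ggml_cpu_has_gpublas());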
|
13400
|
+
|
12353
13401
|
int ggml_cpu_has_sse3(void) {
|
12354
13402
|
#if defined(__SSE3__)
|
12355
13403
|
return 1;
|