llama_cpp 0.0.4 → 0.0.6
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +3 -2
- data/ext/llama_cpp/extconf.rb +26 -0
- data/ext/llama_cpp/llama_cpp.cpp +106 -0
- data/ext/llama_cpp/src/ggml-cuda.h +12 -0
- data/ext/llama_cpp/src/ggml.c +2038 -895
- data/ext/llama_cpp/src/ggml.h +21 -1
- data/ext/llama_cpp/src/llama.cpp +376 -62
- data/ext/llama_cpp/src/llama.h +17 -1
- data/ext/llama_cpp/src/llama_util.h +22 -16
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +3 -3
- data/sig/llama_cpp.rbs +13 -1
- metadata +3 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
```diff
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
```
```diff
@@ -118,7 +119,16 @@ typedef void* thread_ret_t;
 #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
 #define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
 #else
-#define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
+inline static void* ggml_aligned_malloc(size_t size) {
+    void* aligned_memory = NULL;
+    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    if (result != 0) {
+        // Handle allocation failure
+        return NULL;
+    }
+    return aligned_memory;
+}
+#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
 #define GGML_ALIGNED_FREE(ptr)    free(ptr)
 #endif
 
```
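For reference, `posix_memalign` succeeds only when the alignment is a power of two and a multiple of `sizeof(void *)`, and it reports failure through its return value rather than through `errno`. A minimal standalone sketch of the same wrapper pattern (the names here are ours, not from the diff):

```c
#include <stdio.h>
#include <stdlib.h>

#define MEM_ALIGN 16  // must be a power of two and a multiple of sizeof(void *)

// Same pattern as the ggml_aligned_malloc added above: report failure by
// returning NULL so callers keep a single error path.
static void * aligned_malloc_demo(size_t size) {
    void * p = NULL;
    if (posix_memalign(&p, MEM_ALIGN, size) != 0) {
        return NULL;
    }
    return p;
}

int main(void) {
    float * buf = aligned_malloc_demo(1024 * sizeof(float));
    if (!buf) {
        fprintf(stderr, "allocation failed\n");
        return 1;
    }
    printf("aligned to %d bytes: %p\n", MEM_ALIGN, (void *) buf);
    free(buf);  // posix_memalign memory is released with plain free()
    return 0;
}
```

Memory obtained this way is released with ordinary `free`, which is why the `GGML_ALIGNED_FREE` branch above stays unchanged.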
```diff
@@ -133,10 +143,49 @@ typedef void* thread_ret_t;
     } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err)                                               \
+    do {                                                              \
+        cudaError_t err_ = (err);                                     \
+        if (err_ != cudaSuccess) {                                    \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_));                            \
+            exit(1);                                                  \
+        }                                                             \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                             \
+    do {                                                              \
+        cublasStatus_t err_ = (err);                                  \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                          \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1);                                                  \
+        }                                                             \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
```
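The `CUDA_CHECK`/`CUBLAS_CHECK` macros fail fast: every runtime call is wrapped so an error aborts with file and line context, and `init_cublas` creates the handle and its non-blocking stream lazily on first use. A hedged sketch of the same check-macro idiom against the standard CUDA runtime API (the buffer and its size are hypothetical):

```c
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// Same idiom as the CUDA_CHECK macro added above: evaluate the call once,
// print file/line plus the runtime's error string, and abort on failure.
#define CUDA_CHECK(err)                                                     \
    do {                                                                    \
        cudaError_t err_ = (err);                                           \
        if (err_ != cudaSuccess) {                                          \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, \
                __LINE__, cudaGetErrorString(err_));                        \
            exit(1);                                                        \
        }                                                                   \
    } while (0)

int main(void) {
    float * d_buf = NULL;
    CUDA_CHECK(cudaMalloc((void **) &d_buf, 1024 * sizeof(float)));
    CUDA_CHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
    CUDA_CHECK(cudaFree(d_buf));
    return 0;
}
```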
```diff
@@ -418,14 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -456,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
-#elif __AVX__
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -490,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
```
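Both helpers split each packed byte into its low and high 4-bit halves, leaving one quant per output byte; the mask, the and-not, and the 4-bit shift do in parallel what a scalar loop does one byte at a time. A scalar model of `bytes_from_nibbles_16` for illustration (our code, not part of the diff):

```c
#include <stdint.h>
#include <stdio.h>

// Scalar model of bytes_from_nibbles_16: expand 8 packed bytes into 16
// bytes, each holding one 4-bit value in [0, 15]. The SIMD version does
// the same split with _mm_and_si128 / _mm_andnot_si128 plus a shift.
static void bytes_from_nibbles_scalar(const uint8_t * src, uint8_t * dst) {
    for (int j = 0; j < 8; j++) {
        dst[2*j + 0] = src[j] & 0x0F;  // low nibble
        dst[2*j + 1] = src[j] >> 4;    // high nibble
    }
}

int main(void) {
    const uint8_t packed[8] = { 0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F };
    uint8_t out[16];
    bytes_from_nibbles_scalar(packed, out);
    for (int j = 0; j < 16; j++) {
        printf("%d ", out[j]);  // prints 1 2 3 4 ... 15 0
    }
    printf("\n");
    return 0;
}
```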
```diff
@@ -507,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -531,68 +592,88 @@ inline static float vaddvq_f32(float32x4_t v) {
     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
 }
 
-inline static float vminvq_f32(float32x4_t v) {
+float vminvq_f32(float32x4_t v) {
     return
         MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
             MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
 }
 
-inline static float vmaxvq_f32(float32x4_t v) {
+float vmaxvq_f32(float32x4_t v) {
     return
         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
             MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
 }
 
-inline static int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
     return vget_low_s8(vcombine_s8(a, b));
 }
 
-inline static int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
     return vget_high_s8(vcombine_s8(a, b));
 }
 
-inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
    return vget_low_u8(vcombine_u8(a, b));
 }
 
-inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
     return vget_high_u8(vcombine_u8(a, b));
 }
 
 #endif
 #endif
 
-// blocks of QK elements
-// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+
+#define QK4_0 32
 typedef struct {
-    float   d;
-    uint8_t qs[QK / 2];
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-    float   d;
-    float   m;
-    uint8_t qs[QK / 2];
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_2 / 2];  // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float  d;         // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_0/2];
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
             amax = MAX(amax, fabsf(v));
         }
 
```
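Each block couples one scale (plus a minimum for the `_1`/`_3` variants) with packed quants, so the storage cost per weight follows directly from the struct sizes checked by the `static_assert`s above. A small sketch that derives the bits per weight, assuming the usual 2-byte `ggml_fp16_t`:

```c
#include <stdint.h>
#include <stdio.h>

typedef uint16_t ggml_fp16_t;  // assumption: half precision stored in 16 bits

// Mirrors of the block layouts above, used only to print their cost.
typedef struct { float d;          uint8_t qs[32 / 2]; } block_q4_0;
typedef struct { float d; float m; uint8_t qs[32 / 2]; } block_q4_1;
typedef struct { ggml_fp16_t d;    uint8_t qs[16 / 2]; } block_q4_2;
typedef struct { ggml_fp16_t d, m; uint8_t qs[16 / 2]; } block_q4_3;
typedef struct { float d;          int8_t  qs[32];     } block_q8_0;

int main(void) {
    // bits per weight = 8 * sizeof(block) / weights per block
    printf("q4_0: %.2f bpw\n", 8.0 * sizeof(block_q4_0) / 32);  // 5.00
    printf("q4_1: %.2f bpw\n", 8.0 * sizeof(block_q4_1) / 32);  // 6.00
    printf("q4_2: %.2f bpw\n", 8.0 * sizeof(block_q4_2) / 16);  // 5.00
    printf("q4_3: %.2f bpw\n", 8.0 * sizeof(block_q4_3) / 16);  // 6.00
    printf("q8_0: %.2f bpw\n", 8.0 * sizeof(block_q8_0) / 32);  // 9.00
    return 0;
}
```

The smaller 16-value blocks of q4_2/q4_3 buy finer-grained scales at the same bit cost as their 32-value counterparts, because their scale and minimum shrink from fp32 to fp16.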
```diff
@@ -601,9 +682,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+        for (int l = 0; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id;
+            const float v1 = x[i*QK4_0 + l + 1]*id;
 
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -619,8 +700,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
@@ -870,19 +951,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_1/2];
 
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_1; l++) {
+            const float v = x[i*QK4_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -893,9 +974,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
         y[i].d = d;
         y[i].m = min;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = (x[i*QK + l + 0] - min)*id;
-            const float v1 = (x[i*QK + l + 1] - min)*id;
+        for (int l = 0; l < QK4_1; l += 2) {
+            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
@@ -911,9 +992,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
+    assert(k % QK4_1 == 0);
 
-    const int nb = k / QK;
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
@@ -997,7 +1078,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
         float32x4_t minv[8];
         float32x4_t maxv[8];
 
-        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
 
         for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
         for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
```
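The q4_0 reference quantizer above keeps one fp32 scale per 32 values, `d = amax / 7`, and rounds each value to a signed 4-bit level, so the worst-case reconstruction error per value is `d/2`. A minimal scalar roundtrip under those rules (the demo data and names are ours):

```c
#include <math.h>
#include <stdio.h>

#define QK 32  // block size, matching QK4_0 above

// Scalar q4_0 roundtrip for one block: scale by amax/7, store levels in
// [-7, 7] (biased by +8 into a nibble), then reconstruct and measure error.
int main(void) {
    float x[QK], out[QK];
    for (int l = 0; l < QK; l++) x[l] = sinf(0.3f * l);  // demo data

    float amax = 0.0f;
    for (int l = 0; l < QK; l++) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / ((1 << 3) - 1);  // 7 = largest positive level
    const float id = d ? 1.0f/d : 0.0f;

    float maxerr = 0.0f;
    for (int l = 0; l < QK; l++) {
        int vi = (int) roundf(x[l]*id) + 8;  // biased nibble in [1, 15]
        out[l] = (vi - 8) * d;               // dequantize
        maxerr = fmaxf(maxerr, fabsf(out[l] - x[l]));
    }
    printf("d = %f, max roundtrip error = %f (<= d/2 = %f)\n", d, maxerr, d/2);
    return 0;
}
```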
```diff
@@ -1033,9 +1114,327 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK4_2; l++) {
+            const float v = x[i*QK4_2 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const float v0 = x[i*QK4_2 + l + 0]*id;
+            const float v1 = x[i*QK4_2 + l + 1]*id;
+
+            const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
+            const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+        const float * restrict candidates, int8_t * restrict L) {
+    assert (nmin >= INT8_MIN);
+    assert (nmax <= INT8_MAX);
+    float amax = 0;
+    for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/amax;
+        float sumlxP = 0; int suml2P = 0;
+        float sumlxM = 0; int suml2M = 0;
+        for (int i=0; i<n; ++i) {
+            int l = nearest_int(iscale*X[i]);
+            int lp = MAX(nmin, MIN(nmax, +l));
+            int lm = MAX(nmin, MIN(nmax, -l));
+            sumlxP += X[i]*lp; suml2P += lp*lp;
+            sumlxM += X[i]*lm; suml2M += lm*lm;
+        }
+        float sumlxP2 = sumlxP*sumlxP;
+        float sumlxM2 = sumlxM*sumlxM;
+        if (sumlxP2*suml2M > sumlxM2*suml2P) {
+            if (sumlxP2 > best*suml2P) {
+                best = sumlxP2/suml2P; bestScale = iscale;
+            }
+        } else {
+            if (sumlxM2 > best*suml2M) {
+                best = sumlxM2/suml2M; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = nearest_int(bestScale*X[i]);
+        l = MAX(nmin, MIN(nmax, l));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    return scale;
+}
+
+static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+#define CANDIDATE_COUNT 8
+    static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+    assert(k % QK4_2 == 0);
+
+    int8_t L[QK4_2];
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
+        y[i].d = GGML_FP32_TO_FP16(scale);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+
+        x += QK4_2;
+    }
+}
+
+static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_2 == 0);
+
+    block_q4_2 * restrict y = vy;
+
+    //quantize_row_q4_2_reference(x, y, k);
+    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+    quantize_row_q4_2_rmse(x, y, k);
+}
+
+static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK4_3; l++) {
+            const float v = x[i*QK4_3 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+            const uint8_t vi0 = (int) (v0 + 0.5f);
+            const uint8_t vi1 = (int) (v1 + 0.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_3 == 0);
+
+    block_q4_3 * restrict y = vy;
+
+    quantize_row_q4_3_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK8_0; l++) {
+            const float v = x[i*QK8_0 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            const float v = x[i*QK8_0 + l]*id;
+            y[i].qs[l] = roundf(v);
+        }
+    }
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     const block_q4_0 * restrict x = vx;
 
```
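`nearest_int` rounds without calling `roundf`: adding `12582912.0f` (that is, 1.5 × 2^23) pins the float's exponent so its mantissa bits hold the input rounded to nearest (ties to even, the default FP mode), and masking plus subtracting the 2^22 bias recovers the integer. A standalone demo of the trick, under the same IEEE-754 and rounding-mode assumptions the original makes:

```c
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

// Same trick as nearest_int above: after adding 1.5 * 2^23 the value lands
// in [2^23, 2^24), where the float's mantissa stores it as an integer
// rounded half-to-even by the hardware.
static int nearest_int_demo(float fval) {
    assert(fval <= 4194303.f);            // input must stay below 2^22
    float val = fval + 12582912.f;        // 12582912 = 1.5 * 2^23
    int i; memcpy(&i, &val, sizeof(int)); // reinterpret the bits
    return (i & 0x007fffff) - 0x00400000; // mantissa minus the 2^22 bias
}

int main(void) {
    const float tests[] = { 2.4f, 2.5f, 3.5f, -2.5f, -7.3f };
    for (int j = 0; j < 5; j++) {
        printf("%6.2f -> %d (rintf: %d)\n",
               tests[j], nearest_int_demo(tests[j]), (int) rintf(tests[j]));
    }
    return 0;
}
```

Note that it matches `rintf`, not `roundf`: ties go to the even integer, which is harmless for the RMSE search in `kquantize_q4_with_bounds`.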
```diff
@@ -1046,9 +1445,9 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Subtract 8 from the integers
             vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1068,7 +1467,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             // Scale and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_mul_ps(vf[j], d_v);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
             }
         }
     }
@@ -1078,7 +1477,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_0; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1117,10 +1516,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmulq_f32(vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_0 + l +  0, r0);
+            vst1q_f32(y + i*QK4_0 + l +  4, r1);
+            vst1q_f32(y + i*QK4_0 + l +  8, r2);
+            vst1q_f32(y + i*QK4_0 + l + 12, r3);
         }
     }
 #else
@@ -1130,7 +1529,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_0; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1540,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
             //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
         }
     }
 #endif
 }
 
 static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     const block_q4_1 * restrict x = vx;
 
@@ -1164,9 +1563,9 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Convert to 16-bit int
             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1183,7 +1582,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             // Scale, add m and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
             }
         }
     }
@@ -1194,7 +1593,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_1; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1225,10 +1624,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_1 + l +  0, r0);
+            vst1q_f32(y + i*QK4_1 + l +  4, r1);
+            vst1q_f32(y + i*QK4_1 + l +  8, r2);
+            vst1q_f32(y + i*QK4_1 + l + 12, r3);
         }
     }
 #else
@@ -1238,7 +1637,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1247,21 +1646,130 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_1 + l + 0] = v0;
+            y[i*QK4_1 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_1 + l + 0]));
+            assert(!isnan(y[i*QK4_1 + l + 1]));
         }
     }
 #endif
 }
 
-//
-// simd mappings
-//
+static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
 
-// we define a common set of C macros which map to specific intrinsics based on the current architecture
+    const block_q4_2 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK4_2 + l + 0] = v0;
+            y[i*QK4_2 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_2 + l + 0]));
+            assert(!isnan(y[i*QK4_2 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    const block_q4_3 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK4_3 + l + 0] = v0;
+            y[i*QK4_3 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_3 + l + 0]));
+            assert(!isnan(y[i*QK4_3 + l + 1]));
+        }
+    }
+}
+
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
+    },
+    [GGML_TYPE_Q4_2] = {
+        .dequantize_row_q         = dequantize_row_q4_2,
+        .quantize_row_q           = quantize_row_q4_2,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
+    },
+    [GGML_TYPE_Q4_3] = {
+        .dequantize_row_q         = dequantize_row_q4_3,
+        .quantize_row_q           = quantize_row_q4_3,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_3_q8_0,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .dequantize_row_q         = NULL, // TODO
+        .quantize_row_q           = quantize_row_q8_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = NULL, // TODO
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
 // we then implement the fundamental computation operations below using only these macros
 // adding support for new architectures requires to define the corresponding SIMD macros
 //
```
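`quantize_fns` turns the per-type kernels into data: generic code indexes the table by tensor type and calls through the function pointers, with `NULL` marking unsupported slots (as for Q8_0's `dequantize_row_q` above). A simplified sketch of the same table-driven dispatch; the type ids and stub kernels here are ours:

```c
#include <stdio.h>

// Simplified stand-ins for the function-pointer table above (the real
// quantize_fns_t also carries reference and dot-product kernels).
typedef void (*dequantize_row_q_t)(const void * x, float * y, int k);
typedef void (*quantize_row_q_t)(const float * x, void * y, int k);

typedef struct {
    dequantize_row_q_t dequantize_row_q;
    quantize_row_q_t   quantize_row_q;
} quantize_fns_demo_t;

enum { TYPE_A, TYPE_B, TYPE_COUNT };

static void dq_a(const void * x, float * y, int k) { (void) x; for (int i = 0; i < k; i++) y[i] = 0.0f; }
static void  q_a(const float * x, void * y, int k) { (void) x; (void) y; (void) k; }

// One row per type id; unsupported slots stay NULL, like Q8_0's
// dequantize_row_q in the table above.
static const quantize_fns_demo_t table[TYPE_COUNT] = {
    [TYPE_A] = { .dequantize_row_q = dq_a, .quantize_row_q = q_a },
    [TYPE_B] = { .dequantize_row_q = NULL, .quantize_row_q = NULL },
};

int main(void) {
    float row[4];
    if (table[TYPE_A].dequantize_row_q) {  // check before dispatch
        table[TYPE_A].dequantize_row_q(NULL, row, 4);
    }
    printf("TYPE_B dequantize: %s\n",
           table[TYPE_B].dequantize_row_q ? "supported" : "unsupported");
    return 0;
}
```

Designated initializers keyed by the enum keep the table robust to reordering of the type ids, which is presumably why the original uses them too.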
```diff
@@ -1813,37 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
-    __m512 acc,
-    const block_q4_0 * restrict x,
-    const block_q4_0 * restrict y,
-    int i
-) {
-    // Compute combined scale for the block
-    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
-
-    __m256i bx = bytesFromNibbles( x[i].qs );
-    __m256i by = bytesFromNibbles( y[i].qs );
-
-    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-    const __m256i off = _mm256_set1_epi8( 8 );
-    bx = _mm256_sub_epi8( bx, off );
-    by = _mm256_sub_epi8( by, off );
-
-    // Sign-extend 16 signed bytes into int16_t
-    __m512i x32 = _mm512_cvtepi8_epi16( bx );
-    __m512i y32 = _mm512_cvtepi8_epi16( by );
-    // Compute products of int16_t integers, add pairwise
-    __m512i i64 = _mm512_madd_epi16( x32, y32 );
-
-    // Convert int32_t to float
-    __m512 p = _mm512_cvtepi32_ps( i64 );
-    // Apply the scale, and accumulate
-    return _mm512_fmadd_ps( d, p, acc );
-}
-#endif
-
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -1880,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    assert(n % QK == 0);
+    assert(n % QK8_0 == 0);
     assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
-    const block_q4_0 * restrict y = vy;
+    const block_q8_0 * restrict y = vy;
 
     float sumf = 0.0;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0xf);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
-        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
-
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
-
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
 
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
-
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -1952,115 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
-
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }
 
-    sumf = sum0 + sum1;
-#elif defined(__AVX512F__)
-    // Initialize accumulator with zeros
-    __m512 acc0 = _mm512_setzero_ps();
-    __m512 acc1 = _mm512_setzero_ps();
-
-    const int superblock_size = 8;
-    const int superblock_count = nb / superblock_size;
-
-    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
-        int i = superblock_ix * superblock_size;
-
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
-    }
-
-    // Remaining blocks
-    for (int i = superblock_count * superblock_size; i < nb; ++i) {
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
-    }
-
-    // Horizontal sum of all lanes of the accumulator
-    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    /* Prepare the constants we will need during execution */
-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    const __m256i offset_8 = _mm256_set1_epi16( 8 );
-
-#define UNROLL_COUNT 8
-    // make sure we only unroll multiples of the block count
-    assert(nb % UNROLL_COUNT == 0);
-
-    // Main loop
-    for (int i = 0; i < nb; i += UNROLL_COUNT) {
-        // This loop will be unrolled by the compiler
-        for (int u = 0; u < UNROLL_COUNT; u++) {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                    _mm256_broadcast_ss( &x[i+u].d ),
-                    _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
-            /* Unpack values into individual bytes */
-            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
-            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
-
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
-            /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
-
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
-
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
-
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
-        }
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
     }
 
     // Return horizontal sum of the acc vector
@@ -2082,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     __m128i i32[2];
     for (int j = 0; j < 2; ++j) {
         // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-        __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-        __m128i by = bytesFromNibbles( y[i].qs + 8*j );
+        __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
+        __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m128i off = _mm_set1_epi8( 8 );
         bx = _mm_sub_epi8( bx, off );
-        by = _mm_sub_epi8( by, off );
 
         // Get absolute values of x vectors
        const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2116,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res );
-#elif defined(__wasm_simd128__)
-    // wasm simd
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
-
-        const v128_t m4b = wasm_u8x16_splat(0xf);
-        const v128_t s8b = wasm_i8x16_splat(0x8);
-
-        const v128_t v0_0 = wasm_v128_load(x0->qs);
-        const v128_t v0_1 = wasm_v128_load(y0->qs);
-        const v128_t v1_0 = wasm_v128_load(x1->qs);
-        const v128_t v1_1 = wasm_v128_load(y1->qs);
-
-        // 4-bit -> 8-bit
-        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
-        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
-
-        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
-        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
-
-        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
-        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
-
-        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
-        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
-
-        // sub 8
-        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
-        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
-
-        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
-        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
-
-        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
-        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
-
-        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
-        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
-
-        // dot product into int16x8_t
-        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
-        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
-
-        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
-        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
-
-        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
-        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
-
-        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
-        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
-
-        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
-        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
-
-        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
-        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
-
-        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
-        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
-
-        sum0 += x0->d * y0->d * (
-                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
-                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
-                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
-                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
-        sum1 += x1->d * y1->d * (
-                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
-                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
-                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
-                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
-    }
-
-    sumf = sum0 + sum1;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
```
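The AVX2 path above leans on `_mm256_maddubs_epi16`, which multiplies an unsigned byte by a signed byte; taking `|x|` and copying `x`'s sign onto `y` with `_mm256_sign_epi8` preserves each product, since `x*y == |x| * sign(y, x)`. A scalar check of that identity over the ranges these kernels actually produce (our code, not part of the diff):

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Scalar model of the _mm256_sign_epi8 trick: maddubs needs an unsigned
// operand, so x*y is rewritten as |x| * sign(y, x).
static int sign_trick(int8_t x, int8_t y) {
    uint8_t ax = (uint8_t) abs(x);                  // _mm256_sign_epi8(bx, bx)
    int8_t  sy = x < 0 ? (int8_t) -y : x ? y : 0;   // _mm256_sign_epi8(by, bx)
    return ax * sy;                                 // what maddubs accumulates
}

int main(void) {
    // x covers the q4_0 levels after the -8 offset; y covers q8_0 levels,
    // which stay in [-127, 127] (id maps to +/-127), so the int8 edge
    // case -128 never appears.
    for (int x = -8; x <= 7; x++) {
        for (int y = -127; y <= 127; y++) {
            if (sign_trick((int8_t) x, (int8_t) y) != x * y) {
                printf("mismatch at x=%d y=%d\n", x, y);
                return 1;
            }
        }
    }
    printf("x*y == |x| * sign(y, x) for all 4-bit x and q8_0 y\n");
    return 0;
}
```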
@@ -2203,98 +2532,159 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
-        for (int j = 0; j < QK/2; j++) {
+        for (int j = 0; j < QK8_0/2; j++) {
            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
 
-            const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
-            const int8_t i1 = (int8_t) (v0 >> 4) - 8;
+            const int i0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1 = (int8_t) (v0 >> 4) - 8;
 
-            const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
-            const int8_t i3 = (int8_t) (v1 >> 4) - 8;
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
            sumi += i0*i2 + i1*i3;
        }
-        sumf += d0 * d1 * sumi;
+        sumf += d0*d1*sumi;
    }
 #endif
 
    *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
 
    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
+    const block_q8_0 * restrict y = vy;
 
    float sumf = 0.0;
 
-#if defined(__AVX2__)
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+        const int16x8_t s0i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+
+        const int16x8_t s1i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
 
    // Main loop
    for (int i = 0; i < nb; ++i) {
        const float * d0 = &x[i].d;
        const float * d1 = &y[i].d;
-
        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
 
        const __m256 d0v = _mm256_broadcast_ss( d0 );
        const __m256 d1v = _mm256_broadcast_ss( d1 );
        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
 
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256 sums = _mm256_cvtepi32_ps( sumsi );
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-
-        // Apply the scale, and accumulate
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8( bx, bx );
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8( by, bx );
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
+        const __m256i ones = _mm256_set1_epi16( 1 );
+        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
+
+        // Convert to vector of 8 int32_t to 8 floats
+        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+
+        // Compute sum of y values
+        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+
+        // Accumulate d1*m0*y
+        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
    }
 
    // Return horizontal sum of the acc vector
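The q4_1 x q8_0 kernels above (NEON and AVX2 alike) all exploit the same factorization: a q4_1 block stores x_i = d0*qx_i + m0 and a q8_0 block stores y_i = d1*qy_i, so one block's dot product splits into two integer sums, sum_i x_i*y_i = d0*d1 * sum_i qx_i*qy_i + m0*d1 * sum_i qy_i. That is exactly what the NEON path (the s0i/s1i byte sums scaled by x->m*y->d) and the AVX2 path (the d0d1 and d1m0 accumulations) compute. A minimal scalar sketch of the identity follows; the helper name and hard-coded 32-element block are illustrative, not part of ggml's API.

#include <stdint.h>

// Sketch only: the factorization behind the q4_1 x q8_0 kernels above.
static float dot_q4_1_q8_0_block(float d0, float m0, const uint8_t qx[16],
                                 float d1, const int8_t qy[32]) {
    int sum_xy = 0; // sum of qx_i*qy_i (integer part of the dot product)
    int sum_y  = 0; // sum of qy_i (picks up the q4_1 offset m0)
    for (int j = 0; j < 16; j++) {
        const int x_lo = qx[j] & 0xf; // low nibble pairs with qy[2*j + 0]
        const int x_hi = qx[j] >> 4;  // high nibble pairs with qy[2*j + 1]
        sum_xy += x_lo*qy[2*j + 0] + x_hi*qy[2*j + 1];
        sum_y  += qy[2*j + 0] + qy[2*j + 1];
    }
    return d0*d1*(float)sum_xy + m0*d1*(float)sum_y;
}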
@@ -2303,131 +2693,379 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
+    sumf = _mm_cvtss_f32( res );
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4) + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
+
+    const block_q4_2 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
+        const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const  int8x16_t s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
 
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
 
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
-
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
 #else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
+
+        sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
+                vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
+                vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
 
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
 
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
 
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
 
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
    }
 
-    sumf = QK*sum00 + sum01 + sum10 + sum11;
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+    sumf = _mm_cvtss_f32(res);
 #else
    // scalar
    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float m0 = x[i].m;
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
 
-        const float d1 = y[i].d;
-        const float m1 = y[i].m;
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
 
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        int sumi_0 = 0;
+        int sumi_1 = 0;
 
-        for (int j = 0; j < QK/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
 
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4) + m0;
+            const int i0_0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1_0 = (int8_t) (v0 >> 4) - 8;
 
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4) + m1;
+            const int i0_1 = (int8_t) (v1 & 0xf) - 8;
+            const int i1_1 = (int8_t) (v1 >> 4) - 8;
 
-            sumf += f0*f2 + f1*f3;
+            const int i2_0 = y0[2*j + 0];
+            const int i3_0 = y0[2*j + 1];
+
+            const int i2_1 = y0[2*(j + QK8_0/4) + 0];
+            const int i3_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sumi_0 += i0_0*i2_0 + i1_0*i3_0;
+            sumi_1 += i0_1*i2_1 + i1_1*i3_1;
        }
+
+        sumf += (d0 * y[i].d) * sumi_0;
+        sumf += (d1 * y[i].d) * sumi_1;
    }
 #endif
 
    *s = sumf;
 }
 
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
 
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
 
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
-    }
+    const block_q4_3 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
 
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
+    float sumf = 0.0;
 
-    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
+        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
+
+        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
+        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
+        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
+        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
+        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
+
+        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
+        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
+        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
+
+        int sy_0 = 0;
+        int sy_1 = 0;
+
+        int sxy_0 = 0;
+        int sxy_1 = 0;
+
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
+
+            const int x0_0 = v0 & 0xf;
+            const int x1_0 = v0 >> 4;
+
+            const int x0_1 = v1 & 0xf;
+            const int x1_1 = v1 >> 4;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            const int y0_1 = y0[2*(j + QK8_0/4) + 0];
+            const int y1_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sy_0 += y0_0 + y1_0;
+            sy_1 += y0_1 + y1_1;
+
+            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
+            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+        }
+
+        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
+        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+    }
+#endif
+
+    *s = sumf;
+}
+
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
 
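The AVX2 paths above lean on _mm256_maddubs_epi16, which multiplies *unsigned* bytes by *signed* bytes. Since the x nibbles become signed after the -8 offset, the kernels first take |x| with _mm256_sign_epi8(bx, bx) and transfer x's sign onto y with _mm256_sign_epi8(by, bx), which leaves every product x*y unchanged (zero where x is zero). A minimal sketch of the same trick on a single 128-bit lane with SSSE3 intrinsics; this is illustrative, not the library's helper:

#include <tmmintrin.h> // SSSE3: _mm_sign_epi8, _mm_maddubs_epi16

// Sketch: signed x signed byte products via maddubs, as in the kernels above.
// maddubs needs its first operand unsigned, so use |x| and move x's sign to y.
// The int16 pairwise sums saturate, which is safe for these 4-bit formats:
// |x| <= 8 and |y| <= 127, so each pair sum is at most 2*8*127 = 2032.
static inline __m128i dot_i8_pairs(__m128i x, __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x); // |x| (0 where x == 0)
    const __m128i sy = _mm_sign_epi8(y, x); // y carrying x's sign (0 where x == 0)
    return _mm_maddubs_epi16(ax, sy);       // 8 x int16: x0*y0 + x1*y1, ...
}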
@@ -2652,24 +3290,30 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F32] = 1,
    [GGML_TYPE_F16] = 1,
-    [GGML_TYPE_Q4_0] = QK,
-    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_Q4_0] = QK4_0,
+    [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q4_3] = QK4_3,
+    [GGML_TYPE_Q8_0] = QK8_0,
    [GGML_TYPE_I8]  = 1,
    [GGML_TYPE_I16] = 1,
    [GGML_TYPE_I32] = 1,
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F32] = sizeof(float),
    [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
    [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
    [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
+    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
    [GGML_TYPE_I8]  = sizeof(int8_t),
    [GGML_TYPE_I16] = sizeof(int16_t),
    [GGML_TYPE_I32] = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
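Together these two tables determine how many bytes a row of ne elements occupies: ne/GGML_BLCK_SIZE[type] blocks of GGML_TYPE_SIZE[type] bytes each. For example, with QK4_0 == 32 and sizeof(block_q4_0) == 20 (a 4-byte scale plus 16 bytes of packed nibbles), a 4096-element row takes 4096/32*20 = 2560 bytes. A sketch of that computation; the helper name is illustrative, not ggml API:

// Sketch: bytes needed for a row of ne elements of the given type.
// For quantized types ne must be a multiple of the block size.
static size_t row_bytes(enum ggml_type type, int ne) {
    return (size_t) (ne / GGML_BLCK_SIZE[type]) * GGML_TYPE_SIZE[type];
}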
@@ -2677,11 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
    [GGML_TYPE_F16] = "f16",
    [GGML_TYPE_Q4_0] = "q4_0",
    [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q4_3] = "q4_3",
+    [GGML_TYPE_Q8_0] = "q8_0",
    [GGML_TYPE_I8]  = "i8",
    [GGML_TYPE_I16] = "i16",
    [GGML_TYPE_I32] = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
+
+static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32]  = false,
+    [GGML_TYPE_F16]  = false,
+    [GGML_TYPE_Q4_0] = true,
+    [GGML_TYPE_Q4_1] = true,
+    [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q4_3] = true,
+    [GGML_TYPE_Q8_0] = true,
+    [GGML_TYPE_I8]   = false,
+    [GGML_TYPE_I16]  = false,
+    [GGML_TYPE_I32]  = false,
+};
+static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
    "NONE",
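The GGML_IS_QUANTIZED table turns "is this a block format?" into a single lookup, so call sites no longer need to enumerate q4_0 through q8_0. One hypothetical use, foreshadowing the accessor changes further down (the helper name is an assumption, not ggml API):

// Sketch: the 1d element accessors below assert(false) for block-quantized
// tensors, since individual elements only exist inside encoded blocks.
static bool supports_elementwise_access(const struct ggml_tensor * t) {
    return !ggml_is_quantized(t->type);
}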
@@ -2943,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
        (t0->ne[3] == t1->ne[3]);
 }
 
+bool ggml_is_quantized(enum ggml_type type) {
+    return GGML_IS_QUANTIZED[type];
+}
+
 static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
    return tensor->nb[0] > tensor->nb[1];
 }
@@ -3053,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
        GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
        }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        init_cublas();
+        #endif
+
        is_first_call = false;
    }
 
@@ -3354,14 +4024,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
    char * const data = tensor->data;
 
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                assert(tensor->nb[0] == sizeof(int8_t));
@@ -3397,7 +4059,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                }
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -3414,14 +4076,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
    char * const data = tensor->data;
 
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                assert(tensor->nb[0] == sizeof(int8_t));
@@ -3457,7 +4111,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                }
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -3468,14 +4122,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3501,7 +4147,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                return ((float *)(tensor->data))[i];
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -3512,14 +4158,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3545,7 +4183,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                ((float *)(tensor->data))[i] = value;
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -3554,14 +4192,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3587,7 +4217,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                return ((float *)(tensor->data))[i];
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -3598,14 +4228,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
        case GGML_TYPE_I8:
            {
                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3631,7 +4253,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                ((float *)(tensor->data))[i] = value;
            } break;
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
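All six accessor hunks above make the same mechanical change: the per-type Q4_0/Q4_1 stubs and the trailing case GGML_TYPE_COUNT collapse into one default arm, so newly added quantized types (q4_2, q4_3, q8_0, and anything later) fall through to the assertion without touching every switch. The resulting shape, sketched as a fragment:

switch (tensor->type) {
    case GGML_TYPE_I8:  /* ... element access ... */ break;
    case GGML_TYPE_F32: /* ... element access ... */ break;
    default:
        {
            // covers all block-quantized types and any future additions
            GGML_ASSERT(false);
        } break;
}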
@@ -5031,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5043,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
@@ -5053,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
        return;
    }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
    if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5079,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
    if (ggml_is_contiguous(dst)) {
-        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
            if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
@@ -5102,14 +5749,39 @@ static void ggml_compute_forward_dup_f16(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
@@ -5124,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5132,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5140,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5148,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
@@ -5166,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
    if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5187,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
                    }
                }
            }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
            }
        }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5213,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
                    }
                }
            }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
            }
        }
    } else {
@@ -5224,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5236,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
@@ -5246,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];
 
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        // parallelize by elements
+        const int ne = ggml_nelements(dst);
+        const int dr = (ne + nth - 1) / nth;
+        const int ie0 = dr * ith;
+        const int ie1 = MIN(ie0 + dr, ne);
+
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+
        return;
    }
 
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
    if (src0->type == dst->type &&
-        src0->ne[0] == dst->ne[0] &&
-        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        ne00 == ne0 &&
+        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5271,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
 
    if (ggml_is_contiguous(dst)) {
        // TODO: simplify
-        if (src0->nb[0] == sizeof(float)) {
+        if (nb00 == sizeof(float)) {
            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
-                const size_t rs = ne00*nb00;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            char * dst_ptr = (char *) dst->data + id*rs;
-
-                            memcpy(dst_ptr, src0_ptr, rs);
-
-                            id++;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5294,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5302,6 +6056,25 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_is_quantized(dst->type)) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
@@ -5316,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5324,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
@@ -5332,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
 
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
-                        for (int i01 = 0; i01 < ne01; i01++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -5340,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
                                id++;
                            }
                        }
+                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
@@ -5351,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
    }
 
    // dst counters
+
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
@@ -5359,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
    if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                        memcpy(dst_ptr, src0_ptr, sizeof(float));
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5380,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
                    }
                }
            }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
            }
        }
    } else if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
 
-                        if (++i10 == dst->ne[0]) {
+                        if (++i10 == ne0) {
                            i10 = 0;
-                            if (++i11 == dst->ne[1]) {
+                            if (++i11 == ne1) {
                                i11 = 0;
-                                if (++i12 == dst->ne[2]) {
+                                if (++i12 == ne2) {
                                    i12 = 0;
-                                    if (++i13 == dst->ne[3]) {
+                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
@@ -5406,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
                    }
                }
            }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
            }
        }
    } else {
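The repeated while (i10 >= ne0) blocks in both dup kernels are a base-(ne0, ne1, ne2, ne3) carry: skipping the rows other threads own advances the flat destination cursor, and each overflow of one index carries into the next. The same logic factored into a standalone sketch; ggml inlines this instead, keeping the counters in locals:

// Sketch: advance the (i10, i11, i12, i13) destination cursor by `skip`
// elements, propagating carries through the four dimensions.
static void advance_dst_cursor(int64_t skip,
                               int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                               int64_t * i10, int64_t * i11,
                               int64_t * i12, int64_t * i13) {
    *i10 += skip;
    while (*i10 >= ne0) {
        *i10 -= ne0;               // one full row consumed
        if (++*i11 == ne1) {       // carry into the next dimension
            *i11 = 0;
            if (++*i12 == ne2) {
                *i12 = 0;
                if (++*i13 == ne3) {
                    *i13 = 0;
                }
            }
        }
    }
}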
@@ -5426,12 +6256,7 @@ static void ggml_compute_forward_dup(
            {
                ggml_compute_forward_dup_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5497,6 +6322,212 @@ static void ggml_compute_forward_add_f32(
         }
     }
 }
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0 = dst->ne[0];
+    //const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
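The new quantized add path round-trips each row through f32: dequantize src0 into a per-thread work buffer, accumulate src1, then requantize into dst. A toy stand-in (8-bit absmax blocks instead of ggml's 4-bit formats; not the gem's API) showing the same round trip:

// Dequantize -> add -> requantize, mirroring ggml_compute_forward_add_q_f32.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define N 8

typedef struct { float scale; int8_t q[N]; } toy_block;

static void toy_quantize(const float *x, toy_block *b) {
    float amax = 0.0f;
    for (int i = 0; i < N; i++) amax = fmaxf(amax, fabsf(x[i]));
    b->scale = amax / 127.0f;
    for (int i = 0; i < N; i++) b->q[i] = (int8_t) roundf(x[i] / b->scale);
}

static void toy_dequantize(const toy_block *b, float *x) {
    for (int i = 0; i < N; i++) x[i] = b->q[i] * b->scale;
}

int main(void) {
    float row[N] = {1, -2, 3, -4, 5, -6, 7, -8};
    float add[N] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};

    toy_block b;
    toy_quantize(row, &b);                       // src0 starts out quantized

    float w[N];
    toy_dequantize(&b, w);                       // dequantize into work buffer
    for (int i = 0; i < N; i++) w[i] += add[i];  // add the f32 operand
    toy_quantize(w, &b);                         // quantize result back to dst

    toy_dequantize(&b, w);
    for (int i = 0; i < N; i++) printf("%.3f ", w[i]);
    printf("\n");
    return 0;
}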
@@ -5507,13 +6538,26 @@ static void ggml_compute_forward_add(
            {
                ggml_compute_forward_add_f32(params, src0, src1, dst);
            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5559,13 +6603,7 @@ static void ggml_compute_forward_sub(
            {
                ggml_compute_forward_sub_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5611,13 +6649,7 @@ static void ggml_compute_forward_mul(
            {
                ggml_compute_forward_mul_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5663,13 +6695,7 @@ static void ggml_compute_forward_div(
            {
                ggml_compute_forward_div_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5711,13 +6737,7 @@ static void ggml_compute_forward_sqr(
            {
                ggml_compute_forward_sqr_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5759,13 +6779,7 @@ static void ggml_compute_forward_sqrt(
            {
                ggml_compute_forward_sqrt_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5817,13 +6831,7 @@ static void ggml_compute_forward_sum(
            {
                ggml_compute_forward_sum_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5894,13 +6902,7 @@ static void ggml_compute_forward_mean(
            {
                ggml_compute_forward_mean_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -5958,13 +6960,7 @@ static void ggml_compute_forward_repeat(
            {
                ggml_compute_forward_repeat_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6006,13 +7002,7 @@ static void ggml_compute_forward_abs(
            {
                ggml_compute_forward_abs_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6054,13 +7044,7 @@ static void ggml_compute_forward_sgn(
            {
                ggml_compute_forward_sgn_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6102,13 +7086,7 @@ static void ggml_compute_forward_neg(
            {
                ggml_compute_forward_neg_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6150,13 +7128,7 @@ static void ggml_compute_forward_step(
            {
                ggml_compute_forward_step_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6193,18 +7165,12 @@ static void ggml_compute_forward_relu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_relu_f32(params, src0, dst);
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_relu_f32(params, src0, dst);
+            } break;
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6263,13 +7229,7 @@ static void ggml_compute_forward_gelu(
            {
                ggml_compute_forward_gelu_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6330,13 +7290,7 @@ static void ggml_compute_forward_silu(
            {
                ggml_compute_forward_silu_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6416,13 +7370,7 @@ static void ggml_compute_forward_norm(
            {
                ggml_compute_forward_norm_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6496,13 +7444,7 @@ static void ggml_compute_forward_rms_norm(
            {
                ggml_compute_forward_rms_norm_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -6512,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(

 // ggml_compute_forward_mul_mat

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -6552,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -6609,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -6623,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
            return;
        }

+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -6630,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(

                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+#else
                // zT = y * xT
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne11, ne01, ne10,
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

        return;
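cuBLAS is column-major while ggml tensors are row-major; the hunk avoids an explicit transpose by requesting CUBLAS_OP_T on the left operand, so the k x m column-major alias of the row-major matrix is read back in its intended orientation. A minimal standalone sketch of the same call pattern (not part of the gem; compile with nvcc -lcublas):

// C (m x n, column-major) = A^T * B, with A stored k x m column-major.
#include <assert.h>
#include <stdio.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main(void) {
    const int m = 2, n = 2, k = 2;
    const float A[4] = {1, 2, 3, 4};   // columns (1,2) and (3,4)
    const float B[4] = {1, 0, 0, 1};   // identity
    float C[4] = {0};

    cublasHandle_t h;
    assert(cublasCreate(&h) == CUBLAS_STATUS_SUCCESS);

    float *dA, *dB, *dC;
    assert(cudaMalloc(&dA, sizeof(A)) == cudaSuccess);
    assert(cudaMalloc(&dB, sizeof(B)) == cudaSuccess);
    assert(cudaMalloc(&dC, sizeof(C)) == cudaSuccess);
    cudaMemcpy(dA, A, sizeof(A), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B, sizeof(B), cudaMemcpyHostToDevice);

    // Same op/shape combination as the cublasSgemm call in the hunk above.
    const float alpha = 1.0f, beta = 0.0f;
    assert(cublasSgemm(h, CUBLAS_OP_T, CUBLAS_OP_N,
                       m, n, k,
                       &alpha, dA, k,
                               dB, k,
                       &beta,  dC, m) == CUBLAS_STATUS_SUCCESS);

    cudaMemcpy(C, dC, sizeof(C), cudaMemcpyDeviceToHost);
    printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]); // 1 3 2 4 (A^T, col-major)

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    cublasDestroy(h);
    return 0;
}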
@@ -6768,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
    // nb01 >= nb00 - src0 is not transposed
    // compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        GGML_ASSERT(nb10 == sizeof(float));

@@ -6784,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
            return;
        }

-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;

+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                {
                    size_t id = 0;
                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -6796,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                        }
                    }
                }
+#endif

+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+#else
                const float * x = wdata;
                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);

@@ -6808,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }

+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/

        return;
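The f16 path uses cublasGemmEx rather than cublasSgemm so both inputs can stay fp16 while accumulation and output remain fp32 (CUDA_R_16F in, CUDA_R_32F out, CUBLAS_COMPUTE_32F); that is why the hunk converts src1 to fp16 on the host instead of widening src0. A compact companion sketch (assumes a recent CUDA toolkit with host-side __float2half; compile as a .cu file with nvcc -lcublas):

// C (m x n, fp32) = A^T * B with fp16 operands and fp32 accumulation.
#include <assert.h>
#include <stdio.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main(void) {
    const int m = 2, n = 2, k = 2;
    // A is k x m and B is k x n, column-major, stored as fp16
    const __half A[4] = {__float2half(1), __float2half(2),
                         __float2half(3), __float2half(4)};
    const __half B[4] = {__float2half(1), __float2half(0),
                         __float2half(0), __float2half(1)};
    float C[4] = {0};

    cublasHandle_t h;
    assert(cublasCreate(&h) == CUBLAS_STATUS_SUCCESS);

    __half *dA, *dB;
    float *dC;
    cudaMalloc(&dA, sizeof(A));
    cudaMalloc(&dB, sizeof(B));
    cudaMalloc(&dC, sizeof(C));
    cudaMemcpy(dA, A, sizeof(A), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B, sizeof(B), cudaMemcpyHostToDevice);

    const float alpha = 1.0f, beta = 0.0f;
    assert(cublasGemmEx(h, CUBLAS_OP_T, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, dA, CUDA_R_16F, k,
                                dB, CUDA_R_16F, k,
                        &beta,  dC, CUDA_R_32F, m,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT) == CUBLAS_STATUS_SUCCESS);

    cudaMemcpy(C, dC, sizeof(C), cudaMemcpyDeviceToHost);
    printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    cublasDestroy(h);
    return 0;
}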
@@ -6894,27 +7931,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
    //}
 }

-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q   = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q        = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q   = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q        = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
@@ -6962,8 +7978,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
    GGML_ASSERT(ne3 == ne13);

    const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
-    vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6983,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
    // nb01 >= nb00 - src0 is not transposed
    // compute by src0 rows

-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
        if (params->ith != 0) {
            return;
@@ -6997,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
            return;
        }

+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        float *d_Q = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
+        if (type == GGML_TYPE_Q4_0) {
+            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_1) {
+            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_2) {
+            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+#else
        float * const wdata = params->wdata;
        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+#endif

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(
+                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+#else
                {
                    size_t id = 0;
                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7009,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
                        id += ne00;
                    }
                }
-
                const float * x = wdata;
-
+#endif

-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);

+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+#else
                // zT = y * xT
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne11, ne01, ne10,
                        1.0f,    y, ne10,
                                 x, ne00,
                        0.0f,    d, ne01);
+#endif
            }
        }

+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaFree(d_Q));
+#endif
        //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

        return;
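For the quantized BLAS path the weights are copied to the GPU still quantized (d_Q) and expanded to f32 on-device by one of the dequantize_row_q*_cuda kernels declared in the new ggml-cuda.h. The real kernels ship in this gem's CUDA sources; the following is only a hedged sketch of the idea, using an illustrative Q4_0-style block (one f32 scale plus 32 packed 4-bit values offset by 8):

#include <stdint.h>
#include <cuda_runtime.h>

#define QK 32

typedef struct {
    float   d;          // scale
    uint8_t qs[QK / 2]; // two 4-bit values per byte
} block_q4;

__global__ void dequantize_q4(const block_q4 * x, float * y, int k) {
    const int i = blockIdx.x; // one thread block per quantization block
    if (i >= k / QK) return;

    const float d = x[i].d;
    for (int l = threadIdx.x; l < QK / 2; l += blockDim.x) {
        const uint8_t b = x[i].qs[l];
        y[i*QK + l*2 + 0] = ((b & 0xF) - 8) * d; // low nibble first
        y[i*QK + l*2 + 1] = ((b >> 4)  - 8) * d;
    }
}

// Launched on the same stream as the cublasSgemm that consumes d_X, e.g.:
//     dequantize_q4<<<k/QK, 16, 0, stream>>>(d_Q, d_X, k);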
@@ -7032,12 +8113,12 @@ static void ggml_compute_forward_mul_mat_q_f32(

    if (params->type == GGML_TASK_INIT) {
        char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+        const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                    wdata += row_size;
                }
            }
@@ -7063,7 +8144,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
    const int ir1 = MIN(ir0 + dr, nr);

    void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+    const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
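These two hunks switch the src1 work buffer from the weights' own type to Q8_0: whatever 4-bit format src0 uses, the activations are requantized at 8-bit precision for the dot products. A back-of-envelope check of the resulting row size, assuming this era's block_q8_0 layout (one f32 scale plus 32 int8 quants, an assumption labeled as such):

#include <stdint.h>
#include <stdio.h>

#define QK8_0 32

typedef struct {
    float  d;           // scale
    int8_t qs[QK8_0];   // quantized values
} block_q8_0;

int main(void) {
    const int64_t ne10 = 4096; // row length of src1
    const size_t row_size = ne10 * sizeof(block_q8_0) / QK8_0;
    printf("row_size = %zu bytes (%.3f bytes/element)\n",
           row_size, (double) row_size / ne10);
    return 0;
}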
@@ -7111,6 +8192,9 @@ static void ggml_compute_forward_mul_mat(
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q8_0:
            {
                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
            } break;
@@ -7122,42 +8206,11 @@ static void ggml_compute_forward_mul_mat(
            {
                ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
    }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }

 // ggml_compute_forward_scale
@@ -7207,13 +8260,7 @@ static void ggml_compute_forward_scale(
            {
                ggml_compute_forward_scale_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -7374,6 +8421,9 @@ static void ggml_compute_forward_get_rows(
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q8_0:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;
@@ -7385,10 +8435,7 @@ static void ggml_compute_forward_get_rows(
            {
                ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -7461,13 +8508,7 @@ static void ggml_compute_forward_diag_mask_inf(
            {
                ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -7555,13 +8596,7 @@ static void ggml_compute_forward_soft_max(
            {
                ggml_compute_forward_soft_max_f32(params, src0, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -7618,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(

    const float theta_scale = powf(10000.0, -2.0f/n_dims);

+    const bool is_neox = mode & 2;
+
    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
            for (int64_t i1 = 0; i1 < ne1; i1++) {
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;
@@ -7633,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(

                    theta *= theta_scale;

-                    const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                    float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];

-                    const float x0 = src[0];
-                    const float x1 = src[1];
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);

-                    dst_data[0] = x0*cos_theta - x1*sin_theta;
-                    dst_data[1] = x0*sin_theta + x1*cos_theta;
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                }
            }
        }
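The new `mode & 2` bit selects the GPT-NeoX pairing for rotary embeddings: instead of rotating adjacent elements (x[2i], x[2i+1]), it rotates (x[i], x[i + n/2]) across the two halves of the head dimension. The rotation itself is the standard 2-D rotation by theta in both cases. An illustrative single-row sketch (not the ggml API):

#include <math.h>
#include <stdio.h>

// Rotate one head of dimension n at position p; theta follows the same
// schedule as above: p * 10000^(-2i/n) per pair.
static void rope_row(float * x, int n, int p, int neox) {
    const float theta_scale = powf(10000.0f, -2.0f/n);
    float theta = (float) p;
    for (int i = 0; i < n/2; i++) {
        const float c = cosf(theta), s = sinf(theta);
        theta *= theta_scale;

        const int i0 = neox ? i       : 2*i;     // pair layout differs,
        const int i1 = neox ? i + n/2 : 2*i + 1; // the rotation does not

        const float x0 = x[i0], x1 = x[i1];
        x[i0] = x0*c - x1*s;
        x[i1] = x0*s + x1*c;
    }
}

int main(void) {
    float v[4] = {1, 0, 1, 0};
    rope_row(v, 4, 3, /*neox=*/0); // rotate a 4-dim "head" at position 3
    printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
    return 0;
}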
@@ -7695,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(

    const float theta_scale = powf(10000.0, -2.0f/n_dims);

+    const bool is_neox = mode & 2;
+
    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
            for (int64_t i1 = 0; i1 < ne1; i1++) {
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;
@@ -7710,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(

                    theta *= theta_scale;

-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                    ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);

-                    const float x0 = GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = GGML_FP16_TO_FP32(src[1]);
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);

-                    dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                }
            }
        }
@@ -7738,12 +8799,7 @@ static void ggml_compute_forward_rope(
            {
                ggml_compute_forward_rope_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -8006,12 +9062,7 @@ static void ggml_compute_forward_conv_1d_1s(
            {
                ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -8274,12 +9325,7 @@ static void ggml_compute_forward_conv_1d_2s(
            {
                ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -8759,12 +9805,7 @@ static void ggml_compute_forward_flash_attn(
            {
                ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -8970,12 +10011,7 @@ static void ggml_compute_forward_flash_ff(
            {
                GGML_ASSERT(false); // TODO
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -9019,13 +10055,7 @@ static void ggml_compute_forward_map_unary(
            {
                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -9074,13 +10104,7 @@ static void ggml_compute_forward_map_binary(
            {
                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
            {
                GGML_ASSERT(false);
            } break;
@@ -9830,13 +10854,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
        struct ggml_tensor * node = cgraph->nodes[i];

        switch (node->op) {
+            case GGML_OP_CPY:
            case GGML_OP_DUP:
                {
-                    node->n_tasks = 1;
+                    node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+                    if (ggml_is_quantized(node->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
+                    }
+
+                    work_size = MAX(work_size, cur);
                } break;
            case GGML_OP_ADD:
                {
                    node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    if (ggml_is_quantized(node->src0->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                    }
+
+                    work_size = MAX(work_size, cur);
                } break;
            case GGML_OP_SUB:
            case GGML_OP_MUL:
@@ -9881,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                    size_t cur = 0;

                    if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                            node->n_tasks = 1; // TODO: this actually is doing nothing
                                               //       the threads are still spinning
@@ -9897,15 +10937,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                    } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                        cur = 0;
-                    } else if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                    } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                            node->n_tasks = 1;
                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                        } else
 #endif
                        {
-                            cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
                        }
                    } else {
                        GGML_ASSERT(false);
@@ -9917,7 +10957,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                {
                    node->n_tasks = n_threads;
                } break;
-            case GGML_OP_CPY:
            case GGML_OP_CONT:
            case GGML_OP_RESHAPE:
            case GGML_OP_VIEW:
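The graph planner works by having each node propose a scratch requirement and keeping the maximum, so one shared work buffer serves the whole graph; the hunks above teach CPY/DUP and ADD to ask for one f32 row per thread when their inputs are quantized. A toy version of that bookkeeping (illustrative sizes only):

#include <stdio.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int main(void) {
    const int n_threads = 4;
    size_t work_size = 0;

    // e.g. a quantized ADD: one dequantized f32 row per thread
    size_t cur = sizeof(float) * 4096 * n_threads;
    work_size = MAX(work_size, cur);

    // e.g. a quantized MUL_MAT: src1 requantized to Q8_0
    // (~1.125 bytes/element under the block layout assumed earlier)
    cur = (size_t) (4096 * 512 * 1.125);
    work_size = MAX(work_size, cur);

    printf("work buffer: %zu bytes\n", work_size);
    return 0;
}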
@@ -11080,16 +12119,16 @@ enum ggml_opt_result ggml_opt(
 ////////////////////////////////////////////////////////////////////////////////

 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;

    for (int j = 0; j < n; j += k) {
-        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;

        quantize_row_q4_0_reference(src + j, y, k);

        for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_0; l += 2) {
                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -11099,20 +12138,67 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
        }
    }

-    return (n/QK*sizeof(block_q4_0));
+    return (n/QK4_0*sizeof(block_q4_0));
 }

 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;

    for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;

        quantize_row_q4_1_reference(src + j, y, k);

        for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_1; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_1*sizeof(block_q4_1));
+}
+
+size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
+
+        //quantize_row_q4_2_reference(src + j, y, k);
+        quantize_row_q4_2_rmse(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_2; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_2*sizeof(block_q4_2));
+}
+
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
+
+        quantize_row_q4_3_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_3; l += 2) {
                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                const uint8_t vi1 = y[i].qs[l/2] >> 4;

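Each of these quantizers also fills a 16-bin histogram of the emitted 4-bit codes, unpacking two codes per byte. A self-contained check of that bookkeeping:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t qs[4] = {0x12, 0x34, 0xF0, 0xFF};
    int64_t hist[16] = {0};

    for (int i = 0; i < 4; i++) {
        hist[qs[i] & 0xF]++; // low nibble
        hist[qs[i] >> 4]++;  // high nibble
    }

    for (int v = 0; v < 16; v++) {
        if (hist[v]) printf("code %2d: %lld\n", v, (long long) hist[v]);
    }
    return 0;
}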
@@ -11122,7 +12208,40 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
        }
    }

-    return (n/QK*sizeof(block_q4_1));
+    return (n/QK4_3*sizeof(block_q4_3));
+}
+
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+    size_t result = 0;
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(start % QK4_0 == 0);
+                block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
+                result = ggml_quantize_q4_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(start % QK4_1 == 0);
+                block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
+                result = ggml_quantize_q4_1(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_2:
+            {
+                GGML_ASSERT(start % QK4_2 == 0);
+                block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
+                result = ggml_quantize_q4_2(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_3:
+            {
+                GGML_ASSERT(start % QK4_3 == 0);
+                block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
+                result = ggml_quantize_q4_3(src + start, block, n, n, hist);
+            } break;
+        default:
+            assert(false);
+    }
+    return result;
 }

 ////////////////////////////////////////////////////////////////////////////////
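ggml_quantize_chunk gives callers a single type-dispatched entry point instead of picking the per-format function themselves. A possible caller's-eye sketch (assumes this release's ggml.h on the include path; the generous dst sizing is just an upper bound, since quantized data is always smaller than the f32 input):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "ggml.h"

int main(void) {
    enum { N = 1024 };
    float * data = malloc(N * sizeof(float));
    for (int i = 0; i < N; i++) data[i] = (float) i / N - 0.5f;

    void * q = malloc(N * sizeof(float)); // upper bound on quantized size
    int64_t hist[16] = {0};

    // quantize the whole buffer as one chunk of Q4_0
    size_t bytes = ggml_quantize_chunk(GGML_TYPE_Q4_0, data, q, 0, N, hist);
    printf("%d floats -> %zu bytes\n", N, bytes);

    free(q);
    free(data);
    return 0;
}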
@@ -11151,6 +12270,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }

+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
    return 1;
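The new capability probes follow the same compile-time pattern as the existing ones. A trivial report (hypothetical file, not part of the gem; assumes the matching declarations from this release's ggml.h):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    printf("AVX512 VBMI: %d\n", ggml_cpu_has_avx512_vbmi());
    printf("AVX512 VNNI: %d\n", ggml_cpu_has_avx512_vnni());
    printf("BLAS:        %d\n", ggml_cpu_has_blas());
    printf("cuBLAS:      %d\n", ggml_cpu_has_cublas());
    return 0;
}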
@@ -11200,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }

 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
    return 1;
 #else
    return 0;