llama_cpp 0.0.4 → 0.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +3 -2
- data/ext/llama_cpp/extconf.rb +26 -0
- data/ext/llama_cpp/llama_cpp.cpp +106 -0
- data/ext/llama_cpp/src/ggml-cuda.h +12 -0
- data/ext/llama_cpp/src/ggml.c +2038 -895
- data/ext/llama_cpp/src/ggml.h +21 -1
- data/ext/llama_cpp/src/llama.cpp +376 -62
- data/ext/llama_cpp/src/llama.h +17 -1
- data/ext/llama_cpp/src/llama_util.h +22 -16
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +3 -3
- data/sig/llama_cpp.rbs +13 -1
- metadata +3 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -19,6 +19,7 @@
|
|
19
19
|
#include <inttypes.h>
|
20
20
|
#include <stdio.h>
|
21
21
|
#include <float.h>
|
22
|
+
#include <limits.h>
|
22
23
|
|
23
24
|
// if C99 - static_assert is noop
|
24
25
|
// ref: https://stackoverflow.com/a/53923785/4039976
|
@@ -118,7 +119,16 @@ typedef void* thread_ret_t;
|
|
118
119
|
#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
|
119
120
|
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
|
120
121
|
#else
|
121
|
-
|
122
|
+
inline static void* ggml_aligned_malloc(size_t size) {
|
123
|
+
void* aligned_memory = NULL;
|
124
|
+
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
|
125
|
+
if (result != 0) {
|
126
|
+
// Handle allocation failure
|
127
|
+
return NULL;
|
128
|
+
}
|
129
|
+
return aligned_memory;
|
130
|
+
}
|
131
|
+
#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
|
122
132
|
#define GGML_ALIGNED_FREE(ptr) free(ptr)
|
123
133
|
#endif
|
124
134
|
|
@@ -133,10 +143,49 @@ typedef void* thread_ret_t;
|
|
133
143
|
} \
|
134
144
|
} while (0)
|
135
145
|
|
136
|
-
#
|
146
|
+
#if defined(GGML_USE_ACCELERATE)
|
137
147
|
#include <Accelerate/Accelerate.h>
|
138
|
-
#elif GGML_USE_OPENBLAS
|
148
|
+
#elif defined(GGML_USE_OPENBLAS)
|
139
149
|
#include <cblas.h>
|
150
|
+
#elif defined(GGML_USE_CUBLAS)
|
151
|
+
#include <cublas_v2.h>
|
152
|
+
#include <cuda_runtime.h>
|
153
|
+
#include "ggml-cuda.h"
|
154
|
+
|
155
|
+
#define CUDA_CHECK(err) \
|
156
|
+
do { \
|
157
|
+
cudaError_t err_ = (err); \
|
158
|
+
if (err_ != cudaSuccess) { \
|
159
|
+
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
160
|
+
cudaGetErrorString(err_)); \
|
161
|
+
exit(1); \
|
162
|
+
} \
|
163
|
+
} while (0)
|
164
|
+
|
165
|
+
#define CUBLAS_CHECK(err) \
|
166
|
+
do { \
|
167
|
+
cublasStatus_t err_ = (err); \
|
168
|
+
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
169
|
+
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
170
|
+
exit(1); \
|
171
|
+
} \
|
172
|
+
} while (0)
|
173
|
+
|
174
|
+
static cublasHandle_t cublasH = NULL;
|
175
|
+
static cudaStream_t cudaStream = NULL;
|
176
|
+
static void init_cublas(void) {
|
177
|
+
if (cublasH == NULL) {
|
178
|
+
// create cublas handle, bind a stream
|
179
|
+
CUBLAS_CHECK(cublasCreate(&cublasH));
|
180
|
+
|
181
|
+
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
182
|
+
|
183
|
+
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
184
|
+
|
185
|
+
// configure logging to stdout
|
186
|
+
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
187
|
+
}
|
188
|
+
}
|
140
189
|
#endif
|
141
190
|
|
142
191
|
#undef MIN
|
@@ -418,14 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|
418
467
|
// quantization
|
419
468
|
//
|
420
469
|
|
421
|
-
#
|
470
|
+
#if __AVX__ || __AVX2__ || __AVX512F__
|
471
|
+
// Unpack 16 4-bit fields into 16 bytes
|
472
|
+
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
|
473
|
+
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
474
|
+
{
|
475
|
+
// Load 8 bytes from memory
|
476
|
+
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
|
477
|
+
|
478
|
+
// Expand bytes into uint16_t values
|
479
|
+
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
480
|
+
|
481
|
+
// Unpack values into individual bytes
|
482
|
+
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
483
|
+
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
484
|
+
__m128i low = _mm_and_si128( lowMask, bytes );
|
485
|
+
high = _mm_slli_epi16( high, 4 );
|
486
|
+
bytes = _mm_or_si128( low, high );
|
487
|
+
return bytes;
|
488
|
+
}
|
422
489
|
|
423
|
-
// AVX routines provided by GH user Const-me
|
424
|
-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
425
490
|
#if __AVX2__ || __AVX512F__
|
426
491
|
// Unpack 32 4-bit fields into 32 bytes
|
427
492
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
428
|
-
static inline __m256i
|
493
|
+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
429
494
|
{
|
430
495
|
// Load 16 bytes from memory
|
431
496
|
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
|
@@ -456,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
|
|
456
521
|
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
457
522
|
return _mm_packus_epi16( r0, r1 );
|
458
523
|
}
|
459
|
-
#
|
460
|
-
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
|
461
|
-
{
|
462
|
-
// Load 8 bytes from memory
|
463
|
-
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
|
464
|
-
|
465
|
-
// Expand bytes into uint16_t values
|
466
|
-
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
467
|
-
|
468
|
-
// Unpack values into individual bytes
|
469
|
-
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
470
|
-
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
471
|
-
__m128i low = _mm_and_si128( lowMask, bytes );
|
472
|
-
high = _mm_slli_epi16( high, 4 );
|
473
|
-
bytes = _mm_or_si128( low, high );
|
474
|
-
return bytes;
|
475
|
-
}
|
476
|
-
|
524
|
+
#else
|
477
525
|
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
478
526
|
{
|
479
527
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
@@ -490,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
|
490
538
|
return _mm_packus_epi16( bytes1, bytes2);
|
491
539
|
}
|
492
540
|
#endif
|
541
|
+
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
493
542
|
|
494
543
|
#if __ARM_NEON
|
495
544
|
|
@@ -507,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
|
|
507
556
|
(uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
|
508
557
|
}
|
509
558
|
|
559
|
+
inline static int16_t vaddvq_s8(int8x16_t v) {
|
560
|
+
return
|
561
|
+
(int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
|
562
|
+
(int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
|
563
|
+
(int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
|
564
|
+
(int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
|
565
|
+
(int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
|
566
|
+
(int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
|
567
|
+
(int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
|
568
|
+
(int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
|
569
|
+
}
|
570
|
+
|
510
571
|
inline static int32_t vaddvq_s16(int16x8_t v) {
|
511
572
|
return
|
512
573
|
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
@@ -531,68 +592,88 @@ inline static float vaddvq_f32(float32x4_t v) {
|
|
531
592
|
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
|
532
593
|
}
|
533
594
|
|
534
|
-
|
595
|
+
float vminvq_f32(float32x4_t v) {
|
535
596
|
return
|
536
597
|
MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
537
598
|
MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
538
599
|
}
|
539
600
|
|
540
|
-
|
601
|
+
float vmaxvq_f32(float32x4_t v) {
|
541
602
|
return
|
542
603
|
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
543
604
|
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
544
605
|
}
|
545
606
|
|
546
|
-
|
607
|
+
int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
|
547
608
|
return vget_low_s8(vcombine_s8(a, b));
|
548
609
|
}
|
549
610
|
|
550
|
-
|
611
|
+
int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
|
551
612
|
return vget_high_s8(vcombine_s8(a, b));
|
552
613
|
}
|
553
614
|
|
554
|
-
|
615
|
+
uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
555
616
|
return vget_low_u8(vcombine_u8(a, b));
|
556
617
|
}
|
557
618
|
|
558
|
-
|
619
|
+
uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
559
620
|
return vget_high_u8(vcombine_u8(a, b));
|
560
621
|
}
|
561
622
|
|
562
623
|
#endif
|
563
624
|
#endif
|
564
625
|
|
565
|
-
|
566
|
-
|
567
|
-
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
626
|
+
|
627
|
+
#define QK4_0 32
|
568
628
|
typedef struct {
|
569
|
-
float d;
|
570
|
-
uint8_t qs[
|
629
|
+
float d; // delta
|
630
|
+
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
571
631
|
} block_q4_0;
|
572
|
-
static_assert(sizeof(block_q4_0) == sizeof(float) +
|
632
|
+
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
573
633
|
|
574
|
-
|
575
|
-
// blocks of QK elements
|
576
|
-
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
|
634
|
+
#define QK4_1 32
|
577
635
|
typedef struct {
|
578
|
-
float d;
|
579
|
-
float m;
|
580
|
-
uint8_t qs[
|
636
|
+
float d; // delta
|
637
|
+
float m; // min
|
638
|
+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
581
639
|
} block_q4_1;
|
582
|
-
static_assert(sizeof(block_q4_1) == sizeof(float)
|
640
|
+
static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
|
641
|
+
|
642
|
+
#define QK4_2 16
|
643
|
+
typedef struct {
|
644
|
+
ggml_fp16_t d; // delta
|
645
|
+
uint8_t qs[QK4_2 / 2]; // nibbles / quants
|
646
|
+
} block_q4_2;
|
647
|
+
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
|
648
|
+
|
649
|
+
#define QK4_3 16
|
650
|
+
typedef struct {
|
651
|
+
ggml_fp16_t d; // delta
|
652
|
+
ggml_fp16_t m; // min
|
653
|
+
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
654
|
+
} block_q4_3;
|
655
|
+
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
|
656
|
+
|
657
|
+
#define QK8_0 32
|
658
|
+
typedef struct {
|
659
|
+
float d; // delta
|
660
|
+
int8_t qs[QK8_0]; // quants
|
661
|
+
} block_q8_0;
|
662
|
+
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
663
|
+
|
583
664
|
|
584
665
|
// reference implementation for deterministic creation of model files
|
585
666
|
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
586
|
-
assert(k %
|
587
|
-
const int nb = k /
|
667
|
+
assert(k % QK4_0 == 0);
|
668
|
+
const int nb = k / QK4_0;
|
588
669
|
|
589
|
-
uint8_t pp[
|
670
|
+
uint8_t pp[QK4_0/2];
|
590
671
|
|
591
672
|
for (int i = 0; i < nb; i++) {
|
592
673
|
float amax = 0.0f; // absolute max
|
593
674
|
|
594
|
-
for (int l = 0; l <
|
595
|
-
const float v = x[i*
|
675
|
+
for (int l = 0; l < QK4_0; l++) {
|
676
|
+
const float v = x[i*QK4_0 + l];
|
596
677
|
amax = MAX(amax, fabsf(v));
|
597
678
|
}
|
598
679
|
|
@@ -601,9 +682,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
|
601
682
|
|
602
683
|
y[i].d = d;
|
603
684
|
|
604
|
-
for (int l = 0; l <
|
605
|
-
const float v0 = x[i*
|
606
|
-
const float v1 = x[i*
|
685
|
+
for (int l = 0; l < QK4_0; l += 2) {
|
686
|
+
const float v0 = x[i*QK4_0 + l + 0]*id;
|
687
|
+
const float v1 = x[i*QK4_0 + l + 1]*id;
|
607
688
|
|
608
689
|
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
|
609
690
|
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
|
@@ -619,8 +700,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
|
619
700
|
}
|
620
701
|
|
621
702
|
static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
|
622
|
-
assert(k %
|
623
|
-
const int nb = k /
|
703
|
+
assert(k % QK4_0 == 0);
|
704
|
+
const int nb = k / QK4_0;
|
624
705
|
|
625
706
|
block_q4_0 * restrict y = vy;
|
626
707
|
|
@@ -870,19 +951,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|
870
951
|
}
|
871
952
|
|
872
953
|
static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
|
873
|
-
assert(k %
|
874
|
-
const int nb = k /
|
954
|
+
assert(k % QK4_1 == 0);
|
955
|
+
const int nb = k / QK4_1;
|
875
956
|
|
876
957
|
block_q4_1 * restrict y = vy;
|
877
958
|
|
878
|
-
uint8_t pp[
|
959
|
+
uint8_t pp[QK4_1/2];
|
879
960
|
|
880
961
|
for (int i = 0; i < nb; i++) {
|
881
962
|
float min = FLT_MAX;
|
882
963
|
float max = -FLT_MAX;
|
883
964
|
|
884
|
-
for (int l = 0; l <
|
885
|
-
const float v = x[i*
|
965
|
+
for (int l = 0; l < QK4_1; l++) {
|
966
|
+
const float v = x[i*QK4_1 + l];
|
886
967
|
if (v < min) min = v;
|
887
968
|
if (v > max) max = v;
|
888
969
|
}
|
@@ -893,9 +974,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
|
|
893
974
|
y[i].d = d;
|
894
975
|
y[i].m = min;
|
895
976
|
|
896
|
-
for (int l = 0; l <
|
897
|
-
const float v0 = (x[i*
|
898
|
-
const float v1 = (x[i*
|
977
|
+
for (int l = 0; l < QK4_1; l += 2) {
|
978
|
+
const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
|
979
|
+
const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
|
899
980
|
|
900
981
|
const uint8_t vi0 = roundf(v0);
|
901
982
|
const uint8_t vi1 = roundf(v1);
|
@@ -911,9 +992,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
|
|
911
992
|
}
|
912
993
|
|
913
994
|
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
|
914
|
-
assert(k %
|
995
|
+
assert(k % QK4_1 == 0);
|
915
996
|
|
916
|
-
const int nb = k /
|
997
|
+
const int nb = k / QK4_1;
|
917
998
|
|
918
999
|
block_q4_1 * restrict y = vy;
|
919
1000
|
|
@@ -997,7 +1078,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|
997
1078
|
float32x4_t minv[8];
|
998
1079
|
float32x4_t maxv[8];
|
999
1080
|
|
1000
|
-
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*
|
1081
|
+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
|
1001
1082
|
|
1002
1083
|
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
|
1003
1084
|
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
|
@@ -1033,9 +1114,327 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|
1033
1114
|
#endif
|
1034
1115
|
}
|
1035
1116
|
|
1117
|
+
// reference implementation for deterministic creation of model files
|
1118
|
+
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
|
1119
|
+
assert(k % QK4_2 == 0);
|
1120
|
+
|
1121
|
+
const int nb = k / QK4_2;
|
1122
|
+
|
1123
|
+
for (int i = 0; i < nb; i++) {
|
1124
|
+
float amax = 0.0f; // absolute max
|
1125
|
+
|
1126
|
+
for (int l = 0; l < QK4_2; l++) {
|
1127
|
+
const float v = x[i*QK4_2 + l];
|
1128
|
+
amax = MAX(amax, fabsf(v));
|
1129
|
+
}
|
1130
|
+
|
1131
|
+
const float d = amax / ((1 << 3) - 1);
|
1132
|
+
|
1133
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1134
|
+
|
1135
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1136
|
+
|
1137
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1138
|
+
const float v0 = x[i*QK4_2 + l + 0]*id;
|
1139
|
+
const float v1 = x[i*QK4_2 + l + 1]*id;
|
1140
|
+
|
1141
|
+
const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
|
1142
|
+
const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
|
1143
|
+
|
1144
|
+
assert(vi0 < 16);
|
1145
|
+
assert(vi1 < 16);
|
1146
|
+
|
1147
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1148
|
+
}
|
1149
|
+
}
|
1150
|
+
}
|
1151
|
+
|
1152
|
+
static inline int nearest_int(float fval) {
|
1153
|
+
assert(fval <= 4194303.f);
|
1154
|
+
float val = fval + 12582912.f;
|
1155
|
+
int i; memcpy(&i, &val, sizeof(int));
|
1156
|
+
return (i & 0x007fffff) - 0x00400000;
|
1157
|
+
}
|
1158
|
+
|
1159
|
+
static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
|
1160
|
+
const float * restrict candidates, int8_t * restrict L) {
|
1161
|
+
assert (nmin >= INT8_MIN);
|
1162
|
+
assert (nmax <= INT8_MAX);
|
1163
|
+
float amax = 0;
|
1164
|
+
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
|
1165
|
+
if (!amax) { // all zero
|
1166
|
+
for (int i=0; i<n; ++i) L[i] = 0;
|
1167
|
+
return 1.f;
|
1168
|
+
}
|
1169
|
+
float best = 0, bestScale = 0;
|
1170
|
+
for (int si=0; si<nCandidates; ++si) {
|
1171
|
+
float iscale = candidates[si]/amax;
|
1172
|
+
float sumlxP = 0; int suml2P = 0;
|
1173
|
+
float sumlxM = 0; int suml2M = 0;
|
1174
|
+
for (int i=0; i<n; ++i) {
|
1175
|
+
int l = nearest_int(iscale*X[i]);
|
1176
|
+
int lp = MAX(nmin, MIN(nmax, +l));
|
1177
|
+
int lm = MAX(nmin, MIN(nmax, -l));
|
1178
|
+
sumlxP += X[i]*lp; suml2P += lp*lp;
|
1179
|
+
sumlxM += X[i]*lm; suml2M += lm*lm;
|
1180
|
+
}
|
1181
|
+
float sumlxP2 = sumlxP*sumlxP;
|
1182
|
+
float sumlxM2 = sumlxM*sumlxM;
|
1183
|
+
if (sumlxP2*suml2M > sumlxM2*suml2P) {
|
1184
|
+
if (sumlxP2 > best*suml2P) {
|
1185
|
+
best = sumlxP2/suml2P; bestScale = iscale;
|
1186
|
+
}
|
1187
|
+
} else {
|
1188
|
+
if (sumlxM2 > best*suml2M) {
|
1189
|
+
best = sumlxM2/suml2M; bestScale = -iscale;
|
1190
|
+
}
|
1191
|
+
}
|
1192
|
+
}
|
1193
|
+
float sumlx = 0; int suml2 = 0;
|
1194
|
+
for (int i=0; i<n; ++i) {
|
1195
|
+
int l = nearest_int(bestScale*X[i]);
|
1196
|
+
l = MAX(nmin, MIN(nmax, l));
|
1197
|
+
sumlx += X[i]*l; suml2 += l*l;
|
1198
|
+
L[i] = l;
|
1199
|
+
}
|
1200
|
+
float scale = sumlx/suml2;
|
1201
|
+
return scale;
|
1202
|
+
}
|
1203
|
+
|
1204
|
+
static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
|
1205
|
+
#define CANDIDATE_COUNT 8
|
1206
|
+
static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
|
1207
|
+
assert(k % QK4_2 == 0);
|
1208
|
+
|
1209
|
+
int8_t L[QK4_2];
|
1210
|
+
|
1211
|
+
const int nb = k / QK4_2;
|
1212
|
+
|
1213
|
+
for (int i = 0; i < nb; i++) {
|
1214
|
+
float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
|
1215
|
+
y[i].d = GGML_FP32_TO_FP16(scale);
|
1216
|
+
|
1217
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1218
|
+
const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
|
1219
|
+
const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
|
1220
|
+
|
1221
|
+
assert(vi0 < 16);
|
1222
|
+
assert(vi1 < 16);
|
1223
|
+
|
1224
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1225
|
+
}
|
1226
|
+
|
1227
|
+
x += QK4_2;
|
1228
|
+
}
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
|
1232
|
+
assert(k % QK4_2 == 0);
|
1233
|
+
|
1234
|
+
block_q4_2 * restrict y = vy;
|
1235
|
+
|
1236
|
+
//quantize_row_q4_2_reference(x, y, k);
|
1237
|
+
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
|
1238
|
+
quantize_row_q4_2_rmse(x, y, k);
|
1239
|
+
}
|
1240
|
+
|
1241
|
+
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
|
1242
|
+
assert(k % QK4_3 == 0);
|
1243
|
+
const int nb = k / QK4_3;
|
1244
|
+
|
1245
|
+
for (int i = 0; i < nb; i++) {
|
1246
|
+
float min = FLT_MAX;
|
1247
|
+
float max = -FLT_MAX;
|
1248
|
+
|
1249
|
+
for (int l = 0; l < QK4_3; l++) {
|
1250
|
+
const float v = x[i*QK4_3 + l];
|
1251
|
+
if (v < min) min = v;
|
1252
|
+
if (v > max) max = v;
|
1253
|
+
}
|
1254
|
+
|
1255
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
1256
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1257
|
+
|
1258
|
+
y[i].d = GGML_FP32_TO_FP16(d);
|
1259
|
+
y[i].m = GGML_FP32_TO_FP16(min);
|
1260
|
+
|
1261
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1262
|
+
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
|
1263
|
+
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
|
1264
|
+
|
1265
|
+
const uint8_t vi0 = (int) (v0 + 0.5f);
|
1266
|
+
const uint8_t vi1 = (int) (v1 + 0.5f);
|
1267
|
+
|
1268
|
+
assert(vi0 < 16);
|
1269
|
+
assert(vi1 < 16);
|
1270
|
+
|
1271
|
+
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
1272
|
+
}
|
1273
|
+
}
|
1274
|
+
}
|
1275
|
+
|
1276
|
+
static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
|
1277
|
+
assert(k % QK4_3 == 0);
|
1278
|
+
|
1279
|
+
block_q4_3 * restrict y = vy;
|
1280
|
+
|
1281
|
+
quantize_row_q4_3_reference(x, y, k);
|
1282
|
+
}
|
1283
|
+
|
1284
|
+
// reference implementation for deterministic creation of model files
|
1285
|
+
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
|
1286
|
+
assert(k % QK8_0 == 0);
|
1287
|
+
const int nb = k / QK8_0;
|
1288
|
+
|
1289
|
+
for (int i = 0; i < nb; i++) {
|
1290
|
+
float amax = 0.0f; // absolute max
|
1291
|
+
|
1292
|
+
for (int l = 0; l < QK8_0; l++) {
|
1293
|
+
const float v = x[i*QK8_0 + l];
|
1294
|
+
amax = MAX(amax, fabsf(v));
|
1295
|
+
}
|
1296
|
+
|
1297
|
+
const float d = amax / ((1 << 7) - 1);
|
1298
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1299
|
+
|
1300
|
+
y[i].d = d;
|
1301
|
+
|
1302
|
+
for (int l = 0; l < QK8_0; ++l) {
|
1303
|
+
const float v = x[i*QK8_0 + l]*id;
|
1304
|
+
y[i].qs[l] = roundf(v);
|
1305
|
+
}
|
1306
|
+
}
|
1307
|
+
}
|
1308
|
+
|
1309
|
+
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
1310
|
+
assert(k % QK8_0 == 0);
|
1311
|
+
const int nb = k / QK8_0;
|
1312
|
+
|
1313
|
+
block_q8_0 * restrict y = vy;
|
1314
|
+
|
1315
|
+
#if defined(__ARM_NEON)
|
1316
|
+
for (int i = 0; i < nb; i++) {
|
1317
|
+
float32x4_t srcv [8];
|
1318
|
+
float32x4_t asrcv[8];
|
1319
|
+
float32x4_t amaxv[8];
|
1320
|
+
|
1321
|
+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
1322
|
+
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
|
1323
|
+
|
1324
|
+
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
|
1325
|
+
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
1326
|
+
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
1327
|
+
|
1328
|
+
const float amax = vmaxvq_f32(amaxv[0]);
|
1329
|
+
|
1330
|
+
const float d = amax / ((1 << 7) - 1);
|
1331
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1332
|
+
|
1333
|
+
y[i].d = d;
|
1334
|
+
|
1335
|
+
for (int l = 0; l < 8; l++) {
|
1336
|
+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
1337
|
+
const int32x4_t vi = vcvtnq_s32_f32(v);
|
1338
|
+
|
1339
|
+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
|
1340
|
+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
1341
|
+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
1342
|
+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
1343
|
+
}
|
1344
|
+
}
|
1345
|
+
#elif defined(__AVX2__) || defined(__AVX__)
|
1346
|
+
for (int i = 0; i < nb; i++) {
|
1347
|
+
// Load elements into 4 AVX vectors
|
1348
|
+
__m256 v0 = _mm256_loadu_ps( x );
|
1349
|
+
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
1350
|
+
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
1351
|
+
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
1352
|
+
x += 32;
|
1353
|
+
|
1354
|
+
// Compute max(abs(e)) for the block
|
1355
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
1356
|
+
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
1357
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
1358
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
1359
|
+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
1360
|
+
|
1361
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
1362
|
+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
1363
|
+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
1364
|
+
const float maxScalar = _mm_cvtss_f32( max4 );
|
1365
|
+
|
1366
|
+
// Quantize these floats
|
1367
|
+
const float d = maxScalar / 127.f;
|
1368
|
+
y[i].d = d;
|
1369
|
+
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
1370
|
+
const __m256 mul = _mm256_set1_ps( id );
|
1371
|
+
|
1372
|
+
// Apply the multiplier
|
1373
|
+
v0 = _mm256_mul_ps( v0, mul );
|
1374
|
+
v1 = _mm256_mul_ps( v1, mul );
|
1375
|
+
v2 = _mm256_mul_ps( v2, mul );
|
1376
|
+
v3 = _mm256_mul_ps( v3, mul );
|
1377
|
+
|
1378
|
+
// Round to nearest integer
|
1379
|
+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
1380
|
+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
1381
|
+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
1382
|
+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
1383
|
+
|
1384
|
+
// Convert floats to integers
|
1385
|
+
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
1386
|
+
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
1387
|
+
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
1388
|
+
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
1389
|
+
|
1390
|
+
#if defined(__AVX2__)
|
1391
|
+
// Convert int32 to int16
|
1392
|
+
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
1393
|
+
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
1394
|
+
// Convert int16 to int8
|
1395
|
+
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
1396
|
+
|
1397
|
+
// We got our precious signed bytes, but the order is now wrong
|
1398
|
+
// These AVX2 pack instructions process 16-byte pieces independently
|
1399
|
+
// The following instruction is fixing the order
|
1400
|
+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
1401
|
+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
1402
|
+
|
1403
|
+
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
|
1404
|
+
#else
|
1405
|
+
// Since we don't have in AVX some necessary functions,
|
1406
|
+
// we split the registers in half and call AVX2 analogs from SSE
|
1407
|
+
__m128i ni0 = _mm256_castsi256_si128( i0 );
|
1408
|
+
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
|
1409
|
+
__m128i ni2 = _mm256_castsi256_si128( i1 );
|
1410
|
+
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
|
1411
|
+
__m128i ni4 = _mm256_castsi256_si128( i2 );
|
1412
|
+
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
|
1413
|
+
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
1414
|
+
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
1415
|
+
|
1416
|
+
// Convert int32 to int16
|
1417
|
+
ni0 = _mm_packs_epi32( ni0, ni1 );
|
1418
|
+
ni2 = _mm_packs_epi32( ni2, ni3 );
|
1419
|
+
ni4 = _mm_packs_epi32( ni4, ni5 );
|
1420
|
+
ni6 = _mm_packs_epi32( ni6, ni7 );
|
1421
|
+
// Convert int16 to int8
|
1422
|
+
ni0 = _mm_packs_epi16( ni0, ni2 );
|
1423
|
+
ni4 = _mm_packs_epi16( ni4, ni6 );
|
1424
|
+
|
1425
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
|
1426
|
+
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
|
1427
|
+
#endif
|
1428
|
+
}
|
1429
|
+
#else
|
1430
|
+
// scalar
|
1431
|
+
quantize_row_q8_0_reference(x, y, k);
|
1432
|
+
#endif
|
1433
|
+
}
|
1434
|
+
|
1036
1435
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
1037
|
-
assert(k %
|
1038
|
-
const int nb = k /
|
1436
|
+
assert(k % QK4_0 == 0);
|
1437
|
+
const int nb = k / QK4_0;
|
1039
1438
|
|
1040
1439
|
const block_q4_0 * restrict x = vx;
|
1041
1440
|
|
@@ -1046,9 +1445,9 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1046
1445
|
|
1047
1446
|
const uint8_t * restrict pp = x[i].qs;
|
1048
1447
|
|
1049
|
-
for (int l = 0; l <
|
1448
|
+
for (int l = 0; l < QK4_0; l += 32) {
|
1050
1449
|
// Load 32x4-bit integers into 32x8-bit integers
|
1051
|
-
__m256i vx8 =
|
1450
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1052
1451
|
|
1053
1452
|
// Subtract 8 from the integers
|
1054
1453
|
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
@@ -1068,7 +1467,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1068
1467
|
// Scale and store
|
1069
1468
|
for (int j = 0; j < 4; j++) {
|
1070
1469
|
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
1071
|
-
_mm256_storeu_ps(y + i *
|
1470
|
+
_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
|
1072
1471
|
}
|
1073
1472
|
}
|
1074
1473
|
}
|
@@ -1078,7 +1477,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1078
1477
|
|
1079
1478
|
const uint8_t * restrict pp = x[i].qs;
|
1080
1479
|
|
1081
|
-
for (int l = 0; l <
|
1480
|
+
for (int l = 0; l < QK4_0; l += 16) {
|
1082
1481
|
// Load 16x4-bit integers into 8x8-bit integers
|
1083
1482
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1084
1483
|
|
@@ -1117,10 +1516,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1117
1516
|
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
1118
1517
|
|
1119
1518
|
// Store
|
1120
|
-
vst1q_f32(y + i*
|
1121
|
-
vst1q_f32(y + i*
|
1122
|
-
vst1q_f32(y + i*
|
1123
|
-
vst1q_f32(y + i*
|
1519
|
+
vst1q_f32(y + i*QK4_0 + l + 0, r0);
|
1520
|
+
vst1q_f32(y + i*QK4_0 + l + 4, r1);
|
1521
|
+
vst1q_f32(y + i*QK4_0 + l + 8, r2);
|
1522
|
+
vst1q_f32(y + i*QK4_0 + l + 12, r3);
|
1124
1523
|
}
|
1125
1524
|
}
|
1126
1525
|
#else
|
@@ -1130,7 +1529,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1130
1529
|
|
1131
1530
|
const uint8_t * restrict pp = x[i].qs;
|
1132
1531
|
|
1133
|
-
for (int l = 0; l <
|
1532
|
+
for (int l = 0; l < QK4_0; l += 2) {
|
1134
1533
|
const uint8_t vi = pp[l/2];
|
1135
1534
|
|
1136
1535
|
const int8_t vi0 = vi & 0xf;
|
@@ -1141,19 +1540,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
|
1141
1540
|
|
1142
1541
|
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
|
1143
1542
|
|
1144
|
-
y[i*
|
1145
|
-
y[i*
|
1543
|
+
y[i*QK4_0 + l + 0] = v0;
|
1544
|
+
y[i*QK4_0 + l + 1] = v1;
|
1146
1545
|
|
1147
|
-
assert(!isnan(y[i*
|
1148
|
-
assert(!isnan(y[i*
|
1546
|
+
assert(!isnan(y[i*QK4_0 + l + 0]));
|
1547
|
+
assert(!isnan(y[i*QK4_0 + l + 1]));
|
1149
1548
|
}
|
1150
1549
|
}
|
1151
1550
|
#endif
|
1152
1551
|
}
|
1153
1552
|
|
1154
1553
|
static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
|
1155
|
-
assert(k %
|
1156
|
-
const int nb = k /
|
1554
|
+
assert(k % QK4_1 == 0);
|
1555
|
+
const int nb = k / QK4_1;
|
1157
1556
|
|
1158
1557
|
const block_q4_1 * restrict x = vx;
|
1159
1558
|
|
@@ -1164,9 +1563,9 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1164
1563
|
|
1165
1564
|
const uint8_t * restrict pp = x[i].qs;
|
1166
1565
|
|
1167
|
-
for (int l = 0; l <
|
1566
|
+
for (int l = 0; l < QK4_1; l += 32) {
|
1168
1567
|
// Load 32x4-bit integers into 32x8-bit integers
|
1169
|
-
__m256i vx8 =
|
1568
|
+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
1170
1569
|
|
1171
1570
|
// Convert to 16-bit int
|
1172
1571
|
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
@@ -1183,7 +1582,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1183
1582
|
// Scale, add m and store
|
1184
1583
|
for (int j = 0; j < 4; j++) {
|
1185
1584
|
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
|
1186
|
-
_mm256_storeu_ps(y + i *
|
1585
|
+
_mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
|
1187
1586
|
}
|
1188
1587
|
}
|
1189
1588
|
}
|
@@ -1194,7 +1593,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1194
1593
|
|
1195
1594
|
const uint8_t * restrict pp = x[i].qs;
|
1196
1595
|
|
1197
|
-
for (int l = 0; l <
|
1596
|
+
for (int l = 0; l < QK4_1; l += 16) {
|
1198
1597
|
// Load 16x4-bit integers into 8x8-bit integers
|
1199
1598
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
1200
1599
|
|
@@ -1225,10 +1624,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1225
1624
|
const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
|
1226
1625
|
|
1227
1626
|
// Store
|
1228
|
-
vst1q_f32(y + i*
|
1229
|
-
vst1q_f32(y + i*
|
1230
|
-
vst1q_f32(y + i*
|
1231
|
-
vst1q_f32(y + i*
|
1627
|
+
vst1q_f32(y + i*QK4_1 + l + 0, r0);
|
1628
|
+
vst1q_f32(y + i*QK4_1 + l + 4, r1);
|
1629
|
+
vst1q_f32(y + i*QK4_1 + l + 8, r2);
|
1630
|
+
vst1q_f32(y + i*QK4_1 + l + 12, r3);
|
1232
1631
|
}
|
1233
1632
|
}
|
1234
1633
|
#else
|
@@ -1238,7 +1637,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1238
1637
|
|
1239
1638
|
const uint8_t * restrict pp = x[i].qs;
|
1240
1639
|
|
1241
|
-
for (int l = 0; l <
|
1640
|
+
for (int l = 0; l < QK4_1; l += 2) {
|
1242
1641
|
const uint8_t vi = pp[l/2];
|
1243
1642
|
|
1244
1643
|
const int8_t vi0 = vi & 0xf;
|
@@ -1247,21 +1646,130 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|
1247
1646
|
const float v0 = vi0*d + m;
|
1248
1647
|
const float v1 = vi1*d + m;
|
1249
1648
|
|
1250
|
-
y[i*
|
1251
|
-
y[i*
|
1649
|
+
y[i*QK4_1 + l + 0] = v0;
|
1650
|
+
y[i*QK4_1 + l + 1] = v1;
|
1252
1651
|
|
1253
|
-
assert(!isnan(y[i*
|
1254
|
-
assert(!isnan(y[i*
|
1652
|
+
assert(!isnan(y[i*QK4_1 + l + 0]));
|
1653
|
+
assert(!isnan(y[i*QK4_1 + l + 1]));
|
1255
1654
|
}
|
1256
1655
|
}
|
1257
1656
|
#endif
|
1258
1657
|
}
|
1259
1658
|
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1659
|
+
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
|
1660
|
+
assert(k % QK4_2 == 0);
|
1661
|
+
const int nb = k / QK4_2;
|
1263
1662
|
|
1264
|
-
|
1663
|
+
const block_q4_2 * restrict x = vx;
|
1664
|
+
|
1665
|
+
for (int i = 0; i < nb; i++) {
|
1666
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1667
|
+
|
1668
|
+
const uint8_t * restrict pp = x[i].qs;
|
1669
|
+
|
1670
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
1671
|
+
const uint8_t vi = pp[l/2];
|
1672
|
+
|
1673
|
+
const int8_t vi0 = vi & 0xf;
|
1674
|
+
const int8_t vi1 = vi >> 4;
|
1675
|
+
|
1676
|
+
const float v0 = (vi0 - 8)*d;
|
1677
|
+
const float v1 = (vi1 - 8)*d;
|
1678
|
+
|
1679
|
+
y[i*QK4_2 + l + 0] = v0;
|
1680
|
+
y[i*QK4_2 + l + 1] = v1;
|
1681
|
+
|
1682
|
+
assert(!isnan(y[i*QK4_2 + l + 0]));
|
1683
|
+
assert(!isnan(y[i*QK4_2 + l + 1]));
|
1684
|
+
}
|
1685
|
+
}
|
1686
|
+
}
|
1687
|
+
|
1688
|
+
static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
|
1689
|
+
assert(k % QK4_3 == 0);
|
1690
|
+
const int nb = k / QK4_3;
|
1691
|
+
|
1692
|
+
const block_q4_3 * restrict x = vx;
|
1693
|
+
|
1694
|
+
for (int i = 0; i < nb; i++) {
|
1695
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
1696
|
+
const float m = GGML_FP16_TO_FP32(x[i].m);
|
1697
|
+
|
1698
|
+
const uint8_t * restrict pp = x[i].qs;
|
1699
|
+
|
1700
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
1701
|
+
const uint8_t vi = pp[l/2];
|
1702
|
+
|
1703
|
+
const int8_t vi0 = vi & 0xf;
|
1704
|
+
const int8_t vi1 = vi >> 4;
|
1705
|
+
|
1706
|
+
const float v0 = vi0*d + m;
|
1707
|
+
const float v1 = vi1*d + m;
|
1708
|
+
|
1709
|
+
y[i*QK4_3 + l + 0] = v0;
|
1710
|
+
y[i*QK4_3 + l + 1] = v1;
|
1711
|
+
|
1712
|
+
assert(!isnan(y[i*QK4_3 + l + 0]));
|
1713
|
+
assert(!isnan(y[i*QK4_3 + l + 1]));
|
1714
|
+
}
|
1715
|
+
}
|
1716
|
+
}
|
1717
|
+
|
1718
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1719
|
+
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1720
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1721
|
+
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
1722
|
+
|
1723
|
+
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
1724
|
+
[GGML_TYPE_Q4_0] = {
|
1725
|
+
.dequantize_row_q = dequantize_row_q4_0,
|
1726
|
+
.quantize_row_q = quantize_row_q4_0,
|
1727
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
1728
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
1729
|
+
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
1730
|
+
},
|
1731
|
+
[GGML_TYPE_Q4_1] = {
|
1732
|
+
.dequantize_row_q = dequantize_row_q4_1,
|
1733
|
+
.quantize_row_q = quantize_row_q4_1,
|
1734
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
1735
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
1736
|
+
.vec_dot_q = ggml_vec_dot_q4_1_q8_0,
|
1737
|
+
},
|
1738
|
+
[GGML_TYPE_Q4_2] = {
|
1739
|
+
.dequantize_row_q = dequantize_row_q4_2,
|
1740
|
+
.quantize_row_q = quantize_row_q4_2,
|
1741
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
|
1742
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
1743
|
+
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
1744
|
+
},
|
1745
|
+
[GGML_TYPE_Q4_3] = {
|
1746
|
+
.dequantize_row_q = dequantize_row_q4_3,
|
1747
|
+
.quantize_row_q = quantize_row_q4_3,
|
1748
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
|
1749
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
1750
|
+
.vec_dot_q = ggml_vec_dot_q4_3_q8_0,
|
1751
|
+
},
|
1752
|
+
[GGML_TYPE_Q8_0] = {
|
1753
|
+
.dequantize_row_q = NULL, // TODO
|
1754
|
+
.quantize_row_q = quantize_row_q8_0,
|
1755
|
+
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
|
1756
|
+
.quantize_row_q_dot = quantize_row_q8_0,
|
1757
|
+
.vec_dot_q = NULL, // TODO
|
1758
|
+
},
|
1759
|
+
};
|
1760
|
+
|
1761
|
+
// For internal test use
|
1762
|
+
quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
|
1763
|
+
GGML_ASSERT(i < GGML_TYPE_COUNT);
|
1764
|
+
return quantize_fns[i];
|
1765
|
+
}
|
1766
|
+
|
1767
|
+
|
1768
|
+
//
|
1769
|
+
// simd mappings
|
1770
|
+
//
|
1771
|
+
|
1772
|
+
// we define a common set of C macros which map to specific intrinsics based on the current architecture
|
1265
1773
|
// we then implement the fundamental computation operations below using only these macros
|
1266
1774
|
// adding support for new architectures requires to define the corresponding SIMD macros
|
1267
1775
|
//
|
@@ -1813,37 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|
1813
2321
|
*s = sumf;
|
1814
2322
|
}
|
1815
2323
|
|
1816
|
-
#if __AVX512F__ && QK == 32
|
1817
|
-
static inline __m512 dot_q4_0_oneblock_avx512(
|
1818
|
-
__m512 acc,
|
1819
|
-
const block_q4_0 * restrict x,
|
1820
|
-
const block_q4_0 * restrict y,
|
1821
|
-
int i
|
1822
|
-
) {
|
1823
|
-
// Compute combined scale for the block
|
1824
|
-
__m512 d = _mm512_set1_ps( x[i].d * y[i].d );
|
1825
|
-
|
1826
|
-
__m256i bx = bytesFromNibbles( x[i].qs );
|
1827
|
-
__m256i by = bytesFromNibbles( y[i].qs );
|
1828
|
-
|
1829
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
1830
|
-
const __m256i off = _mm256_set1_epi8( 8 );
|
1831
|
-
bx = _mm256_sub_epi8( bx, off );
|
1832
|
-
by = _mm256_sub_epi8( by, off );
|
1833
|
-
|
1834
|
-
// Sign-extend 16 signed bytes into int16_t
|
1835
|
-
__m512i x32 = _mm512_cvtepi8_epi16( bx );
|
1836
|
-
__m512i y32 = _mm512_cvtepi8_epi16( by );
|
1837
|
-
// Compute products of int16_t integers, add pairwise
|
1838
|
-
__m512i i64 = _mm512_madd_epi16( x32, y32 );
|
1839
|
-
|
1840
|
-
// Convert int32_t to float
|
1841
|
-
__m512 p = _mm512_cvtepi32_ps( i64 );
|
1842
|
-
// Apply the scale, and accumulate
|
1843
|
-
return _mm512_fmadd_ps( d, p, acc );
|
1844
|
-
}
|
1845
|
-
#endif
|
1846
|
-
|
1847
2324
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
1848
2325
|
ggml_float sumf = 0.0;
|
1849
2326
|
|
@@ -1880,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
1880
2357
|
*s = sumf;
|
1881
2358
|
}
|
1882
2359
|
|
1883
|
-
static void
|
1884
|
-
const int nb = n /
|
2360
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2361
|
+
const int nb = n / QK8_0;
|
1885
2362
|
|
1886
|
-
assert(n %
|
2363
|
+
assert(n % QK8_0 == 0);
|
1887
2364
|
assert(nb % 2 == 0);
|
1888
2365
|
|
1889
2366
|
const block_q4_0 * restrict x = vx;
|
1890
|
-
const
|
2367
|
+
const block_q8_0 * restrict y = vy;
|
1891
2368
|
|
1892
2369
|
float sumf = 0.0;
|
1893
2370
|
|
1894
2371
|
#if defined(__ARM_NEON)
|
1895
|
-
|
1896
|
-
|
2372
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2373
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
1897
2374
|
|
1898
2375
|
for (int i = 0; i < nb; i += 2) {
|
1899
2376
|
const block_q4_0 * restrict x0 = &x[i + 0];
|
1900
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
1901
2377
|
const block_q4_0 * restrict x1 = &x[i + 1];
|
1902
|
-
const
|
2378
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2379
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
1903
2380
|
|
1904
|
-
const uint8x16_t m4b
|
1905
|
-
const int8x16_t s8b
|
2381
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2382
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
1906
2383
|
|
1907
2384
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
1908
|
-
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
1909
2385
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
1910
|
-
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
1911
2386
|
|
1912
2387
|
// 4-bit -> 8-bit
|
1913
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
1914
|
-
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
|
2388
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
1915
2389
|
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
1916
|
-
const int8x16_t
|
1917
|
-
|
1918
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
1919
|
-
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
|
2390
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
1920
2391
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
1921
|
-
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
|
1922
2392
|
|
1923
2393
|
// sub 8
|
1924
2394
|
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
1925
|
-
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
|
1926
2395
|
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
1927
|
-
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
|
1928
|
-
|
1929
2396
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
1930
|
-
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
|
1931
2397
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
1932
|
-
|
2398
|
+
|
2399
|
+
// load y
|
2400
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2401
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2402
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2403
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2404
|
+
|
2405
|
+
// interleave
|
2406
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2407
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2408
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2409
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
1933
2410
|
|
1934
2411
|
#if defined(__ARM_FEATURE_DOTPROD)
|
1935
2412
|
// dot product into int32x4_t
|
1936
|
-
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
1937
|
-
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
2413
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
2414
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
1938
2415
|
|
1939
|
-
|
1940
|
-
|
1941
|
-
|
1942
|
-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
1943
|
-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2416
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2417
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
1944
2418
|
#else
|
1945
2419
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
1946
2420
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
@@ -1952,115 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
1952
2426
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
1953
2427
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
1954
2428
|
|
1955
|
-
const
|
1956
|
-
const
|
1957
|
-
|
1958
|
-
const
|
1959
|
-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
2429
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2430
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2431
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2432
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
1960
2433
|
|
1961
|
-
|
1962
|
-
|
1963
|
-
|
1964
|
-
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
1965
|
-
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
2434
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2435
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
1966
2436
|
#endif
|
1967
2437
|
}
|
1968
2438
|
|
1969
|
-
sumf =
|
1970
|
-
#elif defined(
|
2439
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2440
|
+
#elif defined(__AVX2__)
|
1971
2441
|
// Initialize accumulator with zeros
|
1972
|
-
|
1973
|
-
__m512 acc1 = _mm512_setzero_ps();
|
2442
|
+
__m256 acc = _mm256_setzero_ps();
|
1974
2443
|
|
1975
|
-
|
1976
|
-
|
2444
|
+
// Main loop
|
2445
|
+
for (int i = 0; i < nb; ++i) {
|
2446
|
+
/* Compute combined scale for the block */
|
2447
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
1977
2448
|
|
1978
|
-
|
1979
|
-
int i = superblock_ix * superblock_size;
|
2449
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
1980
2450
|
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
|
1985
|
-
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
|
1986
|
-
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
|
1987
|
-
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
|
1988
|
-
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
|
1989
|
-
}
|
2451
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2452
|
+
const __m256i off = _mm256_set1_epi8( 8 );
|
2453
|
+
bx = _mm256_sub_epi8( bx, off );
|
1990
2454
|
|
1991
|
-
|
1992
|
-
for (int i = superblock_count * superblock_size; i < nb; ++i) {
|
1993
|
-
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
|
1994
|
-
}
|
2455
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
1995
2456
|
|
1996
|
-
|
1997
|
-
|
1998
|
-
#elif defined(__AVX2__)
|
1999
|
-
// Initialize accumulator with zeros
|
2000
|
-
__m256 acc = _mm256_setzero_ps();
|
2457
|
+
// Get absolute values of x vectors
|
2458
|
+
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2001
2459
|
|
2002
|
-
|
2003
|
-
|
2004
|
-
const __m256i offset_8 = _mm256_set1_epi16( 8 );
|
2460
|
+
// Sign the values of the y vectors
|
2461
|
+
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2005
2462
|
|
2006
|
-
|
2007
|
-
|
2008
|
-
assert(nb % UNROLL_COUNT == 0);
|
2463
|
+
// Perform multiplication and create 16-bit values
|
2464
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2009
2465
|
|
2010
|
-
|
2011
|
-
|
2012
|
-
|
2013
|
-
|
2014
|
-
|
2015
|
-
|
2016
|
-
|
2017
|
-
|
2018
|
-
|
2019
|
-
/* get input from x
|
2020
|
-
Input: 32 Nibbles (16 bytes) at *x[i+u]
|
2021
|
-
Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
|
2022
|
-
|
2023
|
-
/* Load 16 bytes from memory */
|
2024
|
-
const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
|
2025
|
-
/* Expand bytes into uint16_t values */
|
2026
|
-
const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
|
2027
|
-
/* Unpack values into individual bytes */
|
2028
|
-
__m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
|
2029
|
-
const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
|
2030
|
-
__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
|
2031
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2032
|
-
x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
|
2033
|
-
x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
|
2034
|
-
|
2035
|
-
/* get input from y
|
2036
|
-
Input: 32 Nibbles (16 bytes) at *y[i+u]
|
2037
|
-
Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
|
2038
|
-
|
2039
|
-
/* Load 16 bytes from memory */
|
2040
|
-
const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
|
2041
|
-
/* Expand bytes into uint16_t values */
|
2042
|
-
const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
|
2043
|
-
/* Unpack values into individual bytes */
|
2044
|
-
const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
|
2045
|
-
__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
|
2046
|
-
__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
|
2047
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2048
|
-
y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
|
2049
|
-
y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
|
2050
|
-
|
2051
|
-
/* Compute products of int16_t integers, add pairwise, store as int32_t */
|
2052
|
-
__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
|
2053
|
-
__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
|
2054
|
-
|
2055
|
-
/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
|
2056
|
-
__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
|
2057
|
-
|
2058
|
-
/* Convert to vectore of 8 int32_t to 8 floats */
|
2059
|
-
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2060
|
-
|
2061
|
-
/* Multiply q with scale and accumulate */
|
2062
|
-
acc = _mm256_fmadd_ps( scale, q, acc );
|
2063
|
-
}
|
2466
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
2467
|
+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2468
|
+
|
2469
|
+
/* Convert to vectore of 8 int32_t to 8 floats */
|
2470
|
+
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2471
|
+
|
2472
|
+
/* Multiply q with scale and accumulate */
|
2473
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
2064
2474
|
}
|
2065
2475
|
|
2066
2476
|
// Return horizontal sum of the acc vector
|
@@ -2082,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2082
2492
|
__m128i i32[2];
|
2083
2493
|
for (int j = 0; j < 2; ++j) {
|
2084
2494
|
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2085
|
-
__m128i bx =
|
2086
|
-
__m128i by =
|
2495
|
+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
|
2496
|
+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2087
2497
|
|
2088
2498
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2089
2499
|
const __m128i off = _mm_set1_epi8( 8 );
|
2090
2500
|
bx = _mm_sub_epi8( bx, off );
|
2091
|
-
by = _mm_sub_epi8( by, off );
|
2092
2501
|
|
2093
2502
|
// Get absolute values of x vectors
|
2094
2503
|
const __m128i ax = _mm_sign_epi8(bx, bx);
|
@@ -2116,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2116
2525
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2117
2526
|
|
2118
2527
|
sumf = _mm_cvtss_f32( res );
|
2119
|
-
#elif defined(__wasm_simd128__)
|
2120
|
-
// wasm simd
|
2121
|
-
float sum0 = 0.0f;
|
2122
|
-
float sum1 = 0.0f;
|
2123
|
-
|
2124
|
-
for (int i = 0; i < nb; i += 2) {
|
2125
|
-
const block_q4_0 * restrict x0 = &x[i + 0];
|
2126
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
2127
|
-
const block_q4_0 * restrict x1 = &x[i + 1];
|
2128
|
-
const block_q4_0 * restrict y1 = &y[i + 1];
|
2129
|
-
|
2130
|
-
const v128_t m4b = wasm_u8x16_splat(0xf);
|
2131
|
-
const v128_t s8b = wasm_i8x16_splat(0x8);
|
2132
|
-
|
2133
|
-
const v128_t v0_0 = wasm_v128_load(x0->qs);
|
2134
|
-
const v128_t v0_1 = wasm_v128_load(y0->qs);
|
2135
|
-
const v128_t v1_0 = wasm_v128_load(x1->qs);
|
2136
|
-
const v128_t v1_1 = wasm_v128_load(y1->qs);
|
2137
|
-
|
2138
|
-
// 4-bit -> 8-bit
|
2139
|
-
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
|
2140
|
-
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
|
2141
|
-
|
2142
|
-
const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
|
2143
|
-
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
|
2144
|
-
|
2145
|
-
const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
|
2146
|
-
const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
|
2147
|
-
|
2148
|
-
const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
|
2149
|
-
const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
|
2150
|
-
|
2151
|
-
// sub 8
|
2152
|
-
const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
|
2153
|
-
const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
|
2154
|
-
|
2155
|
-
const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
|
2156
|
-
const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
|
2157
|
-
|
2158
|
-
const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
|
2159
|
-
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
|
2160
|
-
|
2161
|
-
const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
|
2162
|
-
const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
|
2163
|
-
|
2164
|
-
// dot product into int16x8_t
|
2165
|
-
const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
|
2166
|
-
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
|
2167
|
-
|
2168
|
-
const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
|
2169
|
-
const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
|
2170
|
-
|
2171
|
-
const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
|
2172
|
-
const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
|
2173
|
-
|
2174
|
-
const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
|
2175
|
-
const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
|
2176
|
-
|
2177
|
-
const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
|
2178
|
-
const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
|
2179
|
-
|
2180
|
-
const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
|
2181
|
-
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
|
2182
|
-
|
2183
|
-
const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
|
2184
|
-
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
|
2185
|
-
|
2186
|
-
sum0 += x0->d * y0->d * (
|
2187
|
-
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
|
2188
|
-
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
|
2189
|
-
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
|
2190
|
-
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
|
2191
|
-
sum1 += x1->d * y1->d * (
|
2192
|
-
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
|
2193
|
-
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
|
2194
|
-
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
|
2195
|
-
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
|
2196
|
-
}
|
2197
|
-
|
2198
|
-
sumf = sum0 + sum1;
|
2199
2528
|
#else
|
2200
2529
|
// scalar
|
2201
2530
|
for (int i = 0; i < nb; i++) {
|
@@ -2203,98 +2532,159 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2203
2532
|
const float d1 = y[i].d;
|
2204
2533
|
|
2205
2534
|
const uint8_t * restrict p0 = x[i].qs;
|
2206
|
-
const
|
2535
|
+
const int8_t * restrict p1 = y[i].qs;
|
2207
2536
|
|
2208
2537
|
int sumi = 0;
|
2209
|
-
for (int j = 0; j <
|
2538
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2210
2539
|
const uint8_t v0 = p0[j];
|
2211
|
-
const uint8_t v1 = p1[j];
|
2212
2540
|
|
2213
|
-
const
|
2214
|
-
const
|
2541
|
+
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
2542
|
+
const int i1 = (int8_t) (v0 >> 4) - 8;
|
2215
2543
|
|
2216
|
-
const
|
2217
|
-
const
|
2544
|
+
const int i2 = p1[2*j + 0];
|
2545
|
+
const int i3 = p1[2*j + 1];
|
2218
2546
|
|
2219
2547
|
sumi += i0*i2 + i1*i3;
|
2220
2548
|
}
|
2221
|
-
sumf += d0
|
2549
|
+
sumf += d0*d1*sumi;
|
2222
2550
|
}
|
2223
2551
|
#endif
|
2224
2552
|
|
2225
2553
|
*s = sumf;
|
2226
2554
|
}
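The new scalar branch above unpacks each byte of x into two 4-bit values, shifts them into [-8, 7], pairs them with two consecutive int8 values of y, and scales the accumulated integer sum by both block scales. A minimal standalone sketch of that loop follows; the block layouts (a float scale d plus 16 packed nibbles for q4_0, a float scale plus 32 int8 values for q8_0) and QK8_0 == 32 are assumptions inferred from the indexing in the diff, and the SIMD branches compute the same per-block quantity with vector instructions.

    #include <stdint.h>

    #define QK8_0 32  // assumed block size

    // Assumed layouts, mirroring how the diff indexes x[i].d, x[i].qs, y[i].d, y[i].qs.
    typedef struct { float d; uint8_t qs[QK8_0 / 2]; } demo_block_q4_0; // 32 packed 4-bit values
    typedef struct { float d; int8_t  qs[QK8_0];     } demo_block_q8_0; // 32 int8 values

    // Scalar reference of the q4_0 x q8_0 dot product: each byte of x yields a low
    // and a high nibble, both offset into [-8, 7] and paired with two consecutive
    // int8 values of y; the integer sum is scaled by the two block scales.
    static float demo_vec_dot_q4_0_q8_0(int n, const demo_block_q4_0 * x, const demo_block_q8_0 * y) {
        const int nb = n / QK8_0;
        float sumf = 0.0f;
        for (int i = 0; i < nb; i++) {
            int sumi = 0;
            for (int j = 0; j < QK8_0/2; j++) {
                const uint8_t v0 = x[i].qs[j];
                const int i0 = (int8_t) (v0 & 0xf) - 8; // low nibble
                const int i1 = (int8_t) (v0 >> 4)  - 8; // high nibble
                sumi += i0*y[i].qs[2*j + 0] + i1*y[i].qs[2*j + 1];
            }
            sumf += x[i].d * y[i].d * sumi;
        }
        return sumf;
    }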
|
2227
2555
|
|
2228
|
-
static void
|
2229
|
-
const int nb = n /
|
2556
|
+
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2557
|
+
const int nb = n / QK8_0;
|
2558
|
+
|
2559
|
+
assert(n % QK8_0 == 0);
|
2560
|
+
assert(nb % 2 == 0);
|
2230
2561
|
|
2231
2562
|
const block_q4_1 * restrict x = vx;
|
2232
|
-
const
|
2563
|
+
const block_q8_0 * restrict y = vy;
|
2233
2564
|
|
2234
2565
|
float sumf = 0.0;
|
2235
2566
|
|
2236
|
-
|
2567
|
+
// TODO: add AVX / WASM SIMD / etc
|
2568
|
+
#if defined(__ARM_NEON)
|
2569
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2570
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2571
|
+
|
2572
|
+
for (int i = 0; i < nb; i += 2) {
|
2573
|
+
const block_q4_1 * restrict x0 = &x[i + 0];
|
2574
|
+
const block_q4_1 * restrict x1 = &x[i + 1];
|
2575
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2576
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2577
|
+
|
2578
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2579
|
+
|
2580
|
+
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2581
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2582
|
+
|
2583
|
+
// 4-bit -> 8-bit
|
2584
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2585
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2586
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2587
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2588
|
+
|
2589
|
+
// load y
|
2590
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2591
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2592
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2593
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2594
|
+
|
2595
|
+
// interleave
|
2596
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2597
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2598
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2599
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2600
|
+
|
2601
|
+
const int16x8_t s0i = vaddq_s16(
|
2602
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
2603
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
2604
|
+
|
2605
|
+
const int16x8_t s1i = vaddq_s16(
|
2606
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
2607
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
2608
|
+
|
2609
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
2610
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
2611
|
+
|
2612
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2613
|
+
// dot product into int32x4_t
|
2614
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
2615
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
|
2616
|
+
|
2617
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2618
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2619
|
+
#else
|
2620
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
|
2621
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
|
2622
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
|
2623
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
|
2624
|
+
|
2625
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
|
2626
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
|
2627
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
|
2628
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
|
2629
|
+
|
2630
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2631
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2632
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2633
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2634
|
+
|
2635
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2636
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
2637
|
+
#endif
|
2638
|
+
}
|
2639
|
+
|
2640
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2641
|
+
#elif defined(__AVX2__)
|
2237
2642
|
// Initialize accumulator with zeros
|
2238
2643
|
__m256 acc = _mm256_setzero_ps();
|
2239
|
-
// Accumulator for constant offsets
|
2240
|
-
float acc_offset = 0.0f;
|
2241
2644
|
|
2242
2645
|
// Main loop
|
2243
2646
|
for (int i = 0; i < nb; ++i) {
|
2244
2647
|
const float * d0 = &x[i].d;
|
2245
2648
|
const float * d1 = &y[i].d;
|
2246
|
-
|
2247
2649
|
const float * m0 = &x[i].m;
|
2248
|
-
const float * m1 = &y[i].m;
|
2249
2650
|
|
2250
2651
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2251
2652
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2252
2653
|
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2253
|
-
const __m256 m1v = _mm256_broadcast_ss( m1 );
|
2254
2654
|
|
2255
|
-
// Compute combined
|
2256
|
-
const __m256
|
2257
|
-
|
2258
|
-
// Compute cross scales for the block
|
2259
|
-
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
|
2260
|
-
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
|
2261
|
-
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
|
2655
|
+
// Compute combined scales
|
2656
|
+
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
2657
|
+
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
2262
2658
|
|
2263
2659
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2264
|
-
__m256i bx =
|
2265
|
-
__m256i by =
|
2266
|
-
|
2267
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
|
2268
|
-
|
2269
|
-
// Sign-extend first 16 signed bytes into int16_t
|
2270
|
-
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
|
2271
|
-
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2272
|
-
// Compute products of int16_t integers, add pairwise
|
2273
|
-
__m256i i32 = _mm256_madd_epi16( x16, y16 );
|
2274
|
-
|
2275
|
-
// Sign-extend last 16 signed bytes into int16_t vectors
|
2276
|
-
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
|
2277
|
-
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2278
|
-
// Accumulate products of int16_t integers
|
2279
|
-
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
|
2280
|
-
|
2281
|
-
// compute sums of unsigned bytes in bx, by in blocks of 8.
|
2282
|
-
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
|
2283
|
-
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
|
2284
|
-
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
|
2285
|
-
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
|
2286
|
-
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
|
2287
|
-
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
2288
|
-
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
2660
|
+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2661
|
+
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
2289
2662
|
|
2290
|
-
//
|
2291
|
-
|
2292
|
-
|
2293
|
-
//
|
2294
|
-
|
2295
|
-
|
2296
|
-
//
|
2297
|
-
|
2663
|
+
// Get absolute values of x vectors
|
2664
|
+
const __m256i ax = _mm256_sign_epi8( bx, bx );
|
2665
|
+
|
2666
|
+
// Sign the values of the y vectors
|
2667
|
+
const __m256i sy = _mm256_sign_epi8( by, bx );
|
2668
|
+
|
2669
|
+
// Perform multiplication and create 16-bit values
|
2670
|
+
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
|
2671
|
+
const __m256i ones = _mm256_set1_epi16( 1 );
|
2672
|
+
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
|
2673
|
+
|
2674
|
+
// Convert the vector of 8 int32_t to 8 floats
|
2675
|
+
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
|
2676
|
+
|
2677
|
+
// Accumulate d0*d1*x*y
|
2678
|
+
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
2679
|
+
|
2680
|
+
// Compute sum of y values
|
2681
|
+
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2682
|
+
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2683
|
+
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
2684
|
+
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
2685
|
+
|
2686
|
+
// Accumulate d1*m0*y
|
2687
|
+
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
2298
2688
|
}
|
2299
2689
|
|
2300
2690
|
// Return horizontal sum of the acc vector
|
@@ -2303,131 +2693,379 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2303
2693
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2304
2694
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2305
2695
|
|
2306
|
-
sumf = _mm_cvtss_f32( res )
|
2307
|
-
#
|
2308
|
-
|
2309
|
-
|
2310
|
-
|
2311
|
-
|
2696
|
+
sumf = _mm_cvtss_f32( res );
|
2697
|
+
#else
|
2698
|
+
// scalar
|
2699
|
+
for (int i = 0; i < nb; i++) {
|
2700
|
+
const float d0 = x[i].d;
|
2701
|
+
const float m0 = x[i].m;
|
2702
|
+
const float d1 = y[i].d;
|
2703
|
+
|
2704
|
+
const uint8_t * restrict p0 = x[i].qs;
|
2705
|
+
const int8_t * restrict p1 = y[i].qs;
|
2706
|
+
|
2707
|
+
// TODO: this is very slow ..
|
2708
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2709
|
+
const uint8_t v0 = p0[j];
|
2710
|
+
|
2711
|
+
const float f0 = d0*(v0 & 0xf) + m0;
|
2712
|
+
const float f1 = d0*(v0 >> 4) + m0;
|
2713
|
+
|
2714
|
+
const float f2 = d1*p1[2*j + 0];
|
2715
|
+
const float f3 = d1*p1[2*j + 1];
|
2716
|
+
|
2717
|
+
sumf += f0*f2 + f1*f3;
|
2718
|
+
}
|
2719
|
+
}
|
2720
|
+
#endif
|
2721
|
+
|
2722
|
+
*s = sumf;
|
2723
|
+
}
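ggml_vec_dot_q4_1_q8_0 above handles the q4_1 format, where each stored weight is d*q + m with q in [0, 15]. Both the NEON and the AVX2 branches rely on the factorization sum((d0*q + m0) * d1*s) = d0*d1*sum(q*s) + m0*d1*sum(s), accumulating the integer dot product and the plain sum of the y values separately (the xy/ysum pair in the AVX2 path, the dot-product and s0i/s1i terms in the NEON path). A minimal sketch of one block in that factored form, with assumed block layouts:

    #include <stdint.h>

    #define QK8_0 32  // assumed block size

    typedef struct { float d, m; uint8_t qs[QK8_0 / 2]; } demo_block_q4_1; // assumed layout
    typedef struct { float d;    int8_t  qs[QK8_0];     } demo_block_q8_0; // assumed layout

    // One q4_1 x q8_0 block, written in the factored form used by the SIMD paths:
    //   d0*d1 * sum(q*s)  +  m0*d1 * sum(s)
    static float demo_dot_q4_1_q8_0_block(const demo_block_q4_1 * x, const demo_block_q8_0 * y) {
        int sum_qs = 0; // integer dot product of the 4-bit and 8-bit values
        int sum_s  = 0; // plain sum of the int8 values of y
        for (int j = 0; j < QK8_0/2; j++) {
            const int q0 = x->qs[j] & 0xf;
            const int q1 = x->qs[j] >> 4;
            const int s0 = y->qs[2*j + 0];
            const int s1 = y->qs[2*j + 1];
            sum_qs += q0*s0 + q1*s1;
            sum_s  += s0 + s1;
        }
        return x->d * y->d * sum_qs + x->m * y->d * sum_s;
    }

The AVX2 branch also uses the common _mm256_maddubs_epi16 pattern visible in the comments above: the instruction multiplies unsigned bytes by signed bytes, so the code takes the absolute value of x with _mm256_sign_epi8(bx, bx) and transfers the sign of x onto y with _mm256_sign_epi8(by, bx) before multiplying.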
|
2724
|
+
|
2725
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2726
|
+
const int nb = n / QK8_0;
|
2727
|
+
|
2728
|
+
assert(n % QK8_0 == 0);
|
2729
|
+
assert(nb % 2 == 0);
|
2730
|
+
assert(QK8_0 == 2*QK4_2);
|
2731
|
+
|
2732
|
+
const block_q4_2 * restrict x = vx;
|
2733
|
+
const block_q8_0 * restrict y = vy;
|
2734
|
+
|
2735
|
+
float sumf = 0.0;
|
2736
|
+
|
2737
|
+
#if defined(__ARM_NEON)
|
2738
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2739
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2312
2740
|
|
2313
2741
|
for (int i = 0; i < nb; i += 2) {
|
2314
|
-
const
|
2315
|
-
const
|
2316
|
-
const
|
2317
|
-
const
|
2742
|
+
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
|
2743
|
+
const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
|
2744
|
+
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2745
|
+
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2318
2746
|
|
2319
|
-
const
|
2747
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2748
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2320
2749
|
|
2321
|
-
const uint8x16_t
|
2322
|
-
const
|
2323
|
-
|
2324
|
-
const uint8x16_t
|
2750
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2751
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2752
|
+
|
2753
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2754
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2325
2755
|
|
2326
2756
|
// 4-bit -> 8-bit
|
2327
|
-
const
|
2328
|
-
const
|
2329
|
-
const
|
2330
|
-
const
|
2757
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2758
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2759
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2760
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2331
2761
|
|
2332
|
-
|
2333
|
-
const
|
2334
|
-
const
|
2335
|
-
const
|
2762
|
+
// sub 8
|
2763
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2764
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2765
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2766
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2336
2767
|
|
2337
|
-
|
2338
|
-
|
2339
|
-
|
2768
|
+
// interleave
|
2769
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
2770
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
2771
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
2772
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
2340
2773
|
|
2341
|
-
|
2342
|
-
|
2343
|
-
|
2774
|
+
// load y
|
2775
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2776
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2777
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2778
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2344
2779
|
|
2345
2780
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2346
|
-
|
2347
|
-
|
2348
|
-
|
2781
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
2782
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
|
2783
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
2349
2784
|
|
2350
|
-
|
2351
|
-
|
2352
|
-
|
2353
|
-
sum11 += x0->d*y0->d*vaddvq_u32(p_0);
|
2354
|
-
sum11 += x1->d*y1->d*vaddvq_u32(p_1);
|
2785
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
2786
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
|
2787
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
2355
2788
|
#else
|
2356
|
-
const
|
2357
|
-
const
|
2358
|
-
const
|
2359
|
-
const
|
2789
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2790
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2791
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2792
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2793
|
+
|
2794
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2795
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2796
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2797
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2798
|
+
|
2799
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2800
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2801
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2802
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2803
|
+
|
2804
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
2805
|
+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
2806
|
+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
2807
|
+
|
2808
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
2809
|
+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
|
2810
|
+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
2811
|
+
#endif
|
2812
|
+
}
|
2813
|
+
|
2814
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2815
|
+
#elif defined(__AVX2__)
|
2816
|
+
// Initialize accumulator with zeros
|
2817
|
+
__m256 acc = _mm256_setzero_ps();
|
2818
|
+
|
2819
|
+
// Main loop
|
2820
|
+
for (int i = 0; i < nb; i++) {
|
2821
|
+
/* Compute combined scale for the block */
|
2822
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
2823
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
2824
|
+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
2360
2825
|
|
2361
|
-
|
2362
|
-
|
2363
|
-
|
2364
|
-
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
|
2826
|
+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
2827
|
+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
2828
|
+
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
2365
2829
|
|
2366
|
-
|
2367
|
-
const
|
2830
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2831
|
+
const __m256i off = _mm256_set1_epi8(8);
|
2832
|
+
bx = _mm256_sub_epi8(bx, off);
|
2368
2833
|
|
2369
|
-
|
2370
|
-
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
|
2834
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2371
2835
|
|
2372
|
-
|
2373
|
-
const
|
2836
|
+
// Get absolute values of x vectors
|
2837
|
+
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2838
|
+
// Sign the values of the y vectors
|
2839
|
+
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2840
|
+
// Perform multiplication and create 16-bit values
|
2841
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2374
2842
|
|
2375
|
-
|
2376
|
-
|
2377
|
-
|
2843
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
2844
|
+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2845
|
+
|
2846
|
+
/* Convert the vector of 8 int32_t to 8 floats */
|
2847
|
+
__m256 q = _mm256_cvtepi32_ps(xy_q);
|
2848
|
+
|
2849
|
+
/* Multiply q with scale and accumulate */
|
2850
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
2378
2851
|
}
|
2379
2852
|
|
2380
|
-
|
2853
|
+
// Return horizontal sum of the acc vector
|
2854
|
+
__m128 res = _mm256_extractf128_ps(acc, 1);
|
2855
|
+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
2856
|
+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
2857
|
+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
2858
|
+
|
2859
|
+
sumf = _mm_cvtss_f32(res);
|
2381
2860
|
#else
|
2382
2861
|
// scalar
|
2383
2862
|
for (int i = 0; i < nb; i++) {
|
2384
|
-
const
|
2385
|
-
const
|
2863
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
2864
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
2865
|
+
const int8_t * restrict y0 = y[i].qs;
|
2386
2866
|
|
2387
|
-
const float
|
2388
|
-
const float
|
2867
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
2868
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
2389
2869
|
|
2390
|
-
|
2391
|
-
|
2870
|
+
int sumi_0 = 0;
|
2871
|
+
int sumi_1 = 0;
|
2392
2872
|
|
2393
|
-
for (int j = 0; j <
|
2394
|
-
const uint8_t v0 =
|
2395
|
-
const uint8_t v1 =
|
2873
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
2874
|
+
const uint8_t v0 = x0[j];
|
2875
|
+
const uint8_t v1 = x1[j];
|
2396
2876
|
|
2397
|
-
const
|
2398
|
-
const
|
2877
|
+
const int i0_0 = (int8_t) (v0 & 0xf) - 8;
|
2878
|
+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
|
2399
2879
|
|
2400
|
-
const
|
2401
|
-
const
|
2880
|
+
const int i0_1 = (int8_t) (v1 & 0xf) - 8;
|
2881
|
+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
|
2402
2882
|
|
2403
|
-
|
2883
|
+
const int i2_0 = y0[2*j + 0];
|
2884
|
+
const int i3_0 = y0[2*j + 1];
|
2885
|
+
|
2886
|
+
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
2887
|
+
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
2888
|
+
|
2889
|
+
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
2890
|
+
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
2404
2891
|
}
|
2892
|
+
|
2893
|
+
sumf += (d0 * y[i].d) * sumi_0;
|
2894
|
+
sumf += (d1 * y[i].d) * sumi_1;
|
2405
2895
|
}
|
2406
2896
|
#endif
|
2407
2897
|
|
2408
2898
|
*s = sumf;
|
2409
2899
|
}
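ggml_vec_dot_q4_2_q8_0 above works on half-size blocks: the assert QK8_0 == 2*QK4_2 means one 32-value q8_0 block lines up with two 16-value q4_2 blocks, each with its own scale stored as ggml_fp16_t and widened via GGML_FP16_TO_FP32. A minimal sketch of the scalar pairing; the layouts are assumed and a plain float stands in for the fp16 scale:

    #include <stdint.h>

    #define QK8_0 32  // assumed
    #define QK4_2 16  // QK8_0 == 2*QK4_2, as asserted in the diff

    typedef struct { float d; uint8_t qs[QK4_2 / 2]; } demo_block_q4_2; // real code stores d as ggml_fp16_t
    typedef struct { float d; int8_t  qs[QK8_0];     } demo_block_q8_0;

    // q8_0 block i pairs with q4_2 blocks 2*i and 2*i+1: the first covers
    // y->qs[0..15], the second y->qs[16..31], each scaled by its own d,
    // mirroring the scalar branch of ggml_vec_dot_q4_2_q8_0.
    static float demo_dot_q4_2_q8_0_block(const demo_block_q4_2 * x0,
                                          const demo_block_q4_2 * x1,
                                          const demo_block_q8_0 * y) {
        int sumi_0 = 0;
        int sumi_1 = 0;
        for (int j = 0; j < QK8_0/4; j++) {
            sumi_0 += ((int8_t) (x0->qs[j] & 0xf) - 8) * y->qs[2*j + 0]
                    + ((int8_t) (x0->qs[j] >>  4) - 8) * y->qs[2*j + 1];
            sumi_1 += ((int8_t) (x1->qs[j] & 0xf) - 8) * y->qs[2*(j + QK8_0/4) + 0]
                    + ((int8_t) (x1->qs[j] >>  4) - 8) * y->qs[2*(j + QK8_0/4) + 1];
        }
        return x0->d * y->d * sumi_0 + x1->d * y->d * sumi_1;
    }

The NEON and AVX2 branches apply the two sub-block scales the same way, either as the x0_0->d / x0_1->d pair fed to vmulq_n_f32 or as the combined _mm256_set_m128(d1, d0) scale vector.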
|
2410
2900
|
|
2411
|
-
|
2412
|
-
|
2413
|
-
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
2414
|
-
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
2901
|
+
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2902
|
+
const int nb = n / QK8_0;
|
2415
2903
|
|
2416
|
-
|
2904
|
+
assert(n % QK8_0 == 0);
|
2905
|
+
assert(nb % 2 == 0);
|
2906
|
+
assert(QK8_0 == 2*QK4_2);
|
2417
2907
|
|
2418
|
-
|
2419
|
-
|
2420
|
-
}
|
2908
|
+
const block_q4_3 * restrict x = vx;
|
2909
|
+
const block_q8_0 * restrict y = vy;
|
2421
2910
|
|
2422
|
-
|
2423
|
-
const int np = (n & ~(GGML_F16_STEP - 1));
|
2911
|
+
float sumf = 0.0;
|
2424
2912
|
|
2425
|
-
|
2913
|
+
#if defined(__ARM_NEON)
|
2914
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2915
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2426
2916
|
|
2427
|
-
|
2428
|
-
|
2917
|
+
for (int i = 0; i < nb; i += 2) {
|
2918
|
+
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
|
2919
|
+
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
|
2920
|
+
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2921
|
+
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2429
2922
|
|
2430
|
-
|
2923
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2924
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2925
|
+
|
2926
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2927
|
+
|
2928
|
+
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
|
2929
|
+
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
|
2930
|
+
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
|
2931
|
+
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
|
2932
|
+
|
2933
|
+
const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
|
2934
|
+
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
|
2935
|
+
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
|
2936
|
+
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
|
2937
|
+
|
2938
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2939
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2940
|
+
|
2941
|
+
// 4-bit -> 8-bit
|
2942
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2943
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2944
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2945
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2946
|
+
|
2947
|
+
// interleave
|
2948
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
2949
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
2950
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
2951
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
2952
|
+
|
2953
|
+
// load y
|
2954
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2955
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2956
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2957
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2958
|
+
|
2959
|
+
const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
|
2960
|
+
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
|
2961
|
+
|
2962
|
+
const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
|
2963
|
+
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
|
2964
|
+
|
2965
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
|
2966
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
|
2967
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
|
2968
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
|
2969
|
+
|
2970
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2971
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
|
2972
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
|
2973
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
|
2974
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
|
2975
|
+
#else
|
2976
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2977
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2978
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2979
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2980
|
+
|
2981
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2982
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2983
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2984
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2985
|
+
|
2986
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2987
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2988
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2989
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2990
|
+
|
2991
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
|
2992
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
|
2993
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
|
2994
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
|
2995
|
+
#endif
|
2996
|
+
}
|
2997
|
+
|
2998
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2999
|
+
#else
|
3000
|
+
// scalar
|
3001
|
+
for (int i = 0; i < nb; i++) {
|
3002
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3003
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3004
|
+
const int8_t * restrict y0 = y[i].qs;
|
3005
|
+
|
3006
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3007
|
+
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
|
3008
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3009
|
+
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
|
3010
|
+
|
3011
|
+
int sy_0 = 0;
|
3012
|
+
int sy_1 = 0;
|
3013
|
+
|
3014
|
+
int sxy_0 = 0;
|
3015
|
+
int sxy_1 = 0;
|
3016
|
+
|
3017
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
3018
|
+
const uint8_t v0 = x0[j];
|
3019
|
+
const uint8_t v1 = x1[j];
|
3020
|
+
|
3021
|
+
const int x0_0 = v0 & 0xf;
|
3022
|
+
const int x1_0 = v0 >> 4;
|
3023
|
+
|
3024
|
+
const int x0_1 = v1 & 0xf;
|
3025
|
+
const int x1_1 = v1 >> 4;
|
3026
|
+
|
3027
|
+
const int y0_0 = y0[2*j + 0];
|
3028
|
+
const int y1_0 = y0[2*j + 1];
|
3029
|
+
|
3030
|
+
const int y0_1 = y0[2*(j + QK8_0/4) + 0];
|
3031
|
+
const int y1_1 = y0[2*(j + QK8_0/4) + 1];
|
3032
|
+
|
3033
|
+
sy_0 += y0_0 + y1_0;
|
3034
|
+
sy_1 += y0_1 + y1_1;
|
3035
|
+
|
3036
|
+
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
|
3037
|
+
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
|
3038
|
+
}
|
3039
|
+
|
3040
|
+
sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
|
3041
|
+
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
|
3042
|
+
}
|
3043
|
+
#endif
|
3044
|
+
|
3045
|
+
*s = sumf;
|
3046
|
+
}
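ggml_vec_dot_q4_3_q8_0 above combines the previous two ideas: q4_3 keeps the half-size sub-blocks of q4_2 and adds a per-sub-block minimum m, with both d and m stored as fp16. The scalar branch accumulates sxy = sum(q*s) and sy = sum(s) per sub-block and applies (d*sxy + m*sy)*y->d, the same factorization as in q4_1. A minimal sketch of one sub-block; the second sub-block is handled identically with the other d/m pair and the y pointer advanced by 16, and all names here are illustrative only:

    #include <stdint.h>

    // One q4_3 sub-block: 16 weights with scale d and minimum m (fp16 in the real
    // code, plain floats here), against half of a q8_0 block starting at ys,
    // scaled by the q8_0 block scale yd.
    static float demo_dot_q4_3_subblock(float d, float m, const uint8_t qs[8],
                                        const int8_t * ys, float yd) {
        int sxy = 0; // sum of q*s
        int sy  = 0; // sum of s
        for (int j = 0; j < 8; j++) {
            const int q0 = qs[j] & 0xf;
            const int q1 = qs[j] >> 4;
            sxy += q0 * ys[2*j + 0] + q1 * ys[2*j + 1];
            sy  += ys[2*j + 0] + ys[2*j + 1];
        }
        return (d * sxy + m * sy) * yd; // matches sumf += (d0*sxy_0 + m0*sy_0)*y[i].d
    }

Note that the nibbles are used unshifted (0..15) here; the minimum m plays the role of the -8 offset used by q4_0 and q4_2, which is exactly what the scalar branch above does.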
|
3047
|
+
|
3048
|
+
|
3049
|
+
// compute GGML_VEC_DOT_UNROLL dot products at once
|
3050
|
+
// xs - x row stride in bytes
|
3051
|
+
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
3052
|
+
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
3053
|
+
|
3054
|
+
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
3055
|
+
|
3056
|
+
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
3057
|
+
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
3058
|
+
}
|
3059
|
+
|
3060
|
+
#if defined(GGML_SIMD)
|
3061
|
+
const int np = (n & ~(GGML_F16_STEP - 1));
|
3062
|
+
|
3063
|
+
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
|
3064
|
+
|
3065
|
+
GGML_F16_VEC ax[GGML_F16_ARR];
|
3066
|
+
GGML_F16_VEC ay[GGML_F16_ARR];
|
3067
|
+
|
3068
|
+
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
2431
3069
|
for (int j = 0; j < GGML_F16_ARR; j++) {
|
2432
3070
|
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
2433
3071
|
|
@@ -2652,24 +3290,30 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
|
|
2652
3290
|
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
2653
3291
|
[GGML_TYPE_F32] = 1,
|
2654
3292
|
[GGML_TYPE_F16] = 1,
|
2655
|
-
[GGML_TYPE_Q4_0] =
|
2656
|
-
[GGML_TYPE_Q4_1] =
|
3293
|
+
[GGML_TYPE_Q4_0] = QK4_0,
|
3294
|
+
[GGML_TYPE_Q4_1] = QK4_1,
|
3295
|
+
[GGML_TYPE_Q4_2] = QK4_2,
|
3296
|
+
[GGML_TYPE_Q4_3] = QK4_3,
|
3297
|
+
[GGML_TYPE_Q8_0] = QK8_0,
|
2657
3298
|
[GGML_TYPE_I8] = 1,
|
2658
3299
|
[GGML_TYPE_I16] = 1,
|
2659
3300
|
[GGML_TYPE_I32] = 1,
|
2660
3301
|
};
|
2661
|
-
static_assert(GGML_TYPE_COUNT ==
|
3302
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
|
2662
3303
|
|
2663
3304
|
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
2664
3305
|
[GGML_TYPE_F32] = sizeof(float),
|
2665
3306
|
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
2666
3307
|
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
2667
3308
|
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
3309
|
+
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
|
3310
|
+
[GGML_TYPE_Q4_3] = sizeof(block_q4_3),
|
3311
|
+
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
2668
3312
|
[GGML_TYPE_I8] = sizeof(int8_t),
|
2669
3313
|
[GGML_TYPE_I16] = sizeof(int16_t),
|
2670
3314
|
[GGML_TYPE_I32] = sizeof(int32_t),
|
2671
3315
|
};
|
2672
|
-
static_assert(GGML_TYPE_COUNT ==
|
3316
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
|
2673
3317
|
|
2674
3318
|
|
2675
3319
|
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
@@ -2677,11 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
|
2677
3321
|
[GGML_TYPE_F16] = "f16",
|
2678
3322
|
[GGML_TYPE_Q4_0] = "q4_0",
|
2679
3323
|
[GGML_TYPE_Q4_1] = "q4_1",
|
3324
|
+
[GGML_TYPE_Q4_2] = "q4_2",
|
3325
|
+
[GGML_TYPE_Q4_3] = "q4_3",
|
3326
|
+
[GGML_TYPE_Q8_0] = "q8_0",
|
2680
3327
|
[GGML_TYPE_I8] = "i8",
|
2681
3328
|
[GGML_TYPE_I16] = "i16",
|
2682
3329
|
[GGML_TYPE_I32] = "i32",
|
2683
3330
|
};
|
2684
|
-
static_assert(GGML_TYPE_COUNT ==
|
3331
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
|
3332
|
+
|
3333
|
+
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
3334
|
+
[GGML_TYPE_F32] = false,
|
3335
|
+
[GGML_TYPE_F16] = false,
|
3336
|
+
[GGML_TYPE_Q4_0] = true,
|
3337
|
+
[GGML_TYPE_Q4_1] = true,
|
3338
|
+
[GGML_TYPE_Q4_2] = true,
|
3339
|
+
[GGML_TYPE_Q4_3] = true,
|
3340
|
+
[GGML_TYPE_Q8_0] = true,
|
3341
|
+
[GGML_TYPE_I8] = false,
|
3342
|
+
[GGML_TYPE_I16] = false,
|
3343
|
+
[GGML_TYPE_I32] = false,
|
3344
|
+
};
|
3345
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
|
2685
3346
|
|
2686
3347
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
2687
3348
|
"NONE",
|
@@ -2943,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
|
2943
3604
|
(t0->ne[3] == t1->ne[3]);
|
2944
3605
|
}
|
2945
3606
|
|
3607
|
+
bool ggml_is_quantized(enum ggml_type type) {
|
3608
|
+
return GGML_IS_QUANTIZED[type];
|
3609
|
+
}
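The table updates above register the new formats in the per-type metadata: GGML_BLCK_SIZE holds the number of weights per block, GGML_TYPE_SIZE the bytes per block, and GGML_IS_QUANTIZED, queried through the new ggml_is_quantized() helper, marks the block formats. Later code derives row sizes from the first two, e.g. nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]) in the dup paths below. A small stand-alone sketch of that relationship with stand-in tables; the byte counts are only the sizes of the demo structs used in the sketches above, not authoritative values:

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>

    // Stand-in per-type tables for two of the formats (demo values only).
    enum demo_type { DEMO_Q4_2, DEMO_Q8_0, DEMO_TYPE_COUNT };
    static const int    DEMO_BLCK_SIZE[DEMO_TYPE_COUNT]    = { 16, 32 };        // weights per block
    static const size_t DEMO_TYPE_SIZE[DEMO_TYPE_COUNT]    = { 2 + 8, 4 + 32 }; // bytes per block (demo structs)
    static const bool   DEMO_IS_QUANTIZED[DEMO_TYPE_COUNT] = { true, true };    // block formats

    // Row size in bytes for ne weights of type t, the same shape of expression as
    // nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]) in the quantized dup branches.
    static size_t demo_row_size(enum demo_type t, int64_t ne) {
        return DEMO_TYPE_SIZE[t] * (size_t) (ne / DEMO_BLCK_SIZE[t]);
    }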
|
3610
|
+
|
2946
3611
|
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
2947
3612
|
return tensor->nb[0] > tensor->nb[1];
|
2948
3613
|
}
|
@@ -3053,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
3053
3718
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
3054
3719
|
}
|
3055
3720
|
|
3721
|
+
// initialize cuBLAS
|
3722
|
+
#if defined(GGML_USE_CUBLAS)
|
3723
|
+
init_cublas();
|
3724
|
+
#endif
|
3725
|
+
|
3056
3726
|
is_first_call = false;
|
3057
3727
|
}
|
3058
3728
|
|
@@ -3354,14 +4024,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3354
4024
|
char * const data = tensor->data;
|
3355
4025
|
|
3356
4026
|
switch (tensor->type) {
|
3357
|
-
case GGML_TYPE_Q4_0:
|
3358
|
-
{
|
3359
|
-
GGML_ASSERT(false);
|
3360
|
-
} break;
|
3361
|
-
case GGML_TYPE_Q4_1:
|
3362
|
-
{
|
3363
|
-
GGML_ASSERT(false);
|
3364
|
-
} break;
|
3365
4027
|
case GGML_TYPE_I8:
|
3366
4028
|
{
|
3367
4029
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3397,7 +4059,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3397
4059
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3398
4060
|
}
|
3399
4061
|
} break;
|
3400
|
-
|
4062
|
+
default:
|
3401
4063
|
{
|
3402
4064
|
GGML_ASSERT(false);
|
3403
4065
|
} break;
|
@@ -3414,14 +4076,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3414
4076
|
char * const data = tensor->data;
|
3415
4077
|
|
3416
4078
|
switch (tensor->type) {
|
3417
|
-
case GGML_TYPE_Q4_0:
|
3418
|
-
{
|
3419
|
-
GGML_ASSERT(false);
|
3420
|
-
} break;
|
3421
|
-
case GGML_TYPE_Q4_1:
|
3422
|
-
{
|
3423
|
-
GGML_ASSERT(false);
|
3424
|
-
} break;
|
3425
4079
|
case GGML_TYPE_I8:
|
3426
4080
|
{
|
3427
4081
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3457,7 +4111,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3457
4111
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3458
4112
|
}
|
3459
4113
|
} break;
|
3460
|
-
|
4114
|
+
default:
|
3461
4115
|
{
|
3462
4116
|
GGML_ASSERT(false);
|
3463
4117
|
} break;
|
@@ -3468,14 +4122,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3468
4122
|
|
3469
4123
|
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
3470
4124
|
switch (tensor->type) {
|
3471
|
-
case GGML_TYPE_Q4_0:
|
3472
|
-
{
|
3473
|
-
GGML_ASSERT(false);
|
3474
|
-
} break;
|
3475
|
-
case GGML_TYPE_Q4_1:
|
3476
|
-
{
|
3477
|
-
GGML_ASSERT(false);
|
3478
|
-
} break;
|
3479
4125
|
case GGML_TYPE_I8:
|
3480
4126
|
{
|
3481
4127
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3501,7 +4147,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3501
4147
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3502
4148
|
return ((float *)(tensor->data))[i];
|
3503
4149
|
} break;
|
3504
|
-
|
4150
|
+
default:
|
3505
4151
|
{
|
3506
4152
|
GGML_ASSERT(false);
|
3507
4153
|
} break;
|
@@ -3512,14 +4158,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3512
4158
|
|
3513
4159
|
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
3514
4160
|
switch (tensor->type) {
|
3515
|
-
case GGML_TYPE_Q4_0:
|
3516
|
-
{
|
3517
|
-
GGML_ASSERT(false);
|
3518
|
-
} break;
|
3519
|
-
case GGML_TYPE_Q4_1:
|
3520
|
-
{
|
3521
|
-
GGML_ASSERT(false);
|
3522
|
-
} break;
|
3523
4161
|
case GGML_TYPE_I8:
|
3524
4162
|
{
|
3525
4163
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3545,7 +4183,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3545
4183
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3546
4184
|
((float *)(tensor->data))[i] = value;
|
3547
4185
|
} break;
|
3548
|
-
|
4186
|
+
default:
|
3549
4187
|
{
|
3550
4188
|
GGML_ASSERT(false);
|
3551
4189
|
} break;
|
@@ -3554,14 +4192,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3554
4192
|
|
3555
4193
|
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
3556
4194
|
switch (tensor->type) {
|
3557
|
-
case GGML_TYPE_Q4_0:
|
3558
|
-
{
|
3559
|
-
GGML_ASSERT(false);
|
3560
|
-
} break;
|
3561
|
-
case GGML_TYPE_Q4_1:
|
3562
|
-
{
|
3563
|
-
GGML_ASSERT(false);
|
3564
|
-
} break;
|
3565
4195
|
case GGML_TYPE_I8:
|
3566
4196
|
{
|
3567
4197
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3587,7 +4217,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3587
4217
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3588
4218
|
return ((float *)(tensor->data))[i];
|
3589
4219
|
} break;
|
3590
|
-
|
4220
|
+
default:
|
3591
4221
|
{
|
3592
4222
|
GGML_ASSERT(false);
|
3593
4223
|
} break;
|
@@ -3598,14 +4228,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3598
4228
|
|
3599
4229
|
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
3600
4230
|
switch (tensor->type) {
|
3601
|
-
case GGML_TYPE_Q4_0:
|
3602
|
-
{
|
3603
|
-
GGML_ASSERT(false);
|
3604
|
-
} break;
|
3605
|
-
case GGML_TYPE_Q4_1:
|
3606
|
-
{
|
3607
|
-
GGML_ASSERT(false);
|
3608
|
-
} break;
|
3609
4231
|
case GGML_TYPE_I8:
|
3610
4232
|
{
|
3611
4233
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3631,7 +4253,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
|
3631
4253
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3632
4254
|
((float *)(tensor->data))[i] = value;
|
3633
4255
|
} break;
|
3634
|
-
|
4256
|
+
default:
|
3635
4257
|
{
|
3636
4258
|
GGML_ASSERT(false);
|
3637
4259
|
} break;
|
@@ -5031,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
|
|
5031
5653
|
const struct ggml_compute_params * params,
|
5032
5654
|
const struct ggml_tensor * src0,
|
5033
5655
|
struct ggml_tensor * dst) {
|
5034
|
-
GGML_ASSERT(params->ith == 0);
|
5035
5656
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5036
5657
|
|
5037
5658
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5043,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
|
|
5043
5664
|
const int64_t ne02 = src0->ne[2];
|
5044
5665
|
const int64_t ne03 = src0->ne[3];
|
5045
5666
|
|
5667
|
+
const int64_t ne0 = dst->ne[0];
|
5668
|
+
const int64_t ne1 = dst->ne[1];
|
5669
|
+
const int64_t ne2 = dst->ne[2];
|
5670
|
+
const int64_t ne3 = dst->ne[3];
|
5671
|
+
|
5046
5672
|
const size_t nb00 = src0->nb[0];
|
5047
5673
|
const size_t nb01 = src0->nb[1];
|
5048
5674
|
const size_t nb02 = src0->nb[2];
|
@@ -5053,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
|
|
5053
5679
|
const size_t nb2 = dst->nb[2];
|
5054
5680
|
const size_t nb3 = dst->nb[3];
|
5055
5681
|
|
5682
|
+
const int ith = params->ith; // thread index
|
5683
|
+
const int nth = params->nth; // number of threads
|
5684
|
+
|
5056
5685
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5057
|
-
|
5686
|
+
// parallelize by elements
|
5687
|
+
const int ne = ggml_nelements(dst);
|
5688
|
+
const int dr = (ne + nth - 1) / nth;
|
5689
|
+
const int ie0 = dr * ith;
|
5690
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
5691
|
+
|
5692
|
+
memcpy(
|
5693
|
+
((char *) dst->data + ie0*nb0),
|
5694
|
+
((char *) src0->data + ie0*nb00),
|
5695
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
5696
|
+
|
5058
5697
|
return;
|
5059
5698
|
}
|
5060
5699
|
|
5700
|
+
// parallelize by rows
|
5701
|
+
const int nr = ne01;
|
5702
|
+
// number of rows per thread
|
5703
|
+
const int dr = (nr + nth - 1) / nth;
|
5704
|
+
// row range for this thread
|
5705
|
+
const int ir0 = dr * ith;
|
5706
|
+
const int ir1 = MIN(ir0 + dr, nr);
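These dup kernels now run on all threads instead of asserting ith == 0: contiguous same-type copies are split by elements (ie0/ie1 above), and everything else by rows, with dr = ceil(nr/nth) rows per thread and thread ith taking the half-open range [ir0, ir1). A minimal sketch of that split, with a small demo of how 10 rows land on 4 threads:

    #include <stdio.h>

    #define DEMO_MIN(a, b) ((a) < (b) ? (a) : (b))

    // Chunked row split used above: dr rows per thread, last thread may get fewer.
    static void demo_row_range(int nr, int nth, int ith, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1) / nth; // rows per thread, rounded up
        *ir0 = dr * ith;
        *ir1 = DEMO_MIN(*ir0 + dr, nr);
    }

    int main(void) {
        for (int ith = 0; ith < 4; ith++) {
            int ir0, ir1;
            demo_row_range(10, 4, ith, &ir0, &ir1);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1); // [0,3) [3,6) [6,9) [9,10)
        }
        return 0;
    }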
|
5707
|
+
|
5061
5708
|
if (src0->type == dst->type &&
|
5062
|
-
|
5063
|
-
|
5709
|
+
ne00 == ne0 &&
|
5710
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5064
5711
|
// copy by rows
|
5065
5712
|
const size_t rs = ne00*nb00;
|
5066
5713
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5067
5714
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5068
|
-
for (int64_t i01 =
|
5715
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5069
5716
|
memcpy(
|
5070
5717
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5071
5718
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
@@ -5079,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
|
|
5079
5726
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
5080
5727
|
|
5081
5728
|
if (ggml_is_contiguous(dst)) {
|
5082
|
-
if (
|
5729
|
+
if (nb00 == sizeof(ggml_fp16_t)) {
|
5083
5730
|
if (dst->type == GGML_TYPE_F16) {
|
5084
5731
|
size_t id = 0;
|
5085
|
-
const size_t rs = ne00*nb00;
|
5732
|
+
const size_t rs = ne00 * nb00;
|
5733
|
+
char * dst_ptr = (char *) dst->data;
|
5086
5734
|
|
5087
5735
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5088
5736
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5089
|
-
|
5737
|
+
id += rs * ir0;
|
5738
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5090
5739
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5091
|
-
|
5092
|
-
|
5093
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5094
|
-
|
5095
|
-
id++;
|
5740
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
5741
|
+
id += rs;
|
5096
5742
|
}
|
5743
|
+
id += rs * (ne01 - ir1);
|
5097
5744
|
}
|
5098
5745
|
}
|
5099
5746
|
} else if (dst->type == GGML_TYPE_F32) {
|
@@ -5102,14 +5749,39 @@ static void ggml_compute_forward_dup_f16(
|
|
5102
5749
|
|
5103
5750
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5104
5751
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5105
|
-
|
5752
|
+
id += ne00 * ir0;
|
5753
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5754
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5106
5755
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5107
|
-
|
5108
|
-
|
5109
|
-
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5756
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5110
5757
|
id++;
|
5111
5758
|
}
|
5112
5759
|
}
|
5760
|
+
id += ne00 * (ne01 - ir1);
|
5761
|
+
}
|
5762
|
+
}
|
5763
|
+
} else if (ggml_is_quantized(dst->type)) {
|
5764
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5765
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
5766
|
+
|
5767
|
+
size_t id = 0;
|
5768
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5769
|
+
char * dst_ptr = (char *) dst->data;
|
5770
|
+
|
5771
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5772
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5773
|
+
id += rs * ir0;
|
5774
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5775
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5776
|
+
|
5777
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5778
|
+
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5779
|
+
}
|
5780
|
+
|
5781
|
+
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
5782
|
+
id += rs;
|
5783
|
+
}
|
5784
|
+
id += rs * (ne01 - ir1);
|
5113
5785
|
}
|
5114
5786
|
}
|
5115
5787
|
} else {
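The quantized branch added in the hunk above lets dup_f16 write straight into a quantized destination: each thread widens one f16 source row into its own float scratch area inside params->wdata and then calls quantize_row_q on it, and the per-thread buffers are padded by CACHE_LINE_SIZE_F32 floats so neighbouring threads do not share a cache line. A minimal sketch of that scratch addressing; everything prefixed with demo_, including the 64-byte line size, is an assumption, and only the quoted expression comes from the diff:

    #include <stddef.h>
    #include <stdint.h>

    #define DEMO_CACHE_LINE_SIZE     64
    #define DEMO_CACHE_LINE_SIZE_F32 (DEMO_CACHE_LINE_SIZE / sizeof(float))

    // Per-thread float scratch row inside a shared work buffer, the same expression
    // as src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith above.
    static float * demo_thread_scratch(void * wdata, int64_t ne00, int ith) {
        return (float *) wdata + ((size_t) ne00 + DEMO_CACHE_LINE_SIZE_F32) * (size_t) ith;
    }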
|
@@ -5124,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5124
5796
|
|
5125
5797
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5126
5798
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5127
|
-
|
5799
|
+
id += ne00 * ir0;
|
5800
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5128
5801
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5129
5802
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5130
5803
|
|
@@ -5132,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5132
5805
|
id++;
|
5133
5806
|
}
|
5134
5807
|
}
|
5808
|
+
id += ne00 * (ne01 - ir1);
|
5135
5809
|
}
|
5136
5810
|
}
|
5137
5811
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5140,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5140
5814
|
|
5141
5815
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5142
5816
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5143
|
-
|
5817
|
+
id += ne00 * ir0;
|
5818
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5144
5819
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5145
5820
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5146
5821
|
|
@@ -5148,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5148
5823
|
id++;
|
5149
5824
|
}
|
5150
5825
|
}
|
5826
|
+
id += ne00 * (ne01 - ir1);
|
5151
5827
|
}
|
5152
5828
|
}
|
5153
5829
|
} else {
|
@@ -5166,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
|
|
5166
5842
|
if (dst->type == GGML_TYPE_F16) {
|
5167
5843
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5168
5844
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5169
|
-
|
5845
|
+
i10 += ne00 * ir0;
|
5846
|
+
while (i10 >= ne0) {
|
5847
|
+
i10 -= ne0;
|
5848
|
+
if (++i11 == ne1) {
|
5849
|
+
i11 = 0;
|
5850
|
+
if (++i12 == ne2) {
|
5851
|
+
i12 = 0;
|
5852
|
+
if (++i13 == ne3) {
|
5853
|
+
i13 = 0;
|
5854
|
+
}
|
5855
|
+
}
|
5856
|
+
}
|
5857
|
+
}
|
5858
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5170
5859
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5171
5860
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5172
5861
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
@@ -5187,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
|
|
5187
5876
|
}
|
5188
5877
|
}
|
5189
5878
|
}
|
5879
|
+
i10 += ne00 * (ne01 - ir1);
|
5880
|
+
while (i10 >= ne0) {
|
5881
|
+
i10 -= ne0;
|
5882
|
+
if (++i11 == ne1) {
|
5883
|
+
i11 = 0;
|
5884
|
+
if (++i12 == ne2) {
|
5885
|
+
i12 = 0;
|
5886
|
+
if (++i13 == ne3) {
|
5887
|
+
i13 = 0;
|
5888
|
+
}
|
5889
|
+
}
|
5890
|
+
}
|
5891
|
+
}
|
5190
5892
|
}
|
5191
5893
|
}
|
5192
5894
|
} else if (dst->type == GGML_TYPE_F32) {
|
5193
5895
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5194
5896
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5195
|
-
|
5897
|
+
i10 += ne00 * ir0;
|
5898
|
+
while (i10 >= ne0) {
|
5899
|
+
i10 -= ne0;
|
5900
|
+
if (++i11 == ne1) {
|
5901
|
+
i11 = 0;
|
5902
|
+
if (++i12 == ne2) {
|
5903
|
+
i12 = 0;
|
5904
|
+
if (++i13 == ne3) {
|
5905
|
+
i13 = 0;
|
5906
|
+
}
|
5907
|
+
}
|
5908
|
+
}
|
5909
|
+
}
|
5910
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5196
5911
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5197
5912
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5198
5913
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5199
5914
|
|
5200
5915
|
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
5201
5916
|
|
5202
|
-
if (++i10 ==
|
5917
|
+
if (++i10 == ne0) {
|
5203
5918
|
i10 = 0;
|
5204
|
-
if (++i11 ==
|
5919
|
+
if (++i11 == ne1) {
|
5205
5920
|
i11 = 0;
|
5206
|
-
if (++i12 ==
|
5921
|
+
if (++i12 == ne2) {
|
5207
5922
|
i12 = 0;
|
5208
|
-
if (++i13 ==
|
5923
|
+
if (++i13 == ne3) {
|
5209
5924
|
i13 = 0;
|
5210
5925
|
}
|
5211
5926
|
}
|
@@ -5213,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
|
|
5213
5928
|
}
|
5214
5929
|
}
|
5215
5930
|
}
|
5931
|
+
i10 += ne00 * (ne01 - ir1);
|
5932
|
+
while (i10 >= ne0) {
|
5933
|
+
i10 -= ne0;
|
5934
|
+
if (++i11 == ne1) {
|
5935
|
+
i11 = 0;
|
5936
|
+
if (++i12 == ne2) {
|
5937
|
+
i12 = 0;
|
5938
|
+
if (++i13 == ne3) {
|
5939
|
+
i13 = 0;
|
5940
|
+
}
|
5941
|
+
}
|
5942
|
+
}
|
5943
|
+
}
|
5216
5944
|
}
|
5217
5945
|
}
|
5218
5946
|
} else {
|
@@ -5224,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
|
|
5224
5952
|
const struct ggml_compute_params * params,
|
5225
5953
|
const struct ggml_tensor * src0,
|
5226
5954
|
struct ggml_tensor * dst) {
|
5227
|
-
GGML_ASSERT(params->ith == 0);
|
5228
5955
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5229
5956
|
|
5230
5957
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5236,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
|
|
5236
5963
|
const int64_t ne02 = src0->ne[2];
|
5237
5964
|
const int64_t ne03 = src0->ne[3];
|
5238
5965
|
|
5966
|
+
const int64_t ne0 = dst->ne[0];
|
5967
|
+
const int64_t ne1 = dst->ne[1];
|
5968
|
+
const int64_t ne2 = dst->ne[2];
|
5969
|
+
const int64_t ne3 = dst->ne[3];
|
5970
|
+
|
5239
5971
|
const size_t nb00 = src0->nb[0];
|
5240
5972
|
const size_t nb01 = src0->nb[1];
|
5241
5973
|
const size_t nb02 = src0->nb[2];
|
@@ -5246,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
|
|
5246
5978
|
const size_t nb2 = dst->nb[2];
|
5247
5979
|
const size_t nb3 = dst->nb[3];
|
5248
5980
|
|
5981
|
+
const int ith = params->ith; // thread index
|
5982
|
+
const int nth = params->nth; // number of threads
|
5983
|
+
|
5249
5984
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5250
|
-
|
5985
|
+
// parallelize by elements
|
5986
|
+
const int ne = ggml_nelements(dst);
|
5987
|
+
const int dr = (ne + nth - 1) / nth;
|
5988
|
+
const int ie0 = dr * ith;
|
5989
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
5990
|
+
|
5991
|
+
memcpy(
|
5992
|
+
((char *) dst->data + ie0*nb0),
|
5993
|
+
((char *) src0->data + ie0*nb00),
|
5994
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
5995
|
+
|
5251
5996
|
return;
|
5252
5997
|
}
|
5253
5998
|
|
5999
|
+
// parallelize by rows
|
6000
|
+
const int nr = ne01;
|
6001
|
+
// number of rows per thread
|
6002
|
+
const int dr = (nr + nth - 1) / nth;
|
6003
|
+
// row range for this thread
|
6004
|
+
const int ir0 = dr * ith;
|
6005
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6006
|
+
|
5254
6007
|
if (src0->type == dst->type &&
|
5255
|
-
|
5256
|
-
|
6008
|
+
ne00 == ne0 &&
|
6009
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5257
6010
|
// copy by rows
|
5258
6011
|
const size_t rs = ne00*nb00;
|
5259
6012
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5260
6013
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5261
|
-
for (int64_t i01 =
|
6014
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5262
6015
|
memcpy(
|
5263
6016
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5264
6017
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
@@ -5271,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
|
|
5271
6024
|
|
5272
6025
|
if (ggml_is_contiguous(dst)) {
|
5273
6026
|
// TODO: simplify
|
5274
|
-
if (
|
6027
|
+
if (nb00 == sizeof(float)) {
|
5275
6028
|
if (dst->type == GGML_TYPE_F32) {
|
5276
6029
|
size_t id = 0;
|
5277
|
-
const size_t rs = ne00*nb00;
|
6030
|
+
const size_t rs = ne00 * nb00;
|
6031
|
+
char * dst_ptr = (char *) dst->data;
|
5278
6032
|
|
5279
6033
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5280
6034
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5281
|
-
|
6035
|
+
id += rs * ir0;
|
6036
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5282
6037
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5283
|
-
|
5284
|
-
|
5285
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5286
|
-
|
5287
|
-
id++;
|
6038
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
6039
|
+
id += rs;
|
5288
6040
|
}
|
6041
|
+
id += rs * (ne01 - ir1);
|
5289
6042
|
}
|
5290
6043
|
}
|
5291
6044
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5294,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5294
6047
|
|
5295
6048
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5296
6049
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5297
|
-
|
6050
|
+
id += ne00 * ir0;
|
6051
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5298
6052
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5299
6053
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5300
6054
|
|
@@ -5302,6 +6056,25 @@ static void ggml_compute_forward_dup_f32(
|
|
5302
6056
|
id++;
|
5303
6057
|
}
|
5304
6058
|
}
|
6059
|
+
id += ne00 * (ne01 - ir1);
|
6060
|
+
}
|
6061
|
+
}
|
6062
|
+
} else if (ggml_is_quantized(dst->type)) {
|
6063
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
6064
|
+
|
6065
|
+
size_t id = 0;
|
6066
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
6067
|
+
char * dst_ptr = (char *) dst->data;
|
6068
|
+
|
6069
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
6070
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
6071
|
+
id += rs * ir0;
|
6072
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
6073
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
6074
|
+
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
6075
|
+
id += rs;
|
6076
|
+
}
|
6077
|
+
id += rs * (ne01 - ir1);
|
5305
6078
|
}
|
5306
6079
|
}
|
5307
6080
|
} else {
|
@@ -5316,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5316
6089
|
|
5317
6090
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5318
6091
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5319
|
-
|
6092
|
+
id += ne00 * ir0;
|
6093
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5320
6094
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5321
6095
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5322
6096
|
|
@@ -5324,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5324
6098
|
id++;
|
5325
6099
|
}
|
5326
6100
|
}
|
6101
|
+
id += ne00 * (ne01 - ir1);
|
5327
6102
|
}
|
5328
6103
|
}
|
5329
6104
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5332,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5332
6107
|
|
5333
6108
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5334
6109
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5335
|
-
|
6110
|
+
id += ne00 * ir0;
|
6111
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5336
6112
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5337
6113
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5338
6114
|
|
@@ -5340,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5340
6116
|
id++;
|
5341
6117
|
}
|
5342
6118
|
}
|
6119
|
+
id += ne00 * (ne01 - ir1);
|
5343
6120
|
}
|
5344
6121
|
}
|
5345
6122
|
} else {
@@ -5351,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
     }
 
     // dst counters
+
     int64_t i10 = 0;
     int64_t i11 = 0;
     int64_t i12 = 0;
@@ -5359,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
     if (dst->type == GGML_TYPE_F32) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                         memcpy(dst_ptr, src0_ptr, sizeof(float));
 
-                        if (++i10 ==
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 ==
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 ==
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 ==
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5380,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else if (dst->type == GGML_TYPE_F16) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
                         *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
 
-                        if (++i10 ==
+                        if (++i10 == ne0) {
                             i10 = 0;
-                            if (++i11 ==
+                            if (++i11 == ne1) {
                                 i11 = 0;
-                                if (++i12 ==
+                                if (++i12 == ne2) {
                                     i12 = 0;
-                                    if (++i13 ==
+                                    if (++i13 == ne3) {
                                         i13 = 0;
                                     }
                                 }
@@ -5406,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
             }
         }
     } else {
@@ -5426,12 +6256,7 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
-
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
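The dup hunks above give each thread a row range [ir0, ir1) and fast-forward the flattened destination counters (i10..i13) past the rows handled by other threads. The following is a minimal sketch of that counter fast-forward only, with the destination extents passed as parameters; it is an illustration of the technique, not code from the diff.

// Advance a 4-D destination cursor by `skip` elements along dimension 0,
// carrying into the higher dimensions, exactly like the while-loops above.
static void advance_counters(int64_t *i10, int64_t *i11, int64_t *i12, int64_t *i13,
                             int64_t skip, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    *i10 += skip;
    while (*i10 >= ne0) {
        *i10 -= ne0;
        if (++*i11 == ne1) {
            *i11 = 0;
            if (++*i12 == ne2) {
                *i12 = 0;
                if (++*i13 == ne3) {
                    *i13 = 0;
                }
            }
        }
    }
}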
@@ -5497,6 +6322,212 @@ static void ggml_compute_forward_add_f32(
         }
     }
 }
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0 = dst->ne[0];
+    //const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5507,13 +6538,26 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
-        case
-        case
-
-
-
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
+        default:
            {
                GGML_ASSERT(false);
            } break;
|
|
5559
6603
|
{
|
5560
6604
|
ggml_compute_forward_sub_f32(params, src0, src1, dst);
|
5561
6605
|
} break;
|
5562
|
-
|
5563
|
-
case GGML_TYPE_Q4_1:
|
5564
|
-
case GGML_TYPE_I8:
|
5565
|
-
case GGML_TYPE_I16:
|
5566
|
-
case GGML_TYPE_I32:
|
5567
|
-
case GGML_TYPE_F16:
|
5568
|
-
case GGML_TYPE_COUNT:
|
6606
|
+
default:
|
5569
6607
|
{
|
5570
6608
|
GGML_ASSERT(false);
|
5571
6609
|
} break;
|
@@ -5611,13 +6649,7 @@ static void ggml_compute_forward_mul(
|
|
5611
6649
|
{
|
5612
6650
|
ggml_compute_forward_mul_f32(params, src0, src1, dst);
|
5613
6651
|
} break;
|
5614
|
-
|
5615
|
-
case GGML_TYPE_Q4_1:
|
5616
|
-
case GGML_TYPE_I8:
|
5617
|
-
case GGML_TYPE_I16:
|
5618
|
-
case GGML_TYPE_I32:
|
5619
|
-
case GGML_TYPE_F16:
|
5620
|
-
case GGML_TYPE_COUNT:
|
6652
|
+
default:
|
5621
6653
|
{
|
5622
6654
|
GGML_ASSERT(false);
|
5623
6655
|
} break;
|
@@ -5663,13 +6695,7 @@ static void ggml_compute_forward_div(
|
|
5663
6695
|
{
|
5664
6696
|
ggml_compute_forward_div_f32(params, src0, src1, dst);
|
5665
6697
|
} break;
|
5666
|
-
|
5667
|
-
case GGML_TYPE_Q4_1:
|
5668
|
-
case GGML_TYPE_I8:
|
5669
|
-
case GGML_TYPE_I16:
|
5670
|
-
case GGML_TYPE_I32:
|
5671
|
-
case GGML_TYPE_F16:
|
5672
|
-
case GGML_TYPE_COUNT:
|
6698
|
+
default:
|
5673
6699
|
{
|
5674
6700
|
GGML_ASSERT(false);
|
5675
6701
|
} break;
|
@@ -5711,13 +6737,7 @@ static void ggml_compute_forward_sqr(
|
|
5711
6737
|
{
|
5712
6738
|
ggml_compute_forward_sqr_f32(params, src0, dst);
|
5713
6739
|
} break;
|
5714
|
-
|
5715
|
-
case GGML_TYPE_Q4_1:
|
5716
|
-
case GGML_TYPE_I8:
|
5717
|
-
case GGML_TYPE_I16:
|
5718
|
-
case GGML_TYPE_I32:
|
5719
|
-
case GGML_TYPE_F16:
|
5720
|
-
case GGML_TYPE_COUNT:
|
6740
|
+
default:
|
5721
6741
|
{
|
5722
6742
|
GGML_ASSERT(false);
|
5723
6743
|
} break;
|
@@ -5759,13 +6779,7 @@ static void ggml_compute_forward_sqrt(
|
|
5759
6779
|
{
|
5760
6780
|
ggml_compute_forward_sqrt_f32(params, src0, dst);
|
5761
6781
|
} break;
|
5762
|
-
|
5763
|
-
case GGML_TYPE_Q4_1:
|
5764
|
-
case GGML_TYPE_I8:
|
5765
|
-
case GGML_TYPE_I16:
|
5766
|
-
case GGML_TYPE_I32:
|
5767
|
-
case GGML_TYPE_F16:
|
5768
|
-
case GGML_TYPE_COUNT:
|
6782
|
+
default:
|
5769
6783
|
{
|
5770
6784
|
GGML_ASSERT(false);
|
5771
6785
|
} break;
|
@@ -5817,13 +6831,7 @@ static void ggml_compute_forward_sum(
|
|
5817
6831
|
{
|
5818
6832
|
ggml_compute_forward_sum_f32(params, src0, dst);
|
5819
6833
|
} break;
|
5820
|
-
|
5821
|
-
case GGML_TYPE_Q4_1:
|
5822
|
-
case GGML_TYPE_I8:
|
5823
|
-
case GGML_TYPE_I16:
|
5824
|
-
case GGML_TYPE_I32:
|
5825
|
-
case GGML_TYPE_F16:
|
5826
|
-
case GGML_TYPE_COUNT:
|
6834
|
+
default:
|
5827
6835
|
{
|
5828
6836
|
GGML_ASSERT(false);
|
5829
6837
|
} break;
|
@@ -5894,13 +6902,7 @@ static void ggml_compute_forward_mean(
|
|
5894
6902
|
{
|
5895
6903
|
ggml_compute_forward_mean_f32(params, src0, dst);
|
5896
6904
|
} break;
|
5897
|
-
|
5898
|
-
case GGML_TYPE_Q4_1:
|
5899
|
-
case GGML_TYPE_I8:
|
5900
|
-
case GGML_TYPE_I16:
|
5901
|
-
case GGML_TYPE_I32:
|
5902
|
-
case GGML_TYPE_F16:
|
5903
|
-
case GGML_TYPE_COUNT:
|
6905
|
+
default:
|
5904
6906
|
{
|
5905
6907
|
GGML_ASSERT(false);
|
5906
6908
|
} break;
|
@@ -5958,13 +6960,7 @@ static void ggml_compute_forward_repeat(
|
|
5958
6960
|
{
|
5959
6961
|
ggml_compute_forward_repeat_f32(params, src0, dst);
|
5960
6962
|
} break;
|
5961
|
-
|
5962
|
-
case GGML_TYPE_Q4_1:
|
5963
|
-
case GGML_TYPE_I8:
|
5964
|
-
case GGML_TYPE_I16:
|
5965
|
-
case GGML_TYPE_I32:
|
5966
|
-
case GGML_TYPE_F16:
|
5967
|
-
case GGML_TYPE_COUNT:
|
6963
|
+
default:
|
5968
6964
|
{
|
5969
6965
|
GGML_ASSERT(false);
|
5970
6966
|
} break;
|
@@ -6006,13 +7002,7 @@ static void ggml_compute_forward_abs(
|
|
6006
7002
|
{
|
6007
7003
|
ggml_compute_forward_abs_f32(params, src0, dst);
|
6008
7004
|
} break;
|
6009
|
-
|
6010
|
-
case GGML_TYPE_Q4_1:
|
6011
|
-
case GGML_TYPE_I8:
|
6012
|
-
case GGML_TYPE_I16:
|
6013
|
-
case GGML_TYPE_I32:
|
6014
|
-
case GGML_TYPE_F16:
|
6015
|
-
case GGML_TYPE_COUNT:
|
7005
|
+
default:
|
6016
7006
|
{
|
6017
7007
|
GGML_ASSERT(false);
|
6018
7008
|
} break;
|
@@ -6054,13 +7044,7 @@ static void ggml_compute_forward_sgn(
|
|
6054
7044
|
{
|
6055
7045
|
ggml_compute_forward_sgn_f32(params, src0, dst);
|
6056
7046
|
} break;
|
6057
|
-
|
6058
|
-
case GGML_TYPE_Q4_1:
|
6059
|
-
case GGML_TYPE_I8:
|
6060
|
-
case GGML_TYPE_I16:
|
6061
|
-
case GGML_TYPE_I32:
|
6062
|
-
case GGML_TYPE_F16:
|
6063
|
-
case GGML_TYPE_COUNT:
|
7047
|
+
default:
|
6064
7048
|
{
|
6065
7049
|
GGML_ASSERT(false);
|
6066
7050
|
} break;
|
@@ -6102,13 +7086,7 @@ static void ggml_compute_forward_neg(
|
|
6102
7086
|
{
|
6103
7087
|
ggml_compute_forward_neg_f32(params, src0, dst);
|
6104
7088
|
} break;
|
6105
|
-
|
6106
|
-
case GGML_TYPE_Q4_1:
|
6107
|
-
case GGML_TYPE_I8:
|
6108
|
-
case GGML_TYPE_I16:
|
6109
|
-
case GGML_TYPE_I32:
|
6110
|
-
case GGML_TYPE_F16:
|
6111
|
-
case GGML_TYPE_COUNT:
|
7089
|
+
default:
|
6112
7090
|
{
|
6113
7091
|
GGML_ASSERT(false);
|
6114
7092
|
} break;
|
@@ -6150,13 +7128,7 @@ static void ggml_compute_forward_step(
|
|
6150
7128
|
{
|
6151
7129
|
ggml_compute_forward_step_f32(params, src0, dst);
|
6152
7130
|
} break;
|
6153
|
-
|
6154
|
-
case GGML_TYPE_Q4_1:
|
6155
|
-
case GGML_TYPE_I8:
|
6156
|
-
case GGML_TYPE_I16:
|
6157
|
-
case GGML_TYPE_I32:
|
6158
|
-
case GGML_TYPE_F16:
|
6159
|
-
case GGML_TYPE_COUNT:
|
7131
|
+
default:
|
6160
7132
|
{
|
6161
7133
|
GGML_ASSERT(false);
|
6162
7134
|
} break;
|
@@ -6193,18 +7165,12 @@ static void ggml_compute_forward_relu(
|
|
6193
7165
|
const struct ggml_compute_params * params,
|
6194
7166
|
const struct ggml_tensor * src0,
|
6195
7167
|
struct ggml_tensor * dst) {
|
6196
|
-
switch (src0->type) {
|
6197
|
-
case GGML_TYPE_F32:
|
6198
|
-
{
|
6199
|
-
ggml_compute_forward_relu_f32(params, src0, dst);
|
6200
|
-
} break;
|
6201
|
-
|
6202
|
-
case GGML_TYPE_Q4_1:
|
6203
|
-
case GGML_TYPE_I8:
|
6204
|
-
case GGML_TYPE_I16:
|
6205
|
-
case GGML_TYPE_I32:
|
6206
|
-
case GGML_TYPE_F16:
|
6207
|
-
case GGML_TYPE_COUNT:
|
7168
|
+
switch (src0->type) {
|
7169
|
+
case GGML_TYPE_F32:
|
7170
|
+
{
|
7171
|
+
ggml_compute_forward_relu_f32(params, src0, dst);
|
7172
|
+
} break;
|
7173
|
+
default:
|
6208
7174
|
{
|
6209
7175
|
GGML_ASSERT(false);
|
6210
7176
|
} break;
|
@@ -6263,13 +7229,7 @@ static void ggml_compute_forward_gelu(
|
|
6263
7229
|
{
|
6264
7230
|
ggml_compute_forward_gelu_f32(params, src0, dst);
|
6265
7231
|
} break;
|
6266
|
-
|
6267
|
-
case GGML_TYPE_Q4_1:
|
6268
|
-
case GGML_TYPE_I8:
|
6269
|
-
case GGML_TYPE_I16:
|
6270
|
-
case GGML_TYPE_I32:
|
6271
|
-
case GGML_TYPE_F16:
|
6272
|
-
case GGML_TYPE_COUNT:
|
7232
|
+
default:
|
6273
7233
|
{
|
6274
7234
|
GGML_ASSERT(false);
|
6275
7235
|
} break;
|
@@ -6330,13 +7290,7 @@ static void ggml_compute_forward_silu(
|
|
6330
7290
|
{
|
6331
7291
|
ggml_compute_forward_silu_f32(params, src0, dst);
|
6332
7292
|
} break;
|
6333
|
-
|
6334
|
-
case GGML_TYPE_Q4_1:
|
6335
|
-
case GGML_TYPE_I8:
|
6336
|
-
case GGML_TYPE_I16:
|
6337
|
-
case GGML_TYPE_I32:
|
6338
|
-
case GGML_TYPE_F16:
|
6339
|
-
case GGML_TYPE_COUNT:
|
7293
|
+
default:
|
6340
7294
|
{
|
6341
7295
|
GGML_ASSERT(false);
|
6342
7296
|
} break;
|
@@ -6416,13 +7370,7 @@ static void ggml_compute_forward_norm(
|
|
6416
7370
|
{
|
6417
7371
|
ggml_compute_forward_norm_f32(params, src0, dst);
|
6418
7372
|
} break;
|
6419
|
-
|
6420
|
-
case GGML_TYPE_Q4_1:
|
6421
|
-
case GGML_TYPE_I8:
|
6422
|
-
case GGML_TYPE_I16:
|
6423
|
-
case GGML_TYPE_I32:
|
6424
|
-
case GGML_TYPE_F16:
|
6425
|
-
case GGML_TYPE_COUNT:
|
7373
|
+
default:
|
6426
7374
|
{
|
6427
7375
|
GGML_ASSERT(false);
|
6428
7376
|
} break;
|
@@ -6496,13 +7444,7 @@ static void ggml_compute_forward_rms_norm(
|
|
6496
7444
|
{
|
6497
7445
|
ggml_compute_forward_rms_norm_f32(params, src0, dst);
|
6498
7446
|
} break;
|
6499
|
-
|
6500
|
-
case GGML_TYPE_Q4_1:
|
6501
|
-
case GGML_TYPE_I8:
|
6502
|
-
case GGML_TYPE_I16:
|
6503
|
-
case GGML_TYPE_I32:
|
6504
|
-
case GGML_TYPE_F16:
|
6505
|
-
case GGML_TYPE_COUNT:
|
7447
|
+
default:
|
6506
7448
|
{
|
6507
7449
|
GGML_ASSERT(false);
|
6508
7450
|
} break;
|
@@ -6512,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
|
|
6512
7454
|
|
6513
7455
|
// ggml_compute_forward_mul_mat
|
6514
7456
|
|
6515
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7457
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
6516
7458
|
// helper function to determine if it is better to use BLAS or not
|
6517
7459
|
// for large matrices, BLAS is faster
|
6518
7460
|
static bool ggml_compute_forward_mul_mat_use_blas(
|
@@ -6552,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
6552
7494
|
const int64_t ne02 = src0->ne[2];
|
6553
7495
|
const int64_t ne03 = src0->ne[3];
|
6554
7496
|
|
6555
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7497
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
6556
7498
|
const int64_t ne10 = src1->ne[0];
|
6557
7499
|
#endif
|
6558
7500
|
const int64_t ne11 = src1->ne[1];
|
@@ -6609,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
6609
7551
|
// nb01 >= nb00 - src0 is not transposed
|
6610
7552
|
// compute by src0 rows
|
6611
7553
|
|
6612
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7554
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
6613
7555
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
6614
7556
|
if (params->ith != 0) {
|
6615
7557
|
return;
|
@@ -6623,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
6623
7565
|
return;
|
6624
7566
|
}
|
6625
7567
|
|
7568
|
+
#if defined(GGML_USE_CUBLAS)
|
7569
|
+
float *d_X = NULL;
|
7570
|
+
float *d_Y = NULL;
|
7571
|
+
float *d_D = NULL;
|
7572
|
+
const float alpha = 1.0f;
|
7573
|
+
const float beta = 0.0f;
|
7574
|
+
const int x_ne = ne01 * ne10;
|
7575
|
+
const int y_ne = ne11 * ne10;
|
7576
|
+
const int d_ne = ne11 * ne01;
|
7577
|
+
|
7578
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
7579
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
7580
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
7581
|
+
#endif
|
7582
|
+
|
6626
7583
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
6627
7584
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
6628
7585
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
@@ -6630,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
6630
7587
|
|
6631
7588
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
6632
7589
|
|
7590
|
+
#if defined(GGML_USE_CUBLAS)
|
7591
|
+
// copy data to device
|
7592
|
+
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
7593
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
7594
|
+
|
7595
|
+
// compute
|
7596
|
+
CUBLAS_CHECK(
|
7597
|
+
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
7598
|
+
ne01, ne11, ne10,
|
7599
|
+
&alpha, d_X, ne00,
|
7600
|
+
d_Y, ne10,
|
7601
|
+
&beta, d_D, ne01));
|
7602
|
+
|
7603
|
+
// copy data to host
|
7604
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
7605
|
+
#else
|
6633
7606
|
// zT = y * xT
|
6634
7607
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
6635
7608
|
ne11, ne01, ne10,
|
6636
7609
|
1.0f, y, ne10,
|
6637
7610
|
x, ne00,
|
6638
7611
|
0.0f, d, ne01);
|
7612
|
+
#endif
|
6639
7613
|
}
|
6640
7614
|
}
|
6641
|
-
|
7615
|
+
#if defined(GGML_USE_CUBLAS)
|
7616
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
7617
|
+
CUDA_CHECK(cudaFree(d_X));
|
7618
|
+
CUDA_CHECK(cudaFree(d_Y));
|
7619
|
+
CUDA_CHECK(cudaFree(d_D));
|
7620
|
+
#endif
|
6642
7621
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
6643
7622
|
|
6644
7623
|
return;
|
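The cuBLAS branch above follows a simple offload pattern: allocate device buffers, stage the operands asynchronously on the stream created by init_cublas, run a column-major SGEMM that yields the transposed product, copy the result back, then synchronize and free. The following is a condensed, self-contained sketch of that pattern under the assumption of an already-created handle and stream; error-check macros are omitted for brevity, and the helper name is illustrative only.

#include <cublas_v2.h>
#include <cuda_runtime.h>

// C (ne11 x ne01, row-major) = Y (ne11 x ne10) * X^T, with X stored as ne01 rows of length ne00 (== ne10).
static void sgemm_offload(cublasHandle_t handle, cudaStream_t stream,
                          const float *x, const float *y, float *d,
                          int ne00, int ne01, int ne10, int ne11) {
    const int x_ne = ne01 * ne10, y_ne = ne11 * ne10, d_ne = ne11 * ne01;
    const float alpha = 1.0f, beta = 0.0f;
    float *d_X, *d_Y, *d_D;
    cudaMalloc((void **)&d_X, sizeof(float) * x_ne);
    cudaMalloc((void **)&d_Y, sizeof(float) * y_ne);
    cudaMalloc((void **)&d_D, sizeof(float) * d_ne);
    cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, stream);
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                &alpha, d_X, ne00,
                        d_Y, ne10,
                &beta,  d_D, ne01);
    cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    cudaFree(d_X); cudaFree(d_Y); cudaFree(d_D);
}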
@@ -6768,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -6784,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
 
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -6796,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -6808,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -6894,27 +7931,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q   = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q        = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q   = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q        = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -6962,8 +7978,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3 == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const
-    vec_dot_q_t      const vec_dot_q
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6983,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -6997,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        float *d_Q = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
+        if (type == GGML_TYPE_Q4_0) {
+            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_1) {
+            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_2) {
+            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+#else
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+#endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(
+                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7009,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
                         id += ne00;
                     }
                 }
-
                 const float * x = wdata;
-
+#endif
 
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaFree(d_Q));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
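For quantized src0 the cuBLAS branch above avoids a host-side dequantization pass: the raw quantized blocks are copied to the GPU, expanded to f32 by a type-specific kernel (the dequantize_row_q4_*_cuda functions declared in ggml-cuda.h in this release), and then fed to the same SGEMM call as the f32 path. A stripped-down sketch of just that staging step, with a made-up helper name and error checks omitted:

// Copy quantized blocks to d_Q, dequantize them into d_X on the stream, then
// d_X can be used as the left operand of cublasSgemm as in the f32 branch.
static void dequant_on_device(const void *q_blocks, size_t q_bytes,
                              void *d_Q, float *d_X, int n_elements,
                              cudaStream_t stream) {
    cudaMemcpyAsync(d_Q, q_blocks, q_bytes, cudaMemcpyHostToDevice, stream);
    dequantize_row_q4_0_cuda(d_Q, d_X, n_elements, stream);  // or the Q4_1/Q4_2 variant
}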
@@ -7032,12 +8113,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[
+        const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                     wdata += row_size;
                 }
             }
@@ -7063,7 +8144,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[
+    const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -7111,6 +8192,9 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -7122,42 +8206,11 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
             } break;
-
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
     }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }
 
 // ggml_compute_forward_scale
@@ -7207,13 +8260,7 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
-
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7374,6 +8421,9 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7385,10 +8435,7 @@ static void ggml_compute_forward_get_rows(
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
-
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7461,13 +8508,7 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
-
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7555,13 +8596,7 @@ static void ggml_compute_forward_soft_max(
             {
                 ggml_compute_forward_soft_max_f32(params, src0, dst);
             } break;
-
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7618,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -7633,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
 
                     theta *= theta_scale;
 
-
-
+                    if (!is_neox) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
 
-
-
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-
-
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                 }
             }
         }
@@ -7695,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -7710,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
 
                     theta *= theta_scale;
 
-
-
+                    if (!is_neox) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
 
-
-
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-
-
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 }
             }
         }
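The RoPE hunks above change the meaning of the mode argument from a plain "0 or not 0" switch into a small bit field: bit 0 keeps the old "skip past positions" behaviour, and the new bit 1 selects the GPT-NeoX style rotation, which pairs element i with element i + n_dims/2 instead of its immediate neighbour. A tiny sketch of how a caller-side decode of those flags could look (names are illustrative, not API from this release):

// mode bit 0: process only new positions and use i2 directly as the position p;
// mode bit 1: NeoX-style pairing of rotated elements.
static inline void rope_mode_flags(int mode, int *skip_past, int *is_neox) {
    *skip_past = (mode & 1) != 0;
    *is_neox   = (mode & 2) != 0;
}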
|
7738
8799
|
{
|
7739
8800
|
ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
7740
8801
|
} break;
|
7741
|
-
|
7742
|
-
case GGML_TYPE_Q4_1:
|
7743
|
-
case GGML_TYPE_I8:
|
7744
|
-
case GGML_TYPE_I16:
|
7745
|
-
case GGML_TYPE_I32:
|
7746
|
-
case GGML_TYPE_COUNT:
|
8802
|
+
default:
|
7747
8803
|
{
|
7748
8804
|
GGML_ASSERT(false);
|
7749
8805
|
} break;
|
@@ -8006,12 +9062,7 @@ static void ggml_compute_forward_conv_1d_1s(
|
|
8006
9062
|
{
|
8007
9063
|
ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
|
8008
9064
|
} break;
|
8009
|
-
|
8010
|
-
case GGML_TYPE_Q4_1:
|
8011
|
-
case GGML_TYPE_I8:
|
8012
|
-
case GGML_TYPE_I16:
|
8013
|
-
case GGML_TYPE_I32:
|
8014
|
-
case GGML_TYPE_COUNT:
|
9065
|
+
default:
|
8015
9066
|
{
|
8016
9067
|
GGML_ASSERT(false);
|
8017
9068
|
} break;
|
@@ -8274,12 +9325,7 @@ static void ggml_compute_forward_conv_1d_2s(
|
|
8274
9325
|
{
|
8275
9326
|
ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
|
8276
9327
|
} break;
|
8277
|
-
|
8278
|
-
case GGML_TYPE_Q4_1:
|
8279
|
-
case GGML_TYPE_I8:
|
8280
|
-
case GGML_TYPE_I16:
|
8281
|
-
case GGML_TYPE_I32:
|
8282
|
-
case GGML_TYPE_COUNT:
|
9328
|
+
default:
|
8283
9329
|
{
|
8284
9330
|
GGML_ASSERT(false);
|
8285
9331
|
} break;
|
@@ -8759,12 +9805,7 @@ static void ggml_compute_forward_flash_attn(
|
|
8759
9805
|
{
|
8760
9806
|
ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
|
8761
9807
|
} break;
|
8762
|
-
|
8763
|
-
case GGML_TYPE_Q4_1:
|
8764
|
-
case GGML_TYPE_I8:
|
8765
|
-
case GGML_TYPE_I16:
|
8766
|
-
case GGML_TYPE_I32:
|
8767
|
-
case GGML_TYPE_COUNT:
|
9808
|
+
default:
|
8768
9809
|
{
|
8769
9810
|
GGML_ASSERT(false);
|
8770
9811
|
} break;
|
@@ -8970,12 +10011,7 @@ static void ggml_compute_forward_flash_ff(
|
|
8970
10011
|
{
|
8971
10012
|
GGML_ASSERT(false); // TODO
|
8972
10013
|
} break;
|
8973
|
-
|
8974
|
-
case GGML_TYPE_Q4_1:
|
8975
|
-
case GGML_TYPE_I8:
|
8976
|
-
case GGML_TYPE_I16:
|
8977
|
-
case GGML_TYPE_I32:
|
8978
|
-
case GGML_TYPE_COUNT:
|
10014
|
+
default:
|
8979
10015
|
{
|
8980
10016
|
GGML_ASSERT(false);
|
8981
10017
|
} break;
|
@@ -9019,13 +10055,7 @@ static void ggml_compute_forward_map_unary(
|
|
9019
10055
|
{
|
9020
10056
|
ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
|
9021
10057
|
} break;
|
9022
|
-
|
9023
|
-
case GGML_TYPE_Q4_1:
|
9024
|
-
case GGML_TYPE_I8:
|
9025
|
-
case GGML_TYPE_I16:
|
9026
|
-
case GGML_TYPE_I32:
|
9027
|
-
case GGML_TYPE_F16:
|
9028
|
-
case GGML_TYPE_COUNT:
|
10058
|
+
default:
|
9029
10059
|
{
|
9030
10060
|
GGML_ASSERT(false);
|
9031
10061
|
} break;
|
@@ -9074,13 +10104,7 @@ static void ggml_compute_forward_map_binary(
|
|
9074
10104
|
{
|
9075
10105
|
ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
|
9076
10106
|
} break;
|
9077
|
-
|
9078
|
-
case GGML_TYPE_Q4_1:
|
9079
|
-
case GGML_TYPE_I8:
|
9080
|
-
case GGML_TYPE_I16:
|
9081
|
-
case GGML_TYPE_I32:
|
9082
|
-
case GGML_TYPE_F16:
|
9083
|
-
case GGML_TYPE_COUNT:
|
10107
|
+
default:
|
9084
10108
|
{
|
9085
10109
|
GGML_ASSERT(false);
|
9086
10110
|
} break;
|
@@ -9830,13 +10854,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9830
10854
|
struct ggml_tensor * node = cgraph->nodes[i];
|
9831
10855
|
|
9832
10856
|
switch (node->op) {
|
10857
|
+
case GGML_OP_CPY:
|
9833
10858
|
case GGML_OP_DUP:
|
9834
10859
|
{
|
9835
|
-
node->n_tasks =
|
10860
|
+
node->n_tasks = n_threads;
|
10861
|
+
|
10862
|
+
size_t cur = 0;
|
10863
|
+
if (ggml_is_quantized(node->type)) {
|
10864
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
|
10865
|
+
}
|
10866
|
+
|
10867
|
+
work_size = MAX(work_size, cur);
|
9836
10868
|
} break;
|
9837
10869
|
case GGML_OP_ADD:
|
9838
10870
|
{
|
9839
10871
|
node->n_tasks = n_threads;
|
10872
|
+
|
10873
|
+
size_t cur = 0;
|
10874
|
+
|
10875
|
+
if (ggml_is_quantized(node->src0->type)) {
|
10876
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
10877
|
+
}
|
10878
|
+
|
10879
|
+
work_size = MAX(work_size, cur);
|
9840
10880
|
} break;
|
9841
10881
|
case GGML_OP_SUB:
|
9842
10882
|
case GGML_OP_MUL:
|
@@ -9881,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9881
10921
|
size_t cur = 0;
|
9882
10922
|
|
9883
10923
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
9884
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
10924
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
9885
10925
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
9886
10926
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
9887
10927
|
// the threads are still spinning
|
@@ -9897,15 +10937,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9897
10937
|
#endif
|
9898
10938
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
9899
10939
|
cur = 0;
|
9900
|
-
} else if (
|
9901
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
10940
|
+
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
10941
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
9902
10942
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
9903
10943
|
node->n_tasks = 1;
|
9904
10944
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
9905
10945
|
} else
|
9906
10946
|
#endif
|
9907
10947
|
{
|
9908
|
-
cur = GGML_TYPE_SIZE[
|
10948
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
9909
10949
|
}
|
9910
10950
|
} else {
|
9911
10951
|
GGML_ASSERT(false);
|
@@ -9917,7 +10957,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9917
10957
|
{
|
9918
10958
|
node->n_tasks = n_threads;
|
9919
10959
|
} break;
|
9920
|
-
case GGML_OP_CPY:
|
9921
10960
|
case GGML_OP_CONT:
|
9922
10961
|
case GGML_OP_RESHAPE:
|
9923
10962
|
case GGML_OP_VIEW:
|
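The graph-planning hunks above reserve scratch space whenever a quantized tensor is involved in CPY/DUP or ADD: every one of the n_threads workers needs one f32 row to dequantize into, so the shared work buffer must hold n_threads rows of 4-byte floats. A one-line sketch of that sizing rule, for illustration only:

// Scratch bytes reserved for a quantized CPY/DUP or ADD node:
// one f32 row of length ne0 per worker thread.
static size_t quantized_op_work_size(int64_t ne0, int n_threads) {
    return sizeof(float) * (size_t) ne0 * (size_t) n_threads;
}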
@@ -11080,16 +12119,16 @@ enum ggml_opt_result ggml_opt(
 ////////////////////////////////////////////////////////////////////////////////
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k %
-    const int nb = k /
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_0 * restrict y = (block_q4_0 *)dst + j/
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
 
         quantize_row_q4_0_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l <
+            for (int l = 0; l < QK4_0; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
@@ -11099,20 +12138,67 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/
+    return (n/QK4_0*sizeof(block_q4_0));
 }
 
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k %
-    const int nb = k /
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
 
         quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l <
+            for (int l = 0; l < QK4_1; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_1*sizeof(block_q4_1));
+}
+
+size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
+
+        //quantize_row_q4_2_reference(src + j, y, k);
+        quantize_row_q4_2_rmse(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_2; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_2*sizeof(block_q4_2));
+}
+
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
+
+        quantize_row_q4_3_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_3; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
@@ -11122,7 +12208,40 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/
+    return (n/QK4_3*sizeof(block_q4_3));
+}
+
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+    size_t result = 0;
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(start % QK4_0 == 0);
+                block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
+                result = ggml_quantize_q4_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(start % QK4_1 == 0);
+                block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
+                result = ggml_quantize_q4_1(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_2:
+            {
+                GGML_ASSERT(start % QK4_2 == 0);
+                block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
+                result = ggml_quantize_q4_2(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_3:
+            {
+                GGML_ASSERT(start % QK4_3 == 0);
+                block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
+                result = ggml_quantize_q4_3(src + start, block, n, n, hist);
+            } break;
+        default:
+            assert(false);
+    }
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
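The new ggml_quantize_chunk dispatcher above lets callers quantize one slice of a larger float buffer without choosing the type-specific function themselves. A hypothetical caller could look like the following sketch; the buffer sizes and the choice of GGML_TYPE_Q4_2 are illustrative, and `start`/`n` are element offsets that must be multiples of the type's block size.

#include <stdint.h>
#include <stdlib.h>
#include "ggml.h"   // provides ggml_quantize_chunk and the GGML_TYPE_* enum per this diff

int main(void) {
    const int n = 4096;                       // element count, multiple of the block size
    float   *src = calloc(n, sizeof(float));  // data to quantize
    uint8_t *dst = calloc(n, 1);              // 4-bit blocks plus scales fit well within n bytes
    int64_t  hist[16] = {0};                  // histogram of 4-bit values, filled by the quantizer

    size_t bytes = ggml_quantize_chunk(GGML_TYPE_Q4_2, src, dst, /*start=*/0, n, hist);
    (void) bytes;

    free(src);
    free(dst);
    return 0;
}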
@@ -11151,6 +12270,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
@@ -11200,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;